multi_ec2.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. #!/usr/bin/env python
  2. # vim: expandtab:tabstop=4:shiftwidth=4
  3. from time import time
  4. import argparse
  5. import yaml
  6. import os
  7. import sys
  8. import pdb
  9. import subprocess
  10. import json
  11. import pprint
  12. class MultiEc2(object):
  13. def __init__(self):
  14. self.config = None
  15. self.all_ec2_results = {}
  16. self.result = {}
  17. self.cache_path = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
  18. self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
  19. self.config_file = os.path.join(self.file_path,"multi_ec2.yaml")
  20. self.parse_cli_args()
  21. # load yaml
  22. if os.path.isfile(self.config_file):
  23. self.config = self.load_yaml_config()
  24. elif os.environ.has_key("AWS_ACCESS_KEY_ID") and os.environ.has_key("AWS_SECRET_ACCESS_KEY"):
  25. self.config = {}
  26. self.config['accounts'] = [
  27. {
  28. 'name': 'default',
  29. 'provider': 'aws/ec2.py',
  30. 'env_vars': {
  31. 'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"],
  32. 'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
  33. }
  34. },
  35. ]
  36. self.config['cache_max_age'] = 0
  37. else:
  38. raise RuntimeError("Could not find valid ec2 credentials in the environment.")
  39. # if its a host query, fetch and do not cache
  40. if self.args.host:
  41. self.get_inventory()
  42. elif not self.is_cache_valid():
  43. # go fetch the inventories and cache them if cache is expired
  44. self.get_inventory()
  45. self.write_to_cache()
  46. else:
  47. # get data from disk
  48. self.get_inventory_from_cache()
  49. def load_yaml_config(self,conf_file=None):
  50. """Load a yaml config file with credentials to query the
  51. respective cloud for inventory.
  52. """
  53. config = None
  54. if not conf_file:
  55. conf_file = self.config_file
  56. with open(conf_file) as conf:
  57. config = yaml.safe_load(conf)
  58. return config
  59. def get_provider_tags(self,provider, env={}):
  60. """Call <provider> and query all of the tags that are usuable
  61. by ansible. If environment is empty use the default env.
  62. """
  63. if not env:
  64. env = os.environ
  65. # Allow relatively path'd providers in config file
  66. if os.path.isfile(os.path.join(self.file_path, provider)):
  67. provider = os.path.join(self.file_path, provider)
  68. # check to see if provider exists
  69. if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
  70. raise RuntimeError("Problem with the provider. Please check path " \
  71. "and that it is executable. (%s)" % provider)
  72. cmds = [provider]
  73. if self.args.host:
  74. cmds.append("--host")
  75. cmds.append(self.args.host)
  76. else:
  77. cmds.append('--list')
  78. cmds.append('--refresh-cache')
  79. return subprocess.Popen(cmds, stderr=subprocess.PIPE, \
  80. stdout=subprocess.PIPE, env=env)
  81. def get_inventory(self):
  82. """Create the subprocess to fetch tags from a provider.
  83. Host query:
  84. Query to return a specific host. If > 1 queries have
  85. results then fail.
  86. List query:
  87. Query all of the different accounts for their tags. Once completed
  88. store all of their results into one merged updated hash.
  89. """
  90. processes = {}
  91. for account in self.config['accounts']:
  92. env = account['env_vars']
  93. name = account['name']
  94. provider = account['provider']
  95. processes[name] = self.get_provider_tags(provider, env)
  96. # for each process collect stdout when its available
  97. all_results = []
  98. for name, process in processes.items():
  99. out, err = process.communicate()
  100. all_results.append({
  101. "name": name,
  102. "out": out.strip(),
  103. "err": err.strip(),
  104. "code": process.returncode
  105. })
  106. # process --host results
  107. if not self.args.host:
  108. # For any non-zero, raise an error on it
  109. for result in all_results:
  110. if result['code'] != 0:
  111. raise RuntimeError(result['err'])
  112. else:
  113. self.all_ec2_results[result['name']] = json.loads(result['out'])
  114. values = self.all_ec2_results.values()
  115. values.insert(0, self.result)
  116. [MultiEc2.merge_destructively(self.result, x) for x in values]
  117. else:
  118. # For any 0 result, return it
  119. count = 0
  120. for results in all_results:
  121. if results['code'] == 0 and results['err'] == '' and results['out'] != '{}':
  122. self.result = json.loads(out)
  123. count += 1
  124. if count > 1:
  125. raise RuntimeError("Found > 1 results for --host %s. \
  126. This is an invalid state." % self.args.host)
  127. @staticmethod
  128. def merge_destructively(a, b):
  129. "merges b into a"
  130. for key in b:
  131. if key in a:
  132. if isinstance(a[key], dict) and isinstance(b[key], dict):
  133. MultiEc2.merge_destructively(a[key], b[key])
  134. elif a[key] == b[key]:
  135. pass # same leaf value
  136. # both lists so add each element in b to a if it does ! exist
  137. elif isinstance(a[key], list) and isinstance(b[key],list):
  138. for x in b[key]:
  139. if x not in a[key]:
  140. a[key].append(x)
  141. # a is a list and not b
  142. elif isinstance(a[key], list):
  143. if b[key] not in a[key]:
  144. a[key].append(b[key])
  145. elif isinstance(b[key], list):
  146. a[key] = [a[key]] + [k for k in b[key] if k != a[key]]
  147. else:
  148. a[key] = [a[key],b[key]]
  149. else:
  150. a[key] = b[key]
  151. return a
  152. def is_cache_valid(self):
  153. ''' Determines if the cache files have expired, or if it is still valid '''
  154. if os.path.isfile(self.cache_path):
  155. mod_time = os.path.getmtime(self.cache_path)
  156. current_time = time()
  157. if (mod_time + self.config['cache_max_age']) > current_time:
  158. return True
  159. return False
  160. def parse_cli_args(self):
  161. ''' Command line argument processing '''
  162. parser = argparse.ArgumentParser(description='Produce an Ansible Inventory file based on a provider')
  163. parser.add_argument('--list', action='store_true', default=True,
  164. help='List instances (default: True)')
  165. parser.add_argument('--host', action='store', default=False,
  166. help='Get all the variables about a specific instance')
  167. self.args = parser.parse_args()
  168. def write_to_cache(self):
  169. ''' Writes data in JSON format to a file '''
  170. json_data = self.json_format_dict(self.result, True)
  171. with open(self.cache_path, 'w') as cache:
  172. cache.write(json_data)
  173. def get_inventory_from_cache(self):
  174. ''' Reads the inventory from the cache file and returns it as a JSON
  175. object '''
  176. with open(self.cache_path, 'r') as cache:
  177. self.result = json.loads(cache.read())
  178. def json_format_dict(self, data, pretty=False):
  179. ''' Converts a dict to a JSON object and dumps it as a formatted
  180. string '''
  181. if pretty:
  182. return json.dumps(data, sort_keys=True, indent=2)
  183. else:
  184. return json.dumps(data)
  185. def result_str(self):
  186. return self.json_format_dict(self.result, True)
  187. if __name__ == "__main__":
  188. mi = MultiEc2()
  189. print mi.result_str()