  1. #!/usr/bin/env python
  2. # vim: expandtab:tabstop=4:shiftwidth=4
  3. from time import time
  4. import argparse
  5. import yaml
  6. import os
  7. import sys
  8. import pdb
  9. import subprocess
  10. import json
  11. import pprint
  12. class MultiEc2(object):
  13. def __init__(self):
  14. self.config = None
  15. self.all_ec2_results = {}
  16. self.result = {}
  17. self.cache_path = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
  18. self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
  19. self.config_file = os.path.join(self.file_path,"multi_ec2.yaml")
  20. self.parse_cli_args()
  21. # load yaml
  22. if os.path.isfile(self.config_file):
  23. self.config = self.load_yaml_config()
  24. elif os.environ.has_key("AWS_ACCESS_KEY_ID") and os.environ.has_key("AWS_SECRET_ACCESS_KEY"):
  25. self.config = {}
  26. self.config['accounts'] = [
  27. {
  28. 'name': 'default',
  29. 'provider': 'aws/ec2.py',
  30. 'env_vars': {
  31. 'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"],
  32. 'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
  33. }
  34. },
  35. ]
  36. self.config['cache_max_age'] = 0
  37. else:
  38. raise RuntimeError("Could not find valid ec2 credentials in the environment.")
  39. if self.args.cache_only:
  40. # get data from disk
  41. result = self.get_inventory_from_cache()
  42. if not result:
  43. self.get_inventory()
  44. self.write_to_cache()
  45. # if its a host query, fetch and do not cache
  46. elif self.args.host:
  47. self.get_inventory()
  48. elif not self.is_cache_valid():
  49. # go fetch the inventories and cache them if cache is expired
  50. self.get_inventory()
  51. self.write_to_cache()
  52. else:
  53. # get data from disk
  54. self.get_inventory_from_cache()
  55. def load_yaml_config(self,conf_file=None):
  56. """Load a yaml config file with credentials to query the
  57. respective cloud for inventory.
  58. """
  59. config = None
  60. if not conf_file:
  61. conf_file = self.config_file
  62. with open(conf_file) as conf:
  63. config = yaml.safe_load(conf)
  64. return config
  65. def get_provider_tags(self,provider, env={}):
  66. """Call <provider> and query all of the tags that are usuable
  67. by ansible. If environment is empty use the default env.
  68. """
  69. if not env:
  70. env = os.environ
  71. # Allow relatively path'd providers in config file
  72. if os.path.isfile(os.path.join(self.file_path, provider)):
  73. provider = os.path.join(self.file_path, provider)
  74. # check to see if provider exists
  75. if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
  76. raise RuntimeError("Problem with the provider. Please check path " \
  77. "and that it is executable. (%s)" % provider)
  78. cmds = [provider]
  79. if self.args.host:
  80. cmds.append("--host")
  81. cmds.append(self.args.host)
  82. else:
  83. cmds.append('--list')
  84. cmds.append('--refresh-cache')
  85. return subprocess.Popen(cmds, stderr=subprocess.PIPE, \
  86. stdout=subprocess.PIPE, env=env)
  87. def get_inventory(self):
  88. """Create the subprocess to fetch tags from a provider.
  89. Host query:
  90. Query to return a specific host. If > 1 queries have
  91. results then fail.
  92. List query:
  93. Query all of the different accounts for their tags. Once completed
  94. store all of their results into one merged updated hash.
  95. """
  96. processes = {}
  97. for account in self.config['accounts']:
  98. env = account['env_vars']
  99. name = account['name']
  100. provider = account['provider']
  101. processes[name] = self.get_provider_tags(provider, env)
  102. # for each process collect stdout when its available
  103. all_results = []
  104. for name, process in processes.items():
  105. out, err = process.communicate()
  106. all_results.append({
  107. "name": name,
  108. "out": out.strip(),
  109. "err": err.strip(),
  110. "code": process.returncode
  111. })
  112. # process --host results
  113. if not self.args.host:
  114. # For any non-zero, raise an error on it
  115. for result in all_results:
  116. if result['code'] != 0:
  117. raise RuntimeError(result['err'])
  118. else:
  119. self.all_ec2_results[result['name']] = json.loads(result['out'])
  120. values = self.all_ec2_results.values()
  121. values.insert(0, self.result)
  122. [MultiEc2.merge_destructively(self.result, x) for x in values]
  123. else:
  124. # For any 0 result, return it
  125. count = 0
  126. for results in all_results:
  127. if results['code'] == 0 and results['err'] == '' and results['out'] != '{}':
  128. self.result = json.loads(out)
  129. count += 1
  130. if count > 1:
  131. raise RuntimeError("Found > 1 results for --host %s. \
  132. This is an invalid state." % self.args.host)
  133. @staticmethod
  134. def merge_destructively(a, b):
  135. "merges b into a"
  136. for key in b:
  137. if key in a:
  138. if isinstance(a[key], dict) and isinstance(b[key], dict):
  139. MultiEc2.merge_destructively(a[key], b[key])
  140. elif a[key] == b[key]:
  141. pass # same leaf value
  142. # both lists so add each element in b to a if it does ! exist
  143. elif isinstance(a[key], list) and isinstance(b[key],list):
  144. for x in b[key]:
  145. if x not in a[key]:
  146. a[key].append(x)
  147. # a is a list and not b
  148. elif isinstance(a[key], list):
  149. if b[key] not in a[key]:
  150. a[key].append(b[key])
  151. elif isinstance(b[key], list):
  152. a[key] = [a[key]] + [k for k in b[key] if k != a[key]]
  153. else:
  154. a[key] = [a[key],b[key]]
  155. else:
  156. a[key] = b[key]
  157. return a
  158. def is_cache_valid(self):
  159. ''' Determines if the cache files have expired, or if it is still valid '''
  160. if os.path.isfile(self.cache_path):
  161. mod_time = os.path.getmtime(self.cache_path)
  162. current_time = time()
  163. if (mod_time + self.config['cache_max_age']) > current_time:
  164. return True
  165. return False
  166. def parse_cli_args(self):
  167. ''' Command line argument processing '''
  168. parser = argparse.ArgumentParser(description='Produce an Ansible Inventory file based on a provider')
  169. parser.add_argument('--cache-only', action='store_true', default=False,
  170. help='Fetch cached only instances (default: False)')
  171. parser.add_argument('--list', action='store_true', default=True,
  172. help='List instances (default: True)')
  173. parser.add_argument('--host', action='store', default=False,
  174. help='Get all the variables about a specific instance')
  175. self.args = parser.parse_args()
  176. def write_to_cache(self):
  177. ''' Writes data in JSON format to a file '''
  178. json_data = self.json_format_dict(self.result, True)
  179. with open(self.cache_path, 'w') as cache:
  180. cache.write(json_data)
  181. def get_inventory_from_cache(self):
  182. ''' Reads the inventory from the cache file and returns it as a JSON
  183. object '''
  184. if not os.path.isfile(self.cache_path):
  185. return None
  186. with open(self.cache_path, 'r') as cache:
  187. self.result = json.loads(cache.read())
  188. return True
  189. def json_format_dict(self, data, pretty=False):
  190. ''' Converts a dict to a JSON object and dumps it as a formatted
  191. string '''
  192. if pretty:
  193. return json.dumps(data, sort_keys=True, indent=2)
  194. else:
  195. return json.dumps(data)
  196. def result_str(self):
  197. return self.json_format_dict(self.result, True)
  198. if __name__ == "__main__":
  199. mi = MultiEc2()
  200. print mi.result_str()