#!/usr/bin/env python2
# vim: expandtab:tabstop=4:shiftwidth=4
# multi_ec2.py
import argparse
import json
import os
import pdb
import pprint
import subprocess
import sys
from time import time

import yaml
  12. CONFIG_FILE_NAME = 'multi_ec2.yaml'
  13. class MultiEc2(object):
  14. def __init__(self):
  15. self.config = None
  16. self.all_ec2_results = {}
  17. self.result = {}
  18. self.cache_path = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
  19. self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
  20. same_dir_config_file = os.path.join(self.file_path, CONFIG_FILE_NAME)
  21. etc_dir_config_file = os.path.join(os.path.sep, 'etc','ansible', CONFIG_FILE_NAME)
  22. # Prefer a file in the same directory, fall back to a file in etc
  23. if os.path.isfile(same_dir_config_file):
  24. self.config_file = same_dir_config_file
  25. elif os.path.isfile(etc_dir_config_file):
  26. self.config_file = etc_dir_config_file
  27. else:
  28. self.config_file = None # expect env vars
  29. self.parse_cli_args()
  30. # load yaml
  31. if self.config_file and os.path.isfile(self.config_file):
  32. self.config = self.load_yaml_config()
  33. elif os.environ.has_key("AWS_ACCESS_KEY_ID") and os.environ.has_key("AWS_SECRET_ACCESS_KEY"):
  34. self.config = {}
  35. self.config['accounts'] = [
  36. {
  37. 'name': 'default',
  38. 'provider': 'aws/ec2.py',
  39. 'env_vars': {
  40. 'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"],
  41. 'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
  42. }
  43. },
  44. ]
  45. self.config['cache_max_age'] = 0
  46. else:
  47. raise RuntimeError("Could not find valid ec2 credentials in the environment.")
  48. if self.args.cache_only:
  49. # get data from disk
  50. result = self.get_inventory_from_cache()
  51. if not result:
  52. self.get_inventory()
  53. self.write_to_cache()
  54. # if its a host query, fetch and do not cache
  55. elif self.args.host:
  56. self.get_inventory()
  57. elif not self.is_cache_valid():
  58. # go fetch the inventories and cache them if cache is expired
  59. self.get_inventory()
  60. self.write_to_cache()
  61. else:
  62. # get data from disk
  63. self.get_inventory_from_cache()
  64. def load_yaml_config(self,conf_file=None):
  65. """Load a yaml config file with credentials to query the
  66. respective cloud for inventory.
  67. """
  68. config = None
  69. if not conf_file:
  70. conf_file = self.config_file
  71. with open(conf_file) as conf:
  72. config = yaml.safe_load(conf)
  73. return config
  74. def get_provider_tags(self,provider, env={}):
  75. """Call <provider> and query all of the tags that are usuable
  76. by ansible. If environment is empty use the default env.
  77. """
  78. if not env:
  79. env = os.environ
  80. # Allow relatively path'd providers in config file
  81. if os.path.isfile(os.path.join(self.file_path, provider)):
  82. provider = os.path.join(self.file_path, provider)
  83. # check to see if provider exists
  84. if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
  85. raise RuntimeError("Problem with the provider. Please check path " \
  86. "and that it is executable. (%s)" % provider)
  87. cmds = [provider]
  88. if self.args.host:
  89. cmds.append("--host")
  90. cmds.append(self.args.host)
  91. else:
  92. cmds.append('--list')
  93. cmds.append('--refresh-cache')
  94. return subprocess.Popen(cmds, stderr=subprocess.PIPE, \
  95. stdout=subprocess.PIPE, env=env)
  96. def get_inventory(self):
  97. """Create the subprocess to fetch tags from a provider.
  98. Host query:
  99. Query to return a specific host. If > 1 queries have
  100. results then fail.
  101. List query:
  102. Query all of the different accounts for their tags. Once completed
  103. store all of their results into one merged updated hash.
  104. """
  105. processes = {}
  106. for account in self.config['accounts']:
  107. env = account['env_vars']
  108. name = account['name']
  109. provider = account['provider']
  110. processes[name] = self.get_provider_tags(provider, env)
  111. # for each process collect stdout when its available
  112. all_results = []
  113. for name, process in processes.items():
  114. out, err = process.communicate()
  115. all_results.append({
  116. "name": name,
  117. "out": out.strip(),
  118. "err": err.strip(),
  119. "code": process.returncode
  120. })
  121. # process --host results
  122. if not self.args.host:
  123. # For any non-zero, raise an error on it
  124. for result in all_results:
  125. if result['code'] != 0:
  126. raise RuntimeError(result['err'])
  127. else:
  128. self.all_ec2_results[result['name']] = json.loads(result['out'])
  129. values = self.all_ec2_results.values()
  130. values.insert(0, self.result)
  131. [MultiEc2.merge_destructively(self.result, x) for x in values]
  132. else:
  133. # For any 0 result, return it
  134. count = 0
  135. for results in all_results:
  136. if results['code'] == 0 and results['err'] == '' and results['out'] != '{}':
  137. self.result = json.loads(out)
  138. count += 1
  139. if count > 1:
  140. raise RuntimeError("Found > 1 results for --host %s. \
  141. This is an invalid state." % self.args.host)
  142. @staticmethod
  143. def merge_destructively(a, b):
  144. "merges b into a"
  145. for key in b:
  146. if key in a:
  147. if isinstance(a[key], dict) and isinstance(b[key], dict):
  148. MultiEc2.merge_destructively(a[key], b[key])
  149. elif a[key] == b[key]:
  150. pass # same leaf value
  151. # both lists so add each element in b to a if it does ! exist
  152. elif isinstance(a[key], list) and isinstance(b[key],list):
  153. for x in b[key]:
  154. if x not in a[key]:
  155. a[key].append(x)
  156. # a is a list and not b
  157. elif isinstance(a[key], list):
  158. if b[key] not in a[key]:
  159. a[key].append(b[key])
  160. elif isinstance(b[key], list):
  161. a[key] = [a[key]] + [k for k in b[key] if k != a[key]]
  162. else:
  163. a[key] = [a[key],b[key]]
  164. else:
  165. a[key] = b[key]
  166. return a
  167. def is_cache_valid(self):
  168. ''' Determines if the cache files have expired, or if it is still valid '''
  169. if os.path.isfile(self.cache_path):
  170. mod_time = os.path.getmtime(self.cache_path)
  171. current_time = time()
  172. if (mod_time + self.config['cache_max_age']) > current_time:
  173. return True
  174. return False
  175. def parse_cli_args(self):
  176. ''' Command line argument processing '''
  177. parser = argparse.ArgumentParser(description='Produce an Ansible Inventory file based on a provider')
  178. parser.add_argument('--cache-only', action='store_true', default=False,
  179. help='Fetch cached only instances (default: False)')
  180. parser.add_argument('--list', action='store_true', default=True,
  181. help='List instances (default: True)')
  182. parser.add_argument('--host', action='store', default=False,
  183. help='Get all the variables about a specific instance')
  184. self.args = parser.parse_args()
  185. def write_to_cache(self):
  186. ''' Writes data in JSON format to a file '''
  187. json_data = self.json_format_dict(self.result, True)
  188. with open(self.cache_path, 'w') as cache:
  189. cache.write(json_data)
  190. def get_inventory_from_cache(self):
  191. ''' Reads the inventory from the cache file and returns it as a JSON
  192. object '''
  193. if not os.path.isfile(self.cache_path):
  194. return None
  195. with open(self.cache_path, 'r') as cache:
  196. self.result = json.loads(cache.read())
  197. return True
  198. def json_format_dict(self, data, pretty=False):
  199. ''' Converts a dict to a JSON object and dumps it as a formatted
  200. string '''
  201. if pretty:
  202. return json.dumps(data, sort_keys=True, indent=2)
  203. else:
  204. return json.dumps(data)
  205. def result_str(self):
  206. return self.json_format_dict(self.result, True)
  207. if __name__ == "__main__":
  208. mi = MultiEc2()
  209. print mi.result_str()