multi_ec2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. #!/usr/bin/env python2
  2. '''
  3. Fetch and combine multiple ec2 account settings into a single
  4. json hash.
  5. '''
  6. # vim: expandtab:tabstop=4:shiftwidth=4
  7. from time import time
  8. import argparse
  9. import yaml
  10. import os
  11. import subprocess
  12. import json
  13. import errno
  14. import fcntl
  15. import tempfile
  16. import copy
  17. CONFIG_FILE_NAME = 'multi_ec2.yaml'
  18. DEFAULT_CACHE_PATH = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
  19. class MultiEc2(object):
  20. '''
  21. MultiEc2 class:
  22. Opens a yaml config file and reads aws credentials.
  23. Stores a json hash of resources in result.
  24. '''
  25. def __init__(self):
  26. self.args = None
  27. self.config = None
  28. self.all_ec2_results = {}
  29. self.result = {}
  30. self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
  31. same_dir_config_file = os.path.join(self.file_path, CONFIG_FILE_NAME)
  32. etc_dir_config_file = os.path.join(os.path.sep, 'etc', 'ansible', CONFIG_FILE_NAME)
  33. # Prefer a file in the same directory, fall back to a file in etc
  34. if os.path.isfile(same_dir_config_file):
  35. self.config_file = same_dir_config_file
  36. elif os.path.isfile(etc_dir_config_file):
  37. self.config_file = etc_dir_config_file
  38. else:
  39. self.config_file = None # expect env vars
  40. self.parse_cli_args()
  41. # load yaml
  42. if self.config_file and os.path.isfile(self.config_file):
  43. self.config = self.load_yaml_config()
  44. elif os.environ.has_key("AWS_ACCESS_KEY_ID") and \
  45. os.environ.has_key("AWS_SECRET_ACCESS_KEY"):
  46. # Build a default config
  47. self.config = {}
  48. self.config['accounts'] = [
  49. {
  50. 'name': 'default',
  51. 'cache_location': DEFAULT_CACHE_PATH,
  52. 'provider': 'aws/hosts/ec2.py',
  53. 'env_vars': {
  54. 'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"],
  55. 'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
  56. }
  57. },
  58. ]
  59. self.config['cache_max_age'] = 0
  60. else:
  61. raise RuntimeError("Could not find valid ec2 credentials in the environment.")
  62. # Set the default cache path but if its defined we'll assign it.
  63. self.cache_path = DEFAULT_CACHE_PATH
  64. if self.config.has_key('cache_location'):
  65. self.cache_path = self.config['cache_location']
  66. if self.args.refresh_cache:
  67. self.get_inventory()
  68. self.write_to_cache()
  69. # if its a host query, fetch and do not cache
  70. elif self.args.host:
  71. self.get_inventory()
  72. elif not self.is_cache_valid():
  73. # go fetch the inventories and cache them if cache is expired
  74. self.get_inventory()
  75. self.write_to_cache()
  76. else:
  77. # get data from disk
  78. self.get_inventory_from_cache()
  79. def load_yaml_config(self, conf_file=None):
  80. """Load a yaml config file with credentials to query the
  81. respective cloud for inventory.
  82. """
  83. config = None
  84. if not conf_file:
  85. conf_file = self.config_file
  86. with open(conf_file) as conf:
  87. config = yaml.safe_load(conf)
  88. return config
  89. def get_provider_tags(self, provider, env=None):
  90. """Call <provider> and query all of the tags that are usuable
  91. by ansible. If environment is empty use the default env.
  92. """
  93. if not env:
  94. env = os.environ
  95. # Allow relatively path'd providers in config file
  96. if os.path.isfile(os.path.join(self.file_path, provider)):
  97. provider = os.path.join(self.file_path, provider)
  98. # check to see if provider exists
  99. if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
  100. raise RuntimeError("Problem with the provider. Please check path " \
  101. "and that it is executable. (%s)" % provider)
  102. cmds = [provider]
  103. if self.args.host:
  104. cmds.append("--host")
  105. cmds.append(self.args.host)
  106. else:
  107. cmds.append('--list')
  108. cmds.append('--refresh-cache')
  109. return subprocess.Popen(cmds, stderr=subprocess.PIPE, \
  110. stdout=subprocess.PIPE, env=env)
  111. @staticmethod
  112. def generate_config(config_data):
  113. """Generate the ec2.ini file in as a secure temp file.
  114. Once generated, pass it to the ec2.py as an environment variable.
  115. """
  116. fildes, tmp_file_path = tempfile.mkstemp(prefix='multi_ec2.ini.')
  117. for section, values in config_data.items():
  118. os.write(fildes, "[%s]\n" % section)
  119. for option, value in values.items():
  120. os.write(fildes, "%s = %s\n" % (option, value))
  121. os.close(fildes)
  122. return tmp_file_path
  123. def run_provider(self):
  124. '''Setup the provider call with proper variables
  125. and call self.get_provider_tags.
  126. '''
  127. try:
  128. all_results = []
  129. tmp_file_paths = []
  130. processes = {}
  131. for account in self.config['accounts']:
  132. env = account['env_vars']
  133. if account.has_key('provider_config'):
  134. tmp_file_paths.append(MultiEc2.generate_config(account['provider_config']))
  135. env['EC2_INI_PATH'] = tmp_file_paths[-1]
  136. name = account['name']
  137. provider = account['provider']
  138. processes[name] = self.get_provider_tags(provider, env)
  139. # for each process collect stdout when its available
  140. for name, process in processes.items():
  141. out, err = process.communicate()
  142. all_results.append({
  143. "name": name,
  144. "out": out.strip(),
  145. "err": err.strip(),
  146. "code": process.returncode
  147. })
  148. finally:
  149. # Clean up the mkstemp file
  150. for tmp_file in tmp_file_paths:
  151. os.unlink(tmp_file)
  152. return all_results
  153. def get_inventory(self):
  154. """Create the subprocess to fetch tags from a provider.
  155. Host query:
  156. Query to return a specific host. If > 1 queries have
  157. results then fail.
  158. List query:
  159. Query all of the different accounts for their tags. Once completed
  160. store all of their results into one merged updated hash.
  161. """
  162. provider_results = self.run_provider()
  163. # process --host results
  164. # For any 0 result, return it
  165. if self.args.host:
  166. count = 0
  167. for results in provider_results:
  168. if results['code'] == 0 and results['err'] == '' and results['out'] != '{}':
  169. self.result = json.loads(results['out'])
  170. count += 1
  171. if count > 1:
  172. raise RuntimeError("Found > 1 results for --host %s. \
  173. This is an invalid state." % self.args.host)
  174. # process --list results
  175. else:
  176. # For any non-zero, raise an error on it
  177. for result in provider_results:
  178. if result['code'] != 0:
  179. raise RuntimeError(result['err'])
  180. else:
  181. self.all_ec2_results[result['name']] = json.loads(result['out'])
  182. # Check if user wants extra vars in yaml by
  183. # having hostvars and all_group defined
  184. for acc_config in self.config['accounts']:
  185. self.apply_account_config(acc_config)
  186. # Build results by merging all dictionaries
  187. values = self.all_ec2_results.values()
  188. values.insert(0, self.result)
  189. for result in values:
  190. MultiEc2.merge_destructively(self.result, result)
  191. def apply_account_config(self, acc_config):
  192. ''' Apply account config settings
  193. '''
  194. if not acc_config.has_key('hostvars') and not acc_config.has_key('all_group'):
  195. return
  196. results = self.all_ec2_results[acc_config['name']]
  197. # Update each hostvar with the newly desired key: value
  198. for host_property, value in acc_config['hostvars'].items():
  199. # Verify the account results look sane
  200. # by checking for these keys ('_meta' and 'hostvars' exist)
  201. if results.has_key('_meta') and results['_meta'].has_key('hostvars'):
  202. for data in results['_meta']['hostvars'].values():
  203. data[str(host_property)] = str(value)
  204. # Add this group
  205. results["%s_%s" % (host_property, value)] = \
  206. copy.copy(results[acc_config['all_group']])
  207. # store the results back into all_ec2_results
  208. self.all_ec2_results[acc_config['name']] = results
  209. @staticmethod
  210. def merge_destructively(input_a, input_b):
  211. "merges b into input_a"
  212. for key in input_b:
  213. if key in input_a:
  214. if isinstance(input_a[key], dict) and isinstance(input_b[key], dict):
  215. MultiEc2.merge_destructively(input_a[key], input_b[key])
  216. elif input_a[key] == input_b[key]:
  217. pass # same leaf value
  218. # both lists so add each element in b to a if it does ! exist
  219. elif isinstance(input_a[key], list) and isinstance(input_b[key], list):
  220. for result in input_b[key]:
  221. if result not in input_a[key]:
  222. input_a[key].append(result)
  223. # a is a list and not b
  224. elif isinstance(input_a[key], list):
  225. if input_b[key] not in input_a[key]:
  226. input_a[key].append(input_b[key])
  227. elif isinstance(input_b[key], list):
  228. input_a[key] = [input_a[key]] + [k for k in input_b[key] if k != input_a[key]]
  229. else:
  230. input_a[key] = [input_a[key], input_b[key]]
  231. else:
  232. input_a[key] = input_b[key]
  233. return input_a
  234. def is_cache_valid(self):
  235. ''' Determines if the cache files have expired, or if it is still valid '''
  236. if os.path.isfile(self.cache_path):
  237. mod_time = os.path.getmtime(self.cache_path)
  238. current_time = time()
  239. if (mod_time + self.config['cache_max_age']) > current_time:
  240. return True
  241. return False
  242. def parse_cli_args(self):
  243. ''' Command line argument processing '''
  244. parser = argparse.ArgumentParser(
  245. description='Produce an Ansible Inventory file based on a provider')
  246. parser.add_argument('--refresh-cache', action='store_true', default=False,
  247. help='Fetch cached only instances (default: False)')
  248. parser.add_argument('--list', action='store_true', default=True,
  249. help='List instances (default: True)')
  250. parser.add_argument('--host', action='store', default=False,
  251. help='Get all the variables about a specific instance')
  252. self.args = parser.parse_args()
  253. def write_to_cache(self):
  254. ''' Writes data in JSON format to a file '''
  255. # if it does not exist, try and create it.
  256. if not os.path.isfile(self.cache_path):
  257. path = os.path.dirname(self.cache_path)
  258. try:
  259. os.makedirs(path)
  260. except OSError as exc:
  261. if exc.errno != errno.EEXIST or not os.path.isdir(path):
  262. raise
  263. json_data = MultiEc2.json_format_dict(self.result, True)
  264. with open(self.cache_path, 'w') as cache:
  265. try:
  266. fcntl.flock(cache, fcntl.LOCK_EX)
  267. cache.write(json_data)
  268. finally:
  269. fcntl.flock(cache, fcntl.LOCK_UN)
  270. def get_inventory_from_cache(self):
  271. ''' Reads the inventory from the cache file and returns it as a JSON
  272. object '''
  273. if not os.path.isfile(self.cache_path):
  274. return None
  275. with open(self.cache_path, 'r') as cache:
  276. self.result = json.loads(cache.read())
  277. return True
  278. @classmethod
  279. def json_format_dict(cls, data, pretty=False):
  280. ''' Converts a dict to a JSON object and dumps it as a formatted
  281. string '''
  282. if pretty:
  283. return json.dumps(data, sort_keys=True, indent=2)
  284. else:
  285. return json.dumps(data)
  286. def result_str(self):
  287. '''Return cache string stored in self.result'''
  288. return self.json_format_dict(self.result, True)
  289. if __name__ == "__main__":
  290. print MultiEc2().result_str()