multi_inventory.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. #!/usr/bin/env python2
  2. '''
  3. Fetch and combine multiple inventory account settings into a single
  4. json hash.
  5. '''
  6. # vim: expandtab:tabstop=4:shiftwidth=4
  7. from time import time
  8. import argparse
  9. import yaml
  10. import os
  11. import subprocess
  12. import json
  13. import errno
  14. import fcntl
  15. import tempfile
  16. import copy
  17. from string import Template
  18. import shutil
# Name of the yaml config file searched for next to this script and in
# /etc/ansible (see MultiInventory.__init__).
CONFIG_FILE_NAME = 'multi_inventory.yaml'
# Default location of the JSON inventory cache; overridden by the config
# file's 'cache_location' key when present.
DEFAULT_CACHE_PATH = os.path.expanduser('~/.ansible/tmp/multi_inventory.cache')
  21. class MultiInventoryException(Exception):
  22. '''Exceptions for MultiInventory class'''
  23. pass
  24. # pylint: disable=too-many-public-methods
  25. # After a refactor of too-many-branches and placing those branches into
  26. # their own corresponding function, we have passed the allowed amount of functions(20).
  27. class MultiInventory(object):
  28. '''
  29. MultiInventory class:
  30. Opens a yaml config file and reads aws credentials.
  31. Stores a json hash of resources in result.
  32. '''
  33. def __init__(self, args=None):
  34. # Allow args to be passed when called as a library
  35. if not args:
  36. self.args = {}
  37. else:
  38. self.args = args
  39. self.cache_path = DEFAULT_CACHE_PATH
  40. self.config = None
  41. self.all_inventory_results = {}
  42. self.result = {}
  43. self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
  44. same_dir_config_file = os.path.join(self.file_path, CONFIG_FILE_NAME)
  45. etc_dir_config_file = os.path.join(os.path.sep, 'etc', 'ansible', CONFIG_FILE_NAME)
  46. # Prefer a file in the same directory, fall back to a file in etc
  47. if os.path.isfile(same_dir_config_file):
  48. self.config_file = same_dir_config_file
  49. elif os.path.isfile(etc_dir_config_file):
  50. self.config_file = etc_dir_config_file
  51. else:
  52. self.config_file = None # expect env vars
  53. # load yaml
  54. if self.config_file and os.path.isfile(self.config_file):
  55. self.config = self.load_yaml_config()
  56. elif os.environ.has_key("AWS_ACCESS_KEY_ID") and \
  57. os.environ.has_key("AWS_SECRET_ACCESS_KEY"):
  58. # Build a default config
  59. self.config = {}
  60. self.config['accounts'] = [
  61. {
  62. 'name': 'default',
  63. 'cache_location': DEFAULT_CACHE_PATH,
  64. 'provider': 'aws/hosts/ec2.py',
  65. 'env_vars': {
  66. 'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"],
  67. 'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
  68. }
  69. },
  70. ]
  71. self.config['cache_max_age'] = 300
  72. else:
  73. raise RuntimeError("Could not find valid ec2 credentials in the environment.")
  74. if self.config.has_key('cache_location'):
  75. self.cache_path = self.config['cache_location']
  76. def run(self):
  77. '''This method checks to see if the local
  78. cache is valid for the inventory.
  79. if the cache is valid; return cache
  80. else the credentials are loaded from multi_inventory.yaml or from the env
  81. and we attempt to get the inventory from the provider specified.
  82. '''
  83. if self.args.get('refresh_cache', None):
  84. self.get_inventory()
  85. self.write_to_cache()
  86. # if its a host query, fetch and do not cache
  87. elif self.args.get('host', None):
  88. self.get_inventory()
  89. elif not self.is_cache_valid():
  90. # go fetch the inventories and cache them if cache is expired
  91. self.get_inventory()
  92. self.write_to_cache()
  93. else:
  94. # get data from disk
  95. self.get_inventory_from_cache()
  96. def load_yaml_config(self, conf_file=None):
  97. """Load a yaml config file with credentials to query the
  98. respective cloud for inventory.
  99. """
  100. config = None
  101. if not conf_file:
  102. conf_file = self.config_file
  103. with open(conf_file) as conf:
  104. config = yaml.safe_load(conf)
  105. # Provide a check for unique account names
  106. if len(set([acc['name'] for acc in config['accounts']])) != len(config['accounts']):
  107. raise MultiInventoryException('Duplicate account names in config file')
  108. return config
  109. def get_provider_tags(self, provider, env=None):
  110. """Call <provider> and query all of the tags that are usuable
  111. by ansible. If environment is empty use the default env.
  112. """
  113. if not env:
  114. env = os.environ
  115. # Allow relatively path'd providers in config file
  116. if os.path.isfile(os.path.join(self.file_path, provider)):
  117. provider = os.path.join(self.file_path, provider)
  118. # check to see if provider exists
  119. if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
  120. raise RuntimeError("Problem with the provider. Please check path " \
  121. "and that it is executable. (%s)" % provider)
  122. cmds = [provider]
  123. if self.args.get('host', None):
  124. cmds.append("--host")
  125. cmds.append(self.args.get('host', None))
  126. else:
  127. cmds.append('--list')
  128. if 'aws' in provider.lower():
  129. cmds.append('--refresh-cache')
  130. return subprocess.Popen(cmds, stderr=subprocess.PIPE, \
  131. stdout=subprocess.PIPE, env=env)
  132. @staticmethod
  133. def generate_config(provider_files):
  134. """Generate the provider_files in a temporary directory.
  135. """
  136. prefix = 'multi_inventory.'
  137. tmp_dir_path = tempfile.mkdtemp(prefix=prefix)
  138. for provider_file in provider_files:
  139. filedes = open(os.path.join(tmp_dir_path, provider_file['name']), 'w+')
  140. content = Template(provider_file['contents']).substitute(tmpdir=tmp_dir_path)
  141. filedes.write(content)
  142. filedes.close()
  143. return tmp_dir_path
  144. def run_provider(self):
  145. '''Setup the provider call with proper variables
  146. and call self.get_provider_tags.
  147. '''
  148. try:
  149. all_results = []
  150. tmp_dir_paths = []
  151. processes = {}
  152. for account in self.config['accounts']:
  153. tmp_dir = None
  154. if account.has_key('provider_files'):
  155. tmp_dir = MultiInventory.generate_config(account['provider_files'])
  156. tmp_dir_paths.append(tmp_dir)
  157. # Update env vars after creating provider_config_files
  158. # so that we can grab the tmp_dir if it exists
  159. env = account.get('env_vars', {})
  160. if env and tmp_dir:
  161. for key, value in env.items():
  162. env[key] = Template(value).substitute(tmpdir=tmp_dir)
  163. name = account['name']
  164. provider = account['provider']
  165. processes[name] = self.get_provider_tags(provider, env)
  166. # for each process collect stdout when its available
  167. for name, process in processes.items():
  168. out, err = process.communicate()
  169. all_results.append({
  170. "name": name,
  171. "out": out.strip(),
  172. "err": err.strip(),
  173. "code": process.returncode
  174. })
  175. finally:
  176. # Clean up the mkdtemp dirs
  177. for tmp_dir in tmp_dir_paths:
  178. shutil.rmtree(tmp_dir)
  179. return all_results
  180. def get_inventory(self):
  181. """Create the subprocess to fetch tags from a provider.
  182. Host query:
  183. Query to return a specific host. If > 1 queries have
  184. results then fail.
  185. List query:
  186. Query all of the different accounts for their tags. Once completed
  187. store all of their results into one merged updated hash.
  188. """
  189. provider_results = self.run_provider()
  190. # process --host results
  191. # For any 0 result, return it
  192. if self.args.get('host', None):
  193. count = 0
  194. for results in provider_results:
  195. if results['code'] == 0 and results['err'] == '' and results['out'] != '{}':
  196. self.result = json.loads(results['out'])
  197. count += 1
  198. if count > 1:
  199. raise RuntimeError("Found > 1 results for --host %s. \
  200. This is an invalid state." % self.args.get('host', None))
  201. # process --list results
  202. else:
  203. # For any non-zero, raise an error on it
  204. for result in provider_results:
  205. if result['code'] != 0:
  206. err_msg = ['\nProblem fetching account: {name}',
  207. 'Error Code: {code}',
  208. 'StdErr: {err}',
  209. 'Stdout: {out}',
  210. ]
  211. raise RuntimeError('\n'.join(err_msg).format(**result))
  212. else:
  213. self.all_inventory_results[result['name']] = json.loads(result['out'])
  214. # Check if user wants extra vars in yaml by
  215. # having hostvars and all_group defined
  216. for acc_config in self.config['accounts']:
  217. self.apply_account_config(acc_config)
  218. # Build results by merging all dictionaries
  219. values = self.all_inventory_results.values()
  220. values.insert(0, self.result)
  221. for result in values:
  222. MultiInventory.merge_destructively(self.result, result)
  223. def add_entry(self, data, keys, item):
  224. ''' Add an item to a dictionary with key notation a.b.c
  225. d = {'a': {'b': 'c'}}}
  226. keys = a.b
  227. item = c
  228. '''
  229. if "." in keys:
  230. key, rest = keys.split(".", 1)
  231. if key not in data:
  232. data[key] = {}
  233. self.add_entry(data[key], rest, item)
  234. else:
  235. data[keys] = item
  236. def get_entry(self, data, keys):
  237. ''' Get an item from a dictionary with key notation a.b.c
  238. d = {'a': {'b': 'c'}}}
  239. keys = a.b
  240. return c
  241. '''
  242. if keys and "." in keys:
  243. key, rest = keys.split(".", 1)
  244. return self.get_entry(data[key], rest)
  245. else:
  246. return data.get(keys, None)
  247. def apply_extra_vars(self, inventory, extra_vars):
  248. ''' Apply the account config extra vars '''
  249. # Extra vars go here
  250. for new_var, value in extra_vars.items():
  251. for data in inventory.values():
  252. self.add_entry(data, new_var, value)
  253. def apply_clone_vars(self, inventory, clone_vars):
  254. ''' Apply the account config clone vars '''
  255. # Clone vars go here
  256. for to_name, from_name in clone_vars.items():
  257. for data in inventory.values():
  258. self.add_entry(data, to_name, self.get_entry(data, from_name))
  259. def apply_extra_groups(self, inventory, extra_groups):
  260. ''' Apply the account config for extra groups '''
  261. _ = self # Here for pylint. wanted an instance method instead of static
  262. for new_var, value in extra_groups.items():
  263. for _ in inventory['_meta']['hostvars'].values():
  264. inventory["%s_%s" % (new_var, value)] = copy.copy(inventory['all_hosts'])
  265. def apply_clone_groups(self, inventory, clone_groups):
  266. ''' Apply the account config for clone groups '''
  267. for to_name, from_name in clone_groups.items():
  268. for name, data in inventory['_meta']['hostvars'].items():
  269. key = '%s_%s' % (to_name, self.get_entry(data, from_name))
  270. if not inventory.has_key(key):
  271. inventory[key] = []
  272. inventory[key].append(name)
  273. def apply_group_selectors(self, inventory, group_selectors):
  274. ''' Apply the account config for group selectors '''
  275. _ = self # Here for pylint. wanted an instance method instead of static
  276. # There could be multiple clusters per account. We need to process these selectors
  277. # based upon the oo_clusterid_ variable.
  278. clusterids = [group for group in inventory if "oo_clusterid_" in group]
  279. for clusterid in clusterids:
  280. for selector in group_selectors:
  281. if inventory.has_key(selector['from_group']):
  282. hosts = list(set(inventory[clusterid]) & set(inventory[selector['from_group']]))
  283. hosts.sort()
  284. # Multiple clusters in an account
  285. if inventory.has_key(selector['name']):
  286. inventory[selector['name']].extend(hosts[0:selector['count']])
  287. else:
  288. inventory[selector['name']] = hosts[0:selector['count']]
  289. for host in hosts:
  290. if host in inventory[selector['name']]:
  291. inventory['_meta']['hostvars'][host][selector['name']] = True
  292. else:
  293. inventory['_meta']['hostvars'][host][selector['name']] = False
  294. def apply_account_config(self, acc_config):
  295. ''' Apply account config settings '''
  296. results = self.all_inventory_results[acc_config['name']]
  297. results['all_hosts'] = results['_meta']['hostvars'].keys()
  298. self.apply_extra_vars(results['_meta']['hostvars'], acc_config.get('extra_vars', {}))
  299. self.apply_clone_vars(results['_meta']['hostvars'], acc_config.get('clone_vars', {}))
  300. self.apply_extra_groups(results, acc_config.get('extra_groups', {}))
  301. self.apply_clone_groups(results, acc_config.get('clone_groups', {}))
  302. self.apply_group_selectors(results, acc_config.get('group_selectors', {}))
  303. # store the results back into all_inventory_results
  304. self.all_inventory_results[acc_config['name']] = results
  305. @staticmethod
  306. def merge_destructively(input_a, input_b):
  307. "merges b into input_a"
  308. for key in input_b:
  309. if key in input_a:
  310. if isinstance(input_a[key], dict) and isinstance(input_b[key], dict):
  311. MultiInventory.merge_destructively(input_a[key], input_b[key])
  312. elif input_a[key] == input_b[key]:
  313. pass # same leaf value
  314. # both lists so add each element in b to a if it does ! exist
  315. elif isinstance(input_a[key], list) and isinstance(input_b[key], list):
  316. for result in input_b[key]:
  317. if result not in input_a[key]:
  318. input_a[key].append(result)
  319. # a is a list and not b
  320. elif isinstance(input_a[key], list):
  321. if input_b[key] not in input_a[key]:
  322. input_a[key].append(input_b[key])
  323. elif isinstance(input_b[key], list):
  324. input_a[key] = [input_a[key]] + [k for k in input_b[key] if k != input_a[key]]
  325. else:
  326. input_a[key] = [input_a[key], input_b[key]]
  327. else:
  328. input_a[key] = input_b[key]
  329. return input_a
  330. def is_cache_valid(self):
  331. ''' Determines if the cache files have expired, or if it is still valid '''
  332. if os.path.isfile(self.cache_path):
  333. mod_time = os.path.getmtime(self.cache_path)
  334. current_time = time()
  335. if (mod_time + self.config['cache_max_age']) > current_time:
  336. return True
  337. return False
  338. def parse_cli_args(self):
  339. ''' Command line argument processing '''
  340. parser = argparse.ArgumentParser(
  341. description='Produce an Ansible Inventory file based on a provider')
  342. parser.add_argument('--refresh-cache', action='store_true', default=False,
  343. help='Fetch cached only instances (default: False)')
  344. parser.add_argument('--list', action='store_true', default=True,
  345. help='List instances (default: True)')
  346. parser.add_argument('--host', action='store', default=False,
  347. help='Get all the variables about a specific instance')
  348. self.args = parser.parse_args().__dict__
  349. def write_to_cache(self):
  350. ''' Writes data in JSON format to a file '''
  351. # if it does not exist, try and create it.
  352. if not os.path.isfile(self.cache_path):
  353. path = os.path.dirname(self.cache_path)
  354. try:
  355. os.makedirs(path)
  356. except OSError as exc:
  357. if exc.errno != errno.EEXIST or not os.path.isdir(path):
  358. raise
  359. json_data = MultiInventory.json_format_dict(self.result, True)
  360. with open(self.cache_path, 'w') as cache:
  361. try:
  362. fcntl.flock(cache, fcntl.LOCK_EX)
  363. cache.write(json_data)
  364. finally:
  365. fcntl.flock(cache, fcntl.LOCK_UN)
  366. def get_inventory_from_cache(self):
  367. ''' Reads the inventory from the cache file and returns it as a JSON
  368. object '''
  369. if not os.path.isfile(self.cache_path):
  370. return None
  371. with open(self.cache_path, 'r') as cache:
  372. self.result = json.loads(cache.read())
  373. return True
  374. @classmethod
  375. def json_format_dict(cls, data, pretty=False):
  376. ''' Converts a dict to a JSON object and dumps it as a formatted
  377. string '''
  378. if pretty:
  379. return json.dumps(data, sort_keys=True, indent=2)
  380. else:
  381. return json.dumps(data)
  382. def result_str(self):
  383. '''Return cache string stored in self.result'''
  384. return self.json_format_dict(self.result, True)
  385. if __name__ == "__main__":
  386. MI2 = MultiInventory()
  387. MI2.parse_cli_args()
  388. MI2.run()
  389. print MI2.result_str()