#!/usr/bin/env python2
'''
    Fetch and combine multiple inventory account settings into a single
    json hash.
'''
# vim: expandtab:tabstop=4:shiftwidth=4

from time import time
import argparse
import yaml
import os
import subprocess
import json
import errno
import fcntl
import tempfile
import copy
from string import Template
import shutil

CONFIG_FILE_NAME = 'multi_inventory.yaml'
DEFAULT_CACHE_PATH = os.path.expanduser('~/.ansible/tmp/multi_inventory.cache')
  21. class MultiInventoryException(Exception):
  22. '''Exceptions for MultiInventory class'''
  23. pass
  24. class MultiInventory(object):
  25. '''
  26. MultiInventory class:
  27. Opens a yaml config file and reads aws credentials.
  28. Stores a json hash of resources in result.
  29. '''
  30. def __init__(self, args=None):
  31. # Allow args to be passed when called as a library
  32. if not args:
  33. self.args = {}
  34. else:
  35. self.args = args
  36. self.cache_path = DEFAULT_CACHE_PATH
  37. self.config = None
  38. self.all_inventory_results = {}
  39. self.result = {}
  40. self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
  41. same_dir_config_file = os.path.join(self.file_path, CONFIG_FILE_NAME)
  42. etc_dir_config_file = os.path.join(os.path.sep, 'etc', 'ansible', CONFIG_FILE_NAME)
  43. # Prefer a file in the same directory, fall back to a file in etc
  44. if os.path.isfile(same_dir_config_file):
  45. self.config_file = same_dir_config_file
  46. elif os.path.isfile(etc_dir_config_file):
  47. self.config_file = etc_dir_config_file
  48. else:
  49. self.config_file = None # expect env vars
  50. def run(self):
  51. '''This method checks to see if the local
  52. cache is valid for the inventory.
  53. if the cache is valid; return cache
  54. else the credentials are loaded from multi_inventory.yaml or from the env
  55. and we attempt to get the inventory from the provider specified.
  56. '''
  57. # load yaml
  58. if self.config_file and os.path.isfile(self.config_file):
  59. self.config = self.load_yaml_config()
  60. elif os.environ.has_key("AWS_ACCESS_KEY_ID") and \
  61. os.environ.has_key("AWS_SECRET_ACCESS_KEY"):
  62. # Build a default config
  63. self.config = {}
  64. self.config['accounts'] = [
  65. {
  66. 'name': 'default',
  67. 'cache_location': DEFAULT_CACHE_PATH,
  68. 'provider': 'aws/hosts/ec2.py',
  69. 'env_vars': {
  70. 'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"],
  71. 'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
  72. }
  73. },
  74. ]
  75. self.config['cache_max_age'] = 300
  76. else:
  77. raise RuntimeError("Could not find valid ec2 credentials in the environment.")
  78. if self.config.has_key('cache_location'):
  79. self.cache_path = self.config['cache_location']
  80. if self.args.get('refresh_cache', None):
  81. self.get_inventory()
  82. self.write_to_cache()
  83. # if its a host query, fetch and do not cache
  84. elif self.args.get('host', None):
  85. self.get_inventory()
  86. elif not self.is_cache_valid():
  87. # go fetch the inventories and cache them if cache is expired
  88. self.get_inventory()
  89. self.write_to_cache()
  90. else:
  91. # get data from disk
  92. self.get_inventory_from_cache()
  93. def load_yaml_config(self, conf_file=None):
  94. """Load a yaml config file with credentials to query the
  95. respective cloud for inventory.
  96. """
  97. config = None
  98. if not conf_file:
  99. conf_file = self.config_file
  100. with open(conf_file) as conf:
  101. config = yaml.safe_load(conf)
  102. # Provide a check for unique account names
  103. if len(set([acc['name'] for acc in config['accounts']])) != len(config['accounts']):
  104. raise MultiInventoryException('Duplicate account names in config file')
  105. return config
  106. def get_provider_tags(self, provider, env=None):
  107. """Call <provider> and query all of the tags that are usuable
  108. by ansible. If environment is empty use the default env.
  109. """
  110. if not env:
  111. env = os.environ
  112. # Allow relatively path'd providers in config file
  113. if os.path.isfile(os.path.join(self.file_path, provider)):
  114. provider = os.path.join(self.file_path, provider)
  115. # check to see if provider exists
  116. if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
  117. raise RuntimeError("Problem with the provider. Please check path " \
  118. "and that it is executable. (%s)" % provider)
  119. cmds = [provider]
  120. if self.args.get('host', None):
  121. cmds.append("--host")
  122. cmds.append(self.args.get('host', None))
  123. else:
  124. cmds.append('--list')
  125. if 'aws' in provider.lower():
  126. cmds.append('--refresh-cache')
  127. return subprocess.Popen(cmds, stderr=subprocess.PIPE, \
  128. stdout=subprocess.PIPE, env=env)
  129. @staticmethod
  130. def generate_config(provider_files):
  131. """Generate the provider_files in a temporary directory.
  132. """
  133. prefix = 'multi_inventory.'
  134. tmp_dir_path = tempfile.mkdtemp(prefix=prefix)
  135. for provider_file in provider_files:
  136. filedes = open(os.path.join(tmp_dir_path, provider_file['name']), 'w+')
  137. content = Template(provider_file['contents']).substitute(tmpdir=tmp_dir_path)
  138. filedes.write(content)
  139. filedes.close()
  140. return tmp_dir_path
  141. def run_provider(self):
  142. '''Setup the provider call with proper variables
  143. and call self.get_provider_tags.
  144. '''
  145. try:
  146. all_results = []
  147. tmp_dir_paths = []
  148. processes = {}
  149. for account in self.config['accounts']:
  150. tmp_dir = None
  151. if account.has_key('provider_files'):
  152. tmp_dir = MultiInventory.generate_config(account['provider_files'])
  153. tmp_dir_paths.append(tmp_dir)
  154. # Update env vars after creating provider_config_files
  155. # so that we can grab the tmp_dir if it exists
  156. env = account.get('env_vars', {})
  157. if env and tmp_dir:
  158. for key, value in env.items():
  159. env[key] = Template(value).substitute(tmpdir=tmp_dir)
  160. name = account['name']
  161. provider = account['provider']
  162. processes[name] = self.get_provider_tags(provider, env)
  163. # for each process collect stdout when its available
  164. for name, process in processes.items():
  165. out, err = process.communicate()
  166. all_results.append({
  167. "name": name,
  168. "out": out.strip(),
  169. "err": err.strip(),
  170. "code": process.returncode
  171. })
  172. finally:
  173. # Clean up the mkdtemp dirs
  174. for tmp_dir in tmp_dir_paths:
  175. shutil.rmtree(tmp_dir)
  176. return all_results
  177. def get_inventory(self):
  178. """Create the subprocess to fetch tags from a provider.
  179. Host query:
  180. Query to return a specific host. If > 1 queries have
  181. results then fail.
  182. List query:
  183. Query all of the different accounts for their tags. Once completed
  184. store all of their results into one merged updated hash.
  185. """
  186. provider_results = self.run_provider()
  187. # process --host results
  188. # For any 0 result, return it
  189. if self.args.get('host', None):
  190. count = 0
  191. for results in provider_results:
  192. if results['code'] == 0 and results['err'] == '' and results['out'] != '{}':
  193. self.result = json.loads(results['out'])
  194. count += 1
  195. if count > 1:
  196. raise RuntimeError("Found > 1 results for --host %s. \
  197. This is an invalid state." % self.args.get('host', None))
  198. # process --list results
  199. else:
  200. # For any non-zero, raise an error on it
  201. for result in provider_results:
  202. if result['code'] != 0:
  203. err_msg = ['\nProblem fetching account: {name}',
  204. 'Error Code: {code}',
  205. 'StdErr: {err}',
  206. 'Stdout: {out}',
  207. ]
  208. raise RuntimeError('\n'.join(err_msg).format(**result))
  209. else:
  210. self.all_inventory_results[result['name']] = json.loads(result['out'])
  211. # Check if user wants extra vars in yaml by
  212. # having hostvars and all_group defined
  213. for acc_config in self.config['accounts']:
  214. self.apply_account_config(acc_config)
  215. # Build results by merging all dictionaries
  216. values = self.all_inventory_results.values()
  217. values.insert(0, self.result)
  218. for result in values:
  219. MultiInventory.merge_destructively(self.result, result)
  220. def add_entry(self, data, keys, item):
  221. ''' Add an item to a dictionary with key notation a.b.c
  222. d = {'a': {'b': 'c'}}}
  223. keys = a.b
  224. item = c
  225. '''
  226. if "." in keys:
  227. key, rest = keys.split(".", 1)
  228. if key not in data:
  229. data[key] = {}
  230. self.add_entry(data[key], rest, item)
  231. else:
  232. data[keys] = item
  233. def get_entry(self, data, keys):
  234. ''' Get an item from a dictionary with key notation a.b.c
  235. d = {'a': {'b': 'c'}}}
  236. keys = a.b
  237. return c
  238. '''
  239. if keys and "." in keys:
  240. key, rest = keys.split(".", 1)
  241. return self.get_entry(data[key], rest)
  242. else:
  243. return data.get(keys, None)
  244. def apply_account_config(self, acc_config):
  245. ''' Apply account config settings
  246. '''
  247. results = self.all_inventory_results[acc_config['name']]
  248. results['all_hosts'] = results['_meta']['hostvars'].keys()
  249. # Extra vars go here
  250. for new_var, value in acc_config.get('extra_vars', {}).items():
  251. for data in results['_meta']['hostvars'].values():
  252. self.add_entry(data, new_var, value)
  253. # Clone vars go here
  254. for to_name, from_name in acc_config.get('clone_vars', {}).items():
  255. for data in results['_meta']['hostvars'].values():
  256. self.add_entry(data, to_name, self.get_entry(data, from_name))
  257. # Extra groups go here
  258. for new_var, value in acc_config.get('extra_groups', {}).items():
  259. for data in results['_meta']['hostvars'].values():
  260. results["%s_%s" % (new_var, value)] = copy.copy(results['all_hosts'])
  261. # Clone groups go here
  262. # Build a group based on the desired key name
  263. for to_name, from_name in acc_config.get('clone_groups', {}).items():
  264. for name, data in results['_meta']['hostvars'].items():
  265. key = '%s_%s' % (to_name, self.get_entry(data, from_name))
  266. if not results.has_key(key):
  267. results[key] = []
  268. results[key].append(name)
  269. # store the results back into all_inventory_results
  270. self.all_inventory_results[acc_config['name']] = results
  271. @staticmethod
  272. def merge_destructively(input_a, input_b):
  273. "merges b into input_a"
  274. for key in input_b:
  275. if key in input_a:
  276. if isinstance(input_a[key], dict) and isinstance(input_b[key], dict):
  277. MultiInventory.merge_destructively(input_a[key], input_b[key])
  278. elif input_a[key] == input_b[key]:
  279. pass # same leaf value
  280. # both lists so add each element in b to a if it does ! exist
  281. elif isinstance(input_a[key], list) and isinstance(input_b[key], list):
  282. for result in input_b[key]:
  283. if result not in input_a[key]:
  284. input_a[key].append(result)
  285. # a is a list and not b
  286. elif isinstance(input_a[key], list):
  287. if input_b[key] not in input_a[key]:
  288. input_a[key].append(input_b[key])
  289. elif isinstance(input_b[key], list):
  290. input_a[key] = [input_a[key]] + [k for k in input_b[key] if k != input_a[key]]
  291. else:
  292. input_a[key] = [input_a[key], input_b[key]]
  293. else:
  294. input_a[key] = input_b[key]
  295. return input_a
  296. def is_cache_valid(self):
  297. ''' Determines if the cache files have expired, or if it is still valid '''
  298. if os.path.isfile(self.cache_path):
  299. mod_time = os.path.getmtime(self.cache_path)
  300. current_time = time()
  301. if (mod_time + self.config['cache_max_age']) > current_time:
  302. return True
  303. return False
  304. def parse_cli_args(self):
  305. ''' Command line argument processing '''
  306. parser = argparse.ArgumentParser(
  307. description='Produce an Ansible Inventory file based on a provider')
  308. parser.add_argument('--refresh-cache', action='store_true', default=False,
  309. help='Fetch cached only instances (default: False)')
  310. parser.add_argument('--list', action='store_true', default=True,
  311. help='List instances (default: True)')
  312. parser.add_argument('--host', action='store', default=False,
  313. help='Get all the variables about a specific instance')
  314. self.args = parser.parse_args().__dict__
  315. def write_to_cache(self):
  316. ''' Writes data in JSON format to a file '''
  317. # if it does not exist, try and create it.
  318. if not os.path.isfile(self.cache_path):
  319. path = os.path.dirname(self.cache_path)
  320. try:
  321. os.makedirs(path)
  322. except OSError as exc:
  323. if exc.errno != errno.EEXIST or not os.path.isdir(path):
  324. raise
  325. json_data = MultiInventory.json_format_dict(self.result, True)
  326. with open(self.cache_path, 'w') as cache:
  327. try:
  328. fcntl.flock(cache, fcntl.LOCK_EX)
  329. cache.write(json_data)
  330. finally:
  331. fcntl.flock(cache, fcntl.LOCK_UN)
  332. def get_inventory_from_cache(self):
  333. ''' Reads the inventory from the cache file and returns it as a JSON
  334. object '''
  335. if not os.path.isfile(self.cache_path):
  336. return None
  337. with open(self.cache_path, 'r') as cache:
  338. self.result = json.loads(cache.read())
  339. return True
  340. @classmethod
  341. def json_format_dict(cls, data, pretty=False):
  342. ''' Converts a dict to a JSON object and dumps it as a formatted
  343. string '''
  344. if pretty:
  345. return json.dumps(data, sort_keys=True, indent=2)
  346. else:
  347. return json.dumps(data)
  348. def result_str(self):
  349. '''Return cache string stored in self.result'''
  350. return self.json_format_dict(self.result, True)
  351. if __name__ == "__main__":
  352. MI2 = MultiInventory()
  353. MI2.parse_cli_args()
  354. MI2.run()
  355. print MI2.result_str()