  1. #!/usr/bin/env python2
  2. '''
  3. Fetch and combine multiple ec2 account settings into a single
  4. json hash.
  5. '''
  6. # vim: expandtab:tabstop=4:shiftwidth=4
  7. from time import time
  8. import argparse
  9. import yaml
  10. import os
  11. import subprocess
  12. import json
  13. import errno
  14. import fcntl
  15. import tempfile
# Name of the yaml config file looked for next to this script and in /etc/ansible.
CONFIG_FILE_NAME = 'multi_ec2.yaml'
# Fallback location of the merged-inventory cache file when the config
# does not supply its own 'cache_location'.
DEFAULT_CACHE_PATH = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
  18. class MultiEc2(object):
  19. '''
  20. MultiEc2 class:
  21. Opens a yaml config file and reads aws credentials.
  22. Stores a json hash of resources in result.
  23. '''
  24. def __init__(self):
  25. self.args = None
  26. self.config = None
  27. self.all_ec2_results = {}
  28. self.result = {}
  29. self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
  30. same_dir_config_file = os.path.join(self.file_path, CONFIG_FILE_NAME)
  31. etc_dir_config_file = os.path.join(os.path.sep, 'etc', 'ansible', CONFIG_FILE_NAME)
  32. # Prefer a file in the same directory, fall back to a file in etc
  33. if os.path.isfile(same_dir_config_file):
  34. self.config_file = same_dir_config_file
  35. elif os.path.isfile(etc_dir_config_file):
  36. self.config_file = etc_dir_config_file
  37. else:
  38. self.config_file = None # expect env vars
  39. self.parse_cli_args()
  40. # load yaml
  41. if self.config_file and os.path.isfile(self.config_file):
  42. self.config = self.load_yaml_config()
  43. elif os.environ.has_key("AWS_ACCESS_KEY_ID") and \
  44. os.environ.has_key("AWS_SECRET_ACCESS_KEY"):
  45. # Build a default config
  46. self.config = {}
  47. self.config['accounts'] = [
  48. {
  49. 'name': 'default',
  50. 'cache_location': DEFAULT_CACHE_PATH,
  51. 'provider': 'aws/hosts/ec2.py',
  52. 'env_vars': {
  53. 'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"],
  54. 'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
  55. }
  56. },
  57. ]
  58. self.config['cache_max_age'] = 0
  59. else:
  60. raise RuntimeError("Could not find valid ec2 credentials in the environment.")
  61. # Set the default cache path but if its defined we'll assign it.
  62. self.cache_path = DEFAULT_CACHE_PATH
  63. if self.config.has_key('cache_location'):
  64. self.cache_path = self.config['cache_location']
  65. if self.args.refresh_cache:
  66. self.get_inventory()
  67. self.write_to_cache()
  68. # if its a host query, fetch and do not cache
  69. elif self.args.host:
  70. self.get_inventory()
  71. elif not self.is_cache_valid():
  72. # go fetch the inventories and cache them if cache is expired
  73. self.get_inventory()
  74. self.write_to_cache()
  75. else:
  76. # get data from disk
  77. self.get_inventory_from_cache()
  78. def load_yaml_config(self, conf_file=None):
  79. """Load a yaml config file with credentials to query the
  80. respective cloud for inventory.
  81. """
  82. config = None
  83. if not conf_file:
  84. conf_file = self.config_file
  85. with open(conf_file) as conf:
  86. config = yaml.safe_load(conf)
  87. return config
  88. def get_provider_tags(self, provider, env=None):
  89. """Call <provider> and query all of the tags that are usuable
  90. by ansible. If environment is empty use the default env.
  91. """
  92. if not env:
  93. env = os.environ
  94. # Allow relatively path'd providers in config file
  95. if os.path.isfile(os.path.join(self.file_path, provider)):
  96. provider = os.path.join(self.file_path, provider)
  97. # check to see if provider exists
  98. if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
  99. raise RuntimeError("Problem with the provider. Please check path " \
  100. "and that it is executable. (%s)" % provider)
  101. cmds = [provider]
  102. if self.args.host:
  103. cmds.append("--host")
  104. cmds.append(self.args.host)
  105. else:
  106. cmds.append('--list')
  107. cmds.append('--refresh-cache')
  108. return subprocess.Popen(cmds, stderr=subprocess.PIPE, \
  109. stdout=subprocess.PIPE, env=env)
  110. @staticmethod
  111. def generate_config(config_data):
  112. """Generate the ec2.ini file in as a secure temp file.
  113. Once generated, pass it to the ec2.py as an environment variable.
  114. """
  115. fildes, tmp_file_path = tempfile.mkstemp(prefix='multi_ec2.ini.')
  116. for section, values in config_data.items():
  117. os.write(fildes, "[%s]\n" % section)
  118. for option, value in values.items():
  119. os.write(fildes, "%s = %s\n" % (option, value))
  120. os.close(fildes)
  121. return tmp_file_path
  122. def run_provider(self):
  123. '''Setup the provider call with proper variables
  124. and call self.get_provider_tags.
  125. '''
  126. try:
  127. all_results = []
  128. tmp_file_path = None
  129. processes = {}
  130. for account in self.config['accounts']:
  131. env = account['env_vars']
  132. if account.has_key('provider_config'):
  133. tmp_file_path = MultiEc2.generate_config(account['provider_config'])
  134. env['EC2_INI_PATH'] = tmp_file_path
  135. name = account['name']
  136. provider = account['provider']
  137. processes[name] = self.get_provider_tags(provider, env)
  138. # for each process collect stdout when its available
  139. for name, process in processes.items():
  140. out, err = process.communicate()
  141. all_results.append({
  142. "name": name,
  143. "out": out.strip(),
  144. "err": err.strip(),
  145. "code": process.returncode
  146. })
  147. finally:
  148. # Clean up the mkstemp file
  149. if tmp_file_path:
  150. os.unlink(tmp_file_path)
  151. return all_results
  152. def get_inventory(self):
  153. """Create the subprocess to fetch tags from a provider.
  154. Host query:
  155. Query to return a specific host. If > 1 queries have
  156. results then fail.
  157. List query:
  158. Query all of the different accounts for their tags. Once completed
  159. store all of their results into one merged updated hash.
  160. """
  161. provider_results = self.run_provider()
  162. # process --host results
  163. if not self.args.host:
  164. # For any non-zero, raise an error on it
  165. for result in provider_results:
  166. if result['code'] != 0:
  167. raise RuntimeError(result['err'])
  168. else:
  169. self.all_ec2_results[result['name']] = json.loads(result['out'])
  170. values = self.all_ec2_results.values()
  171. values.insert(0, self.result)
  172. for result in values:
  173. MultiEc2.merge_destructively(self.result, result)
  174. else:
  175. # For any 0 result, return it
  176. count = 0
  177. for results in provider_results:
  178. if results['code'] == 0 and results['err'] == '' and results['out'] != '{}':
  179. self.result = json.loads(results['out'])
  180. count += 1
  181. if count > 1:
  182. raise RuntimeError("Found > 1 results for --host %s. \
  183. This is an invalid state." % self.args.host)
  184. @staticmethod
  185. def merge_destructively(input_a, input_b):
  186. "merges b into input_a"
  187. for key in input_b:
  188. if key in input_a:
  189. if isinstance(input_a[key], dict) and isinstance(input_b[key], dict):
  190. MultiEc2.merge_destructively(input_a[key], input_b[key])
  191. elif input_a[key] == input_b[key]:
  192. pass # same leaf value
  193. # both lists so add each element in b to a if it does ! exist
  194. elif isinstance(input_a[key], list) and isinstance(input_b[key], list):
  195. for result in input_b[key]:
  196. if result not in input_a[key]:
  197. input_a[key].append(result)
  198. # a is a list and not b
  199. elif isinstance(input_a[key], list):
  200. if input_b[key] not in input_a[key]:
  201. input_a[key].append(input_b[key])
  202. elif isinstance(input_b[key], list):
  203. input_a[key] = [input_a[key]] + [k for k in input_b[key] if k != input_a[key]]
  204. else:
  205. input_a[key] = [input_a[key], input_b[key]]
  206. else:
  207. input_a[key] = input_b[key]
  208. return input_a
  209. def is_cache_valid(self):
  210. ''' Determines if the cache files have expired, or if it is still valid '''
  211. if os.path.isfile(self.cache_path):
  212. mod_time = os.path.getmtime(self.cache_path)
  213. current_time = time()
  214. if (mod_time + self.config['cache_max_age']) > current_time:
  215. return True
  216. return False
  217. def parse_cli_args(self):
  218. ''' Command line argument processing '''
  219. parser = argparse.ArgumentParser(
  220. description='Produce an Ansible Inventory file based on a provider')
  221. parser.add_argument('--refresh-cache', action='store_true', default=False,
  222. help='Fetch cached only instances (default: False)')
  223. parser.add_argument('--list', action='store_true', default=True,
  224. help='List instances (default: True)')
  225. parser.add_argument('--host', action='store', default=False,
  226. help='Get all the variables about a specific instance')
  227. self.args = parser.parse_args()
  228. def write_to_cache(self):
  229. ''' Writes data in JSON format to a file '''
  230. # if it does not exist, try and create it.
  231. if not os.path.isfile(self.cache_path):
  232. path = os.path.dirname(self.cache_path)
  233. try:
  234. os.makedirs(path)
  235. except OSError as exc:
  236. if exc.errno != errno.EEXIST or not os.path.isdir(path):
  237. raise
  238. json_data = MultiEc2.json_format_dict(self.result, True)
  239. with open(self.cache_path, 'w') as cache:
  240. try:
  241. fcntl.flock(cache, fcntl.LOCK_EX)
  242. cache.write(json_data)
  243. finally:
  244. fcntl.flock(cache, fcntl.LOCK_UN)
  245. def get_inventory_from_cache(self):
  246. ''' Reads the inventory from the cache file and returns it as a JSON
  247. object '''
  248. if not os.path.isfile(self.cache_path):
  249. return None
  250. with open(self.cache_path, 'r') as cache:
  251. self.result = json.loads(cache.read())
  252. return True
  253. @classmethod
  254. def json_format_dict(cls, data, pretty=False):
  255. ''' Converts a dict to a JSON object and dumps it as a formatted
  256. string '''
  257. if pretty:
  258. return json.dumps(data, sort_keys=True, indent=2)
  259. else:
  260. return json.dumps(data)
  261. def result_str(self):
  262. '''Return cache string stored in self.result'''
  263. return self.json_format_dict(self.result, True)
  264. if __name__ == "__main__":
  265. print MultiEc2().result_str()