multi_ec2.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. #!/usr/bin/env python
  2. from time import time
  3. import argparse
  4. import yaml
  5. import os
  6. import sys
  7. import pdb
  8. import subprocess
  9. import json
  10. import pprint
  11. class MultiEc2(object):
  12. def __init__(self):
  13. self.config = None
  14. self.all_ec2_results = {}
  15. self.result = {}
  16. self.cache_path = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
  17. self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
  18. self.config_file = os.path.join(self.file_path,"multi_ec2.yaml")
  19. self.parse_cli_args()
  20. # load yaml
  21. if os.path.isfile(self.config_file):
  22. self.config = self.load_yaml_config()
  23. elif os.environ.has_key("AWS_ACCESS_KEY_ID") and os.environ.has_key("AWS_SECRET_ACCESS_KEY"):
  24. self.config = {}
  25. self.config['accounts'] = [
  26. {
  27. 'name': 'default',
  28. 'provider': 'aws/ec2.py',
  29. 'env_vars': {
  30. 'AWS_ACCESS_KEY_ID': os.environ["AWS_ACCESS_KEY_ID"],
  31. 'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
  32. }
  33. },
  34. ]
  35. self.config['cache_max_age'] = 0
  36. else:
  37. raise RuntimeError("Could not find valid ec2 credentials in the environment.")
  38. # if its a host query, fetch and do not cache
  39. if self.args.host:
  40. self.get_inventory()
  41. elif not self.is_cache_valid():
  42. # go fetch the inventories and cache them if cache is expired
  43. self.get_inventory()
  44. self.write_to_cache()
  45. else:
  46. # get data from disk
  47. self.get_inventory_from_cache()
  48. def load_yaml_config(self,conf_file=None):
  49. """Load a yaml config file with credentials to query the
  50. respective cloud for inventory.
  51. """
  52. config = None
  53. if not conf_file:
  54. conf_file = self.config_file
  55. with open(conf_file) as conf:
  56. config = yaml.safe_load(conf)
  57. return config
  58. def get_provider_tags(self,provider, env={}):
  59. """Call <provider> and query all of the tags that are usuable
  60. by ansible. If environment is empty use the default env.
  61. """
  62. if not env:
  63. env = os.environ
  64. # check to see if provider exists
  65. if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
  66. raise RuntimeError("Problem with the provider. Please check path " \
  67. "and that it is executable. (%s)" % provider)
  68. cmds = [provider]
  69. if self.args.host:
  70. cmds.append("--host")
  71. cmds.append(self.args.host)
  72. else:
  73. cmds.append('--list')
  74. cmds.append('--refresh-cache')
  75. return subprocess.Popen(cmds, stderr=subprocess.PIPE, \
  76. stdout=subprocess.PIPE, env=env)
  77. def get_inventory(self):
  78. """Create the subprocess to fetch tags from a provider.
  79. Host query:
  80. Query to return a specific host. If > 1 queries have
  81. results then fail.
  82. List query:
  83. Query all of the different accounts for their tags. Once completed
  84. store all of their results into one merged updated hash.
  85. """
  86. processes = {}
  87. for account in self.config['accounts']:
  88. env = account['env_vars']
  89. name = account['name']
  90. provider = account['provider']
  91. processes[name] = self.get_provider_tags(provider, env)
  92. # for each process collect stdout when its available
  93. all_results = []
  94. for name, process in processes.items():
  95. out, err = process.communicate()
  96. all_results.append({
  97. "name": name,
  98. "out": out.strip(),
  99. "err": err.strip(),
  100. "code": process.returncode
  101. })
  102. if not self.args.host:
  103. # For any non-zero, raise an error on it
  104. for result in all_results:
  105. if result['code'] != 0:
  106. raise RuntimeError(result['err'])
  107. else:
  108. self.all_ec2_results[result['name']] = json.loads(result['out'])
  109. values = self.all_ec2_results.values()
  110. values.insert(0, self.result)
  111. [self.merge_destructively(self.result, x) for x in values]
  112. else:
  113. # For any 0 result, return it
  114. count = 0
  115. for results in all_results:
  116. if results['code'] == 0 and results['err'] == '' and results['out'] != '{}':
  117. self.result = json.loads(out)
  118. count += 1
  119. if count > 1:
  120. raise RuntimeError("Found > 1 results for --host %s. \
  121. This is an invalid state." % self.args.host)
  122. def merge_destructively(self, a, b):
  123. "merges b into a"
  124. for key in b:
  125. if key in a:
  126. if isinstance(a[key], dict) and isinstance(b[key], dict):
  127. self.merge_destructively(a[key], b[key])
  128. elif a[key] == b[key]:
  129. pass # same leaf value
  130. # both lists so add each element in b to a if it does ! exist
  131. elif isinstance(a[key], list) and isinstance(b[key],list):
  132. for x in b[key]:
  133. if x not in a[key]:
  134. a[key].append(x)
  135. # a is a list and not b
  136. elif isinstance(a[key], list):
  137. if b[key] not in a[key]:
  138. a[key].append(b[key])
  139. elif isinstance(b[key], list):
  140. a[key] = [a[key]] + [k for k in b[key] if k != a[key]]
  141. else:
  142. a[key] = [a[key],b[key]]
  143. else:
  144. a[key] = b[key]
  145. return a
  146. def is_cache_valid(self):
  147. ''' Determines if the cache files have expired, or if it is still valid '''
  148. if os.path.isfile(self.cache_path):
  149. mod_time = os.path.getmtime(self.cache_path)
  150. current_time = time()
  151. if (mod_time + self.config['cache_max_age']) > current_time:
  152. return True
  153. return False
  154. def parse_cli_args(self):
  155. ''' Command line argument processing '''
  156. parser = argparse.ArgumentParser(description='Produce an Ansible Inventory file based on a provider')
  157. parser.add_argument('--list', action='store_true', default=True,
  158. help='List instances (default: True)')
  159. parser.add_argument('--host', action='store',
  160. help='Get all the variables about a specific instance')
  161. self.args = parser.parse_args()
  162. def write_to_cache(self):
  163. ''' Writes data in JSON format to a file '''
  164. json_data = self.json_format_dict(self.result, True)
  165. with open(self.cache_path, 'w') as cache:
  166. cache.write(json_data)
  167. def get_inventory_from_cache(self):
  168. ''' Reads the inventory from the cache file and returns it as a JSON
  169. object '''
  170. with open(self.cache_path, 'r') as cache:
  171. self.result = json.loads(cache.read())
  172. def json_format_dict(self, data, pretty=False):
  173. ''' Converts a dict to a JSON object and dumps it as a formatted
  174. string '''
  175. if pretty:
  176. return json.dumps(data, sort_keys=True, indent=2)
  177. else:
  178. return json.dumps(data)
  179. def result_str(self):
  180. return self.json_format_dict(self.result, True)
  181. if __name__ == "__main__":
  182. mi = MultiEc2()
  183. print mi.result_str()