Browse Source

Merge pull request #38 from kwoodson/default_env

Adding default aws credentials when calling mutli_ec2
Kenny Woodson 10 years ago
parent
commit
2c11830eb8
1 changed files with 30 additions and 8 deletions
  1. 30 8
      inventory/multi_ec2.py

+ 30 - 8
inventory/multi_ec2.py

@@ -15,15 +15,33 @@ class MultiEc2(object):
 
     def __init__(self):
         self.config = None
-        self.results = {}
+        self.all_ec2_results = {}
         self.result = {}
         self.cache_path_cache = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
         self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
-
+        self.config_file = os.path.join(self.file_path,"multi_ec2.yaml")
         self.parse_cli_args()
 
         # load yaml
-        self.load_yaml_config()
+        if os.path.isfile(self.config_file):
+            self.config = self.load_yaml_config()
+        elif os.environ.has_key("AWS_ACCESS_KEY_ID") and os.environ.has_key("AWS_SECRET_ACCESS_KEY"):
+            self.config = {}
+            self.config['accounts'] = [
+                {
+                    'name': 'default',
+                    'provider': 'aws/ec2.py',
+                    'env_vars': {
+                        'AWS_ACCESS_KEY_ID':     os.environ["AWS_ACCESS_KEY_ID"],
+                        'AWS_SECRET_ACCESS_KEY': os.environ["AWS_SECRET_ACCESS_KEY"],
+                    }
+                },
+            ]
+
+            self.config['cache_max_age'] = 0
+        else:
+            raise RuntimeError("Could not find valid ec2 credentials in the environment.")
+
 
         # if its a host query, fetch and do not cache
         if self.args.host:
@@ -41,10 +59,14 @@ class MultiEc2(object):
         respective cloud for inventory.
         """
         config = None
+
         if not conf_file:
-            conf_file = os.path.join(self.file_path,'multi_ec2.yaml')
+            conf_file = self.config_file
+
         with open(conf_file) as conf:
-          self.config = yaml.safe_load(conf)
+            config = yaml.safe_load(conf)
+
+        return config
 
     def get_provider_tags(self,provider, env={}):
         """Call <provider> and query all of the tags that are usuable
@@ -103,10 +125,10 @@ class MultiEc2(object):
                 if result['code'] != 0:
                     raise RuntimeError(result['err'])
                 else:
-                    self.results[result['name']] = json.loads(result['out'])
-            values = self.results.values()
+                    self.all_ec2_results[result['name']] = json.loads(result['out'])
+            values = self.all_ec2_results.values()
             values.insert(0, self.result)
-            map(lambda x: self.merge_destructively(self.result, x), values)
+            [self.merge_destructively(self.result, x) for x in  values]
         else:
             # For any 0 result, return it
             count = 0