Browse Source

Updated merge function to merge recursively

Kenny Woodson 10 years ago
parent
commit
6eead926a8
2 changed files with 37 additions and 24 deletions
  1. 35 22
      inventory/multi_ec2.py
  2. 2 2
      inventory/multi_ec2.yaml.example

+ 35 - 22
inventory/multi_ec2.py

@@ -18,6 +18,7 @@ class MultiEc2(object):
         self.results = {}
         self.result = {}
         self.cache_path_cache = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
+        self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
 
         self.parse_cli_args()
 
@@ -35,11 +36,13 @@ class MultiEc2(object):
             # get data from disk
             self.get_inventory_from_cache()
 
-    def load_yaml_config(self,conf_file=os.path.join(os.getcwd(),'multi_ec2.yaml')):
+    def load_yaml_config(self,conf_file=None):
         """Load a yaml config file with credentials to query the
         respective cloud for inventory.
         """
         config = None
+        if not conf_file:
+            conf_file = os.path.join(self.file_path,'multi_ec2.yaml')
         with open(conf_file) as conf:
           self.config = yaml.safe_load(conf)
 
@@ -51,8 +54,9 @@ class MultiEc2(object):
             env = os.environ
 
         # check to see if provider exists
-        if not os.path.isfile(os.path.join(os.getcwd(),provider)):
-            raise RuntimeError("Unkown provider: %s" % provider)
+        if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
+            raise RuntimeError("Problem with the provider.  Please check path " \
+                        "and that it is executable. (%s)" % provider)
 
         cmds = [provider]
         if self.args.host:
@@ -72,11 +76,11 @@ class MultiEc2(object):
         results then fail.
 
         List query:
-        Query all of the different clouds for their tags.  Once completed
+        Query all of the different accounts for their tags.  Once completed
         store all of their results into one merged updated hash.
         """
         processes = {}
-        for account in self.config['clouds']:
+        for account in self.config['accounts']:
             env = account['env_vars']
             name = account['name']
             provider = account['provider']
@@ -100,7 +104,9 @@ class MultiEc2(object):
                     raise RuntimeError(result['err'])
                 else:
                     self.results[result['name']] = json.loads(result['out'])
-            self.merge()
+            values = self.results.values()
+            values.insert(0, self.result)
+            map(lambda x: self.merge(self.result, x), values)
         else:
             # For any 0 result, return it
             count = 0
@@ -111,23 +117,30 @@ class MultiEc2(object):
                 if count > 1:
                     raise RuntimeError("Found > 1 results for --host %s. \
                                        This is an invalid state." % self.args.host)
-
-    def merge(self):
-        """Merge the results into a single hash.  Duplicate keys are placed
-        into a list.
-        """
-        for name, cloud_result in self.results.items():
-            for k,v in cloud_result.items():
-                if self.result.has_key(k):
-                    # need to combine into a list
-                    if isinstance(self.result[k], list):
-                        self.result[k].append(v)
-                    else:
-                        self.result[k] = [self.result[k],v]
+    def merge(self, a, b):
+        "merges b into a"
+        for key in b:
+            if key in a:
+                if isinstance(a[key], dict) and isinstance(b[key], dict):
+                    self.merge(a[key], b[key])
+                elif a[key] == b[key]:
+                    pass # same leaf value
+                # both lists so add each element in b to a if it does ! exist
+                elif isinstance(a[key], list) and isinstance(b[key],list):
+                    for x in b[key]:
+                        if x not in a[key]:
+                            a[key].append(x)
+                # a is a list and not b
+                elif isinstance(a[key], list):
+                    if b[key] not in a[key]:
+                        a[key].append(b[key])
+                elif isinstance(b[key], list):
+                    a[key] = [a[key]] + [k for k in b[key] if k != a[key]]
                 else:
-                    self.result[k] = [v]
-
-        self.result = self.json_format_dict(self.result)
+                    a[key] = [a[key],b[key]]
+            else:
+                a[key] = b[key]
+        return a
 
     def is_cache_valid(self):
         ''' Determines if the cache files have expired, or if it is still valid '''

+ 2 - 2
inventory/multi_ec2.yaml.example

@@ -1,5 +1,5 @@
-# meta inventory configs
-clouds:
+# multi ec2 inventory configs
+accounts:
   - name: aws1
     provider: aws/ec2.py
     env_vars: