Browse Source

Updated merge function to merge recursively

Kenny Woodson 10 years ago
parent
commit
6eead926a8
2 changed files with 37 additions and 24 deletions
  1. 35 22
      inventory/multi_ec2.py
  2. 2 2
      inventory/multi_ec2.yaml.example

+ 35 - 22
inventory/multi_ec2.py

@@ -18,6 +18,7 @@ class MultiEc2(object):
         self.results = {}
         self.results = {}
         self.result = {}
         self.result = {}
         self.cache_path_cache = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
         self.cache_path_cache = os.path.expanduser('~/.ansible/tmp/multi_ec2_inventory.cache')
+        self.file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)))
 
 
         self.parse_cli_args()
         self.parse_cli_args()
 
 
@@ -35,11 +36,13 @@ class MultiEc2(object):
             # get data from disk
             # get data from disk
             self.get_inventory_from_cache()
             self.get_inventory_from_cache()
 
 
-    def load_yaml_config(self,conf_file=os.path.join(os.getcwd(),'multi_ec2.yaml')):
+    def load_yaml_config(self,conf_file=None):
         """Load a yaml config file with credentials to query the
         """Load a yaml config file with credentials to query the
         respective cloud for inventory.
         respective cloud for inventory.
         """
         """
         config = None
         config = None
+        if not conf_file:
+            conf_file = os.path.join(self.file_path,'multi_ec2.yaml')
         with open(conf_file) as conf:
         with open(conf_file) as conf:
           self.config = yaml.safe_load(conf)
           self.config = yaml.safe_load(conf)
 
 
@@ -51,8 +54,9 @@ class MultiEc2(object):
             env = os.environ
             env = os.environ
 
 
         # check to see if provider exists
         # check to see if provider exists
-        if not os.path.isfile(os.path.join(os.getcwd(),provider)):
-            raise RuntimeError("Unkown provider: %s" % provider)
+        if not os.path.isfile(provider) or not os.access(provider, os.X_OK):
+            raise RuntimeError("Problem with the provider.  Please check path " \
+                        "and that it is executable. (%s)" % provider)
 
 
         cmds = [provider]
         cmds = [provider]
         if self.args.host:
         if self.args.host:
@@ -72,11 +76,11 @@ class MultiEc2(object):
         results then fail.
         results then fail.
 
 
         List query:
         List query:
-        Query all of the different clouds for their tags.  Once completed
+        Query all of the different accounts for their tags.  Once completed
         store all of their results into one merged updated hash.
         store all of their results into one merged updated hash.
         """
         """
         processes = {}
         processes = {}
-        for account in self.config['clouds']:
+        for account in self.config['accounts']:
             env = account['env_vars']
             env = account['env_vars']
             name = account['name']
             name = account['name']
             provider = account['provider']
             provider = account['provider']
@@ -100,7 +104,9 @@ class MultiEc2(object):
                     raise RuntimeError(result['err'])
                     raise RuntimeError(result['err'])
                 else:
                 else:
                     self.results[result['name']] = json.loads(result['out'])
                     self.results[result['name']] = json.loads(result['out'])
-            self.merge()
+            values = self.results.values()
+            values.insert(0, self.result)
+            map(lambda x: self.merge(self.result, x), values)
         else:
         else:
             # For any 0 result, return it
             # For any 0 result, return it
             count = 0
             count = 0
@@ -111,23 +117,30 @@ class MultiEc2(object):
                 if count > 1:
                 if count > 1:
                     raise RuntimeError("Found > 1 results for --host %s. \
                     raise RuntimeError("Found > 1 results for --host %s. \
                                        This is an invalid state." % self.args.host)
                                        This is an invalid state." % self.args.host)
-
-    def merge(self):
-        """Merge the results into a single hash.  Duplicate keys are placed
-        into a list.
-        """
-        for name, cloud_result in self.results.items():
-            for k,v in cloud_result.items():
-                if self.result.has_key(k):
-                    # need to combine into a list
-                    if isinstance(self.result[k], list):
-                        self.result[k].append(v)
-                    else:
-                        self.result[k] = [self.result[k],v]
+    def merge(self, a, b):
+        "merges b into a"
+        for key in b:
+            if key in a:
+                if isinstance(a[key], dict) and isinstance(b[key], dict):
+                    self.merge(a[key], b[key])
+                elif a[key] == b[key]:
+                    pass # same leaf value
+                # both lists so add each element in b to a if it does ! exist
+                elif isinstance(a[key], list) and isinstance(b[key],list):
+                    for x in b[key]:
+                        if x not in a[key]:
+                            a[key].append(x)
+                # a is a list and not b
+                elif isinstance(a[key], list):
+                    if b[key] not in a[key]:
+                        a[key].append(b[key])
+                elif isinstance(b[key], list):
+                    a[key] = [a[key]] + [k for k in b[key] if k != a[key]]
                 else:
                 else:
-                    self.result[k] = [v]
-
-        self.result = self.json_format_dict(self.result)
+                    a[key] = [a[key],b[key]]
+            else:
+                a[key] = b[key]
+        return a
 
 
     def is_cache_valid(self):
     def is_cache_valid(self):
         ''' Determines if the cache files have expired, or if it is still valid '''
         ''' Determines if the cache files have expired, or if it is still valid '''

+ 2 - 2
inventory/multi_ec2.yaml.example

@@ -1,5 +1,5 @@
-# meta inventory configs
-clouds:
+# multi ec2 inventory configs
+accounts:
   - name: aws1
   - name: aws1
     provider: aws/ec2.py
     provider: aws/ec2.py
     env_vars:
     env_vars: