Browse Source

Make resolve_checks more strict

So that all names it return are directly usable, without checking for
existence.
Rodolfo Carvalho 8 years ago
parent
commit
fe0dafda92

+ 33 - 21
roles/openshift_health_checker/action_plugins/openshift_health_check.py

@@ -4,6 +4,7 @@ Ansible action plugin to execute health checks in OpenShift clusters.
 # pylint: disable=wrong-import-position,missing-docstring,invalid-name
 import sys
 import os
+from collections import defaultdict
 
 try:
     from __main__ import display
@@ -41,20 +42,11 @@ class ActionModule(ActionBase):
             return result
 
         args = self._task.args
-        requested_checks = resolve_checks(args.get("checks", []), known_checks.values())
-
-        unknown_checks = requested_checks - set(known_checks)
-        if unknown_checks:
-            result["failed"] = True
-            result["msg"] = (
-                "One or more checks are unknown: {}. "
-                "Make sure there is no typo in the playbook and no files are missing."
-            ).format(", ".join(unknown_checks))
-            return result
+        resolved_checks = resolve_checks(args.get("checks", []), known_checks.values())
 
         result["checks"] = check_results = {}
 
-        for check_name in requested_checks & set(known_checks):
+        for check_name in resolved_checks:
             display.banner("CHECK [{} : {}]".format(check_name, task_vars["ansible_host"]))
             check = known_checks[check_name]
 
@@ -101,19 +93,39 @@ class ActionModule(ActionBase):
 def resolve_checks(names, all_checks):
     """Returns a set of resolved check names.
 
-    Resolving a check name involves expanding tag references (e.g., '@tag') with
-    all the checks that contain the given tag.
+    Resolving a check name expands tag references (e.g., "@tag") to all the
+    checks that contain the given tag. OpenShiftCheckException is raised if
+    names contains an unknown check or tag name.
 
     names should be a sequence of strings.
 
     all_checks should be a sequence of check classes/instances.
     """
-    resolved = set()
-    for name in names:
-        if name.startswith("@"):
-            for check in all_checks:
-                if name[1:] in check.tags:
-                    resolved.add(check.name)
-        else:
-            resolved.add(name)
+    known_check_names = set(check.name for check in all_checks)
+    known_tag_names = set(name for check in all_checks for name in check.tags)
+
+    check_names = set(name for name in names if not name.startswith('@'))
+    tag_names = set(name[1:] for name in names if name.startswith('@'))
+
+    unknown_check_names = check_names - known_check_names
+    unknown_tag_names = tag_names - known_tag_names
+
+    if unknown_check_names or unknown_tag_names:
+        msg = []
+        if unknown_check_names:
+            msg.append('Unknown check names: {}.'.format(', '.join(sorted(unknown_check_names))))
+        if unknown_tag_names:
+            msg.append('Unknown tag names: {}.'.format(', '.join(sorted(unknown_tag_names))))
+        msg.append('Make sure there is no typo in the playbook and no files are missing.')
+        raise OpenShiftCheckException('\n'.join(msg))
+
+    tag_to_checks = defaultdict(set)
+    for check in all_checks:
+        for tag in check.tags:
+            tag_to_checks[tag].add(check.name)
+
+    resolved = check_names.copy()
+    for tag in tag_names:
+        resolved.update(tag_to_checks[tag])
+
     return resolved

+ 74 - 0
roles/openshift_health_checker/test/action_plugin_test.py

@@ -0,0 +1,74 @@
+import pytest
+
+from openshift_health_check import resolve_checks
+
+
+class FakeCheck(object):
+    def __init__(self, name, tags=None):
+        self.name = name
+        self.tags = tags or []
+
+
+@pytest.mark.parametrize('names,all_checks,expected', [
+    ([], [], set()),
+    (
+        ['a', 'b'],
+        [
+            FakeCheck('a'),
+            FakeCheck('b'),
+        ],
+        set(['a', 'b']),
+    ),
+    (
+        ['a', 'b', '@group'],
+        [
+            FakeCheck('from_group_1', ['group', 'another_group']),
+            FakeCheck('not_in_group', ['another_group']),
+            FakeCheck('from_group_2', ['preflight', 'group']),
+            FakeCheck('a'),
+            FakeCheck('b'),
+        ],
+        set(['a', 'b', 'from_group_1', 'from_group_2']),
+    ),
+])
+def test_resolve_checks_ok(names, all_checks, expected):
+    assert resolve_checks(names, all_checks) == expected
+
+
+@pytest.mark.parametrize('names,all_checks,words_in_exception,words_not_in_exception', [
+    (
+        ['testA', 'testB'],
+        [],
+        ['check', 'name', 'testA', 'testB'],
+        ['tag', 'group', '@'],
+    ),
+    (
+        ['@group'],
+        [],
+        ['tag', 'name', 'group'],
+        ['check', '@'],
+    ),
+    (
+        ['testA', 'testB', '@group'],
+        [],
+        ['check', 'name', 'testA', 'testB', 'tag', 'group'],
+        ['@'],
+    ),
+    (
+        ['testA', 'testB', '@group'],
+        [
+            FakeCheck('from_group_1', ['group', 'another_group']),
+            FakeCheck('not_in_group', ['another_group']),
+            FakeCheck('from_group_2', ['preflight', 'group']),
+        ],
+        ['check', 'name', 'testA', 'testB'],
+        ['tag', 'group', '@'],
+    ),
+])
+def test_resolve_checks_failure(names, all_checks, words_in_exception, words_not_in_exception):
+    with pytest.raises(Exception) as excinfo:
+        resolve_checks(names, all_checks)
+    for word in words_in_exception:
+        assert word in str(excinfo.value)
+    for word in words_not_in_exception:
+        assert word not in str(excinfo.value)

+ 7 - 2
roles/openshift_health_checker/test/conftest.py

@@ -1,5 +1,10 @@
 import os
 import sys
 
-# extend sys.path so that tests can import openshift_checks
-sys.path.insert(1, os.path.dirname(os.path.dirname(__file__)))
+# extend sys.path so that tests can import openshift_checks and action plugins
+# from this role.
+openshift_health_checker_path = os.path.dirname(os.path.dirname(__file__))
+sys.path[1:1] = [
+    openshift_health_checker_path,
+    os.path.join(openshift_health_checker_path, 'action_plugins')
+]