diff --git a/plugins/module_utils/common/argspec_validate.py b/plugins/module_utils/common/argspec_validate.py new file mode 100644 index 0000000..a28c6a2 --- /dev/null +++ b/plugins/module_utils/common/argspec_validate.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +"""Use AnsibleModule's argspec validation + +def _check_argspec(self): + aav = AnsibleArgSpecValidator( + data=self._task.args, + schema=DOCUMENTATION, + schema_format="doc", + schema_conditionals={}, + other_args={}, + name=self._task.action, + ) + valid, errors = aav.validate() + if not valid: + raise AnsibleActionFail(errors) + +""" +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import json +import re +from ansible.module_utils.basic import AnsibleModule +from ansible_collections.ansible.utils.plugins.module_utils.common.utils import ( + dict_merge, +) +from ansible.module_utils.six import iteritems, string_types +from ansible.module_utils._text import to_bytes + +try: + import yaml + + # use C version if possible for speedup + try: + from yaml import CSafeLoader as SafeLoader + except ImportError: + from yaml import SafeLoader + HAS_YAML = True +except ImportError: + HAS_YAML = False + +# TODO: Update this to point to functionality being exposed in 2.11 +# ansible-base 2.11 should expose argspec validation outside of the +# ansiblemodule class +try: + from ansible.module_utils.somefile import FutureBaseArgspecValidator + + HAS_ANSIBLE_ARG_SPEC_VALIDATOR = True +except ImportError: + HAS_ANSIBLE_ARG_SPEC_VALIDATOR = False + + +OPTION_METADATA = ( + "type", + "choices", + "default", + "required", + "aliases", + "elements", + "fallback", + "no_log", + "apply_defaults", + "deprecated_aliases", + "removed_in_version", +) +OPTION_CONDITIONALS = ( + "mutually_exclusive", + "required_one_of", + "required_together", + "required_by", + "required_if", +) + +VALID_ANSIBLEMODULE_ARGS = ( + "argument_spec", + "bypass_checks", + "no_log", + "add_file_common_args", + "supports_check_mode", +) + OPTION_CONDITIONALS + +BASE_ARG_AVAIL = 2.11 + + +class MonkeyModule(AnsibleModule): + """A derivative of the AnsibleModule used + to just validate the data (task.args) against + the schema(argspec) + """ + + def __init__(self, data, schema, name): + self._errors = None + self._valid = True + self._schema = schema + self.name = name + self.params = data + + def fail_json(self, msg): + """Replace the AnsibleModule fail_json here + :param msg: The message for the failure + :type msg: str + """ + if self.name: + msg = re.sub( + r"\(basic\.pyc?\)", + "'{name}'".format(name=self.name), + msg, + ) + self._valid = False + self._errors = msg + + def _load_params(self): + """This replaces the AnsibleModule _load_params + fn because we already set self.params in init + """ + pass + + def validate(self): + """Instantiate the super, validating the schema + against the data + :return valid: if the data passed + :rtype valid: bool + :return errors: errors reported during validation + :rtype errors: str + """ + super(MonkeyModule, self).__init__(**self._schema) + return self._valid, self._errors + + +class AnsibleArgSpecValidator: + def __init__( + self, + data, + schema, + schema_format, + schema_conditionals=None, + name=None, + other_args=None, + ): + """Validate some data against a schema + :param data: The data to valdiate + :type data: dict + :param schema: A schema in ansible argspec format + :type schema: dict + :param schema_format: 'doc' (ansible docstring) or 'argspec' (ansible argspec) + :type schema: str if doc, dict if argspec + :param schema_conditionals: A dict of schema conditionals, ie required_if + :type schema_conditionals: dict + :param name: the name of the plugin calling this class, used in error messages + :type name: str + :param other_args: Other valid kv pairs for the argspec, eg no_log, bypass_checks + :type other_args: dict + + note: + - the schema conditionals can be root conditionals or deeply nested conditionals + these get dict_merged into the argspec from the docstring, since the docstring cannot + contain them. + """ + self._errors = "" + self._name = name + self._other_args = other_args + self._schema = schema + self._schema_format = schema_format + self._schema_conditionals = schema_conditionals + self._data = data + + def _extract_schema_from_doc(self, doc_obj, temp_schema): + """Extract the schema from a doc string + :param doc_obj: The doc as a python obj + :type doc_obj: dictionary + :params temp_schema: The dict in which we stuff the schema parts + :type temp_schema: dict + """ + options_obj = doc_obj.get("options") + for okey, ovalue in iteritems(options_obj): + temp_schema[okey] = {} + for metakey in list(ovalue): + if metakey == "suboptions": + temp_schema[okey].update({"options": {}}) + suboptions_obj = {"options": ovalue["suboptions"]} + self._extract_schema_from_doc( + suboptions_obj, temp_schema[okey]["options"] + ) + elif metakey in OPTION_METADATA + OPTION_CONDITIONALS: + temp_schema[okey].update({metakey: ovalue[metakey]}) + + # TODO: Support extends_documentation_fragment + def _convert_doc_to_schema(self): + """Convert the doc string to an obj, was yaml + add back other valid conditionals and params + """ + doc_obj = yaml.load(self._schema, SafeLoader) + temp_schema = {} + self._extract_schema_from_doc(doc_obj, temp_schema) + self._schema = {"argument_spec": temp_schema} + + def _validate(self): + """Validate the data gainst the schema + convert doc string in argspec if necessary + """ + if self._schema_format == "doc": + self._convert_doc_to_schema() + if self._schema_conditionals is not None: + self._schema = dict_merge(self._schema, self._schema_conditionals) + if self._other_args is not None: + self._schema = dict_merge(self._schema, self._other_args) + invalid_keys = [ + k for k in self._schema.keys() if k not in VALID_ANSIBLEMODULE_ARGS + ] + if invalid_keys: + valid = False + errors = "Invalid schema. Invalid keys found: {ikeys}".format( + ikeys=",".join(invalid_keys) + ) + else: + mm = MonkeyModule( + data=self._data, schema=self._schema, name=self._name + ) + valid, errors = mm.validate() + return valid, errors + + def validate(self): + """The public validate method + check for future argspec validation + that is coming in 2.11, change the check according above + """ + if HAS_ANSIBLE_ARG_SPEC_VALIDATOR: + return self._validate() + else: + return self._validate() diff --git a/plugins/module_utils/common/utils.py b/plugins/module_utils/common/utils.py new file mode 100644 index 0000000..a69dabe --- /dev/null +++ b/plugins/module_utils/common/utils.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +from copy import deepcopy +from itertools import chain + +from ansible.module_utils.common._collections_compat import Mapping +from ansible.module_utils.six import iteritems, string_types + + +def sort_list(val): + if isinstance(val, list): + if isinstance(val[0], dict): + sorted_keys = [tuple(sorted(dict_.keys())) for dict_ in val] + # All keys should be identical + if len(set(sorted_keys)) != 1: + raise ValueError("dictionaries do not match") + + return sorted( + val, key=lambda d: tuple(d[k] for k in sorted_keys[0]) + ) + return sorted(val) + return val + + +def dict_merge(base, other): + """Return a new dict object that combines base and other + + This will create a new dict object that is a combination of the key/value + pairs from base and other. When both keys exist, the value will be + selected from other. + + If the value in base is a list, and the value in other is a list + the base list will be extended with the values from the other list that were + not already present in the base list + + If the value in base is a list, and the value in other is a list + and the two have the same entries, the value from other will be + used, preserving the order from the other list + + If the value in base is a list, and the value in other is not a list + the value from other will be used + + :param base: dict object to serve as base + :param other: dict object to combine with base + + :returns: new combined dict object + """ + if not isinstance(base, dict): + raise AssertionError("`base` must be of type ") + if not isinstance(other, dict): + raise AssertionError("`other` must be of type ") + + combined = dict() + + for key, value in iteritems(deepcopy(base)): + if isinstance(value, dict): + if key in other: + item = other.get(key) + if item is not None: + if isinstance(other[key], Mapping): + combined[key] = dict_merge(value, other[key]) + else: + combined[key] = other[key] + else: + combined[key] = item + else: + combined[key] = value + elif isinstance(value, list): + if key in other: + item = other.get(key) + if isinstance(item, list): + if sort_list(value) == sort_list(item): + combined[key] = item + else: + value.extend([i for i in item if i not in value]) + combined[key] = value + else: + combined[key] = item + else: + combined[key] = value + else: + if key in other: + other_value = other.get(key) + if other_value is not None: + if sort_list(base[key]) != sort_list(other_value): + combined[key] = other_value + else: + combined[key] = value + else: + combined[key] = other_value + else: + combined[key] = value + + for key in set(other.keys()).difference(base.keys()): + combined[key] = other.get(key) + + return combined diff --git a/tests/unit/module_utils/fixtures/docstring.py b/tests/unit/module_utils/fixtures/docstring.py new file mode 100644 index 0000000..d94bf3b --- /dev/null +++ b/tests/unit/module_utils/fixtures/docstring.py @@ -0,0 +1,35 @@ +DOCUMENTATION = """ +module: test +author: Bradley Thornton (@cidrblock) +short_description: Short description here +description: +- A longer description here +version_added: 0.0.0 +options: + param_str: + type: str + description: + - A string param + required: True + params_bool: + type: bool + description: + - A bool param + params_dict: + type: dict + description: + - A dict param + suboptions: + subo_str: + type: str + description: + - A string suboption + subo_list: + type: list + description: + - A list suboption + subo_dict: + type: dict + description: + - A dict suboption +""" diff --git a/tests/unit/module_utils/test_argspec_validate.py b/tests/unit/module_utils/test_argspec_validate.py new file mode 100644 index 0000000..6ed2eb2 --- /dev/null +++ b/tests/unit/module_utils/test_argspec_validate.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import unittest +from ansible_collections.ansible.utils.plugins.module_utils.common.argspec_validate import ( + AnsibleArgSpecValidator, +) +from .fixtures.docstring import DOCUMENTATION + + +class TestSortList(unittest.TestCase): + def test_simple_pass(self): + data = {"param_str": "string"} + aav = AnsibleArgSpecValidator( + data=data, + schema=DOCUMENTATION, + schema_format="doc", + schema_conditionals={}, + name="test_action", + ) + valid, errors = aav.validate() + self.assertTrue(valid) + self.assertEqual(errors, None) + + def test_simple_fail(self): + data = {} + aav = AnsibleArgSpecValidator( + data=data, + schema=DOCUMENTATION, + schema_format="doc", + schema_conditionals={}, + name="test_action", + ) + valid, errors = aav.validate() + self.assertFalse(valid) + self.assertIn("missing required arguments: param_str", errors) + + def test_simple_fail_no_name(self): + data = {} + aav = AnsibleArgSpecValidator( + data=data, + schema=DOCUMENTATION, + schema_format="doc", + schema_conditionals={}, + ) + valid, errors = aav.validate() + self.assertFalse(valid) + self.assertIn("missing required arguments: param_str", errors) + + def test_not_doc(self): + data = {"param_str": "string"} + aav = AnsibleArgSpecValidator( + data=data, + schema={"argument_spec": {"param_str": {"type": "str"}}}, + schema_format="argspec", + name="test_action", + ) + valid, errors = aav.validate() + self.assertTrue(valid) + self.assertEqual(errors, None) + + def test_schema_conditional(self): + data = {"param_str": "string"} + aav = AnsibleArgSpecValidator( + data=data, + schema=DOCUMENTATION, + schema_format="doc", + schema_conditionals={ + "required_together": [["param_str", "param_bool"]] + }, + name="test_action", + ) + valid, errors = aav.validate() + self.assertFalse(valid) + self.assertIn( + "parameters are required together: param_str, param_bool", errors + ) + + def test_unsupported_param(self): + data = {"param_str": "string", "not_valid": "string"} + aav = AnsibleArgSpecValidator( + data=data, + schema=DOCUMENTATION, + schema_format="doc", + name="test_action", + # other_args={'bypass_checks': True}, + ) + valid, errors = aav.validate() + self.assertFalse(valid) + self.assertIn( + "Unsupported parameters for 'test_action' module: not_valid", + errors, + ) + + def test_other_args(self): + data = {} + aav = AnsibleArgSpecValidator( + data=data, + schema=DOCUMENTATION, + schema_format="doc", + name="test_action", + other_args={"bypass_checks": True}, + ) + valid, errors = aav.validate() + self.assertTrue(valid) + self.assertIsNone(errors) + + def test_invalid_spec(self): + data = {} + aav = AnsibleArgSpecValidator( + data=data, + schema={"not_valid": True}, + schema_format="argspec", + name="test_action", + other_args={"bypass_checks": True}, + ) + valid, errors = aav.validate() + self.assertFalse(valid) + self.assertIn("Invalid keys found: not_valid", errors) diff --git a/tests/unit/module_utils/test_dict_merge.py b/tests/unit/module_utils/test_dict_merge.py new file mode 100644 index 0000000..6a1d944 --- /dev/null +++ b/tests/unit/module_utils/test_dict_merge.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import unittest +from ansible_collections.ansible.utils.plugins.module_utils.common.utils import ( + dict_merge, +) + + +class TestDict_merge(unittest.TestCase): + def test_not_dict_base(self): + base = [0] + other = {"a": "b"} + expected = "must be of type " + with self.assertRaises(Exception) as exc: + dict_merge(base, other) + self.assertIn(expected, str(exc.exception)) + + def test_not_dict_other(self): + base = {"a": "b"} + other = [0] + expected = "must be of type " + with self.assertRaises(Exception) as exc: + dict_merge(base, other) + self.assertIn(expected, str(exc.exception)) + + def test_simple(self): + base = {"a": "b"} + other = {"c": "d"} + expected = {"a": "b", "c": "d"} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_simple_other_is_string(self): + base = {"a": "b"} + other = {"a": "c"} + expected = {"a": "c"} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_simple_other_is_none(self): + base = {"a": "b"} + other = {"a": None} + expected = {"a": None} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_list_value_same(self): + base = {"a": [2, 1, 0]} + other = {"a": [0, 1, 2]} + expected = {"a": [0, 1, 2]} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_list_value_combine(self): + base = {"a": [2, 1, 0]} + other = {"a": [0, 3]} + expected = {"a": [2, 1, 0, 3]} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_list_missing_from_other(self): + base = {"a": [2, 1, 0]} + other = {"b": [2, 1, 0]} + expected = {"a": [2, 1, 0], "b": [2, 1, 0]} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_list_other_none(self): + base = {"a": [2, 1, 0]} + other = {"a": None} + expected = {"a": None} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + # modified dict merge from netcommon so this works + def test_list_other_dict(self): + base = {"a": [2, 1, 0]} + other = {"a": {"b": "c"}} + expected = {"a": {"b": "c"}} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + # modified dict merge from netcommon so this works + def test_list_other_string(self): + base = {"a": [2, 1, 0]} + other = {"a": "xyz"} + expected = {"a": "xyz"} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_dict_of_dict_merged(self): + base = {"a": {"b": 0}} + other = {"a": {"c": 1}} + expected = {"a": {"b": 0, "c": 1}} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_dict_of_dict_replaced(self): + base = {"a": {"b": 0}} + other = {"a": {"b": 1}} + expected = {"a": {"b": 1}} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_dict_of_dict_replaced_other_none(self): + base = {"a": {"b": 0}} + other = {"a": None} + expected = {"a": None} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_dict_of_dict_replaced_other_string(self): + base = {"a": {"b": 0}} + other = {"a": "xyz"} + expected = {"a": "xyz"} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_dict_of_dict_replaced_other_missing(self): + base = {"a": {"b": 0}} + other = {"c": 1} + expected = {"a": {"b": 0}, "c": 1} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_not_list_or_dict_different(self): + base = {"a": 0} + other = {"a": 1} + expected = {"a": 1} + result = dict_merge(base, other) + self.assertEqual(result, expected) + + def test_not_list_or_dict_same(self): + base = {"a": 0} + other = {"a": 0} + expected = {"a": 0} + result = dict_merge(base, other) + self.assertEqual(result, expected) diff --git a/tests/unit/module_utils/test_sort_list.py b/tests/unit/module_utils/test_sort_list.py new file mode 100644 index 0000000..ae5ab21 --- /dev/null +++ b/tests/unit/module_utils/test_sort_list.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Red Hat +# GNU General Public License v3.0+ +# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +from __future__ import absolute_import, division, print_function + +__metaclass__ = type + +import unittest +from ansible_collections.ansible.utils.plugins.module_utils.common.utils import ( + sort_list, +) + + +class TestSortList(unittest.TestCase): + def test_simple(self): + var = [3, 2, 1] + result = sort_list(var) + expected = [1, 2, 3] + self.assertEqual(result, expected) + + def test_mot_list(self): + var = {"a": "b"} + result = sort_list(var) + self.assertEqual(result, var) + + def test_not_same(self): + var = [{"a": "b", "c": "d"}, {"b": "a"}] + expected = "dictionaries do not match" + with self.assertRaises(Exception) as exc: + sort_list(var) + self.assertIn(expected, str(exc.exception)) + + def test_pass(self): + var = [{"a": 2, "b": 3}, {"a": 0, "b": 1}] + expected = [{"a": 0, "b": 1}, {"a": 2, "b": 3}] + result = sort_list(var) + self.assertEqual(result, expected)