Ansible arg spec validation using AnsibleModule (#9)

* Add argspec validator

* Add argspec validator

* Remove usused param

* lint fix

* Remove unneeded import

* Add sort_list tests

* Unit tests for dict_merge

* lint fix

* argspec unit tests

* Remove q

* Rerun black

* doc updates

Co-authored-by: cidrblock <brad@thethorntons.net>
pull/11/head
Bradley A. Thornton 2020-10-19 08:23:50 -07:00 committed by GitHub
parent 12333c2386
commit 59eb835d96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 686 additions and 0 deletions

View File

@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Red Hat
# GNU General Public License v3.0+
# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
"""Use AnsibleModule's argspec validation
def _check_argspec(self):
aav = AnsibleArgSpecValidator(
data=self._task.args,
schema=DOCUMENTATION,
schema_format="doc",
schema_conditionals={},
other_args={},
name=self._task.action,
)
valid, errors = aav.validate()
if not valid:
raise AnsibleActionFail(errors)
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import json
import re
from ansible.module_utils.basic import AnsibleModule
from ansible_collections.ansible.utils.plugins.module_utils.common.utils import (
dict_merge,
)
from ansible.module_utils.six import iteritems, string_types
from ansible.module_utils._text import to_bytes
try:
import yaml
# use C version if possible for speedup
try:
from yaml import CSafeLoader as SafeLoader
except ImportError:
from yaml import SafeLoader
HAS_YAML = True
except ImportError:
HAS_YAML = False
# TODO: Update this to point to functionality being exposed in 2.11
# ansible-base 2.11 should expose argspec validation outside of the
# ansiblemodule class
try:
from ansible.module_utils.somefile import FutureBaseArgspecValidator
HAS_ANSIBLE_ARG_SPEC_VALIDATOR = True
except ImportError:
HAS_ANSIBLE_ARG_SPEC_VALIDATOR = False
OPTION_METADATA = (
"type",
"choices",
"default",
"required",
"aliases",
"elements",
"fallback",
"no_log",
"apply_defaults",
"deprecated_aliases",
"removed_in_version",
)
OPTION_CONDITIONALS = (
"mutually_exclusive",
"required_one_of",
"required_together",
"required_by",
"required_if",
)
VALID_ANSIBLEMODULE_ARGS = (
"argument_spec",
"bypass_checks",
"no_log",
"add_file_common_args",
"supports_check_mode",
) + OPTION_CONDITIONALS
BASE_ARG_AVAIL = 2.11
class MonkeyModule(AnsibleModule):
"""A derivative of the AnsibleModule used
to just validate the data (task.args) against
the schema(argspec)
"""
def __init__(self, data, schema, name):
self._errors = None
self._valid = True
self._schema = schema
self.name = name
self.params = data
def fail_json(self, msg):
"""Replace the AnsibleModule fail_json here
:param msg: The message for the failure
:type msg: str
"""
if self.name:
msg = re.sub(
r"\(basic\.pyc?\)",
"'{name}'".format(name=self.name),
msg,
)
self._valid = False
self._errors = msg
def _load_params(self):
"""This replaces the AnsibleModule _load_params
fn because we already set self.params in init
"""
pass
def validate(self):
"""Instantiate the super, validating the schema
against the data
:return valid: if the data passed
:rtype valid: bool
:return errors: errors reported during validation
:rtype errors: str
"""
super(MonkeyModule, self).__init__(**self._schema)
return self._valid, self._errors
class AnsibleArgSpecValidator:
def __init__(
self,
data,
schema,
schema_format,
schema_conditionals=None,
name=None,
other_args=None,
):
"""Validate some data against a schema
:param data: The data to valdiate
:type data: dict
:param schema: A schema in ansible argspec format
:type schema: dict
:param schema_format: 'doc' (ansible docstring) or 'argspec' (ansible argspec)
:type schema: str if doc, dict if argspec
:param schema_conditionals: A dict of schema conditionals, ie required_if
:type schema_conditionals: dict
:param name: the name of the plugin calling this class, used in error messages
:type name: str
:param other_args: Other valid kv pairs for the argspec, eg no_log, bypass_checks
:type other_args: dict
note:
- the schema conditionals can be root conditionals or deeply nested conditionals
these get dict_merged into the argspec from the docstring, since the docstring cannot
contain them.
"""
self._errors = ""
self._name = name
self._other_args = other_args
self._schema = schema
self._schema_format = schema_format
self._schema_conditionals = schema_conditionals
self._data = data
def _extract_schema_from_doc(self, doc_obj, temp_schema):
"""Extract the schema from a doc string
:param doc_obj: The doc as a python obj
:type doc_obj: dictionary
:params temp_schema: The dict in which we stuff the schema parts
:type temp_schema: dict
"""
options_obj = doc_obj.get("options")
for okey, ovalue in iteritems(options_obj):
temp_schema[okey] = {}
for metakey in list(ovalue):
if metakey == "suboptions":
temp_schema[okey].update({"options": {}})
suboptions_obj = {"options": ovalue["suboptions"]}
self._extract_schema_from_doc(
suboptions_obj, temp_schema[okey]["options"]
)
elif metakey in OPTION_METADATA + OPTION_CONDITIONALS:
temp_schema[okey].update({metakey: ovalue[metakey]})
# TODO: Support extends_documentation_fragment
def _convert_doc_to_schema(self):
"""Convert the doc string to an obj, was yaml
add back other valid conditionals and params
"""
doc_obj = yaml.load(self._schema, SafeLoader)
temp_schema = {}
self._extract_schema_from_doc(doc_obj, temp_schema)
self._schema = {"argument_spec": temp_schema}
def _validate(self):
"""Validate the data gainst the schema
convert doc string in argspec if necessary
"""
if self._schema_format == "doc":
self._convert_doc_to_schema()
if self._schema_conditionals is not None:
self._schema = dict_merge(self._schema, self._schema_conditionals)
if self._other_args is not None:
self._schema = dict_merge(self._schema, self._other_args)
invalid_keys = [
k for k in self._schema.keys() if k not in VALID_ANSIBLEMODULE_ARGS
]
if invalid_keys:
valid = False
errors = "Invalid schema. Invalid keys found: {ikeys}".format(
ikeys=",".join(invalid_keys)
)
else:
mm = MonkeyModule(
data=self._data, schema=self._schema, name=self._name
)
valid, errors = mm.validate()
return valid, errors
def validate(self):
"""The public validate method
check for future argspec validation
that is coming in 2.11, change the check according above
"""
if HAS_ANSIBLE_ARG_SPEC_VALIDATOR:
return self._validate()
else:
return self._validate()

View File

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Red Hat
# GNU General Public License v3.0+
# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from copy import deepcopy
from itertools import chain
from ansible.module_utils.common._collections_compat import Mapping
from ansible.module_utils.six import iteritems, string_types
def sort_list(val):
if isinstance(val, list):
if isinstance(val[0], dict):
sorted_keys = [tuple(sorted(dict_.keys())) for dict_ in val]
# All keys should be identical
if len(set(sorted_keys)) != 1:
raise ValueError("dictionaries do not match")
return sorted(
val, key=lambda d: tuple(d[k] for k in sorted_keys[0])
)
return sorted(val)
return val
def dict_merge(base, other):
"""Return a new dict object that combines base and other
This will create a new dict object that is a combination of the key/value
pairs from base and other. When both keys exist, the value will be
selected from other.
If the value in base is a list, and the value in other is a list
the base list will be extended with the values from the other list that were
not already present in the base list
If the value in base is a list, and the value in other is a list
and the two have the same entries, the value from other will be
used, preserving the order from the other list
If the value in base is a list, and the value in other is not a list
the value from other will be used
:param base: dict object to serve as base
:param other: dict object to combine with base
:returns: new combined dict object
"""
if not isinstance(base, dict):
raise AssertionError("`base` must be of type <dict>")
if not isinstance(other, dict):
raise AssertionError("`other` must be of type <dict>")
combined = dict()
for key, value in iteritems(deepcopy(base)):
if isinstance(value, dict):
if key in other:
item = other.get(key)
if item is not None:
if isinstance(other[key], Mapping):
combined[key] = dict_merge(value, other[key])
else:
combined[key] = other[key]
else:
combined[key] = item
else:
combined[key] = value
elif isinstance(value, list):
if key in other:
item = other.get(key)
if isinstance(item, list):
if sort_list(value) == sort_list(item):
combined[key] = item
else:
value.extend([i for i in item if i not in value])
combined[key] = value
else:
combined[key] = item
else:
combined[key] = value
else:
if key in other:
other_value = other.get(key)
if other_value is not None:
if sort_list(base[key]) != sort_list(other_value):
combined[key] = other_value
else:
combined[key] = value
else:
combined[key] = other_value
else:
combined[key] = value
for key in set(other.keys()).difference(base.keys()):
combined[key] = other.get(key)
return combined

View File

@ -0,0 +1,35 @@
DOCUMENTATION = """
module: test
author: Bradley Thornton (@cidrblock)
short_description: Short description here
description:
- A longer description here
version_added: 0.0.0
options:
param_str:
type: str
description:
- A string param
required: True
params_bool:
type: bool
description:
- A bool param
params_dict:
type: dict
description:
- A dict param
suboptions:
subo_str:
type: str
description:
- A string suboption
subo_list:
type: list
description:
- A list suboption
subo_dict:
type: dict
description:
- A dict suboption
"""

View File

@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Red Hat
# GNU General Public License v3.0+
# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import unittest
from ansible_collections.ansible.utils.plugins.module_utils.common.argspec_validate import (
AnsibleArgSpecValidator,
)
from .fixtures.docstring import DOCUMENTATION
class TestSortList(unittest.TestCase):
def test_simple_pass(self):
data = {"param_str": "string"}
aav = AnsibleArgSpecValidator(
data=data,
schema=DOCUMENTATION,
schema_format="doc",
schema_conditionals={},
name="test_action",
)
valid, errors = aav.validate()
self.assertTrue(valid)
self.assertEqual(errors, None)
def test_simple_fail(self):
data = {}
aav = AnsibleArgSpecValidator(
data=data,
schema=DOCUMENTATION,
schema_format="doc",
schema_conditionals={},
name="test_action",
)
valid, errors = aav.validate()
self.assertFalse(valid)
self.assertIn("missing required arguments: param_str", errors)
def test_simple_fail_no_name(self):
data = {}
aav = AnsibleArgSpecValidator(
data=data,
schema=DOCUMENTATION,
schema_format="doc",
schema_conditionals={},
)
valid, errors = aav.validate()
self.assertFalse(valid)
self.assertIn("missing required arguments: param_str", errors)
def test_not_doc(self):
data = {"param_str": "string"}
aav = AnsibleArgSpecValidator(
data=data,
schema={"argument_spec": {"param_str": {"type": "str"}}},
schema_format="argspec",
name="test_action",
)
valid, errors = aav.validate()
self.assertTrue(valid)
self.assertEqual(errors, None)
def test_schema_conditional(self):
data = {"param_str": "string"}
aav = AnsibleArgSpecValidator(
data=data,
schema=DOCUMENTATION,
schema_format="doc",
schema_conditionals={
"required_together": [["param_str", "param_bool"]]
},
name="test_action",
)
valid, errors = aav.validate()
self.assertFalse(valid)
self.assertIn(
"parameters are required together: param_str, param_bool", errors
)
def test_unsupported_param(self):
data = {"param_str": "string", "not_valid": "string"}
aav = AnsibleArgSpecValidator(
data=data,
schema=DOCUMENTATION,
schema_format="doc",
name="test_action",
# other_args={'bypass_checks': True},
)
valid, errors = aav.validate()
self.assertFalse(valid)
self.assertIn(
"Unsupported parameters for 'test_action' module: not_valid",
errors,
)
def test_other_args(self):
data = {}
aav = AnsibleArgSpecValidator(
data=data,
schema=DOCUMENTATION,
schema_format="doc",
name="test_action",
other_args={"bypass_checks": True},
)
valid, errors = aav.validate()
self.assertTrue(valid)
self.assertIsNone(errors)
def test_invalid_spec(self):
data = {}
aav = AnsibleArgSpecValidator(
data=data,
schema={"not_valid": True},
schema_format="argspec",
name="test_action",
other_args={"bypass_checks": True},
)
valid, errors = aav.validate()
self.assertFalse(valid)
self.assertIn("Invalid keys found: not_valid", errors)

View File

@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Red Hat
# GNU General Public License v3.0+
# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import unittest
from ansible_collections.ansible.utils.plugins.module_utils.common.utils import (
dict_merge,
)
class TestDict_merge(unittest.TestCase):
def test_not_dict_base(self):
base = [0]
other = {"a": "b"}
expected = "must be of type <dict>"
with self.assertRaises(Exception) as exc:
dict_merge(base, other)
self.assertIn(expected, str(exc.exception))
def test_not_dict_other(self):
base = {"a": "b"}
other = [0]
expected = "must be of type <dict>"
with self.assertRaises(Exception) as exc:
dict_merge(base, other)
self.assertIn(expected, str(exc.exception))
def test_simple(self):
base = {"a": "b"}
other = {"c": "d"}
expected = {"a": "b", "c": "d"}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_simple_other_is_string(self):
base = {"a": "b"}
other = {"a": "c"}
expected = {"a": "c"}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_simple_other_is_none(self):
base = {"a": "b"}
other = {"a": None}
expected = {"a": None}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_list_value_same(self):
base = {"a": [2, 1, 0]}
other = {"a": [0, 1, 2]}
expected = {"a": [0, 1, 2]}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_list_value_combine(self):
base = {"a": [2, 1, 0]}
other = {"a": [0, 3]}
expected = {"a": [2, 1, 0, 3]}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_list_missing_from_other(self):
base = {"a": [2, 1, 0]}
other = {"b": [2, 1, 0]}
expected = {"a": [2, 1, 0], "b": [2, 1, 0]}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_list_other_none(self):
base = {"a": [2, 1, 0]}
other = {"a": None}
expected = {"a": None}
result = dict_merge(base, other)
self.assertEqual(result, expected)
# modified dict merge from netcommon so this works
def test_list_other_dict(self):
base = {"a": [2, 1, 0]}
other = {"a": {"b": "c"}}
expected = {"a": {"b": "c"}}
result = dict_merge(base, other)
self.assertEqual(result, expected)
# modified dict merge from netcommon so this works
def test_list_other_string(self):
base = {"a": [2, 1, 0]}
other = {"a": "xyz"}
expected = {"a": "xyz"}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_dict_of_dict_merged(self):
base = {"a": {"b": 0}}
other = {"a": {"c": 1}}
expected = {"a": {"b": 0, "c": 1}}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_dict_of_dict_replaced(self):
base = {"a": {"b": 0}}
other = {"a": {"b": 1}}
expected = {"a": {"b": 1}}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_dict_of_dict_replaced_other_none(self):
base = {"a": {"b": 0}}
other = {"a": None}
expected = {"a": None}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_dict_of_dict_replaced_other_string(self):
base = {"a": {"b": 0}}
other = {"a": "xyz"}
expected = {"a": "xyz"}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_dict_of_dict_replaced_other_missing(self):
base = {"a": {"b": 0}}
other = {"c": 1}
expected = {"a": {"b": 0}, "c": 1}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_not_list_or_dict_different(self):
base = {"a": 0}
other = {"a": 1}
expected = {"a": 1}
result = dict_merge(base, other)
self.assertEqual(result, expected)
def test_not_list_or_dict_same(self):
base = {"a": 0}
other = {"a": 0}
expected = {"a": 0}
result = dict_merge(base, other)
self.assertEqual(result, expected)

View File

@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Red Hat
# GNU General Public License v3.0+
# (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import unittest
from ansible_collections.ansible.utils.plugins.module_utils.common.utils import (
sort_list,
)
class TestSortList(unittest.TestCase):
def test_simple(self):
var = [3, 2, 1]
result = sort_list(var)
expected = [1, 2, 3]
self.assertEqual(result, expected)
def test_mot_list(self):
var = {"a": "b"}
result = sort_list(var)
self.assertEqual(result, var)
def test_not_same(self):
var = [{"a": "b", "c": "d"}, {"b": "a"}]
expected = "dictionaries do not match"
with self.assertRaises(Exception) as exc:
sort_list(var)
self.assertIn(expected, str(exc.exception))
def test_pass(self):
var = [{"a": 2, "b": 3}, {"a": 0, "b": 1}]
expected = [{"a": 0, "b": 1}, {"a": 2, "b": 3}]
result = sort_list(var)
self.assertEqual(result, expected)