From 038bcb163ea60c094ba786a6a54f694a8a9ff371 Mon Sep 17 00:00:00 2001 From: Andrew J Roth Date: Fri, 2 Aug 2024 06:39:34 -0400 Subject: [PATCH] fixed ipcut to support ipv4 (#358) * fixed ipcut to support ipv4 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add changelog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ruchi Pakhle <72685035+Ruchip16@users.noreply.github.com> Co-authored-by: Ruchi Pakhle --- changelogs/fragments/358_ipcut.yaml | 3 ++ plugins/filter/ipcut.py | 40 ++++++++++++++++--------- tests/unit/plugins/filter/test_ipcut.py | 18 +++++++++-- 3 files changed, 45 insertions(+), 16 deletions(-) create mode 100644 changelogs/fragments/358_ipcut.yaml diff --git a/changelogs/fragments/358_ipcut.yaml b/changelogs/fragments/358_ipcut.yaml new file mode 100644 index 0000000..684c295 --- /dev/null +++ b/changelogs/fragments/358_ipcut.yaml @@ -0,0 +1,3 @@ +--- +minor_changes: + - Previously, the ansible.utils.ipcut filter only supported IPv6 addresses, leading to confusing error messages when used with IPv4 addresses. This fix ensures that the filter now appropriately handles both IPv4 and IPv6 addresses. diff --git a/plugins/filter/ipcut.py b/plugins/filter/ipcut.py index 4e61a45..b8d47a7 100644 --- a/plugins/filter/ipcut.py +++ b/plugins/filter/ipcut.py @@ -105,10 +105,15 @@ def _ipcut(*args, **kwargs): def ipcut(value, amount): - ipv6_oct = [] try: ip = netaddr.IPAddress(value) - ipv6address = ip.bits().replace(":", "") + if ip.version == 6: + ip_bits = ip.bits().replace(":", "") + elif ip.version == 4: + ip_bits = ip.bits().replace(".", "") + else: + msg = "Unknown IP Address Version: {0}".format(ip.version) + raise AnsibleFilterError(msg) except (netaddr.AddrFormatError, ValueError): msg = "You must pass a valid IP address; {0} is invalid".format(value) raise AnsibleFilterError(msg) @@ -120,20 +125,27 @@ def ipcut(value, amount): raise AnsibleFilterError(msg) else: if amount < 0: - ipsub = ipv6address[amount:] + ipsub = ip_bits[amount:] else: - ipsub = ipv6address[0:amount] + ipsub = ip_bits[0:amount] - ipsubfinal = [] - - for i in range(0, len(ipsub), 16): - oct_sub = i + 16 - ipsubfinal.append(ipsub[i:oct_sub]) - - for i in ipsubfinal: - x = hex(int(i, 2)) - ipv6_oct.append(x.replace("0x", "")) - return str(":".join(ipv6_oct)) + if ip.version == 6: + ipv4_oct = [] + for i in range(0, len(ipsub), 16): + oct_sub = i + 16 + ipv4_oct.append( + hex(int(ipsub[i:oct_sub], 2)).replace("0x", ""), + ) + result = str(":".join(ipv4_oct)) + else: # ip.version == 4: + ipv4_oct = [] + for i in range(0, len(ipsub), 8): + oct_sub = i + 8 + ipv4_oct.append( + str(int(ipsub[i:oct_sub], 2)), + ) + result = str(".".join(ipv4_oct)) + return result class FilterModule(object): diff --git a/tests/unit/plugins/filter/test_ipcut.py b/tests/unit/plugins/filter/test_ipcut.py index c086833..4a66878 100644 --- a/tests/unit/plugins/filter/test_ipcut.py +++ b/tests/unit/plugins/filter/test_ipcut.py @@ -21,16 +21,30 @@ class TestIpCut(TestCase): def setUp(self): pass - def test_get_last_X_bits(self): + def test_get_last_X_bits_ipv6(self): """Get last X bits of Ipv6 address""" args = ["", "1234:4321:abcd:dcba::17", -80] result = _ipcut(*args) self.assertEqual(result, "dcba:0:0:0:17") - def test_get_first_X_bits(self): + def test_get_first_X_bits_ipv6(self): """Get first X bits of Ipv6 address""" args = ["", "1234:4321:abcd:dcba::17", 64] result = _ipcut(*args) self.assertEqual(result, "1234:4321:abcd:dcba") + + def test_get_last_X_bits_ipv4(self): + """Get last X bits of Ipv4 address""" + + args = ["", "10.2.3.0", -16] + result = _ipcut(*args) + self.assertEqual(result, "3.0") + + def test_get_first_X_bits_ipv4(self): + """Get first X bits of Ipv4 address""" + + args = ["", "10.2.3.0", 24] + result = _ipcut(*args) + self.assertEqual(result, "10.2.3")