fixed ipcut to support ipv4 (#358)

* fixed ipcut to support ipv4

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add changelog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ruchi Pakhle <72685035+Ruchip16@users.noreply.github.com>
Co-authored-by: Ruchi Pakhle <ruchipakhle@gmail.com>
pull/376/head
Andrew J Roth 2024-08-02 06:39:34 -04:00 committed by GitHub
parent 1df7af5ab2
commit 038bcb163e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 16 deletions

View File

@ -0,0 +1,3 @@
---
minor_changes:
- Previously, the ansible.utils.ipcut filter only supported IPv6 addresses, leading to confusing error messages when used with IPv4 addresses. This fix ensures that the filter now appropriately handles both IPv4 and IPv6 addresses.

View File

@ -105,10 +105,15 @@ def _ipcut(*args, **kwargs):
def ipcut(value, amount):
ipv6_oct = []
try:
ip = netaddr.IPAddress(value)
ipv6address = ip.bits().replace(":", "")
if ip.version == 6:
ip_bits = ip.bits().replace(":", "")
elif ip.version == 4:
ip_bits = ip.bits().replace(".", "")
else:
msg = "Unknown IP Address Version: {0}".format(ip.version)
raise AnsibleFilterError(msg)
except (netaddr.AddrFormatError, ValueError):
msg = "You must pass a valid IP address; {0} is invalid".format(value)
raise AnsibleFilterError(msg)
@ -120,20 +125,27 @@ def ipcut(value, amount):
raise AnsibleFilterError(msg)
else:
if amount < 0:
ipsub = ipv6address[amount:]
ipsub = ip_bits[amount:]
else:
ipsub = ipv6address[0:amount]
ipsub = ip_bits[0:amount]
ipsubfinal = []
for i in range(0, len(ipsub), 16):
oct_sub = i + 16
ipsubfinal.append(ipsub[i:oct_sub])
for i in ipsubfinal:
x = hex(int(i, 2))
ipv6_oct.append(x.replace("0x", ""))
return str(":".join(ipv6_oct))
if ip.version == 6:
ipv4_oct = []
for i in range(0, len(ipsub), 16):
oct_sub = i + 16
ipv4_oct.append(
hex(int(ipsub[i:oct_sub], 2)).replace("0x", ""),
)
result = str(":".join(ipv4_oct))
else: # ip.version == 4:
ipv4_oct = []
for i in range(0, len(ipsub), 8):
oct_sub = i + 8
ipv4_oct.append(
str(int(ipsub[i:oct_sub], 2)),
)
result = str(".".join(ipv4_oct))
return result
class FilterModule(object):

View File

@ -21,16 +21,30 @@ class TestIpCut(TestCase):
def setUp(self):
pass
def test_get_last_X_bits(self):
def test_get_last_X_bits_ipv6(self):
"""Get last X bits of Ipv6 address"""
args = ["", "1234:4321:abcd:dcba::17", -80]
result = _ipcut(*args)
self.assertEqual(result, "dcba:0:0:0:17")
def test_get_first_X_bits(self):
def test_get_first_X_bits_ipv6(self):
"""Get first X bits of Ipv6 address"""
args = ["", "1234:4321:abcd:dcba::17", 64]
result = _ipcut(*args)
self.assertEqual(result, "1234:4321:abcd:dcba")
def test_get_last_X_bits_ipv4(self):
"""Get last X bits of Ipv4 address"""
args = ["", "10.2.3.0", -16]
result = _ipcut(*args)
self.assertEqual(result, "3.0")
def test_get_first_X_bits_ipv4(self):
"""Get first X bits of Ipv4 address"""
args = ["", "10.2.3.0", 24]
result = _ipcut(*args)
self.assertEqual(result, "10.2.3")