From 4dbe0e562de0175ab6d47bd0d28f05ca72e757ab Mon Sep 17 00:00:00 2001 From: Jeltz Date: Wed, 17 Aug 2022 18:23:47 +0200 Subject: [PATCH] dns_zone: cleanup + hosts + product --- library/dns_zone.py | 158 +++++++++++++++++++++++++++++++++----------- 1 file changed, 119 insertions(+), 39 deletions(-) diff --git a/library/dns_zone.py b/library/dns_zone.py index 033eff2..825a0d7 100755 --- a/library/dns_zone.py +++ b/library/dns_zone.py @@ -1,26 +1,30 @@ #!/usr/bin/env python3 -import itertools import dataclasses - -from typing import Any +import ipaddress +import itertools +import sys +import typing +from typing import Annotated, Any import dns -import dns.serial -import dns.zone import dns.rdata import dns.rdataclass import dns.rdatatype -import dns.rdtypes.IN.A -import dns.rdtypes.IN.AAAA +import dns.rdtypes.ANY.CNAME import dns.rdtypes.ANY.MX -import dns.rdtypes.ANY.SOA import dns.rdtypes.ANY.NS +import dns.rdtypes.ANY.SOA import dns.rdtypes.ANY.TXT - +import dns.rdtypes.IN.A +import dns.rdtypes.IN.AAAA +import dns.serial +import dns.zone from ansible.module_utils.basic import AnsibleModule class RName(dns.name.Name): + """Domain name used to represent an e-mail address (see RFC 1035).""" + def __init__(self, address): try: local, domain = address.split("@") @@ -31,10 +35,18 @@ class RName(dns.name.Name): super().__init__((local,) + dns.name.from_text(domain).labels) +class MultiRecords: + """Annotation used to indicate that a field can be filled in more than + once via a list, and that this will create as many records as values. + """ + + ... + + @dataclasses.dataclass class A: address: str - name: dns.name.Name = dns.name.empty + name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty def rdata(self) -> dns.rdata.Rdata: return dns.rdtypes.IN.A.A( @@ -45,7 +57,7 @@ class A: @dataclasses.dataclass class AAAA: address: str - name: dns.name.Name = dns.name.empty + name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty def rdata(self) -> dns.rdata.Rdata: return dns.rdtypes.IN.AAAA.AAAA( @@ -56,7 +68,7 @@ class AAAA: @dataclasses.dataclass class CNAME: address: dns.name.Name - name: dns.name.Name = dns.name.empty + name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty def rdata(self) -> dns.rdata.Rdata: return dns.rdtypes.ANY.CNAME.CNAME( @@ -66,8 +78,8 @@ class CNAME: @dataclasses.dataclass class MX: - exchange: dns.name.Name - name: dns.name.Name = dns.name.empty + exchange: Annotated[dns.name.Name, MultiRecords] + name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty priority: int = 10 def rdata(self) -> dns.rdata.Rdata: @@ -81,8 +93,8 @@ class MX: @dataclasses.dataclass class NS: - address: dns.name.Name - name: dns.name.Name = dns.name.empty + address: Annotated[dns.name.Name, MultiRecords] + name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty def rdata(self) -> dns.rdata.Rdata: return dns.rdtypes.ANY.NS.NS( @@ -93,7 +105,7 @@ class NS: @dataclasses.dataclass class TXT: data: str - name: dns.name.Name = dns.name.empty + name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty def rdata(self) -> dns.rdata.Rdata: return dns.rdtypes.ANY.TXT.TXT( @@ -110,7 +122,7 @@ class SOA: expire: int ttl: int serial: int = 1 - name: dns.name.Name = dns.name.empty + name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty def rdata(self) -> dns.rdata.Rdata: return dns.rdtypes.ANY.SOA.SOA( @@ -126,6 +138,25 @@ class SOA: ) +def has_annotation(ty, annotation): + """Is the type `ty` annotated with a given `annotation`.""" + return ( + typing.get_origin(ty) == typing.Annotated + and annotation in typing.get_args(ty)[1:] + ) + + +def annotated_origin(ty): + """Returns the origin of an annotated type `ty`.""" + assert typing.get_origin(ty) == typing.Annotated + return typing.get_args(ty)[0] + + +def is_multi_records(ty): + """Is the type `ty` annotated with `MultiRecords`.""" + return has_annotation(ty, MultiRecords) + + def spec_option_of_field(field): types = { str: "str", @@ -133,13 +164,20 @@ def spec_option_of_field(field): RName: "str", int: "int", } - return { - "type": types[field.type], - "required": field.default is dataclasses.MISSING, - } + if is_multi_records(field.type): + option = { + "type": "list", + "elements": types[annotated_origin(field.type)], + } + else: + option = {"type": types[field.type]} + option["required"] = field.default is dataclasses.MISSING + return option def spec_options_of_type(ty): + """Convert a `dataclass` type to Ansible `argument_spec` `options`' + format.""" return { field.name: spec_option_of_field(field) for field in dataclasses.fields(ty) @@ -147,12 +185,32 @@ def spec_options_of_type(ty): def coerce_dns_name(value: Any) -> dns.name.Name: + """Try to convert a `value` to `dns.name.Name`.""" if not isinstance(value, dns.name.Name): return dns.name.from_text(value, origin=dns.name.empty) return value -def make_record(args, ty): +def product_dict(dct, keys=None): + """Compute the "cartesian product" of a dictionnary `dct` + w.r.t some `keys` (if `keys` is None, then the product is computed + on all the keys).""" + if keys is None: + keys = dct.keys() + wrapped = {k: v if k in keys else [v] for k, v in dct.items()} + for values in itertools.product(*wrapped.values()): + yield dict(zip(wrapped.keys(), values)) + + +def make_hosts_records(hosts): + for host, addrs in hosts.items(): + for addr in addrs: + name = dns.name.from_text(host, origin=dns.name.empty) + decoded = ipaddress.ip_address(addr) + yield AAAA(addr, name) if decoded.version == 6 else A(addr, name) + + +def make_records(args, ty): # TODO: Ça n'est pas du tout élégant, mais : # 1. je n'ai pas réussi à spécifier dans `argument_spec` un type tiers # 2. Ansible positionne à `None` les entrées non passées à la tâche et @@ -163,10 +221,16 @@ def make_record(args, ty): RName: RName, } + def coerce_single(value, ty): + if ty in coercers: + return coercers[ty](value) + return value + def coerce(name, value): - if types[name] not in coercers: - return value - return coercers[types[name]](value) + if is_multi_records(types[name]): + origin = annotated_origin(types[name]) + return [coerce_single(v, origin) for v in value] + return coerce_single(value, types[name]) clean_args = { name: coerce(name, value) @@ -174,11 +238,16 @@ def make_record(args, ty): if value is not None } - return ty(**clean_args) + multi_keys = (k for k, v in types.items() if is_multi_records(v)) + for single_args in product_dict(clean_args, multi_keys): + yield ty(**single_args) -def zones_eq(a: dns.zone.Zone, b: dns.zone.Zone) -> bool: - return a.to_text(relativize=False) == b.to_text(relativize=False) + +def zones_eq(lhs: dns.zone.Zone, rhs: dns.zone.Zone) -> bool: + """Returns a `bool` indicating whether two `dns.zone.Zone`s are equal + w.r.t. their text representation.""" + return lhs.to_text(relativize=False) == rhs.to_text(relativize=False) def main() -> int: @@ -188,17 +257,19 @@ def main() -> int: "txt": TXT, "a": A, "aaaa": AAAA, + "cname": CNAME, "mx": MX, } module_args = { - "path": {"type": "path", "required": True}, + "path": {"type": "str", "required": True}, "origin": {"type": "str", "required": True}, "soa": { "type": "dict", "required": True, "options": spec_options_of_type(SOA), }, + "hosts": {"type": "dict", "default": {}}, } for name, ty in record_types.items(): @@ -209,7 +280,10 @@ def main() -> int: "options": spec_options_of_type(ty), } - module = AnsibleModule(argument_spec=module_args) + module = AnsibleModule( + argument_spec=module_args, + add_file_common_args=True, + ) origin = dns.name.from_text(module.params["origin"]) path = module.params["path"] @@ -218,16 +292,18 @@ def main() -> int: try: current = dns.zone.from_file(path, origin=origin) - except: + except Exception: current = None - records = [make_record(module.params["soa"], SOA)] - - records.extend( + records = itertools.chain( + make_records(module.params["soa"], SOA), + make_hosts_records(module.params["hosts"]), itertools.chain.from_iterable( - (make_record(args, ty) for args in module.params[name]) + itertools.chain.from_iterable( + make_records(args, ty) for args in module.params[name] + ) for name, ty in record_types.items() - ) + ), ) for record in records: @@ -236,9 +312,13 @@ def main() -> int: dataset = node.get_rdataset(rdata.rdclass, rdata.rdtype, create=True) dataset.add(rdata) + file_args = module.load_file_common_arguments(module.params) + changed = current is None or not zones_eq(zone, current) if changed: - zone.to_file(path, relativize=False) + zone.to_file(module.params["path"], relativize=False) + + changed = module.set_fs_attributes_if_different(file_args, changed) module.exit_json(changed=changed) @@ -246,4 +326,4 @@ def main() -> int: if __name__ == "__main__": - exit(main()) + sys.exit(main())