import dataclasses
import ipaddress
import itertools
import sys
import typing
from typing import Annotated, Any

import dns
import dns.name
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.rdtypes.ANY.CNAME
import dns.rdtypes.ANY.MX
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.PTR
import dns.rdtypes.ANY.SOA
import dns.rdtypes.ANY.SPF
import dns.rdtypes.ANY.TXT
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
import dns.rdtypes.IN.SRV
import dns.reversename
import dns.serial
import dns.zone
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.validation import check_type_list


class RName(dns.name.Name):
    """Domain name used to represent an e-mail address (see RFC 1035).

    The local part of the address becomes the first label as a whole (dots
    inside it stay within that single label). The domain part supplies the
    remaining labels. The domain is parsed relative to the empty name, so the
    result is relative unless the domain ends with a dot.

    :param address: e-mail address of the form ``local@domain``.
    :raises ValueError: if ``address`` does not contain exactly one ``@``.
    """

    def __init__(self, address):
        try:
            local, domain = address.split("@")
        except ValueError as err:
            # Keep the original unpacking error as the cause for debugging.
            raise ValueError(
                "Invalid e-mail address format: {}".format(address)
            ) from err
        super().__init__(
            (local,)
            + dns.name.from_text(domain, origin=dns.name.empty).labels
        )


class MultiRecords:
    """Annotation used to indicate that a field can be filled in more than
    once via a list, and that this will create as many records as values.
    """

    ...
@dataclasses.dataclass
class A:
    """IPv4 address record."""

    address: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.IN.A.A(
            dns.rdataclass.IN, dns.rdatatype.A, self.address
        )


@dataclasses.dataclass
class AAAA:
    """IPv6 address record."""

    address: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.IN.AAAA.AAAA(
            dns.rdataclass.IN, dns.rdatatype.AAAA, self.address
        )


@dataclasses.dataclass
class PTR:
    """Pointer record (`name` points to `target`)."""

    target: dns.name.Name
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.ANY.PTR.PTR(
            dns.rdataclass.IN, dns.rdatatype.PTR, self.target
        )


@dataclasses.dataclass
class CNAME:
    """Canonical name (alias) record."""

    target: dns.name.Name
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.ANY.CNAME.CNAME(
            dns.rdataclass.IN, dns.rdatatype.CNAME, self.target
        )


@dataclasses.dataclass
class MX:
    """Mail exchange record."""

    exchange: Annotated[dns.name.Name, MultiRecords]
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
    preference: int = 10

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.ANY.MX.MX(
            dns.rdataclass.IN,
            dns.rdatatype.MX,
            self.preference,
            self.exchange,
        )


@dataclasses.dataclass
class NS:
    """Name server record."""

    target: Annotated[dns.name.Name, MultiRecords]
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.ANY.NS.NS(
            dns.rdataclass.IN, dns.rdatatype.NS, self.target
        )


# Maximum length, in bytes, of a single DNS <character-string>
# (RFC 1035 section 3.3).
_MAX_CHARACTER_STRING = 255


def _character_strings(data: str) -> tuple:
    """Split `data` into UTF-8 encoded chunks of at most 255 bytes.

    TXT and SPF RDATA is a sequence of <character-string>s, each limited to
    255 bytes. Consumers such as SPF (RFC 7208 section 3.3) concatenate the
    chunks back together. Splitting here lets long values (long SPF policies,
    DKIM keys, ...) be stored as a few maximal chunks rather than being
    rejected or mis-split when handed to dnspython as one oversized string.
    An empty `data` yields a single empty string, like the unsplit form.
    """
    encoded = data.encode()
    if not encoded:
        return (b"",)
    return tuple(
        encoded[start : start + _MAX_CHARACTER_STRING]
        for start in range(0, len(encoded), _MAX_CHARACTER_STRING)
    )


@dataclasses.dataclass
class SPF:
    """SPF record; `data` may exceed 255 bytes and is split as needed."""

    data: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.ANY.SPF.SPF(
            dns.rdataclass.IN,
            dns.rdatatype.SPF,
            _character_strings(self.data),
        )


@dataclasses.dataclass
class TXT:
    """Text record; `data` may exceed 255 bytes and is split as needed."""

    data: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.ANY.TXT.TXT(
            dns.rdataclass.IN,
            dns.rdatatype.TXT,
            _character_strings(self.data),
        )


@dataclasses.dataclass
class SRV:
    """Service locator record."""

    target: Annotated[dns.name.Name, MultiRecords]
    weight: int
    port: int
    priority: int = 10
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.IN.SRV.SRV(
            dns.rdataclass.IN,
            dns.rdatatype.SRV,
            self.priority,
            self.weight,
            self.port,
            self.target,
        )


@dataclasses.dataclass
class SOA:
    """Start of authority record."""

    mname: dns.name.Name
    rname: RName
    refresh: int
    retry: int
    expire: int
    minimum: int
    serial: int = 1
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        return dns.rdtypes.ANY.SOA.SOA(
            dns.rdataclass.IN,
            dns.rdatatype.SOA,
            self.mname,
            self.rname,
            self.serial,
            self.refresh,
            self.retry,
            self.expire,
            self.minimum,
        )


def has_annotation(ty, annotation):
    """Is the type `ty` annotated with a given `annotation`."""
    return (
        typing.get_origin(ty) == typing.Annotated
        and annotation in typing.get_args(ty)[1:]
    )


def annotated_origin(ty):
    """Returns the origin of an annotated type `ty`."""
    assert typing.get_origin(ty) == typing.Annotated
    return typing.get_args(ty)[0]


def is_multi_records(ty):
    """Is the type `ty` annotated with `MultiRecords`."""
    return has_annotation(ty, MultiRecords)


def spec_option_of_field(field):
    """Convert one dataclass `field` to an Ansible `argument_spec` option.

    Fields annotated with `MultiRecords` become lists of the underlying type.
    A field is required exactly when it has no default value.

    :raises KeyError: if the field type has no Ansible equivalent below.
    """
    types = {
        str: "str",
        dns.name.Name: "str",
        RName: "str",
        int: "int",
    }
    if is_multi_records(field.type):
        option = {
            "type": "list",
            "elements": types[annotated_origin(field.type)],
        }
    else:
        option = {"type": types[field.type]}
    option["required"] = field.default is dataclasses.MISSING
    return option


def spec_options_of_type(ty):
    """Convert a `dataclass` type to Ansible `argument_spec` `options`' format."""
    return {
        field.name: spec_option_of_field(field)
        for field in dataclasses.fields(ty)
    }


def coerce_dns_name(value: Any) -> dns.name.Name:
    """Try to convert a `value` to `dns.name.Name` (parsed relative to the
    empty name unless it already is a `dns.name.Name`)."""
    if not isinstance(value, dns.name.Name):
        return dns.name.from_text(value, origin=dns.name.empty)
    return value


def product_dict(dct, keys=None):
    """Compute the "cartesian product" of a dictionary `dct` w.r.t. some
    `keys` (if `keys` is None, then the product is computed on all the keys).

    Values of the selected keys must be iterables; every other value is kept
    as-is in each yielded dict.
    """
    if keys is None:
        keys = dct.keys()
    # Wrap non-product values in a singleton list so itertools.product
    # treats every key uniformly.
    wrapped = {k: v if k in keys else [v] for k, v in dct.items()}
    for values in itertools.product(*wrapped.values()):
        yield dict(zip(wrapped.keys(), values))


def make_hosts_records(hosts):
    """Yield an A or AAAA record for every address of every host.

    :param hosts: mapping of host name (relative) to one address or a list
        of addresses.
    :raises ValueError: if an address is not a valid IPv4/IPv6 address.
    """
    for host, addrs in hosts.items():
        # The host name does not depend on the address: parse it once.
        name = dns.name.from_text(host, origin=dns.name.empty)
        for addr in check_type_list(addrs):
            decoded = ipaddress.ip_address(addr)
            yield AAAA(addr, name) if decoded.version == 6 else A(addr, name)


def make_reverse_hosts_records(hosts):
    """Yield a PTR record for every address of every host.

    The record lives at the reverse-lookup name of the address and points
    to the host name, which is parsed as an absolute name.
    """
    for host, addrs in hosts.items():
        name = dns.name.from_text(host)
        for addr in check_type_list(addrs):
            reverse = dns.reversename.from_address(addr)
            yield PTR(name, reverse)


def make_records(args, ty):
    """Build records of dataclass type `ty` from Ansible task arguments.

    Options left unset (`None`) are dropped so the dataclass defaults apply.
    Values are coerced to DNS name types where the field requires it, and
    one record is yielded per combination of `MultiRecords` field values.
    """
    # TODO: This is not elegant at all, but:
    # 1. I could not find a way to declare a third-party type in
    #    `argument_spec`;
    # 2. Ansible sets options that were not passed to the task to `None`,
    #    and this behavior does not seem to be configurable.
    types = {f.name: f.type for f in dataclasses.fields(ty)}
    coercers = {
        dns.name.Name: coerce_dns_name,
        RName: RName,
    }

    def coerce_single(value, ty):
        if ty in coercers:
            return coercers[ty](value)
        return value

    def coerce(name, value):
        if is_multi_records(types[name]):
            origin = annotated_origin(types[name])
            return [coerce_single(v, origin) for v in value]
        return coerce_single(value, types[name])

    clean_args = {
        name: coerce(name, value)
        for name, value in args.items()
        if value is not None
    }
    multi_keys = {k for k, v in types.items() if is_multi_records(v)}
    for single_args in product_dict(clean_args, multi_keys):
        yield ty(**single_args)


def zones_eq(lhs: dns.zone.Zone, rhs: dns.zone.Zone) -> bool:
    """Returns a `bool` indicating whether two `dns.zone.Zone`s are equal
    w.r.t. their text representation."""
    return lhs.to_text(relativize=False, sorted=True) == rhs.to_text(
        relativize=False, sorted=True
    )


def write_text_file(path, text, module):
    """Naive text file write function with support for Ansible's diff and
    check modes.

    The file is rewritten only if its content differs from `text` and the
    module is not in check mode. The file attributes requested in the task
    are then applied.

    :returns: ``(changed, [content_diff, attributes_diff])``.
    """
    diff_text = {
        "before_header": f"{path} (content)",
        "after_header": f"{path} (content)",
        "after": text,
    }
    try:
        with open(path, encoding="utf-8") as f:
            current = f.read()
    except (OSError, UnicodeDecodeError):
        # Missing or unreadable file: consider it different so it gets written.
        current = None
    diff_text["before"] = current
    changed = text != current

    if changed and not module.check_mode:
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

    file_args = module.load_file_common_arguments(module.params)
    diff_attrs = {
        "before_header": f"{path} (attributes)",
        "after_header": f"{path} (attributes)",
    }
    changed = module.set_file_attributes_if_different(
        file_args, changed, diff_attrs
    )

    return changed, [diff_text, diff_attrs]


def main() -> int:
    """Ansible module entry point: build the zone described by the task
    parameters and write it, in text form, to ``path``."""
    # Local import, only used to catch dnspython errors below. It must come
    # before any other use of `dns` here, since it binds `dns` as a local.
    import dns.exception

    record_types = {
        "ns": NS,
        "txt": TXT,
        "a": A,
        "aaaa": AAAA,
        "srv": SRV,
        "spf": SPF,
        "ptr": PTR,
        "cname": CNAME,
        "mx": MX,
    }

    module_args = {
        "path": {"type": "str", "required": True},
        "origin": {"type": "str", "required": True},
        "soa": {
            "type": "dict",
            "required": True,
            "options": spec_options_of_type(SOA),
        },
        "hosts": {"type": "dict", "default": {}},
        "reverse_hosts": {"type": "dict", "default": {}},
    }
    for name, ty in record_types.items():
        module_args[name] = {
            "type": "list",
            "default": [],
            "elements": "dict",
            "options": spec_options_of_type(ty),
        }

    module = AnsibleModule(
        argument_spec=module_args,
        add_file_common_args=True,
        supports_check_mode=True,
    )

    path = module.params["path"]

    # Records are generated lazily, so invalid user data (bad address, bad
    # label, name outside the zone origin, oversized or malformed data)
    # surfaces while iterating. Report it as a task failure, not a traceback.
    try:
        origin = dns.name.from_text(module.params["origin"])
        zone = dns.zone.Zone(origin)

        records = itertools.chain(
            make_records(module.params["soa"], SOA),
            make_reverse_hosts_records(module.params["reverse_hosts"]),
            make_hosts_records(module.params["hosts"]),
            itertools.chain.from_iterable(
                make_records(args, ty)
                for name, ty in record_types.items()
                for args in module.params[name]
            ),
        )

        for record in records:
            node = zone.get_node(record.name, create=True)
            rdata = record.rdata()
            dataset = node.get_rdataset(
                rdata.rdclass, rdata.rdtype, create=True
            )
            dataset.add(rdata)

        zone_text = zone.to_text(relativize=False, sorted=True)
    except (ValueError, KeyError, dns.exception.DNSException) as err:
        module.fail_json(msg=f"Unable to build the zone: {err}")

    changed, diff = write_text_file(path, zone_text, module)

    module.exit_json(changed=changed, diff=diff)
    return 0


if __name__ == "__main__":
    sys.exit(main())