ansible/roles/knotd/library/dns_zone.py
2025-01-01 14:15:25 +01:00

418 lines
11 KiB
Python
Executable file

import dataclasses
import ipaddress
import itertools
import os
import sys
import tempfile
import typing
from typing import Annotated, Any

import dns
import dns.exception
import dns.name
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.rdtypes.ANY.CNAME
import dns.rdtypes.ANY.MX
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.PTR
import dns.rdtypes.ANY.SOA
import dns.rdtypes.ANY.SPF
import dns.rdtypes.ANY.TXT
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
import dns.rdtypes.IN.SRV
import dns.reversename
import dns.serial
import dns.zone
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.validation import check_type_list
class RName(dns.name.Name):
    """Domain name used to represent an e-mail address (see RFC 1035).

    `"hostmaster@example.org"` becomes a name whose first label is the
    local part (`hostmaster`), followed by the labels of the domain part.
    The domain part is parsed with `origin=dns.name.empty`, so it stays
    relative unless written with a trailing dot.

    :raises ValueError: if `address` does not contain exactly one `@` or
        its local part is empty.
    """

    def __init__(self, address):
        try:
            local, domain = address.split("@")
        except ValueError:
            # The unpacking error itself carries no useful context.
            raise ValueError(
                "Invalid e-mail address format: {}".format(address)
            ) from None
        if not local:
            # An empty first label would otherwise fail deep inside dnspython
            # with a much less helpful error.
            raise ValueError(
                "Invalid e-mail address format: {}".format(address)
            )
        super().__init__(
            (local,) + dns.name.from_text(domain, origin=dns.name.empty).labels
        )
class MultiRecords:
    """Marker placed in the metadata of a `typing.Annotated` record field.

    A field carrying this marker accepts a list of values, and one record is
    generated for every value (see `make_records`).
    """
@dataclasses.dataclass
class A:
    """IPv4 address (A) record.

    `name` is annotated with `MultiRecords`: when given as a list,
    `make_records` builds one record per owner name.
    """

    address: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record."""
        return dns.rdtypes.IN.A.A(
            dns.rdataclass.IN, dns.rdatatype.A, self.address
        )
@dataclasses.dataclass
class AAAA:
    """IPv6 address (AAAA) record.

    `name` is annotated with `MultiRecords`: when given as a list,
    `make_records` builds one record per owner name.
    """

    address: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record."""
        return dns.rdtypes.IN.AAAA.AAAA(
            dns.rdataclass.IN, dns.rdatatype.AAAA, self.address
        )
@dataclasses.dataclass
class PTR:
    """Pointer (PTR) record: maps the owner `name` to `target`.

    `name` is annotated with `MultiRecords`: when given as a list,
    `make_records` builds one record per owner name.
    """

    target: dns.name.Name
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record."""
        return dns.rdtypes.ANY.PTR.PTR(
            dns.rdataclass.IN, dns.rdatatype.PTR, self.target
        )
@dataclasses.dataclass
class CNAME:
    """Canonical name (CNAME) record: aliases the owner `name` to `target`.

    `name` is annotated with `MultiRecords`: when given as a list,
    `make_records` builds one record per owner name.
    """

    target: dns.name.Name
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record."""
        return dns.rdtypes.ANY.CNAME.CNAME(
            dns.rdataclass.IN, dns.rdatatype.CNAME, self.target
        )
@dataclasses.dataclass
class MX:
    """Mail exchange (MX) record.

    Both `exchange` and `name` are annotated with `MultiRecords`: when they
    are given as lists, `make_records` builds one record per
    (exchange, name) combination. `preference` defaults to 10.
    """

    exchange: Annotated[dns.name.Name, MultiRecords]
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
    preference: int = 10

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record."""
        return dns.rdtypes.ANY.MX.MX(
            dns.rdataclass.IN,
            dns.rdatatype.MX,
            self.preference,
            self.exchange,
        )
@dataclasses.dataclass
class NS:
    """Name server (NS) record delegating the owner `name` to `target`.

    Both `target` and `name` are annotated with `MultiRecords`: when they
    are given as lists, `make_records` builds one record per
    (target, name) combination.
    """

    target: Annotated[dns.name.Name, MultiRecords]
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record."""
        return dns.rdtypes.ANY.NS.NS(
            dns.rdataclass.IN, dns.rdatatype.NS, self.target
        )
@dataclasses.dataclass
class SPF:
    """Sender Policy Framework (SPF) record holding the policy text `data`.

    `name` is annotated with `MultiRecords`: when given as a list,
    `make_records` builds one record per owner name.
    """

    data: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record.

        A DNS <character-string> holds at most 255 octets (RFC 1035 §3.3),
        so the UTF-8 encoded `data` is split explicitly into 255-octet
        chunks instead of relying on how dnspython treats one oversized
        string. An empty `data` still yields a single empty string.
        """
        encoded = self.data.encode()
        chunks = [encoded[i:i + 255] for i in range(0, len(encoded), 255)]
        return dns.rdtypes.ANY.SPF.SPF(
            dns.rdataclass.IN, dns.rdatatype.SPF, chunks or [b""]
        )
@dataclasses.dataclass
class TXT:
    """Text (TXT) record holding `data`.

    `name` is annotated with `MultiRecords`: when given as a list,
    `make_records` builds one record per owner name.
    """

    data: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record.

        A DNS <character-string> holds at most 255 octets (RFC 1035 §3.3),
        so the UTF-8 encoded `data` is split explicitly into 255-octet
        chunks. This keeps long values such as DKIM keys valid instead of
        relying on how dnspython treats one oversized string. An empty
        `data` still yields a single empty string.
        """
        encoded = self.data.encode()
        chunks = [encoded[i:i + 255] for i in range(0, len(encoded), 255)]
        return dns.rdtypes.ANY.TXT.TXT(
            dns.rdataclass.IN, dns.rdatatype.TXT, chunks or [b""]
        )
@dataclasses.dataclass
class SRV:
    """Service locator (SRV) record.

    Both `target` and `name` are annotated with `MultiRecords`: when they
    are given as lists, `make_records` builds one record per
    (target, name) combination. `priority` defaults to 10.
    """

    target: Annotated[dns.name.Name, MultiRecords]
    weight: int
    port: int
    priority: int = 10
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record."""
        return dns.rdtypes.IN.SRV.SRV(
            dns.rdataclass.IN,
            dns.rdatatype.SRV,
            self.priority,
            self.weight,
            self.port,
            self.target,
        )
@dataclasses.dataclass
class SOA:
    """Start of authority (SOA) record.

    `mname` is the primary name server and `rname` the administrator
    mailbox (built from an e-mail address by `RName`). `serial` defaults
    to 1; the timer fields are passed to dnspython unchanged.
    """

    mname: dns.name.Name
    rname: RName
    refresh: int
    retry: int
    expire: int
    minimum: int
    serial: int = 1
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the dnspython rdata (class IN) for this record."""
        return dns.rdtypes.ANY.SOA.SOA(
            dns.rdataclass.IN,
            dns.rdatatype.SOA,
            self.mname,
            self.rname,
            self.serial,
            self.refresh,
            self.retry,
            self.expire,
            self.minimum,
        )
def has_annotation(ty, annotation):
    """Tell whether `ty` is an `Annotated[...]` type whose metadata
    contains `annotation`.

    Only the metadata part is searched; the wrapped base type is ignored.
    """
    if typing.get_origin(ty) is not typing.Annotated:
        return False
    _base, *metadata = typing.get_args(ty)
    return annotation in metadata
def annotated_origin(ty):
    """Return the base type wrapped by the `Annotated[...]` type `ty`.

    Callers must only pass `Annotated` types; this is checked with an
    `assert`.
    """
    is_annotated = typing.get_origin(ty) == typing.Annotated
    assert is_annotated
    wrapped_type = typing.get_args(ty)[0]
    return wrapped_type
def is_multi_records(ty):
    """Tell whether `ty` is an `Annotated[...]` type carrying the
    `MultiRecords` marker in its metadata."""
    if typing.get_origin(ty) != typing.Annotated:
        return False
    return MultiRecords in typing.get_args(ty)[1:]
def spec_option_of_field(field):
    """Convert one dataclass `field` to an Ansible `argument_spec` option.

    Fields annotated with `MultiRecords` become `list` options whose
    elements have the Ansible type of the wrapped type. Other fields map
    directly to a scalar type. A field is required exactly when it has
    neither a default value nor a default factory.

    :raises KeyError: if the field type has no Ansible equivalent below.
    """
    ansible_types = {
        str: "str",
        dns.name.Name: "str",
        RName: "str",
        int: "int",
    }
    if is_multi_records(field.type):
        option = {
            "type": "list",
            "elements": ansible_types[annotated_origin(field.type)],
        }
    else:
        option = {"type": ansible_types[field.type]}
    # A field declared with `default_factory` has `default is MISSING` but
    # is still optional, so both must be checked.
    option["required"] = (
        field.default is dataclasses.MISSING
        and field.default_factory is dataclasses.MISSING
    )
    return option
def spec_options_of_type(ty):
    """Describe the fields of the dataclass `ty` as Ansible `argument_spec`
    options.

    Returns a dict mapping each field name, in declaration order, to the
    option built by `spec_option_of_field`.
    """
    options = {}
    for field in dataclasses.fields(ty):
        options[field.name] = spec_option_of_field(field)
    return options
def coerce_dns_name(value: Any) -> dns.name.Name:
    """Return `value` as a `dns.name.Name`.

    `Name` instances are returned untouched; anything else is parsed with
    `dns.name.from_text` against the empty origin, so it stays relative
    unless written with a trailing dot.
    """
    if isinstance(value, dns.name.Name):
        return value
    return dns.name.from_text(value, origin=dns.name.empty)
def product_dict(dct, keys=None):
    """Yield the "cartesian product" of the dictionary `dct`.

    Values of the keys listed in `keys` (all keys when `keys` is None) are
    iterables that are expanded. Every other value is kept as a single fixed
    item. Each yielded dict has the same keys, in the same order, as `dct`.
    """
    expanded = dct.keys() if keys is None else keys
    names = list(dct)
    axes = [dct[k] if k in expanded else (dct[k],) for k in names]
    for combination in itertools.product(*axes):
        yield dict(zip(names, combination))
def make_hosts_records(hosts):
    """Yield an A or AAAA record for every address of every host.

    `hosts` maps a host name to one or more addresses, normalised with
    `check_type_list`. Host names are parsed against the empty origin.
    IPv6 addresses produce `AAAA` records and all others `A` records.
    """
    for host, addrs in hosts.items():
        for addr in check_type_list(addrs):
            owner = dns.name.from_text(host, origin=dns.name.empty)
            is_ipv6 = ipaddress.ip_address(addr).version == 6
            record_type = AAAA if is_ipv6 else A
            yield record_type(addr, owner)
def make_reverse_hosts_records(hosts):
    """Yield a PTR record for every address of every host.

    `hosts` maps a host name to one or more addresses, normalised with
    `check_type_list`. The record is owned by the reverse-lookup name of
    the address and points to the host name. The host name is parsed with
    `dns.name.from_text`'s default origin (unlike `make_hosts_records`).
    """
    for host, addrs in hosts.items():
        for addr in check_type_list(addrs):
            pointed_to = dns.name.from_text(host)
            yield PTR(
                target=pointed_to,
                name=dns.reversename.from_address(addr),
            )
def make_records(args, ty):
    """Yield instances of the record dataclass `ty` built from Ansible args.

    `args` maps field names of `ty` to raw option values.

    - `None` values (options not given to the task) are dropped, so the
      dataclass defaults apply.
    - Values of `dns.name.Name` / `RName` fields are converted to those
      types.
    - Each field annotated with `MultiRecords` holds a list, and one record
      is yielded per combination of those list values (via `product_dict`).
    """
    # TODO: This is not elegant at all, but:
    # 1. I could not manage to declare a third-party type in `argument_spec`;
    # 2. Ansible sets the options not passed to the task to `None`, and this
    #    behaviour does not seem to be configurable.
    field_types = {field.name: field.type for field in dataclasses.fields(ty)}
    converters = {
        dns.name.Name: coerce_dns_name,
        RName: RName,
    }

    def convert(value, target_type):
        converter = converters.get(target_type)
        return value if converter is None else converter(value)

    provided = {}
    for key, value in args.items():
        if value is None:
            continue
        field_type = field_types[key]
        if is_multi_records(field_type):
            element_type = annotated_origin(field_type)
            provided[key] = [convert(item, element_type) for item in value]
        else:
            provided[key] = convert(value, field_type)

    multi_fields = {
        key
        for key, field_type in field_types.items()
        if is_multi_records(field_type)
    }
    for combination in product_dict(provided, multi_fields):
        yield ty(**combination)
def zones_eq(lhs: dns.zone.Zone, rhs: dns.zone.Zone) -> bool:
    """Tell whether two zones render to the same text.

    Both zones are rendered with absolute names and sorted, so equality
    does not depend on node insertion order.
    """
    def render(zone):
        return zone.to_text(relativize=False, sorted=True)

    return render(lhs) == render(rhs)
def write_text_file(path, text, module):
    """Write `text` to `path`, supporting Ansible's check and diff modes.

    The file is only rewritten when its current content differs from `text`.
    New content is written to a temporary file first and then moved over
    `path` with `module.atomic_move`, so readers never observe a partially
    written file. Afterwards the attributes requested through the module's
    common file arguments (mode, owner, ...) are applied.

    :param path: destination path; a symlink is written through to its target.
    :param text: complete desired file content.
    :param module: running `AnsibleModule` created with
        `add_file_common_args=True`.
    :return: `(changed, diff)`, where `diff` is the content diff followed by
        the attribute diff.
    """
    diff_text = {
        "before_header": f"{path} (content)",
        "after_header": f"{path} (content)",
        "after": text,
    }
    exists = True
    try:
        with open(path, encoding="utf-8") as f:
            current = f.read()
    except FileNotFoundError:
        exists = False
        current = None
    except UnicodeDecodeError:
        # Not readable as text: treat it as foreign content to be replaced.
        current = None
    changed = current is None or text != current
    diff_text["before"] = current

    if changed and not module.check_mode:
        fd, tmp_path = tempfile.mkstemp(dir=module.tmpdir)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(text)
        # `realpath` keeps writing through a symlink (as `open(path, "w")`
        # did) instead of replacing the link itself.
        module.atomic_move(
            tmp_path,
            os.path.realpath(path),
            unsafe_writes=module.params.get("unsafe_writes", False),
        )

    diff_attrs = {
        "before_header": f"{path} (attributes)",
        "after_header": f"{path} (attributes)",
    }
    if not exists and module.check_mode:
        # Nothing was created, so there are no attributes to inspect yet;
        # the file is already reported as changed.
        return changed, [diff_text, diff_attrs]
    file_args = module.load_file_common_arguments(module.params)
    changed = module.set_file_attributes_if_different(
        file_args, changed, diff_attrs
    )
    return changed, [diff_text, diff_attrs]
def _iter_module_records(params, record_types):
    """Yield every record described by the module parameters `params`.

    Order: the SOA, then `reverse_hosts`, then `hosts`, then each list
    option named in `record_types`, in dictionary order.
    """
    yield from make_records(params["soa"], SOA)
    yield from make_reverse_hosts_records(params["reverse_hosts"])
    yield from make_hosts_records(params["hosts"])
    for option, record_type in record_types.items():
        for args in params[option]:
            yield from make_records(args, record_type)


def main() -> int:
    """Run the `dns_zone` Ansible module.

    Builds a `dns.zone.Zone` rooted at `origin` from the `soa`, `hosts`,
    `reverse_hosts` and per-type record options. The zone is rendered as
    text (absolute names, sorted) and written to `path`, honouring check
    and diff modes.

    Every record gets the TTL given by the optional `ttl` option. The
    default of 0 matches what earlier versions always produced. Invalid
    record data is reported through `fail_json` rather than a traceback.
    """
    record_types = {
        "ns": NS,
        "txt": TXT,
        "a": A,
        "aaaa": AAAA,
        "srv": SRV,
        "spf": SPF,
        "ptr": PTR,
        "cname": CNAME,
        "mx": MX,
    }
    module_args = {
        "path": {"type": "str", "required": True},
        "origin": {"type": "str", "required": True},
        "soa": {
            "type": "dict",
            "required": True,
            "options": spec_options_of_type(SOA),
        },
        "hosts": {"type": "dict", "default": {}},
        "reverse_hosts": {"type": "dict", "default": {}},
        # TTL applied to every record; 0 keeps the historical output.
        "ttl": {"type": "int", "default": 0},
    }
    for option, record_type in record_types.items():
        module_args[option] = {
            "type": "list",
            "default": [],
            "elements": "dict",
            "options": spec_options_of_type(record_type),
        }
    module = AnsibleModule(
        argument_spec=module_args,
        add_file_common_args=True,
        supports_check_mode=True,
    )

    ttl = module.params["ttl"]
    max_ttl = 2**31 - 1
    if not 0 <= ttl <= max_ttl:
        # RFC 2181 §8: a TTL is an unsigned 31-bit value.
        module.fail_json(msg=f"ttl must be between 0 and {max_ttl}, got {ttl}")

    path = module.params["path"]
    try:
        zone = dns.zone.Zone(dns.name.from_text(module.params["origin"]))
        for record in _iter_module_records(module.params, record_types):
            node = zone.get_node(record.name, create=True)
            rdata = record.rdata()
            rdataset = node.get_rdataset(
                rdata.rdclass, rdata.rdtype, create=True
            )
            rdataset.add(rdata, ttl)
    except (
        TypeError,
        ValueError,
        KeyError,
        dns.exception.DNSException,
    ) as err:
        # Bad addresses, malformed names, or owners outside the zone.
        module.fail_json(msg=f"Invalid zone data: {err}")

    zone_text = zone.to_text(relativize=False, sorted=True)
    changed, diff = write_text_file(path, zone_text, module)
    module.exit_json(changed=changed, diff=diff)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())