dns_zone: cleanup + hosts + product

This commit is contained in:
jeltz 2022-08-17 18:23:47 +02:00
parent c97dca8fa8
commit 4dbe0e562d
Signed by: jeltz
GPG key ID: 800882B66C0C3326

View file

@ -1,26 +1,30 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import itertools
import dataclasses import dataclasses
import ipaddress
from typing import Any import itertools
import sys
import typing
from typing import Annotated, Any
import dns import dns
import dns.serial
import dns.zone
import dns.rdata import dns.rdata
import dns.rdataclass import dns.rdataclass
import dns.rdatatype import dns.rdatatype
import dns.rdtypes.ANY.CNAME
import dns.rdtypes.ANY.MX
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.SOA
import dns.rdtypes.ANY.TXT
import dns.rdtypes.IN.A import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA import dns.rdtypes.IN.AAAA
import dns.rdtypes.ANY.MX import dns.serial
import dns.rdtypes.ANY.SOA import dns.zone
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.TXT
from ansible.module_utils.basic import AnsibleModule from ansible.module_utils.basic import AnsibleModule
class RName(dns.name.Name): class RName(dns.name.Name):
"""Domain name used to represent an e-mail address (see RFC 1035)."""
def __init__(self, address): def __init__(self, address):
try: try:
local, domain = address.split("@") local, domain = address.split("@")
@ -31,10 +35,18 @@ class RName(dns.name.Name):
super().__init__((local,) + dns.name.from_text(domain).labels) super().__init__((local,) + dns.name.from_text(domain).labels)
class MultiRecords:
"""Annotation used to indicate that a field can be filled in more than
once via a list, and that this will create as many records as values.
"""
...
@dataclasses.dataclass @dataclasses.dataclass
class A: class A:
address: str address: str
name: dns.name.Name = dns.name.empty name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata: def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.IN.A.A( return dns.rdtypes.IN.A.A(
@ -45,7 +57,7 @@ class A:
@dataclasses.dataclass @dataclasses.dataclass
class AAAA: class AAAA:
address: str address: str
name: dns.name.Name = dns.name.empty name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata: def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.IN.AAAA.AAAA( return dns.rdtypes.IN.AAAA.AAAA(
@ -56,7 +68,7 @@ class AAAA:
@dataclasses.dataclass @dataclasses.dataclass
class CNAME: class CNAME:
address: dns.name.Name address: dns.name.Name
name: dns.name.Name = dns.name.empty name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata: def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.ANY.CNAME.CNAME( return dns.rdtypes.ANY.CNAME.CNAME(
@ -66,8 +78,8 @@ class CNAME:
@dataclasses.dataclass @dataclasses.dataclass
class MX: class MX:
exchange: dns.name.Name exchange: Annotated[dns.name.Name, MultiRecords]
name: dns.name.Name = dns.name.empty name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
priority: int = 10 priority: int = 10
def rdata(self) -> dns.rdata.Rdata: def rdata(self) -> dns.rdata.Rdata:
@ -81,8 +93,8 @@ class MX:
@dataclasses.dataclass @dataclasses.dataclass
class NS: class NS:
address: dns.name.Name address: Annotated[dns.name.Name, MultiRecords]
name: dns.name.Name = dns.name.empty name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata: def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.ANY.NS.NS( return dns.rdtypes.ANY.NS.NS(
@ -93,7 +105,7 @@ class NS:
@dataclasses.dataclass @dataclasses.dataclass
class TXT: class TXT:
data: str data: str
name: dns.name.Name = dns.name.empty name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata: def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.ANY.TXT.TXT( return dns.rdtypes.ANY.TXT.TXT(
@ -110,7 +122,7 @@ class SOA:
expire: int expire: int
ttl: int ttl: int
serial: int = 1 serial: int = 1
name: dns.name.Name = dns.name.empty name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata: def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.ANY.SOA.SOA( return dns.rdtypes.ANY.SOA.SOA(
@ -126,6 +138,25 @@ class SOA:
) )
def has_annotation(ty, annotation):
"""Is the type `ty` annotated with a given `annotation`."""
return (
typing.get_origin(ty) == typing.Annotated
and annotation in typing.get_args(ty)[1:]
)
def annotated_origin(ty):
"""Returns the origin of an annotated type `ty`."""
assert typing.get_origin(ty) == typing.Annotated
return typing.get_args(ty)[0]
def is_multi_records(ty):
"""Is the type `ty` annotated with `MultiRecords`."""
return has_annotation(ty, MultiRecords)
def spec_option_of_field(field): def spec_option_of_field(field):
types = { types = {
str: "str", str: "str",
@ -133,13 +164,20 @@ def spec_option_of_field(field):
RName: "str", RName: "str",
int: "int", int: "int",
} }
return { if is_multi_records(field.type):
"type": types[field.type], option = {
"required": field.default is dataclasses.MISSING, "type": "list",
"elements": types[annotated_origin(field.type)],
} }
else:
option = {"type": types[field.type]}
option["required"] = field.default is dataclasses.MISSING
return option
def spec_options_of_type(ty): def spec_options_of_type(ty):
"""Convert a `dataclass` type to Ansible `argument_spec` `options`'
format."""
return { return {
field.name: spec_option_of_field(field) field.name: spec_option_of_field(field)
for field in dataclasses.fields(ty) for field in dataclasses.fields(ty)
@ -147,12 +185,32 @@ def spec_options_of_type(ty):
def coerce_dns_name(value: Any) -> dns.name.Name: def coerce_dns_name(value: Any) -> dns.name.Name:
"""Try to convert a `value` to `dns.name.Name`."""
if not isinstance(value, dns.name.Name): if not isinstance(value, dns.name.Name):
return dns.name.from_text(value, origin=dns.name.empty) return dns.name.from_text(value, origin=dns.name.empty)
return value return value
def make_record(args, ty): def product_dict(dct, keys=None):
"""Compute the "cartesian product" of a dictionnary `dct`
w.r.t some `keys` (if `keys` is None, then the product is computed
on all the keys)."""
if keys is None:
keys = dct.keys()
wrapped = {k: v if k in keys else [v] for k, v in dct.items()}
for values in itertools.product(*wrapped.values()):
yield dict(zip(wrapped.keys(), values))
def make_hosts_records(hosts):
for host, addrs in hosts.items():
for addr in addrs:
name = dns.name.from_text(host, origin=dns.name.empty)
decoded = ipaddress.ip_address(addr)
yield AAAA(addr, name) if decoded.version == 6 else A(addr, name)
def make_records(args, ty):
# TODO: Ça n'est pas du tout élégant, mais : # TODO: Ça n'est pas du tout élégant, mais :
# 1. je n'ai pas réussi à spécifier dans `argument_spec` un type tiers # 1. je n'ai pas réussi à spécifier dans `argument_spec` un type tiers
# 2. Ansible positionne à `None` les entrées non passées à la tâche et # 2. Ansible positionne à `None` les entrées non passées à la tâche et
@ -163,10 +221,16 @@ def make_record(args, ty):
RName: RName, RName: RName,
} }
def coerce(name, value): def coerce_single(value, ty):
if types[name] not in coercers: if ty in coercers:
return coercers[ty](value)
return value return value
return coercers[types[name]](value)
def coerce(name, value):
if is_multi_records(types[name]):
origin = annotated_origin(types[name])
return [coerce_single(v, origin) for v in value]
return coerce_single(value, types[name])
clean_args = { clean_args = {
name: coerce(name, value) name: coerce(name, value)
@ -174,11 +238,16 @@ def make_record(args, ty):
if value is not None if value is not None
} }
return ty(**clean_args) multi_keys = (k for k, v in types.items() if is_multi_records(v))
for single_args in product_dict(clean_args, multi_keys):
yield ty(**single_args)
def zones_eq(a: dns.zone.Zone, b: dns.zone.Zone) -> bool: def zones_eq(lhs: dns.zone.Zone, rhs: dns.zone.Zone) -> bool:
return a.to_text(relativize=False) == b.to_text(relativize=False) """Returns a `bool` indicating whether two `dns.zone.Zone`s are equal
w.r.t. their text representation."""
return lhs.to_text(relativize=False) == rhs.to_text(relativize=False)
def main() -> int: def main() -> int:
@ -188,17 +257,19 @@ def main() -> int:
"txt": TXT, "txt": TXT,
"a": A, "a": A,
"aaaa": AAAA, "aaaa": AAAA,
"cname": CNAME,
"mx": MX, "mx": MX,
} }
module_args = { module_args = {
"path": {"type": "path", "required": True}, "path": {"type": "str", "required": True},
"origin": {"type": "str", "required": True}, "origin": {"type": "str", "required": True},
"soa": { "soa": {
"type": "dict", "type": "dict",
"required": True, "required": True,
"options": spec_options_of_type(SOA), "options": spec_options_of_type(SOA),
}, },
"hosts": {"type": "dict", "default": {}},
} }
for name, ty in record_types.items(): for name, ty in record_types.items():
@ -209,7 +280,10 @@ def main() -> int:
"options": spec_options_of_type(ty), "options": spec_options_of_type(ty),
} }
module = AnsibleModule(argument_spec=module_args) module = AnsibleModule(
argument_spec=module_args,
add_file_common_args=True,
)
origin = dns.name.from_text(module.params["origin"]) origin = dns.name.from_text(module.params["origin"])
path = module.params["path"] path = module.params["path"]
@ -218,16 +292,18 @@ def main() -> int:
try: try:
current = dns.zone.from_file(path, origin=origin) current = dns.zone.from_file(path, origin=origin)
except: except Exception:
current = None current = None
records = [make_record(module.params["soa"], SOA)] records = itertools.chain(
make_records(module.params["soa"], SOA),
records.extend( make_hosts_records(module.params["hosts"]),
itertools.chain.from_iterable( itertools.chain.from_iterable(
(make_record(args, ty) for args in module.params[name]) itertools.chain.from_iterable(
for name, ty in record_types.items() make_records(args, ty) for args in module.params[name]
) )
for name, ty in record_types.items()
),
) )
for record in records: for record in records:
@ -236,9 +312,13 @@ def main() -> int:
dataset = node.get_rdataset(rdata.rdclass, rdata.rdtype, create=True) dataset = node.get_rdataset(rdata.rdclass, rdata.rdtype, create=True)
dataset.add(rdata) dataset.add(rdata)
file_args = module.load_file_common_arguments(module.params)
changed = current is None or not zones_eq(zone, current) changed = current is None or not zones_eq(zone, current)
if changed: if changed:
zone.to_file(path, relativize=False) zone.to_file(module.params["path"], relativize=False)
changed = module.set_fs_attributes_if_different(file_args, changed)
module.exit_json(changed=changed) module.exit_json(changed=changed)
@ -246,4 +326,4 @@ def main() -> int:
if __name__ == "__main__": if __name__ == "__main__":
exit(main()) sys.exit(main())