dns_zone: cleanup + hosts + product

dns
jeltz 2 years ago
parent c97dca8fa8
commit 4dbe0e562d
Signed by: jeltz
GPG Key ID: 800882B66C0C3326

@ -1,26 +1,30 @@
#!/usr/bin/env python3
import itertools
import dataclasses
from typing import Any
import ipaddress
import itertools
import sys
import typing
from typing import Annotated, Any
import dns
import dns.serial
import dns.zone
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
import dns.rdtypes.ANY.CNAME
import dns.rdtypes.ANY.MX
import dns.rdtypes.ANY.SOA
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.SOA
import dns.rdtypes.ANY.TXT
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
import dns.serial
import dns.zone
from ansible.module_utils.basic import AnsibleModule
class RName(dns.name.Name):
"""Domain name used to represent an e-mail address (see RFC 1035)."""
def __init__(self, address):
try:
local, domain = address.split("@")
@ -31,10 +35,18 @@ class RName(dns.name.Name):
super().__init__((local,) + dns.name.from_text(domain).labels)
class MultiRecords:
"""Annotation used to indicate that a field can be filled in more than
once via a list, and that this will create as many records as values.
"""
...
@dataclasses.dataclass
class A:
address: str
name: dns.name.Name = dns.name.empty
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.IN.A.A(
@ -45,7 +57,7 @@ class A:
@dataclasses.dataclass
class AAAA:
address: str
name: dns.name.Name = dns.name.empty
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.IN.AAAA.AAAA(
@ -56,7 +68,7 @@ class AAAA:
@dataclasses.dataclass
class CNAME:
address: dns.name.Name
name: dns.name.Name = dns.name.empty
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.ANY.CNAME.CNAME(
@ -66,8 +78,8 @@ class CNAME:
@dataclasses.dataclass
class MX:
exchange: dns.name.Name
name: dns.name.Name = dns.name.empty
exchange: Annotated[dns.name.Name, MultiRecords]
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
priority: int = 10
def rdata(self) -> dns.rdata.Rdata:
@ -81,8 +93,8 @@ class MX:
@dataclasses.dataclass
class NS:
address: dns.name.Name
name: dns.name.Name = dns.name.empty
address: Annotated[dns.name.Name, MultiRecords]
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.ANY.NS.NS(
@ -93,7 +105,7 @@ class NS:
@dataclasses.dataclass
class TXT:
data: str
name: dns.name.Name = dns.name.empty
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.ANY.TXT.TXT(
@ -110,7 +122,7 @@ class SOA:
expire: int
ttl: int
serial: int = 1
name: dns.name.Name = dns.name.empty
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
def rdata(self) -> dns.rdata.Rdata:
return dns.rdtypes.ANY.SOA.SOA(
@ -126,6 +138,25 @@ class SOA:
)
def has_annotation(ty, annotation):
"""Is the type `ty` annotated with a given `annotation`."""
return (
typing.get_origin(ty) == typing.Annotated
and annotation in typing.get_args(ty)[1:]
)
def annotated_origin(ty):
"""Returns the origin of an annotated type `ty`."""
assert typing.get_origin(ty) == typing.Annotated
return typing.get_args(ty)[0]
def is_multi_records(ty):
"""Is the type `ty` annotated with `MultiRecords`."""
return has_annotation(ty, MultiRecords)
def spec_option_of_field(field):
types = {
str: "str",
@ -133,13 +164,20 @@ def spec_option_of_field(field):
RName: "str",
int: "int",
}
return {
"type": types[field.type],
"required": field.default is dataclasses.MISSING,
}
if is_multi_records(field.type):
option = {
"type": "list",
"elements": types[annotated_origin(field.type)],
}
else:
option = {"type": types[field.type]}
option["required"] = field.default is dataclasses.MISSING
return option
def spec_options_of_type(ty):
"""Convert a `dataclass` type to Ansible `argument_spec` `options`'
format."""
return {
field.name: spec_option_of_field(field)
for field in dataclasses.fields(ty)
@ -147,12 +185,32 @@ def spec_options_of_type(ty):
def coerce_dns_name(value: Any) -> dns.name.Name:
"""Try to convert a `value` to `dns.name.Name`."""
if not isinstance(value, dns.name.Name):
return dns.name.from_text(value, origin=dns.name.empty)
return value
def make_record(args, ty):
def product_dict(dct, keys=None):
"""Compute the "cartesian product" of a dictionnary `dct`
w.r.t some `keys` (if `keys` is None, then the product is computed
on all the keys)."""
if keys is None:
keys = dct.keys()
wrapped = {k: v if k in keys else [v] for k, v in dct.items()}
for values in itertools.product(*wrapped.values()):
yield dict(zip(wrapped.keys(), values))
def make_hosts_records(hosts):
for host, addrs in hosts.items():
for addr in addrs:
name = dns.name.from_text(host, origin=dns.name.empty)
decoded = ipaddress.ip_address(addr)
yield AAAA(addr, name) if decoded.version == 6 else A(addr, name)
def make_records(args, ty):
# TODO: Ça n'est pas du tout élégant, mais :
# 1. je n'ai pas réussi à spécifier dans `argument_spec` un type tiers
# 2. Ansible positionne à `None` les entrées non passées à la tâche et
@ -163,10 +221,16 @@ def make_record(args, ty):
RName: RName,
}
def coerce_single(value, ty):
if ty in coercers:
return coercers[ty](value)
return value
def coerce(name, value):
if types[name] not in coercers:
return value
return coercers[types[name]](value)
if is_multi_records(types[name]):
origin = annotated_origin(types[name])
return [coerce_single(v, origin) for v in value]
return coerce_single(value, types[name])
clean_args = {
name: coerce(name, value)
@ -174,11 +238,16 @@ def make_record(args, ty):
if value is not None
}
return ty(**clean_args)
multi_keys = (k for k, v in types.items() if is_multi_records(v))
for single_args in product_dict(clean_args, multi_keys):
yield ty(**single_args)
def zones_eq(a: dns.zone.Zone, b: dns.zone.Zone) -> bool:
return a.to_text(relativize=False) == b.to_text(relativize=False)
def zones_eq(lhs: dns.zone.Zone, rhs: dns.zone.Zone) -> bool:
"""Returns a `bool` indicating whether two `dns.zone.Zone`s are equal
w.r.t. their text representation."""
return lhs.to_text(relativize=False) == rhs.to_text(relativize=False)
def main() -> int:
@ -188,17 +257,19 @@ def main() -> int:
"txt": TXT,
"a": A,
"aaaa": AAAA,
"cname": CNAME,
"mx": MX,
}
module_args = {
"path": {"type": "path", "required": True},
"path": {"type": "str", "required": True},
"origin": {"type": "str", "required": True},
"soa": {
"type": "dict",
"required": True,
"options": spec_options_of_type(SOA),
},
"hosts": {"type": "dict", "default": {}},
}
for name, ty in record_types.items():
@ -209,7 +280,10 @@ def main() -> int:
"options": spec_options_of_type(ty),
}
module = AnsibleModule(argument_spec=module_args)
module = AnsibleModule(
argument_spec=module_args,
add_file_common_args=True,
)
origin = dns.name.from_text(module.params["origin"])
path = module.params["path"]
@ -218,16 +292,18 @@ def main() -> int:
try:
current = dns.zone.from_file(path, origin=origin)
except:
except Exception:
current = None
records = [make_record(module.params["soa"], SOA)]
records.extend(
records = itertools.chain(
make_records(module.params["soa"], SOA),
make_hosts_records(module.params["hosts"]),
itertools.chain.from_iterable(
(make_record(args, ty) for args in module.params[name])
itertools.chain.from_iterable(
make_records(args, ty) for args in module.params[name]
)
for name, ty in record_types.items()
)
),
)
for record in records:
@ -236,9 +312,13 @@ def main() -> int:
dataset = node.get_rdataset(rdata.rdclass, rdata.rdtype, create=True)
dataset.add(rdata)
file_args = module.load_file_common_arguments(module.params)
changed = current is None or not zones_eq(zone, current)
if changed:
zone.to_file(path, relativize=False)
zone.to_file(module.params["path"], relativize=False)
changed = module.set_fs_attributes_if_different(file_args, changed)
module.exit_json(changed=changed)
@ -246,4 +326,4 @@ def main() -> int:
if __name__ == "__main__":
exit(main())
sys.exit(main())

Loading…
Cancel
Save