# ansible/roles/knotd/library/dns_zone.py
# (419 lines, 11 KiB, Python)
import dataclasses
2022-08-17 18:23:47 +02:00
import ipaddress
import itertools
import sys
import typing
from typing import Annotated, Any
2022-08-16 20:13:25 +02:00
import dns
import dns.rdata
import dns.rdataclass
import dns.rdatatype
2022-08-17 18:23:47 +02:00
import dns.rdtypes.ANY.CNAME
2022-08-16 20:13:25 +02:00
import dns.rdtypes.ANY.MX
import dns.rdtypes.ANY.NS
import dns.rdtypes.ANY.PTR
2022-08-17 18:23:47 +02:00
import dns.rdtypes.ANY.SOA
import dns.rdtypes.ANY.SPF
2022-08-16 20:13:25 +02:00
import dns.rdtypes.ANY.TXT
2022-08-17 18:23:47 +02:00
import dns.rdtypes.IN.A
import dns.rdtypes.IN.AAAA
import dns.rdtypes.IN.SRV
import dns.reversename
2022-08-17 18:23:47 +02:00
import dns.serial
import dns.zone
2022-08-16 20:13:25 +02:00
from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.common.validation import check_type_list
2022-08-16 20:13:25 +02:00
class RName(dns.name.Name):
    """Domain name used to represent an e-mail address (see RFC 1035).

    ``local@domain`` becomes the name whose first label is ``local`` (kept as
    one single label, even if it contains dots), followed by the labels of
    ``domain`` parsed relative to the empty origin. Used for the SOA RNAME.

    Raises ``ValueError`` when ``address`` does not contain exactly one "@"
    or when the local part is empty.
    """

    def __init__(self, address):
        try:
            local, domain = address.split("@")
        except ValueError:
            raise ValueError(
                f"Invalid e-mail address format: {address}"
            ) from None
        if not local:
            # An empty first label would otherwise surface as an obscure
            # dnspython error (or, for "@", silently yield the root name).
            raise ValueError(f"Invalid e-mail address format: {address}")
        super().__init__(
            (local,) + dns.name.from_text(domain, origin=dns.name.empty).labels
        )
class MultiRecords:
    """Marker for ``Annotated[...]`` fields that accept a list of values.

    A field carrying this marker may be filled several times through a list,
    and as many records are created as there are values.
    """
@dataclasses.dataclass
class A:
    """IPv4 address record.

    ``address`` is the textual IPv4 address; ``name`` is the owner name and
    defaults to the empty (relative) name.
    """

    address: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN A`` rdata for this record."""
        # `dns.rdataclass.IN` is already the enum member; the former
        # `dns.rdataclass.IN.IN` relied on member-to-member attribute access.
        return dns.rdtypes.IN.A.A(
            dns.rdataclass.IN, dns.rdatatype.A, self.address
        )
@dataclasses.dataclass
class AAAA:
    """IPv6 address record.

    ``address`` is the textual IPv6 address; ``name`` is the owner name and
    defaults to the empty (relative) name.
    """

    address: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN AAAA`` rdata for this record."""
        # Use the enum member directly rather than `dns.rdataclass.IN.IN`.
        return dns.rdtypes.IN.AAAA.AAAA(
            dns.rdataclass.IN, dns.rdatatype.AAAA, self.address
        )
@dataclasses.dataclass
class PTR:
    """Pointer record: ``name`` (owner) points to the domain ``target``."""

    target: dns.name.Name
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN PTR`` rdata for this record."""
        # Use the enum member directly rather than `dns.rdataclass.IN.IN`.
        return dns.rdtypes.ANY.PTR.PTR(
            dns.rdataclass.IN, dns.rdatatype.PTR, self.target
        )
@dataclasses.dataclass
class CNAME:
    """Canonical name record: ``name`` is an alias of ``target``."""

    target: dns.name.Name
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN CNAME`` rdata for this record."""
        # Use the enum member directly rather than `dns.rdataclass.IN.IN`.
        return dns.rdtypes.ANY.CNAME.CNAME(
            dns.rdataclass.IN, dns.rdatatype.CNAME, self.target
        )
@dataclasses.dataclass
class MX:
    """Mail exchange record.

    ``exchange`` and ``name`` are ``MultiRecords`` fields, so the task may
    give them as lists; ``preference`` defaults to 10.
    """

    exchange: Annotated[dns.name.Name, MultiRecords]
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
    preference: int = 10

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN MX`` rdata for this record."""
        # Use the enum member directly rather than `dns.rdataclass.IN.IN`.
        return dns.rdtypes.ANY.MX.MX(
            dns.rdataclass.IN,
            dns.rdatatype.MX,
            self.preference,
            self.exchange,
        )
@dataclasses.dataclass
class NS:
    """Name server record; ``target`` and ``name`` may be given as lists."""

    target: Annotated[dns.name.Name, MultiRecords]
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN NS`` rdata for this record."""
        # Use the enum member directly rather than `dns.rdataclass.IN.IN`.
        return dns.rdtypes.ANY.NS.NS(
            dns.rdataclass.IN, dns.rdatatype.NS, self.target
        )
@dataclasses.dataclass
class SPF:
    """SPF record (RR type SPF) holding the policy text ``data``.

    NOTE: RFC 7208 deprecates the SPF RR type in favour of TXT records.
    """

    data: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN SPF`` rdata for this record.

        A character-string holds at most 255 bytes (RFC 1035, section 3.3),
        so ``data`` is split into consecutive 255-byte strings. When given a
        longer ``str`` directly, dnspython falls back to iterating it and
        produces one string per character.
        """
        encoded = self.data.encode()
        strings = [
            encoded[start : start + 255]
            for start in range(0, len(encoded), 255)
        ] or [b""]  # keep a single empty string for empty data, as before
        return dns.rdtypes.ANY.SPF.SPF(
            dns.rdataclass.IN, dns.rdatatype.SPF, strings
        )
@dataclasses.dataclass
class TXT:
    """Text record holding ``data`` (e.g. SPF policies, DKIM keys)."""

    data: str
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN TXT`` rdata for this record.

        A character-string holds at most 255 bytes (RFC 1035, section 3.3),
        so ``data`` is split into consecutive 255-byte strings. When given a
        longer ``str`` directly, dnspython falls back to iterating it and
        produces one string per character.
        """
        encoded = self.data.encode()
        strings = [
            encoded[start : start + 255]
            for start in range(0, len(encoded), 255)
        ] or [b""]  # keep a single empty string for empty data, as before
        return dns.rdtypes.ANY.TXT.TXT(
            dns.rdataclass.IN, dns.rdatatype.TXT, strings
        )
@dataclasses.dataclass
class SRV:
    """Service location record.

    ``target`` and ``name`` may be given as lists; ``priority`` defaults
    to 10.
    """

    target: Annotated[dns.name.Name, MultiRecords]
    weight: int
    port: int
    priority: int = 10
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN SRV`` rdata for this record."""
        # Use the enum member directly rather than `dns.rdataclass.IN.IN`.
        return dns.rdtypes.IN.SRV.SRV(
            dns.rdataclass.IN,
            dns.rdatatype.SRV,
            self.priority,
            self.weight,
            self.port,
            self.target,
        )
@dataclasses.dataclass
class SOA:
    """Start of authority record.

    ``mname`` is the primary name server and ``rname`` the administrator's
    e-mail address (converted to a domain name by ``RName``). The timer
    fields are passed to dnspython unchanged.
    """

    mname: dns.name.Name
    rname: RName
    refresh: int
    retry: int
    expire: int
    minimum: int
    # NOTE(review): the serial defaults to 1 and nothing in this record
    # increments it when the zone changes; confirm that the DNS server's
    # serial policy takes care of it.
    serial: int = 1
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Build the ``IN SOA`` rdata for this record."""
        # Use the enum member directly rather than `dns.rdataclass.IN.IN`.
        return dns.rdtypes.ANY.SOA.SOA(
            dns.rdataclass.IN,
            dns.rdatatype.SOA,
            self.mname,
            self.rname,
            self.serial,
            self.refresh,
            self.retry,
            self.expire,
            self.minimum,
        )
def has_annotation(ty, annotation):
    """Tell whether `ty` is an ``Annotated[...]`` type whose metadata
    includes `annotation`."""
    if typing.get_origin(ty) != typing.Annotated:
        return False
    _base, *metadata = typing.get_args(ty)
    return annotation in metadata
def annotated_origin(ty):
    """Return the base type wrapped by the ``Annotated[...]`` type `ty`."""
    assert typing.get_origin(ty) == typing.Annotated
    base, *_metadata = typing.get_args(ty)
    return base
def is_multi_records(ty):
    """Tell whether `ty` carries the `MultiRecords` marker annotation."""
    return has_annotation(ty, annotation=MultiRecords)
def spec_option_of_field(field):
    """Translate one dataclass `field` into an Ansible `argument_spec` option.

    `MultiRecords` fields become ``list`` options whose ``elements`` type is
    derived from the annotated base type; other fields map directly to a
    scalar type. The option is required exactly when the field has no
    default value.
    """
    ansible_type_names = {
        str: "str",
        dns.name.Name: "str",
        RName: "str",
        int: "int",
    }
    required = field.default is dataclasses.MISSING
    if not is_multi_records(field.type):
        return {"type": ansible_type_names[field.type], "required": required}
    element_type = annotated_origin(field.type)
    return {
        "type": "list",
        "elements": ansible_type_names[element_type],
        "required": required,
    }
def spec_options_of_type(ty):
    """Convert the dataclass type `ty` to the ``options`` format of an
    Ansible `argument_spec`, with one entry per dataclass field."""
    options = {}
    for field in dataclasses.fields(ty):
        options[field.name] = spec_option_of_field(field)
    return options
def coerce_dns_name(value: Any) -> dns.name.Name:
    """Return `value` as a `dns.name.Name`.

    Existing names are returned unchanged; anything else is parsed as text
    relative to the empty origin.
    """
    if isinstance(value, dns.name.Name):
        return value
    return dns.name.from_text(value, origin=dns.name.empty)
def product_dict(dct, keys=None):
    """Yield the "cartesian product" of the dictionary `dct` w.r.t. `keys`.

    Every key listed in `keys` must map to an iterable of candidate values;
    one dictionary is yielded per combination of those candidates, while the
    remaining keys keep their value unchanged. When `keys` is None, all the
    keys of `dct` are expanded. Key order follows `dct`.
    """
    expanded = dct.keys() if keys is None else keys
    names = []
    choices = []
    for key, value in dct.items():
        names.append(key)
        choices.append(value if key in expanded else (value,))
    for combination in itertools.product(*choices):
        yield dict(zip(names, combination))
def make_hosts_records(hosts):
    """Yield an `A` or `AAAA` record for every address of every host.

    `hosts` maps a host name to its addresses (anything `check_type_list`
    accepts). The name is parsed relative to the empty origin. The record
    type follows the IP version of each address.
    """
    for host, addrs in hosts.items():
        for addr in check_type_list(addrs):
            owner = dns.name.from_text(host, origin=dns.name.empty)
            if ipaddress.ip_address(addr).version == 6:
                yield AAAA(addr, owner)
            else:
                yield A(addr, owner)
def make_reverse_hosts_records(hosts):
    """Yield one `PTR` record per address of every host.

    The owner of each record is the reverse-lookup name of the address, and
    its target is the host name parsed as an absolute name.
    """
    for host, addrs in hosts.items():
        for addr in check_type_list(addrs):
            yield PTR(
                target=dns.name.from_text(host),
                name=dns.reversename.from_address(addr),
            )
def make_records(args, ty):
    """Yield the `ty` records described by one Ansible argument dictionary.

    :param args: mapping of `ty` field names to the values given in the
        task; ``None`` values are skipped so that the dataclass defaults
        apply.
    :param ty: a record dataclass such as `A` or `SOA`.

    `dns.name.Name` and `RName` values are converted from their text form.
    Fields annotated with `MultiRecords` hold lists, and one record is
    yielded for every combination of their elements (see `product_dict`).
    """
    # TODO: This is not elegant at all, but:
    # 1. I did not manage to declare a third-party type in `argument_spec`;
    # 2. Ansible sets the entries that were not passed to the task to `None`
    #    and this behaviour does not seem to be configurable.
    field_types = {field.name: field.type for field in dataclasses.fields(ty)}
    converters = {
        dns.name.Name: coerce_dns_name,
        RName: RName,
    }

    def convert(value, target):
        # Only name-like types need converting; other values pass through.
        converter = converters.get(target)
        if converter is None:
            return value
        return converter(value)

    def convert_field(field_name, value):
        field_type = field_types[field_name]
        if not is_multi_records(field_type):
            return convert(value, field_type)
        element_type = annotated_origin(field_type)
        return [convert(item, element_type) for item in value]

    given = {}
    for field_name, value in args.items():
        if value is not None:
            given[field_name] = convert_field(field_name, value)
    multi_fields = {
        field_name
        for field_name, field_type in field_types.items()
        if is_multi_records(field_type)
    }
    for record_args in product_dict(given, multi_fields):
        yield ty(**record_args)
def zones_eq(lhs: dns.zone.Zone, rhs: dns.zone.Zone) -> bool:
    """Tell whether two zones are equal, comparing their sorted,
    non-relativized text representations."""

    def render(zone):
        return zone.to_text(relativize=False, sorted=True)

    return render(lhs) == render(rhs)
def write_text_file(path, text, module):
    """Write `text` to `path` unless the file already holds exactly that
    text, honouring Ansible's check and diff modes.

    The common file attributes taken from ``module.params`` are then applied
    with ``module.set_file_attributes_if_different``.

    :param path: destination file path.
    :param text: full content the file should have.
    :param module: the running ``AnsibleModule``.
    :returns: ``(changed, diffs)`` where ``diffs`` holds the content diff and
        the attributes diff, in Ansible's ``before``/``after`` format.
    """
    diff_text = {
        "before_header": f"{path} (content)",
        "after_header": f"{path} (content)",
        "after": text,
    }
    try:
        with open(path, encoding="utf-8") as f:
            current = f.read()
    except (OSError, UnicodeDecodeError):
        # Missing or unreadable file: consider it out of date and show an
        # empty "before" side in the diff.
        current = None
    changed = current is None or text != current
    diff_text["before"] = current
    if changed and not module.check_mode:
        try:
            # NOTE: the file is rewritten in place, not atomically.
            with open(path, "w", encoding="utf-8") as f:
                f.write(text)
        except OSError as err:
            module.fail_json(msg=f"Unable to write {path}: {err}")
    file_args = module.load_file_common_arguments(module.params)
    diff_attrs = {
        "before_header": f"{path} (attributes)",
        "after_header": f"{path} (attributes)",
    }
    changed = module.set_file_attributes_if_different(
        file_args, changed, diff_attrs
    )
    return changed, [diff_text, diff_attrs]
def main() -> int:
    """Run the module: build the DNS zone described by the task arguments and
    write its text form to ``path``.

    Besides the SOA, records come from ``hosts`` (A/AAAA), ``reverse_hosts``
    (PTR) and one list option per entry of ``record_types``. The optional
    ``ttl`` option is applied to every record. Its default of 0 matches the
    TTL a freshly created dnspython rdataset has, so omitting it reproduces
    the previous output.

    Returns 0; ``module.exit_json`` is expected to end the process first.
    """
    record_types = {
        "ns": NS,
        "txt": TXT,
        "a": A,
        "aaaa": AAAA,
        "srv": SRV,
        "spf": SPF,
        "ptr": PTR,
        "cname": CNAME,
        "mx": MX,
    }
    module_args = {
        "path": {"type": "str", "required": True},
        "origin": {"type": "str", "required": True},
        "soa": {
            "type": "dict",
            "required": True,
            "options": spec_options_of_type(SOA),
        },
        "hosts": {"type": "dict", "default": {}},
        "reverse_hosts": {"type": "dict", "default": {}},
        "ttl": {"type": "int", "default": 0},
    }
    for name, ty in record_types.items():
        module_args[name] = {
            "type": "list",
            "default": [],
            "elements": "dict",
            "options": spec_options_of_type(ty),
        }

    module = AnsibleModule(
        argument_spec=module_args,
        add_file_common_args=True,
        supports_check_mode=True,
    )

    ttl = module.params["ttl"]
    # RFC 2181, section 8: a TTL is an unsigned value in 0 .. 2**31 - 1.
    if not 0 <= ttl <= 2**31 - 1:
        module.fail_json(msg=f"Invalid TTL: {ttl}")

    origin = dns.name.from_text(module.params["origin"])
    path = module.params["path"]
    zone = dns.zone.Zone(origin)

    records = itertools.chain(
        make_records(module.params["soa"], SOA),
        make_reverse_hosts_records(module.params["reverse_hosts"]),
        make_hosts_records(module.params["hosts"]),
        itertools.chain.from_iterable(
            make_records(args, ty)
            for name, ty in record_types.items()
            for args in module.params[name]
        ),
    )
    for record in records:
        node = zone.get_node(record.name, create=True)
        rdata = record.rdata()
        dataset = node.get_rdataset(rdata.rdclass, rdata.rdtype, create=True)
        dataset.add(rdata, ttl)
    zone_text = zone.to_text(relativize=False, sorted=True)

    changed, diff = write_text_file(path, zone_text, module)

    module.exit_json(changed=changed, diff=diff)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())