392 lines
10 KiB
Python
Executable file
392 lines
10 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
import dataclasses
|
|
import ipaddress
|
|
import itertools
|
|
import sys
|
|
import typing
|
|
from typing import Annotated, Any
|
|
|
|
import dns
|
|
import dns.rdata
|
|
import dns.rdataclass
|
|
import dns.rdatatype
|
|
import dns.rdtypes.ANY.CNAME
|
|
import dns.rdtypes.ANY.MX
|
|
import dns.rdtypes.ANY.NS
|
|
import dns.rdtypes.ANY.PTR
|
|
import dns.rdtypes.ANY.SOA
|
|
import dns.rdtypes.ANY.SPF
|
|
import dns.rdtypes.ANY.TXT
|
|
import dns.rdtypes.IN.A
|
|
import dns.rdtypes.IN.AAAA
|
|
import dns.rdtypes.IN.SRV
|
|
import dns.reversename
|
|
import dns.serial
|
|
import dns.zone
|
|
from ansible.module_utils.basic import AnsibleModule
|
|
from ansible.module_utils.common.validation import check_type_list
|
|
|
|
|
|
class RName(dns.name.Name):
    """Domain name encoding an e-mail address, as used in the SOA RNAME field.

    Following RFC 1035 (section 8), ``local@domain`` becomes a name whose
    first label is the whole local part (so dots inside it stay within that
    single label), followed by the labels of ``domain``.
    """

    def __init__(self, address):
        """Build the name from the e-mail `address`.

        `domain` is parsed with ``origin=dns.name.empty``, so no origin is
        appended to it here.

        Raises:
            ValueError: if `address` does not contain exactly one ``@``.
        """
        try:
            local, domain = address.split("@")
        except ValueError:
            # `from None`: the tuple-unpacking error carries no useful detail.
            raise ValueError(
                f"Invalid e-mail address format: {address}"
            ) from None
        super().__init__(
            (local,) + dns.name.from_text(domain, origin=dns.name.empty).labels
        )
|
|
|
|
|
|
class MultiRecords:
    """Marker placed in ``typing.Annotated`` metadata of a dataclass field.

    A field marked this way may receive a list of values, and one record is
    produced per value (per combination of values when several fields are
    marked).
    """
|
|
|
|
|
|
@dataclasses.dataclass
class A:
    """IPv4 address record.

    ``name`` may be given as a list, in which case one record per name is
    created.
    """

    # IPv4 address in text form.
    address: str
    # Owner name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/A rdata for this record."""
        # `dns.rdataclass.IN` directly: reaching it through another member
        # (`IN.IN`) relies on discouraged enum member-to-member access.
        return dns.rdtypes.IN.A.A(
            dns.rdataclass.IN, dns.rdatatype.A, self.address
        )
|
|
|
|
|
|
@dataclasses.dataclass
class AAAA:
    """IPv6 address record.

    ``name`` may be given as a list, in which case one record per name is
    created.
    """

    # IPv6 address in text form.
    address: str
    # Owner name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/AAAA rdata for this record."""
        # `dns.rdataclass.IN` directly rather than the member-to-member
        # `IN.IN` access.
        return dns.rdtypes.IN.AAAA.AAAA(
            dns.rdataclass.IN, dns.rdatatype.AAAA, self.address
        )
|
|
|
|
|
|
@dataclasses.dataclass
class PTR:
    """Pointer record mapping ``name`` to ``target``.

    ``name`` may be given as a list, in which case one record per name is
    created.
    """

    # Name the record points to.
    target: dns.name.Name
    # Owner name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/PTR rdata for this record."""
        # `dns.rdataclass.IN` directly rather than the member-to-member
        # `IN.IN` access.
        return dns.rdtypes.ANY.PTR.PTR(
            dns.rdataclass.IN, dns.rdatatype.PTR, self.target
        )
|
|
|
|
|
|
@dataclasses.dataclass
class CNAME:
    """Canonical-name (alias) record: ``name`` is an alias of ``target``.

    ``name`` may be given as a list, in which case one record per name is
    created.
    """

    # Canonical name the alias resolves to.
    target: dns.name.Name
    # Owner (alias) name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/CNAME rdata for this record."""
        # `dns.rdataclass.IN` directly rather than the member-to-member
        # `IN.IN` access.
        return dns.rdtypes.ANY.CNAME.CNAME(
            dns.rdataclass.IN, dns.rdatatype.CNAME, self.target
        )
|
|
|
|
|
|
@dataclasses.dataclass
class MX:
    """Mail-exchanger record.

    Both ``exchange`` and ``name`` may be given as lists; one record is
    created for every (exchange, name) combination.
    """

    # Mail server host name(s).
    exchange: Annotated[dns.name.Name, MultiRecords]
    # Owner name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
    # MX preference value.
    preference: int = 10

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/MX rdata for this record."""
        # `dns.rdataclass.IN` directly rather than the member-to-member
        # `IN.IN` access.
        return dns.rdtypes.ANY.MX.MX(
            dns.rdataclass.IN,
            dns.rdatatype.MX,
            self.preference,
            self.exchange,
        )
|
|
|
|
|
|
@dataclasses.dataclass
class NS:
    """Name-server record.

    Both ``target`` and ``name`` may be given as lists; one record is
    created for every (target, name) combination.
    """

    # Name server host name(s).
    target: Annotated[dns.name.Name, MultiRecords]
    # Owner name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/NS rdata for this record."""
        # `dns.rdataclass.IN` directly rather than the member-to-member
        # `IN.IN` access.
        return dns.rdtypes.ANY.NS.NS(
            dns.rdataclass.IN, dns.rdatatype.NS, self.target
        )
|
|
|
|
|
|
@dataclasses.dataclass
class SPF:
    """SPF record carrying the policy text ``data``.

    ``name`` may be given as a list, in which case one record per name is
    created.
    """

    # Policy text; split into strings of at most 255 bytes by rdata().
    data: str
    # Owner name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/SPF rdata for this record.

        A DNS character-string holds at most 255 bytes (RFC 1035, 3.3), so
        longer ``data`` is split into consecutive strings, which SPF
        evaluators concatenate (RFC 7208, 3.3). Handing dnspython one longer
        string makes it either fail or split the text character by
        character, depending on the version.
        """
        max_string_length = 255
        raw = self.data.encode() if isinstance(self.data, str) else self.data
        strings = tuple(
            raw[start : start + max_string_length]
            for start in range(0, len(raw), max_string_length)
        ) or (b"",)  # empty data still yields one (empty) string
        return dns.rdtypes.ANY.SPF.SPF(
            dns.rdataclass.IN, dns.rdatatype.SPF, strings
        )
|
|
|
|
|
|
@dataclasses.dataclass
class TXT:
    """Text record carrying ``data``.

    ``name`` may be given as a list, in which case one record per name is
    created.
    """

    # Record text; split into strings of at most 255 bytes by rdata().
    data: str
    # Owner name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/TXT rdata for this record.

        A DNS character-string holds at most 255 bytes (RFC 1035, 3.3), so
        longer ``data`` (e.g. DKIM keys) is split into consecutive strings,
        which SPF/DKIM consumers concatenate. Handing dnspython one longer
        string makes it either fail or split the text character by
        character, depending on the version.
        """
        max_string_length = 255
        raw = self.data.encode() if isinstance(self.data, str) else self.data
        strings = tuple(
            raw[start : start + max_string_length]
            for start in range(0, len(raw), max_string_length)
        ) or (b"",)  # empty data still yields one (empty) string
        return dns.rdtypes.ANY.TXT.TXT(
            dns.rdataclass.IN, dns.rdatatype.TXT, strings
        )
|
|
|
|
|
|
@dataclasses.dataclass
class SRV:
    """Service-location record.

    Both ``target`` and ``name`` may be given as lists; one record is
    created for every (target, name) combination.
    """

    # Host name(s) providing the service.
    target: Annotated[dns.name.Name, MultiRecords]
    # SRV weight.
    weight: int
    # TCP/UDP port of the service.
    port: int
    # SRV priority.
    priority: int = 10
    # Owner name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/SRV rdata for this record."""
        # `dns.rdataclass.IN` directly rather than the member-to-member
        # `IN.IN` access.
        return dns.rdtypes.IN.SRV.SRV(
            dns.rdataclass.IN,
            dns.rdatatype.SRV,
            self.priority,
            self.weight,
            self.port,
            self.target,
        )
|
|
|
|
|
|
@dataclasses.dataclass
class SOA:
    """Start-of-authority record.

    ``name`` may be given as a list, in which case one record per name is
    created.
    """

    # Primary name server.
    mname: dns.name.Name
    # Responsible mailbox, given as an e-mail address (see RName).
    rname: RName
    # SOA timers.
    refresh: int
    retry: int
    expire: int
    minimum: int
    # Zone serial; used as given (defaults to 1).
    serial: int = 1
    # Owner name; the default empty name denotes the zone origin.
    name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty

    def rdata(self) -> dns.rdata.Rdata:
        """Return the IN/SOA rdata for this record."""
        # `dns.rdataclass.IN` directly rather than the member-to-member
        # `IN.IN` access.
        return dns.rdtypes.ANY.SOA.SOA(
            dns.rdataclass.IN,
            dns.rdatatype.SOA,
            self.mname,
            self.rname,
            self.serial,
            self.refresh,
            self.retry,
            self.expire,
            self.minimum,
        )
|
|
|
|
|
|
def has_annotation(ty, annotation):
    """Tell whether `ty` is an ``Annotated[...]`` type carrying `annotation`.

    Only the metadata is searched; the wrapped type (first argument of
    ``Annotated``) is not considered.
    """
    if typing.get_origin(ty) != typing.Annotated:
        return False
    _wrapped, *metadata = typing.get_args(ty)
    return annotation in metadata
|
|
|
|
|
|
def annotated_origin(ty):
    """Return the type wrapped by the ``Annotated[...]`` type `ty`.

    `ty` must be an ``Annotated`` type (checked with an assertion).
    """
    assert typing.get_origin(ty) == typing.Annotated
    wrapped, *_metadata = typing.get_args(ty)
    return wrapped
|
|
|
|
|
|
def is_multi_records(ty):
    """Tell whether `ty` is ``Annotated[..., MultiRecords]``, i.e. a field
    that accepts a list of values, each producing its own record."""
    marker = MultiRecords
    return has_annotation(ty, marker)
|
|
|
|
|
|
def spec_option_of_field(field):
    """Translate one dataclass `field` into an Ansible argument_spec option.

    Multi-record fields become ``list`` options whose elements use the
    Ansible type of the wrapped field type; other fields map directly.
    A field is required exactly when it has no default value.
    """
    ansible_type_of = {
        str: "str",
        dns.name.Name: "str",
        RName: "str",
        int: "int",
    }
    required = field.default is dataclasses.MISSING
    if not is_multi_records(field.type):
        return {"type": ansible_type_of[field.type], "required": required}
    return {
        "type": "list",
        "elements": ansible_type_of[annotated_origin(field.type)],
        "required": required,
    }
|
|
|
|
|
|
def spec_options_of_type(ty):
    """Convert the dataclass type `ty` into the ``options`` mapping of an
    Ansible argument_spec entry, with one option per dataclass field."""
    options = {}
    for field in dataclasses.fields(ty):
        options[field.name] = spec_option_of_field(field)
    return options
|
|
|
|
|
|
def coerce_dns_name(value: Any) -> dns.name.Name:
    """Return `value` as a ``dns.name.Name``.

    Name instances are returned unchanged; anything else is parsed with
    ``dns.name.from_text(value, origin=dns.name.empty)``.
    """
    if isinstance(value, dns.name.Name):
        return value
    return dns.name.from_text(value, origin=dns.name.empty)
|
|
|
|
|
|
def product_dict(dct, keys=None):
    """Yield the "cartesian product" of the dictionary `dct` over `keys`.

    Each key listed in `keys` (all keys when `keys` is None) must map to an
    iterable of alternatives; every other key keeps its value as is. One
    dict is yielded per combination of alternatives, with keys in `dct`
    order.
    """
    selected = dct.keys() if keys is None else keys
    names = list(dct)
    choices = [dct[name] if name in selected else [dct[name]] for name in names]
    for combination in itertools.product(*choices):
        yield dict(zip(names, combination))
|
|
|
|
|
|
def make_hosts_records(hosts):
    """Yield A/AAAA records from a ``{hostname: addresses}`` mapping.

    Each value is normalised to a list by Ansible's ``check_type_list``.
    One record is yielded per address: AAAA for IPv6 addresses, A otherwise.
    Host names are parsed with the empty origin (not made absolute).
    """
    for host, addresses in hosts.items():
        for address in check_type_list(addresses):
            owner = dns.name.from_text(host, origin=dns.name.empty)
            is_ipv6 = ipaddress.ip_address(address).version == 6
            record_type = AAAA if is_ipv6 else A
            yield record_type(address, owner)
|
|
|
|
|
|
def make_reverse_hosts_records(hosts):
    """Yield PTR records from a ``{hostname: addresses}`` mapping.

    Each value is normalised to a list by Ansible's ``check_type_list``.
    For every address, the record owner is the address's reverse-lookup name
    (``dns.reversename.from_address``). The target is the host name parsed
    with dnspython's default origin, so it is absolute.
    """
    for host, addresses in hosts.items():
        for address in check_type_list(addresses):
            target = dns.name.from_text(host)
            owner = dns.reversename.from_address(address)
            yield PTR(target, owner)
|
|
|
|
|
|
def make_records(args, ty):
    """Yield `ty` dataclass instances built from Ansible option values.

    `args` maps field names of `ty` to the raw values Ansible supplies.
    ``None`` entries (options not set by the task) are dropped so the
    dataclass defaults apply. The remaining values are coerced to DNS name /
    RName objects where the field type asks for it. Multi-record fields hold
    lists, and one instance is yielded for every combination of their values.
    """
    # TODO: This is not elegant at all, but:
    # 1. I could not find a way to declare a third-party type in
    #    `argument_spec`;
    # 2. Ansible sets options that were not passed to the task to `None`,
    #    and this behaviour does not seem to be configurable.
    field_types = {f.name: f.type for f in dataclasses.fields(ty)}
    converters = {
        dns.name.Name: coerce_dns_name,
        RName: RName,
    }

    def convert(value, target_type):
        converter = converters.get(target_type)
        return value if converter is None else converter(value)

    def convert_field(field_name, value):
        field_type = field_types[field_name]
        if not is_multi_records(field_type):
            return convert(value, field_type)
        element_type = annotated_origin(field_type)
        return [convert(item, element_type) for item in value]

    provided = {}
    for field_name, value in args.items():
        if value is not None:
            provided[field_name] = convert_field(field_name, value)

    multi_fields = {
        field_name
        for field_name, field_type in field_types.items()
        if is_multi_records(field_type)
    }

    for combination in product_dict(provided, multi_fields):
        yield ty(**combination)
|
|
|
|
|
|
def zones_eq(lhs: dns.zone.Zone, rhs: dns.zone.Zone) -> bool:
    """Tell whether two zones have the same text representation.

    Both zones are rendered with absolute names (``relativize=False``) and
    sorted node names (``sorted=True``) before comparing.
    """

    def render(zone):
        return zone.to_text(relativize=False, sorted=True)

    return render(lhs) == render(rhs)
|
|
|
|
|
|
def main() -> int:
    """Entry point of the Ansible module: render a DNS zone file.

    The zone is rebuilt in memory from the task options (``soa``, ``hosts``,
    ``reverse_hosts`` and one list option per record type). It is compared
    with the zone currently stored at ``path`` and written only when they
    differ. The common file arguments are applied afterwards. In check mode
    nothing is written, but the would-be change is reported.

    Returns 0, although ``module.exit_json`` normally ends the process first.
    """
    # Ansible option name -> record dataclass. Each entry becomes a
    # list-of-dicts option whose sub-options mirror the dataclass fields.
    record_types = {
        "ns": NS,
        "txt": TXT,
        "a": A,
        "aaaa": AAAA,
        "srv": SRV,
        "spf": SPF,
        "ptr": PTR,
        "cname": CNAME,
        "mx": MX,
    }

    module_args = {
        "path": {"type": "str", "required": True},
        "origin": {"type": "str", "required": True},
        "soa": {
            "type": "dict",
            "required": True,
            "options": spec_options_of_type(SOA),
        },
        # {hostname: address(es)} -> A/AAAA records.
        "hosts": {"type": "dict", "default": {}},
        # {hostname: address(es)} -> PTR records.
        "reverse_hosts": {"type": "dict", "default": {}},
    }
    for name, ty in record_types.items():
        module_args[name] = {
            "type": "list",
            "default": [],
            "elements": "dict",
            "options": spec_options_of_type(ty),
        }

    module = AnsibleModule(
        argument_spec=module_args,
        add_file_common_args=True,
        supports_check_mode=True,
    )

    origin = dns.name.from_text(module.params["origin"])
    path = module.params["path"]

    zone = dns.zone.Zone(origin)

    try:
        # check_origin=False: only the content is compared below. With the
        # default check, a stored zone without apex NS records (the `ns`
        # option defaults to an empty list) fails to load, and every run
        # would report a change.
        current = dns.zone.from_file(path, origin=origin, check_origin=False)
    except Exception:
        # Deliberately broad: a missing or unparsable file is regenerated.
        current = None

    # One flat stream, in option order: per type, per list entry, per record.
    per_type_records = (
        record
        for name, ty in record_types.items()
        for args in module.params[name]
        for record in make_records(args, ty)
    )
    records = itertools.chain(
        make_records(module.params["soa"], SOA),
        make_reverse_hosts_records(module.params["reverse_hosts"]),
        make_hosts_records(module.params["hosts"]),
        per_type_records,
    )

    # NOTE(review): rdatasets are created and filled without an explicit
    # TTL, so every RRset presumably gets dnspython's default TTL (0) -
    # confirm whether a TTL option is needed.
    for record in records:
        node = zone.get_node(record.name, create=True)
        rdata = record.rdata()
        rdataset = node.get_rdataset(rdata.rdclass, rdata.rdtype, create=True)
        rdataset.add(rdata)

    file_args = module.load_file_common_arguments(module.params)

    # NOTE(review): the SOA serial is used as given (default 1) and is not
    # incremented when the content changes, so secondaries comparing serials
    # may miss updates. `dns.serial` is imported but appears unused -
    # confirm whether automatic incrementing was intended.
    changed = current is None or not zones_eq(zone, current)
    if changed and not module.check_mode:
        zone.to_file(path, relativize=False, sorted=True)

    # With a pending content change in check mode, the file may not exist yet
    # and the result is "changed" anyway, so the attribute check is skipped.
    if not (changed and module.check_mode):
        changed = module.set_fs_attributes_if_different(file_args, changed)

    module.exit_json(changed=changed)

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    status = main()
    sys.exit(status)
|