dns_zone: cleanup + hosts + product
This commit is contained in:
parent
c97dca8fa8
commit
4dbe0e562d
1 changed files with 120 additions and 40 deletions
|
@ -1,26 +1,30 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import itertools
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import ipaddress
|
||||||
from typing import Any
|
import itertools
|
||||||
|
import sys
|
||||||
|
import typing
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
import dns
|
import dns
|
||||||
import dns.serial
|
|
||||||
import dns.zone
|
|
||||||
import dns.rdata
|
import dns.rdata
|
||||||
import dns.rdataclass
|
import dns.rdataclass
|
||||||
import dns.rdatatype
|
import dns.rdatatype
|
||||||
|
import dns.rdtypes.ANY.CNAME
|
||||||
|
import dns.rdtypes.ANY.MX
|
||||||
|
import dns.rdtypes.ANY.NS
|
||||||
|
import dns.rdtypes.ANY.SOA
|
||||||
|
import dns.rdtypes.ANY.TXT
|
||||||
import dns.rdtypes.IN.A
|
import dns.rdtypes.IN.A
|
||||||
import dns.rdtypes.IN.AAAA
|
import dns.rdtypes.IN.AAAA
|
||||||
import dns.rdtypes.ANY.MX
|
import dns.serial
|
||||||
import dns.rdtypes.ANY.SOA
|
import dns.zone
|
||||||
import dns.rdtypes.ANY.NS
|
|
||||||
import dns.rdtypes.ANY.TXT
|
|
||||||
|
|
||||||
from ansible.module_utils.basic import AnsibleModule
|
from ansible.module_utils.basic import AnsibleModule
|
||||||
|
|
||||||
|
|
||||||
class RName(dns.name.Name):
|
class RName(dns.name.Name):
|
||||||
|
"""Domain name used to represent an e-mail address (see RFC 1035)."""
|
||||||
|
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
try:
|
try:
|
||||||
local, domain = address.split("@")
|
local, domain = address.split("@")
|
||||||
|
@ -31,10 +35,18 @@ class RName(dns.name.Name):
|
||||||
super().__init__((local,) + dns.name.from_text(domain).labels)
|
super().__init__((local,) + dns.name.from_text(domain).labels)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiRecords:
|
||||||
|
"""Annotation used to indicate that a field can be filled in more than
|
||||||
|
once via a list, and that this will create as many records as values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class A:
|
class A:
|
||||||
address: str
|
address: str
|
||||||
name: dns.name.Name = dns.name.empty
|
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||||
|
|
||||||
def rdata(self) -> dns.rdata.Rdata:
|
def rdata(self) -> dns.rdata.Rdata:
|
||||||
return dns.rdtypes.IN.A.A(
|
return dns.rdtypes.IN.A.A(
|
||||||
|
@ -45,7 +57,7 @@ class A:
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class AAAA:
|
class AAAA:
|
||||||
address: str
|
address: str
|
||||||
name: dns.name.Name = dns.name.empty
|
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||||
|
|
||||||
def rdata(self) -> dns.rdata.Rdata:
|
def rdata(self) -> dns.rdata.Rdata:
|
||||||
return dns.rdtypes.IN.AAAA.AAAA(
|
return dns.rdtypes.IN.AAAA.AAAA(
|
||||||
|
@ -56,7 +68,7 @@ class AAAA:
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class CNAME:
|
class CNAME:
|
||||||
address: dns.name.Name
|
address: dns.name.Name
|
||||||
name: dns.name.Name = dns.name.empty
|
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||||
|
|
||||||
def rdata(self) -> dns.rdata.Rdata:
|
def rdata(self) -> dns.rdata.Rdata:
|
||||||
return dns.rdtypes.ANY.CNAME.CNAME(
|
return dns.rdtypes.ANY.CNAME.CNAME(
|
||||||
|
@ -66,8 +78,8 @@ class CNAME:
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class MX:
|
class MX:
|
||||||
exchange: dns.name.Name
|
exchange: Annotated[dns.name.Name, MultiRecords]
|
||||||
name: dns.name.Name = dns.name.empty
|
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||||
priority: int = 10
|
priority: int = 10
|
||||||
|
|
||||||
def rdata(self) -> dns.rdata.Rdata:
|
def rdata(self) -> dns.rdata.Rdata:
|
||||||
|
@ -81,8 +93,8 @@ class MX:
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class NS:
|
class NS:
|
||||||
address: dns.name.Name
|
address: Annotated[dns.name.Name, MultiRecords]
|
||||||
name: dns.name.Name = dns.name.empty
|
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||||
|
|
||||||
def rdata(self) -> dns.rdata.Rdata:
|
def rdata(self) -> dns.rdata.Rdata:
|
||||||
return dns.rdtypes.ANY.NS.NS(
|
return dns.rdtypes.ANY.NS.NS(
|
||||||
|
@ -93,7 +105,7 @@ class NS:
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class TXT:
|
class TXT:
|
||||||
data: str
|
data: str
|
||||||
name: dns.name.Name = dns.name.empty
|
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||||
|
|
||||||
def rdata(self) -> dns.rdata.Rdata:
|
def rdata(self) -> dns.rdata.Rdata:
|
||||||
return dns.rdtypes.ANY.TXT.TXT(
|
return dns.rdtypes.ANY.TXT.TXT(
|
||||||
|
@ -110,7 +122,7 @@ class SOA:
|
||||||
expire: int
|
expire: int
|
||||||
ttl: int
|
ttl: int
|
||||||
serial: int = 1
|
serial: int = 1
|
||||||
name: dns.name.Name = dns.name.empty
|
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||||
|
|
||||||
def rdata(self) -> dns.rdata.Rdata:
|
def rdata(self) -> dns.rdata.Rdata:
|
||||||
return dns.rdtypes.ANY.SOA.SOA(
|
return dns.rdtypes.ANY.SOA.SOA(
|
||||||
|
@ -126,6 +138,25 @@ class SOA:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def has_annotation(ty, annotation):
|
||||||
|
"""Is the type `ty` annotated with a given `annotation`."""
|
||||||
|
return (
|
||||||
|
typing.get_origin(ty) == typing.Annotated
|
||||||
|
and annotation in typing.get_args(ty)[1:]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def annotated_origin(ty):
|
||||||
|
"""Returns the origin of an annotated type `ty`."""
|
||||||
|
assert typing.get_origin(ty) == typing.Annotated
|
||||||
|
return typing.get_args(ty)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def is_multi_records(ty):
|
||||||
|
"""Is the type `ty` annotated with `MultiRecords`."""
|
||||||
|
return has_annotation(ty, MultiRecords)
|
||||||
|
|
||||||
|
|
||||||
def spec_option_of_field(field):
|
def spec_option_of_field(field):
|
||||||
types = {
|
types = {
|
||||||
str: "str",
|
str: "str",
|
||||||
|
@ -133,13 +164,20 @@ def spec_option_of_field(field):
|
||||||
RName: "str",
|
RName: "str",
|
||||||
int: "int",
|
int: "int",
|
||||||
}
|
}
|
||||||
return {
|
if is_multi_records(field.type):
|
||||||
"type": types[field.type],
|
option = {
|
||||||
"required": field.default is dataclasses.MISSING,
|
"type": "list",
|
||||||
|
"elements": types[annotated_origin(field.type)],
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
option = {"type": types[field.type]}
|
||||||
|
option["required"] = field.default is dataclasses.MISSING
|
||||||
|
return option
|
||||||
|
|
||||||
|
|
||||||
def spec_options_of_type(ty):
|
def spec_options_of_type(ty):
|
||||||
|
"""Convert a `dataclass` type to Ansible `argument_spec` `options`'
|
||||||
|
format."""
|
||||||
return {
|
return {
|
||||||
field.name: spec_option_of_field(field)
|
field.name: spec_option_of_field(field)
|
||||||
for field in dataclasses.fields(ty)
|
for field in dataclasses.fields(ty)
|
||||||
|
@ -147,12 +185,32 @@ def spec_options_of_type(ty):
|
||||||
|
|
||||||
|
|
||||||
def coerce_dns_name(value: Any) -> dns.name.Name:
|
def coerce_dns_name(value: Any) -> dns.name.Name:
|
||||||
|
"""Try to convert a `value` to `dns.name.Name`."""
|
||||||
if not isinstance(value, dns.name.Name):
|
if not isinstance(value, dns.name.Name):
|
||||||
return dns.name.from_text(value, origin=dns.name.empty)
|
return dns.name.from_text(value, origin=dns.name.empty)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def make_record(args, ty):
|
def product_dict(dct, keys=None):
|
||||||
|
"""Compute the "cartesian product" of a dictionnary `dct`
|
||||||
|
w.r.t some `keys` (if `keys` is None, then the product is computed
|
||||||
|
on all the keys)."""
|
||||||
|
if keys is None:
|
||||||
|
keys = dct.keys()
|
||||||
|
wrapped = {k: v if k in keys else [v] for k, v in dct.items()}
|
||||||
|
for values in itertools.product(*wrapped.values()):
|
||||||
|
yield dict(zip(wrapped.keys(), values))
|
||||||
|
|
||||||
|
|
||||||
|
def make_hosts_records(hosts):
|
||||||
|
for host, addrs in hosts.items():
|
||||||
|
for addr in addrs:
|
||||||
|
name = dns.name.from_text(host, origin=dns.name.empty)
|
||||||
|
decoded = ipaddress.ip_address(addr)
|
||||||
|
yield AAAA(addr, name) if decoded.version == 6 else A(addr, name)
|
||||||
|
|
||||||
|
|
||||||
|
def make_records(args, ty):
|
||||||
# TODO: Ça n'est pas du tout élégant, mais :
|
# TODO: Ça n'est pas du tout élégant, mais :
|
||||||
# 1. je n'ai pas réussi à spécifier dans `argument_spec` un type tiers
|
# 1. je n'ai pas réussi à spécifier dans `argument_spec` un type tiers
|
||||||
# 2. Ansible positionne à `None` les entrées non passées à la tâche et
|
# 2. Ansible positionne à `None` les entrées non passées à la tâche et
|
||||||
|
@ -163,10 +221,16 @@ def make_record(args, ty):
|
||||||
RName: RName,
|
RName: RName,
|
||||||
}
|
}
|
||||||
|
|
||||||
def coerce(name, value):
|
def coerce_single(value, ty):
|
||||||
if types[name] not in coercers:
|
if ty in coercers:
|
||||||
|
return coercers[ty](value)
|
||||||
return value
|
return value
|
||||||
return coercers[types[name]](value)
|
|
||||||
|
def coerce(name, value):
|
||||||
|
if is_multi_records(types[name]):
|
||||||
|
origin = annotated_origin(types[name])
|
||||||
|
return [coerce_single(v, origin) for v in value]
|
||||||
|
return coerce_single(value, types[name])
|
||||||
|
|
||||||
clean_args = {
|
clean_args = {
|
||||||
name: coerce(name, value)
|
name: coerce(name, value)
|
||||||
|
@ -174,11 +238,16 @@ def make_record(args, ty):
|
||||||
if value is not None
|
if value is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
return ty(**clean_args)
|
multi_keys = (k for k, v in types.items() if is_multi_records(v))
|
||||||
|
|
||||||
|
for single_args in product_dict(clean_args, multi_keys):
|
||||||
|
yield ty(**single_args)
|
||||||
|
|
||||||
|
|
||||||
def zones_eq(a: dns.zone.Zone, b: dns.zone.Zone) -> bool:
|
def zones_eq(lhs: dns.zone.Zone, rhs: dns.zone.Zone) -> bool:
|
||||||
return a.to_text(relativize=False) == b.to_text(relativize=False)
|
"""Returns a `bool` indicating whether two `dns.zone.Zone`s are equal
|
||||||
|
w.r.t. their text representation."""
|
||||||
|
return lhs.to_text(relativize=False) == rhs.to_text(relativize=False)
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
|
@ -188,17 +257,19 @@ def main() -> int:
|
||||||
"txt": TXT,
|
"txt": TXT,
|
||||||
"a": A,
|
"a": A,
|
||||||
"aaaa": AAAA,
|
"aaaa": AAAA,
|
||||||
|
"cname": CNAME,
|
||||||
"mx": MX,
|
"mx": MX,
|
||||||
}
|
}
|
||||||
|
|
||||||
module_args = {
|
module_args = {
|
||||||
"path": {"type": "path", "required": True},
|
"path": {"type": "str", "required": True},
|
||||||
"origin": {"type": "str", "required": True},
|
"origin": {"type": "str", "required": True},
|
||||||
"soa": {
|
"soa": {
|
||||||
"type": "dict",
|
"type": "dict",
|
||||||
"required": True,
|
"required": True,
|
||||||
"options": spec_options_of_type(SOA),
|
"options": spec_options_of_type(SOA),
|
||||||
},
|
},
|
||||||
|
"hosts": {"type": "dict", "default": {}},
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, ty in record_types.items():
|
for name, ty in record_types.items():
|
||||||
|
@ -209,7 +280,10 @@ def main() -> int:
|
||||||
"options": spec_options_of_type(ty),
|
"options": spec_options_of_type(ty),
|
||||||
}
|
}
|
||||||
|
|
||||||
module = AnsibleModule(argument_spec=module_args)
|
module = AnsibleModule(
|
||||||
|
argument_spec=module_args,
|
||||||
|
add_file_common_args=True,
|
||||||
|
)
|
||||||
|
|
||||||
origin = dns.name.from_text(module.params["origin"])
|
origin = dns.name.from_text(module.params["origin"])
|
||||||
path = module.params["path"]
|
path = module.params["path"]
|
||||||
|
@ -218,16 +292,18 @@ def main() -> int:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current = dns.zone.from_file(path, origin=origin)
|
current = dns.zone.from_file(path, origin=origin)
|
||||||
except:
|
except Exception:
|
||||||
current = None
|
current = None
|
||||||
|
|
||||||
records = [make_record(module.params["soa"], SOA)]
|
records = itertools.chain(
|
||||||
|
make_records(module.params["soa"], SOA),
|
||||||
records.extend(
|
make_hosts_records(module.params["hosts"]),
|
||||||
itertools.chain.from_iterable(
|
itertools.chain.from_iterable(
|
||||||
(make_record(args, ty) for args in module.params[name])
|
itertools.chain.from_iterable(
|
||||||
for name, ty in record_types.items()
|
make_records(args, ty) for args in module.params[name]
|
||||||
)
|
)
|
||||||
|
for name, ty in record_types.items()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
|
@ -236,9 +312,13 @@ def main() -> int:
|
||||||
dataset = node.get_rdataset(rdata.rdclass, rdata.rdtype, create=True)
|
dataset = node.get_rdataset(rdata.rdclass, rdata.rdtype, create=True)
|
||||||
dataset.add(rdata)
|
dataset.add(rdata)
|
||||||
|
|
||||||
|
file_args = module.load_file_common_arguments(module.params)
|
||||||
|
|
||||||
changed = current is None or not zones_eq(zone, current)
|
changed = current is None or not zones_eq(zone, current)
|
||||||
if changed:
|
if changed:
|
||||||
zone.to_file(path, relativize=False)
|
zone.to_file(module.params["path"], relativize=False)
|
||||||
|
|
||||||
|
changed = module.set_fs_attributes_if_different(file_args, changed)
|
||||||
|
|
||||||
module.exit_json(changed=changed)
|
module.exit_json(changed=changed)
|
||||||
|
|
||||||
|
@ -246,4 +326,4 @@ def main() -> int:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
exit(main())
|
sys.exit(main())
|
||||||
|
|
Loading…
Reference in a new issue