dns_zone: cleanup + hosts + product
This commit is contained in:
parent
c97dca8fa8
commit
4dbe0e562d
1 changed files with 120 additions and 40 deletions
|
@ -1,26 +1,30 @@
|
|||
#!/usr/bin/env python3
|
||||
import itertools
|
||||
import dataclasses
|
||||
|
||||
from typing import Any
|
||||
import ipaddress
|
||||
import itertools
|
||||
import sys
|
||||
import typing
|
||||
from typing import Annotated, Any
|
||||
|
||||
import dns
|
||||
import dns.serial
|
||||
import dns.zone
|
||||
import dns.rdata
|
||||
import dns.rdataclass
|
||||
import dns.rdatatype
|
||||
import dns.rdtypes.ANY.CNAME
|
||||
import dns.rdtypes.ANY.MX
|
||||
import dns.rdtypes.ANY.NS
|
||||
import dns.rdtypes.ANY.SOA
|
||||
import dns.rdtypes.ANY.TXT
|
||||
import dns.rdtypes.IN.A
|
||||
import dns.rdtypes.IN.AAAA
|
||||
import dns.rdtypes.ANY.MX
|
||||
import dns.rdtypes.ANY.SOA
|
||||
import dns.rdtypes.ANY.NS
|
||||
import dns.rdtypes.ANY.TXT
|
||||
|
||||
import dns.serial
|
||||
import dns.zone
|
||||
from ansible.module_utils.basic import AnsibleModule
|
||||
|
||||
|
||||
class RName(dns.name.Name):
|
||||
"""Domain name used to represent an e-mail address (see RFC 1035)."""
|
||||
|
||||
def __init__(self, address):
|
||||
try:
|
||||
local, domain = address.split("@")
|
||||
|
@ -31,10 +35,18 @@ class RName(dns.name.Name):
|
|||
super().__init__((local,) + dns.name.from_text(domain).labels)
|
||||
|
||||
|
||||
class MultiRecords:
|
||||
"""Annotation used to indicate that a field can be filled in more than
|
||||
once via a list, and that this will create as many records as values.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class A:
|
||||
address: str
|
||||
name: dns.name.Name = dns.name.empty
|
||||
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||
|
||||
def rdata(self) -> dns.rdata.Rdata:
|
||||
return dns.rdtypes.IN.A.A(
|
||||
|
@ -45,7 +57,7 @@ class A:
|
|||
@dataclasses.dataclass
|
||||
class AAAA:
|
||||
address: str
|
||||
name: dns.name.Name = dns.name.empty
|
||||
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||
|
||||
def rdata(self) -> dns.rdata.Rdata:
|
||||
return dns.rdtypes.IN.AAAA.AAAA(
|
||||
|
@ -56,7 +68,7 @@ class AAAA:
|
|||
@dataclasses.dataclass
|
||||
class CNAME:
|
||||
address: dns.name.Name
|
||||
name: dns.name.Name = dns.name.empty
|
||||
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||
|
||||
def rdata(self) -> dns.rdata.Rdata:
|
||||
return dns.rdtypes.ANY.CNAME.CNAME(
|
||||
|
@ -66,8 +78,8 @@ class CNAME:
|
|||
|
||||
@dataclasses.dataclass
|
||||
class MX:
|
||||
exchange: dns.name.Name
|
||||
name: dns.name.Name = dns.name.empty
|
||||
exchange: Annotated[dns.name.Name, MultiRecords]
|
||||
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||
priority: int = 10
|
||||
|
||||
def rdata(self) -> dns.rdata.Rdata:
|
||||
|
@ -81,8 +93,8 @@ class MX:
|
|||
|
||||
@dataclasses.dataclass
|
||||
class NS:
|
||||
address: dns.name.Name
|
||||
name: dns.name.Name = dns.name.empty
|
||||
address: Annotated[dns.name.Name, MultiRecords]
|
||||
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||
|
||||
def rdata(self) -> dns.rdata.Rdata:
|
||||
return dns.rdtypes.ANY.NS.NS(
|
||||
|
@ -93,7 +105,7 @@ class NS:
|
|||
@dataclasses.dataclass
|
||||
class TXT:
|
||||
data: str
|
||||
name: dns.name.Name = dns.name.empty
|
||||
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||
|
||||
def rdata(self) -> dns.rdata.Rdata:
|
||||
return dns.rdtypes.ANY.TXT.TXT(
|
||||
|
@ -110,7 +122,7 @@ class SOA:
|
|||
expire: int
|
||||
ttl: int
|
||||
serial: int = 1
|
||||
name: dns.name.Name = dns.name.empty
|
||||
name: Annotated[dns.name.Name, MultiRecords] = dns.name.empty
|
||||
|
||||
def rdata(self) -> dns.rdata.Rdata:
|
||||
return dns.rdtypes.ANY.SOA.SOA(
|
||||
|
@ -126,6 +138,25 @@ class SOA:
|
|||
)
|
||||
|
||||
|
||||
def has_annotation(ty, annotation):
|
||||
"""Is the type `ty` annotated with a given `annotation`."""
|
||||
return (
|
||||
typing.get_origin(ty) == typing.Annotated
|
||||
and annotation in typing.get_args(ty)[1:]
|
||||
)
|
||||
|
||||
|
||||
def annotated_origin(ty):
|
||||
"""Returns the origin of an annotated type `ty`."""
|
||||
assert typing.get_origin(ty) == typing.Annotated
|
||||
return typing.get_args(ty)[0]
|
||||
|
||||
|
||||
def is_multi_records(ty):
|
||||
"""Is the type `ty` annotated with `MultiRecords`."""
|
||||
return has_annotation(ty, MultiRecords)
|
||||
|
||||
|
||||
def spec_option_of_field(field):
|
||||
types = {
|
||||
str: "str",
|
||||
|
@ -133,13 +164,20 @@ def spec_option_of_field(field):
|
|||
RName: "str",
|
||||
int: "int",
|
||||
}
|
||||
return {
|
||||
"type": types[field.type],
|
||||
"required": field.default is dataclasses.MISSING,
|
||||
if is_multi_records(field.type):
|
||||
option = {
|
||||
"type": "list",
|
||||
"elements": types[annotated_origin(field.type)],
|
||||
}
|
||||
else:
|
||||
option = {"type": types[field.type]}
|
||||
option["required"] = field.default is dataclasses.MISSING
|
||||
return option
|
||||
|
||||
|
||||
def spec_options_of_type(ty):
|
||||
"""Convert a `dataclass` type to Ansible `argument_spec` `options`'
|
||||
format."""
|
||||
return {
|
||||
field.name: spec_option_of_field(field)
|
||||
for field in dataclasses.fields(ty)
|
||||
|
@ -147,12 +185,32 @@ def spec_options_of_type(ty):
|
|||
|
||||
|
||||
def coerce_dns_name(value: Any) -> dns.name.Name:
|
||||
"""Try to convert a `value` to `dns.name.Name`."""
|
||||
if not isinstance(value, dns.name.Name):
|
||||
return dns.name.from_text(value, origin=dns.name.empty)
|
||||
return value
|
||||
|
||||
|
||||
def make_record(args, ty):
|
||||
def product_dict(dct, keys=None):
|
||||
"""Compute the "cartesian product" of a dictionnary `dct`
|
||||
w.r.t some `keys` (if `keys` is None, then the product is computed
|
||||
on all the keys)."""
|
||||
if keys is None:
|
||||
keys = dct.keys()
|
||||
wrapped = {k: v if k in keys else [v] for k, v in dct.items()}
|
||||
for values in itertools.product(*wrapped.values()):
|
||||
yield dict(zip(wrapped.keys(), values))
|
||||
|
||||
|
||||
def make_hosts_records(hosts):
|
||||
for host, addrs in hosts.items():
|
||||
for addr in addrs:
|
||||
name = dns.name.from_text(host, origin=dns.name.empty)
|
||||
decoded = ipaddress.ip_address(addr)
|
||||
yield AAAA(addr, name) if decoded.version == 6 else A(addr, name)
|
||||
|
||||
|
||||
def make_records(args, ty):
|
||||
# TODO: Ça n'est pas du tout élégant, mais :
|
||||
# 1. je n'ai pas réussi à spécifier dans `argument_spec` un type tiers
|
||||
# 2. Ansible positionne à `None` les entrées non passées à la tâche et
|
||||
|
@ -163,10 +221,16 @@ def make_record(args, ty):
|
|||
RName: RName,
|
||||
}
|
||||
|
||||
def coerce(name, value):
|
||||
if types[name] not in coercers:
|
||||
def coerce_single(value, ty):
|
||||
if ty in coercers:
|
||||
return coercers[ty](value)
|
||||
return value
|
||||
return coercers[types[name]](value)
|
||||
|
||||
def coerce(name, value):
|
||||
if is_multi_records(types[name]):
|
||||
origin = annotated_origin(types[name])
|
||||
return [coerce_single(v, origin) for v in value]
|
||||
return coerce_single(value, types[name])
|
||||
|
||||
clean_args = {
|
||||
name: coerce(name, value)
|
||||
|
@ -174,11 +238,16 @@ def make_record(args, ty):
|
|||
if value is not None
|
||||
}
|
||||
|
||||
return ty(**clean_args)
|
||||
multi_keys = (k for k, v in types.items() if is_multi_records(v))
|
||||
|
||||
for single_args in product_dict(clean_args, multi_keys):
|
||||
yield ty(**single_args)
|
||||
|
||||
|
||||
def zones_eq(a: dns.zone.Zone, b: dns.zone.Zone) -> bool:
|
||||
return a.to_text(relativize=False) == b.to_text(relativize=False)
|
||||
def zones_eq(lhs: dns.zone.Zone, rhs: dns.zone.Zone) -> bool:
|
||||
"""Returns a `bool` indicating whether two `dns.zone.Zone`s are equal
|
||||
w.r.t. their text representation."""
|
||||
return lhs.to_text(relativize=False) == rhs.to_text(relativize=False)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
|
@ -188,17 +257,19 @@ def main() -> int:
|
|||
"txt": TXT,
|
||||
"a": A,
|
||||
"aaaa": AAAA,
|
||||
"cname": CNAME,
|
||||
"mx": MX,
|
||||
}
|
||||
|
||||
module_args = {
|
||||
"path": {"type": "path", "required": True},
|
||||
"path": {"type": "str", "required": True},
|
||||
"origin": {"type": "str", "required": True},
|
||||
"soa": {
|
||||
"type": "dict",
|
||||
"required": True,
|
||||
"options": spec_options_of_type(SOA),
|
||||
},
|
||||
"hosts": {"type": "dict", "default": {}},
|
||||
}
|
||||
|
||||
for name, ty in record_types.items():
|
||||
|
@ -209,7 +280,10 @@ def main() -> int:
|
|||
"options": spec_options_of_type(ty),
|
||||
}
|
||||
|
||||
module = AnsibleModule(argument_spec=module_args)
|
||||
module = AnsibleModule(
|
||||
argument_spec=module_args,
|
||||
add_file_common_args=True,
|
||||
)
|
||||
|
||||
origin = dns.name.from_text(module.params["origin"])
|
||||
path = module.params["path"]
|
||||
|
@ -218,16 +292,18 @@ def main() -> int:
|
|||
|
||||
try:
|
||||
current = dns.zone.from_file(path, origin=origin)
|
||||
except:
|
||||
except Exception:
|
||||
current = None
|
||||
|
||||
records = [make_record(module.params["soa"], SOA)]
|
||||
|
||||
records.extend(
|
||||
records = itertools.chain(
|
||||
make_records(module.params["soa"], SOA),
|
||||
make_hosts_records(module.params["hosts"]),
|
||||
itertools.chain.from_iterable(
|
||||
(make_record(args, ty) for args in module.params[name])
|
||||
for name, ty in record_types.items()
|
||||
itertools.chain.from_iterable(
|
||||
make_records(args, ty) for args in module.params[name]
|
||||
)
|
||||
for name, ty in record_types.items()
|
||||
),
|
||||
)
|
||||
|
||||
for record in records:
|
||||
|
@ -236,9 +312,13 @@ def main() -> int:
|
|||
dataset = node.get_rdataset(rdata.rdclass, rdata.rdtype, create=True)
|
||||
dataset.add(rdata)
|
||||
|
||||
file_args = module.load_file_common_arguments(module.params)
|
||||
|
||||
changed = current is None or not zones_eq(zone, current)
|
||||
if changed:
|
||||
zone.to_file(path, relativize=False)
|
||||
zone.to_file(module.params["path"], relativize=False)
|
||||
|
||||
changed = module.set_fs_attributes_if_different(file_args, changed)
|
||||
|
||||
module.exit_json(changed=changed)
|
||||
|
||||
|
@ -246,4 +326,4 @@ def main() -> int:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
sys.exit(main())
|
||||
|
|
Loading…
Reference in a new issue