ansible/roles/bird/filter_plugins/bird.py

299 lines
7.5 KiB
Python

from __future__ import annotations

import itertools
import re
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv4Network, IPv6Network, ip_network
from typing import Any, Generic, Iterable, Iterator, Literal, TypeVar

from pydantic import (
    BaseModel,
    Field,
    IPvAnyAddress,
    ValidationError,
    parse_obj_as,
)
T = TypeVar("T")


class AutoList(list[T], Generic[T]):
    """A list field type that also accepts a single bare item.

    Used as a field annotation (e.g. ``AutoList[str]``) so configuration may
    write either a single value or a list of values; a value that does not
    parse as a list is wrapped into a one-element list.
    """

    @classmethod
    def __get_validators__(cls):
        # pydantic custom-type hook: yields the validator callables for this type.
        yield cls.__validator__

    @classmethod
    def __validator__(cls, value):
        # NOTE(review): T is the unbound module-level TypeVar here, not the
        # concrete item type of e.g. AutoList[str]; per-item validation is
        # presumably done by pydantic from the field's declared item type —
        # confirm against the pydantic version in use.
        try:
            return parse_obj_as(list[T], value)
        except ValidationError:
            # Not list-like: treat the value as a single item.
            return [parse_obj_as(T, value)]
# Placeholders that map to a bare BIRD filter expression instead of to a
# Context attribute: "${net.len}" is emitted as the expression net.len.
VARIABLES = {
    "net.len": "net.len",
}

# A "${name}" placeholder that is not preceded by a backslash. The capturing
# group makes re.split keep the placeholders in its output.
_PLACEHOLDER = re.compile(r"(?<!\\)(\$\{[_a-z][_a-z0-9.]*\})")


def interpolate(string: str, ctx: Context) -> str:
    """Turn a message template into a comma-separated BIRD expression list.

    Literal text becomes quoted string literals; each "${name}" placeholder
    becomes either the BIRD expression registered in ``VARIABLES`` or the
    quoted value of the ``Context`` attribute ``name``. For example
    ``"too long: ${net.len}"`` becomes ``'"too long: ", net.len'``.

    NOTE: a backslash before "${" suppresses substitution, but the backslash
    itself is kept (and escaped) in the output.

    Args:
        string: The template text.
        ctx: Rendering context whose attributes may be referenced.

    Returns:
        The parts joined with ", " (empty string for an empty template).

    Raises:
        ValueError: if a placeholder names neither a ``VARIABLES`` entry nor
            a ``Context`` attribute.
    """

    def lookup(var: str) -> str:
        try:
            return VARIABLES[var]
        except KeyError:
            pass
        try:
            value = getattr(ctx, var)
        except AttributeError:
            raise ValueError(
                f"unknown placeholder ${{{var}}} in {string!r}"
            ) from None
        # str(): not every Context attribute is a string (ipv4 is a bool),
        # and quoted() only accepts strings.
        return quoted(str(value))

    parts = [
        lookup(piece[2:-1]) if _PLACEHOLDER.fullmatch(piece) else quoted(piece)
        for piece in _PLACEHOLDER.split(string)
        if piece
    ]
    return ", ".join(parts)
class Proto(BaseModel):
    """Condition: the route came from one of the named protocols.

    Names are given without the address-family suffix; rendering appends
    "4" or "6" via ``bird_name``. A single name need not be wrapped in a list.
    """

    protos: AutoList[str]
class Source(BaseModel):
    """Condition: the route's ``krt_source`` is one of the given integers.

    A single value need not be wrapped in a list; an empty list renders as
    ``false``.
    """

    sources: AutoList[int]
class And(BaseModel):
    """Condition ``{"and": [...]}``: every sub-condition holds.

    An empty list renders as ``true``.
    """

    # "and" is a Python keyword, hence the alias.
    conditions: AutoList[Condition] = Field(alias="and")
class Or(BaseModel):
    """Condition ``{"or": [...]}``: at least one sub-condition holds.

    An empty list renders as ``false``.
    """

    # "or" is a Python keyword, hence the alias.
    conditions: AutoList[Condition] = Field(alias="or")
class Not(BaseModel):
    """Condition ``{"not": <condition>}``: the sub-condition does not hold."""

    # "not" is a Python keyword, hence the alias.
    condition: Condition = Field(alias="not")
class AsPathContains(BaseModel):
    """Condition ``{"as_path.contains": <asn or list of asns>}``.

    Rendered as ``bgp_path ~ [ <asn>, ... ]``.
    """

    contains: AutoList[int] = Field(alias="as_path.contains")
class AsPathLength(BaseModel):
    """Condition ``{"as_path.len": [min, max]}``: inclusive bounds on the AS path length."""

    # Exactly two integers. NOTE(review): ge=0 is declared on the list field
    # and is meant to bound each item; min <= max is not checked here.
    length: list[int] = Field(
        ge=0, min_items=2, max_items=2, alias="as_path.len"
    )
class IPv4orIPv6(BaseModel):
    """Per-address-family ``[min, max]`` pairs; both families are required.

    The renderer picks the pair matching the family being rendered.
    """

    ipv4: list[int] = Field(ge=0, min_items=2, max_items=2)
    ipv6: list[int] = Field(ge=0, min_items=2, max_items=2)
class IPFlag:
    """pydantic custom type for one BIRD prefix-set entry.

    Accepts a string ``"<prefix>[<modifier>]"`` where the optional modifier is
    ``+``, ``-`` or ``{low,high}`` (e.g. ``"10.0.0.0/8+"``,
    ``"2001:db8::/32{33,48}"``) and validates it to a
    ``(network, modifier)`` tuple; the modifier is ``""`` when absent.
    """

    # Lazy prefix, optional modifier, end of string.
    _PATTERN = re.compile(r"(?P<ip>.*?)(?P<flag>[+-]|\{[0-9]+,[0-9]+\})?$")

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse *v* into ``(ip_network, modifier)``.

        Raises:
            ValueError: if *v* is not a valid prefix-set entry (including an
                invalid network as reported by ``ip_network``).
        """
        parts = cls._PATTERN.match(v)
        if parts is None:
            # "." does not match newlines, so this happens when v contains a
            # line break that is not its final character. Raise ValueError so
            # pydantic reports a validation error instead of an AttributeError.
            raise ValueError(f"invalid prefix-set entry: {v!r}")
        return (ip_network(parts.group("ip")), parts.group("flag") or "")
class NetMatch(BaseModel):
    """Condition ``{"net.match": [<entry>, ...]}``: ``net ~ [ ... ]`` prefix-set match.

    Entries use ``IPFlag`` syntax (e.g. ``"10.0.0.0/8+"``). At render time,
    only entries of the address family being rendered are kept.
    """

    # NOTE(review): unlike the other list fields this is a plain list, not an
    # AutoList, so a single bare entry is not accepted — confirm intentional.
    matches: list[IPFlag] = Field(alias="net.match")
class NetLength(BaseModel):
    """Condition ``{"net.len": {"ipv4": [min, max], "ipv6": [min, max]}}``.

    Inclusive bounds on the prefix length, chosen per address family.
    """

    length: IPv4orIPv6 = Field(alias="net.len")
# Union of every condition model that may appear in a rule. Each member has
# its own required key ("protos", "sources", "and", "or", "not",
# "as_path.contains", "as_path.len", "net.len", "net.match").
Condition = (
    Proto
    | Source
    | And
    | Or
    | Not
    | AsPathContains
    | AsPathLength
    | NetLength
    | NetMatch
)
# And/Or/Not reference Condition before it is defined (annotations are
# strings under `from __future__ import annotations`); resolve them now.
And.update_forward_refs()
Or.update_forward_refs()
Not.update_forward_refs()
# Bare-string actions: the literals "accept" and "reject", rendered as the
# BIRD statements "accept;" and "reject;".
Accept = Literal["accept"]
Reject = Literal["reject"]
class RejectWithMsg(BaseModel):
    """Action ``{"reject": "<message>"}``: reject with a message.

    The message may contain ``${...}`` placeholders, which are expanded by
    ``interpolate`` at render time.
    """

    reject: str
class PrefSrc(BaseModel):
    """Action ``{"pref_src": <address or list of addresses>}``: set ``krt_prefsrc``.

    When rendering, the first listed address of the current address family
    is used.
    """

    pref_src: AutoList[IPvAnyAddress]
class Conditional(BaseModel):
    """Action ``{"if": <condition>, "then": <action or list of actions>}``.

    Rendered as ``if <condition> then { <actions> }``.
    """

    condition: Condition = Field(alias="if")
    # Action is defined after this class; the forward reference is resolved
    # by Conditional.update_forward_refs() once Action exists.
    actions: AutoList[Action] = Field(alias="then")
# Any single filter action.
Action = Accept | Reject | RejectWithMsg | PrefSrc | Conditional
# Conditional.actions refers to Action, which only exists from here on.
Conditional.update_forward_refs()
# A complete import/export rule: either a bare condition (rendered with
# "where") or one or more actions (a single action need not be in a list).
Rule = Condition | AutoList[Action]
@dataclass
class Context:
    """Rendering state passed through the statement, action and condition renderers."""

    # True to render for IPv4, False for IPv6.
    ipv4: bool
    # Prefix added to each nested line (see indent()).
    indent: str
    # Statement keyword: "import" or "export".
    verb: str
def flatten(iterable: Iterable[Iterable[T]]) -> Iterable[T]:
    """Lazily concatenate the inner iterables of *iterable*, in order."""
    return (element for group in iterable for element in group)
def indent(iterable, ctx: Context) -> Iterable[str]:
    """Yield every line of *iterable* prefixed with ``ctx.indent``."""
    for line in iterable:
        yield f"{ctx.indent}{line}"
def filter_addrs(addrs, ctx: Context):
    """Yield the addresses in *addrs* that belong to the family being rendered.

    ``IPv4Address`` instances are kept when ``ctx.ipv4`` is true; everything
    else is kept when it is false. Input order is preserved.
    """
    for addr in addrs:
        if isinstance(addr, IPv4Address) == ctx.ipv4:
            yield addr
def quoted(string: str) -> str:
    """Return *string* as a double-quoted literal.

    Backslashes and double quotes inside the text are backslash-escaped.
    """
    escaped = string.translate({ord("\\"): "\\\\", ord('"'): '\\"'})
    return '"' + escaped + '"'
def bird_name(name: str, ipv4: bool) -> str:
    """Return the per-family BIRD name: *name* followed by "4" or "6"."""
    suffix = "4" if ipv4 else "6"
    return f"{name}{suffix}"
def str_of_condition(condition: Condition, ctx: bool) -> str:
match condition:
case Proto(protos=[]) | Source(sources=[]) | Or(conditions=[]):
return "false"
case And(conditions=[]):
return "true"
case Not(condition=condition):
return f"!{str_of_condition(condition)}"
case And(conditions=[condition]) | Or(conditions=[condition]):
return str_of_condition(condition, ctx)
case And(conditions=conditions):
return " && ".join(
f"({str_of_condition(c, ctx)})" for c in conditions
)
case Or(conditions=conditions):
return " || ".join(
f"({str_of_condition(c, ctx)})" for c in conditions
)
case Proto(protos=[proto]):
return f"proto = {quoted(bird_name(proto, ctx.ipv4))}"
case Proto(protos=protos):
protos = [quoted(bird_name(p, ctx.ipv4)) for p in protos]
return " || ".join(f"proto = {p}" for p in protos)
case Source(sources=[source]):
return f"krt_source = {source}"
case Source(sources=sources):
sources = [str(s) for s in sources]
return f"krt_source ~ [ {', '.join(sources)} ]"
case AsPathContains(contains=contains):
return (
f"bgp_path ~ [ {', '.join([str(asn) for asn in contains])} ]"
)
case AsPathLength(length=[min_len, max_len]):
return f"{min_len} <= bgp_path.len && bgp_path.len <= {max_len}"
case NetLength(
length=IPv4orIPv6(ipv4=[min_v4, max_v4], ipv6=[min_v6, max_v6])
):
if ctx.ipv4:
return f"{min_v4} <= net.len && net.len <= {max_v4}"
else:
return f"{min_v6} <= net.len && net.len <= {max_v6}"
case NetMatch(matches=matches):
networkType = IPv4Network if ctx.ipv4 else IPv6Network
networks = [m for m in matches if isinstance(m[0], networkType)]
return f"net ~ [ {', '.join([f'{network}{flag}' for (network, flag) in networks])} ]"
def lines_of_action(action: Action, ctx: Context) -> Iterable[str]:
match action:
case "accept" | "reject":
yield f"{action};"
case RejectWithMsg(reject=reason):
yield f"reject {interpolate(reason, ctx)};"
case Conditional(condition=condition, actions=actions):
yield f"if {str_of_condition(condition, ctx)} then {'{'}"
yield from indent(
flatten(lines_of_action(a, ctx) for a in actions), ctx
)
yield "}"
case PrefSrc(pref_src=sources):
source = next(filter_addrs(sources, ctx))
yield f"krt_prefsrc = {source};"
def lines_of_stmt(rule: Rule, ctx: Context) -> Iterable[str]:
match parse_obj_as(Rule, rule):
case ["accept"]:
yield f"{ctx.verb} all;"
case [] | ["reject"]:
yield f"{ctx.verb} none;"
# FIXME
case (Proto() | Source() | And() | Or() | Not()) as condition:
# Conditional(condition=condition, actions=["accept"])
yield f"{ctx.verb} where {str_of_condition(condition, ctx)};"
case _ as actions:
yield f"{ctx.verb} filter {'{'}"
yield from indent(
flatten(lines_of_action(a, ctx) for a in actions), ctx
)
yield "};"
def bird_import(rule: Rule, ipv4: bool, indent: str = " ") -> str:
    """Render *rule* as a BIRD ``import`` statement.

    Args:
        rule: Raw rule value from the configuration.
        ipv4: True to render for IPv4, False for IPv6.
        indent: Prefix added to each nested filter line.

    Returns:
        The statement lines joined with newlines.
    """
    lines = lines_of_stmt(rule, Context(ipv4, indent, "import"))
    return "\n".join(lines)
def bird_export(rule: Rule, ipv4: bool, indent: str = " ") -> str:
    """Render *rule* as a BIRD ``export`` statement.

    Args:
        rule: Raw rule value from the configuration.
        ipv4: True to render for IPv4, False for IPv6.
        indent: Prefix added to each nested filter line.

    Returns:
        The statement lines joined with newlines.
    """
    lines = lines_of_stmt(rule, Context(ipv4, indent, "export"))
    return "\n".join(lines)
class FilterModule:
    """Ansible filter plugin entry point exposing the BIRD rendering helpers."""

    def filters(self):
        """Map Jinja filter names to their implementations."""
        return dict(
            bird_import=bird_import,
            bird_export=bird_export,
            bird_name=bird_name,
        )