from __future__ import annotations import itertools import re from dataclasses import dataclass from ipaddress import IPv4Address, IPv4Network, IPv6Network, ip_network from typing import Any, Generic, Iterator, Literal, TypeVar from pydantic import ( BaseModel, Field, IPvAnyAddress, ValidationError, parse_obj_as, ) T = TypeVar("T") class AutoList(list[T], Generic[T]): @classmethod def __get_validators__(cls): yield cls.__validator__ @classmethod def __validator__(cls, value): try: return parse_obj_as(list[T], value) except ValidationError: return [parse_obj_as(T, value)] VARIABLES = { "net.len": "net.len", } def interpolate(string: str, ctx: Context) -> str: pattern = r"(? str: try: return VARIABLES[var] except KeyError: return quoted(getattr(ctx, var)) split = re.split(pattern, string) parts = [ (lookup(p[2:-1]) if re.match(pattern, p) else quoted(p)) for p in split if p ] return ", ".join(parts) class Proto(BaseModel): protos: AutoList[str] class Source(BaseModel): sources: AutoList[int] class And(BaseModel): conditions: AutoList[Condition] = Field(alias="and") class Or(BaseModel): conditions: AutoList[Condition] = Field(alias="or") class Not(BaseModel): condition: Condition = Field(alias="not") class AsPathContains(BaseModel): contains: AutoList[int] = Field(alias="as_path.contains") class AsPathLength(BaseModel): length: list[int] = Field( ge=0, min_items=2, max_items=2, alias="as_path.len" ) class IPv4orIPv6(BaseModel): ipv4: list[int] = Field(ge=0, min_items=2, max_items=2) ipv6: list[int] = Field(ge=0, min_items=2, max_items=2) class IPFlag: @classmethod def __get_validators__(cls): yield cls.validate @classmethod def validate(cls, v): pattern = r"(?P.*?)(?P[+-]|\{[0-9]+,[0-9]+\})?$" parts = re.match(pattern, v) return (ip_network(parts.group("ip")), parts.group("flag") or "") class NetMatch(BaseModel): matches: list[IPFlag] = Field(alias="net.match") class NetLength(BaseModel): length: IPv4orIPv6 = Field(alias="net.len") Condition = ( Proto | Source | And | Or | Not | AsPathContains | AsPathLength | NetLength | NetMatch ) And.update_forward_refs() Or.update_forward_refs() Not.update_forward_refs() Accept = Literal["accept"] Reject = Literal["reject"] class RejectWithMsg(BaseModel): reject: str class PrefSrc(BaseModel): pref_src: AutoList[IPvAnyAddress] class Conditional(BaseModel): condition: Condition = Field(alias="if") actions: AutoList[Action] = Field(alias="then") Action = Accept | Reject | RejectWithMsg | PrefSrc | Conditional Conditional.update_forward_refs() Rule = Condition | AutoList[Action] @dataclass class Context: ipv4: bool indent: str verb: str def flatten(iterable: Iterable[Iterable[T]]) -> Iterable[T]: return itertools.chain.from_iterable(iterable) def indent(iterable, ctx: Context) -> Iterable[str]: yield from (f"{ctx.indent}{i}" for i in iterable) def filter_addrs(addrs, ctx: Context): yield from (a for a in addrs if isinstance(a, IPv4Address) == ctx.ipv4) def quoted(string: str) -> str: escaped = string.replace("\\", "\\\\").replace('"', '\\"') return f'"{escaped}"' def bird_name(name: str, ipv4: bool) -> str: return f"{name}{'4' if ipv4 else '6'}" def str_of_condition(condition: Condition, ctx: bool) -> str: match condition: case Proto(protos=[]) | Source(sources=[]) | Or(conditions=[]): return "false" case And(conditions=[]): return "true" case Not(condition=condition): return f"!{str_of_condition(condition)}" case And(conditions=[condition]) | Or(conditions=[condition]): return str_of_condition(condition, ctx) case And(conditions=conditions): return " && ".join( f"({str_of_condition(c, ctx)})" for c in conditions ) case Or(conditions=conditions): return " || ".join( f"({str_of_condition(c, ctx)})" for c in conditions ) case Proto(protos=[proto]): return f"proto = {quoted(bird_name(proto, ctx.ipv4))}" case Proto(protos=protos): protos = [quoted(bird_name(p, ctx.ipv4)) for p in protos] return " || ".join(f"proto = {p}" for p in protos) case Source(sources=[source]): return f"krt_source = {source}" case Source(sources=sources): sources = [str(s) for s in sources] return f"krt_source ~ [ {', '.join(sources)} ]" case AsPathContains(contains=contains): return ( f"bgp_path ~ [ {', '.join([str(asn) for asn in contains])} ]" ) case AsPathLength(length=[min_len, max_len]): return f"{min_len} <= bgp_path.len && bgp_path.len <= {max_len}" case NetLength( length=IPv4orIPv6(ipv4=[min_v4, max_v4], ipv6=[min_v6, max_v6]) ): if ctx.ipv4: return f"{min_v4} <= net.len && net.len <= {max_v4}" else: return f"{min_v6} <= net.len && net.len <= {max_v6}" case NetMatch(matches=matches): networkType = IPv4Network if ctx.ipv4 else IPv6Network networks = [m for m in matches if isinstance(m[0], networkType)] return f"net ~ [ {', '.join([f'{network}{flag}' for (network, flag) in networks])} ]" def lines_of_action(action: Action, ctx: Context) -> Iterable[str]: match action: case "accept" | "reject": yield f"{action};" case RejectWithMsg(reject=reason): yield f"reject {interpolate(reason, ctx)};" case Conditional(condition=condition, actions=actions): yield f"if {str_of_condition(condition, ctx)} then {'{'}" yield from indent( flatten(lines_of_action(a, ctx) for a in actions), ctx ) yield "}" case PrefSrc(pref_src=sources): source = next(filter_addrs(sources, ctx)) yield f"krt_prefsrc = {source};" def lines_of_stmt(rule: Rule, ctx: Context) -> Iterable[str]: match parse_obj_as(Rule, rule): case ["accept"]: yield f"{ctx.verb} all;" case [] | ["reject"]: yield f"{ctx.verb} none;" # FIXME case (Proto() | Source() | And() | Or() | Not()) as condition: # Conditional(condition=condition, actions=["accept"]) yield f"{ctx.verb} where {str_of_condition(condition, ctx)};" case _ as actions: yield f"{ctx.verb} filter {'{'}" yield from indent( flatten(lines_of_action(a, ctx) for a in actions), ctx ) yield "};" def bird_import(rule: Rule, ipv4: bool, indent: str = " ") -> str: ctx = Context(verb="import", ipv4=ipv4, indent=indent) return "\n".join(lines_of_stmt(rule, ctx)) def bird_export(rule: Rule, ipv4: bool, indent: str = " ") -> str: ctx = Context(verb="export", ipv4=ipv4, indent=indent) return "\n".join(lines_of_stmt(rule, ctx)) class FilterModule: def filters(self): return { "bird_import": bird_import, "bird_export": bird_export, "bird_name": bird_name, }