from __future__ import annotations import itertools from dataclasses import dataclass from ipaddress import IPv4Address from typing import Any, Generic, Iterator, Literal, TypeVar from pydantic import ( BaseModel, Field, IPvAnyAddress, ValidationError, parse_obj_as, ) T = TypeVar("T") class AutoList(list[T], Generic[T]): @classmethod def __get_validators__(cls): yield cls.__validator__ @classmethod def __validator__(cls, value): try: return parse_obj_as(list[T], value) except ValidationError: return [parse_obj_as(T, value)] class Proto(BaseModel): protos: AutoList[str] class Source(BaseModel): sources: AutoList[int] class And(BaseModel): conditions: AutoList[Condition] = Field(alias="and") class Or(BaseModel): conditions: AutoList[Condition] = Field(alias="or") class Not(BaseModel): condition: Condition = Field(alias="not") Condition = Proto | Source | And | Or | Not And.update_forward_refs() Or.update_forward_refs() Not.update_forward_refs() Accept = Literal["accept"] Reject = Literal["reject"] class PrefSrc(BaseModel): pref_src: AutoList[IPvAnyAddress] class Conditional(BaseModel): condition: Condition = Field(alias="if") actions: AutoList[Action] = Field(alias="then") Action = Accept | Reject | PrefSrc | Conditional Conditional.update_forward_refs() Rule = Condition | AutoList[Action] @dataclass class Context: ipv4: bool indent: str verb: str def flatten(iterable: Iterable[Iterable[T]]) -> Iterable[T]: return itertools.chain.from_iterable(iterable) def indent(iterable, ctx: Context) -> Iterable[str]: yield from (f"{ctx.indent}{i}" for i in iterable) def filter_addrs(addrs, ctx: Context): yield from (a for a in addrs if isinstance(a, IPv4Address) == ctx.ipv4) def quoted(string: str) -> str: escaped = string.replace("\\", "\\\\").replace('"', '\\"') return f'"{escaped}"' def bird_name(name: str, ipv4: bool) -> str: return f"{name}{'4' if ipv4 else '6'}" def str_of_condition(condition: Condition, ctx: bool) -> str: match condition: case Proto(protos=[]) | Source(sources=[]) | Or(conditions=[]): return "false" case And(conditions=[]): return "true" case Not(condition=condition): return f"!{str_of_condition(condition)}" case And(conditions=[condition]) | Or(conditions=[condition]): return str_of_condition(condition, ctx) case And(conditions=conditions): return " && ".join( f"({str_of_condition(c, ctx)})" for c in conditions ) case Or(conditions=conditions): return " || ".join( f"({str_of_condition(c, ctx)})" for c in conditions ) case Proto(protos=[proto]): return f"proto = {quoted(bird_name(proto, ctx.ipv4))}" case Proto(protos=protos): protos = [quoted(bird_name(p, ctx.ipv4)) for p in protos] return f"proto ~ [ {', '.join(protos)} ]" case Source(sources=[source]): return f"krt_source = {source}" case Source(sources=sources): sources = [str(s) for s in sources] return f"krt_source ~ [ {', '.join(sources)} ]" def lines_of_action(action: Action, ctx: Context) -> Iterable[str]: match action: case "accept" | "reject": yield f"{action};" case Conditional(condition=condition, actions=actions): yield f"if {str_of_condition(condition, ctx)} then {'{'}" yield from indent( flatten(lines_of_action(a, ctx) for a in actions), ctx ) yield "}" case PrefSrc(pref_src=sources): source = next(filter_addrs(sources, ctx)) yield f"krt_prefsrc = {source};" def lines_of_stmt(rule: Rule, ctx: Context) -> Iterable[str]: match parse_obj_as(Rule, rule): case ["accept"]: yield f"{ctx.verb} all;" case [] | ["reject"]: yield f"{ctx.verb} none;" # FIXME case (Proto() | Source() | And() | Or() | Not()) as condition: # Conditional(condition=condition, actions=["accept"]) yield f"{ctx.verb} where {str_of_condition(condition, ctx)};" case _ as actions: yield f"{ctx.verb} filter {'{'}" yield from indent( flatten(lines_of_action(a, ctx) for a in actions), ctx ) yield "};" def bird_import(rule: Rule, ipv4: bool, indent: str = " ") -> str: ctx = Context(verb="import", ipv4=ipv4, indent=indent) return "\n".join(lines_of_stmt(rule, ctx)) def bird_export(rule: Rule, ipv4: bool, indent: str = " ") -> str: ctx = Context(verb="export", ipv4=ipv4, indent=indent) return "\n".join(lines_of_stmt(rule, ctx)) class FilterModule: def filters(self): return { "bird_import": bird_import, "bird_export": bird_export, "bird_name": bird_name, }