You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ansible/roles/bird/filter_plugins/bird.py

200 lines
5.0 KiB
Python

from __future__ import annotations

import itertools
from dataclasses import dataclass
from ipaddress import IPv4Address
from typing import Any, Generic, Iterable, Iterator, Literal, TypeVar

from pydantic import (
    BaseModel,
    Field,
    IPvAnyAddress,
    ValidationError,
    parse_obj_as,
)
T = TypeVar("T")


class AutoList(list[T], Generic[T]):
    """List type whose validator also accepts a single bare item.

    The validator first tries to parse the input as a list; when that fails
    validation, the input is treated as one element and wrapped in a
    one-element list.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validator callables pydantic should run for this type."""
        yield cls.__validator__

    @classmethod
    def __validator__(cls, value):
        """Return ``value`` parsed as a list, or wrapped in a list if it is not one."""
        try:
            as_list = parse_obj_as(list[T], value)
        except ValidationError:
            # Not list-like: treat the input as a single element.
            return [parse_obj_as(T, value)]
        else:
            return as_list
class Proto(BaseModel):
    """Condition on the ``proto`` attribute, given as one or more names.

    Names are stored without the address-family suffix; the suffix is added
    when the condition is rendered.
    """

    protos: AutoList[str]
class Source(BaseModel):
    """Condition on the ``krt_source`` attribute, given as one or more integers."""

    sources: AutoList[int]
class And(BaseModel):
    """Conjunction of sub-conditions, supplied under the input key ``and``."""

    conditions: AutoList[Condition] = Field(alias="and")
class Or(BaseModel):
    """Disjunction of sub-conditions, supplied under the input key ``or``."""

    conditions: AutoList[Condition] = Field(alias="or")
class Not(BaseModel):
    """Negation of a single sub-condition, supplied under the input key ``not``."""

    condition: Condition = Field(alias="not")
# A condition is any one of the five models defined above.
Condition = Proto | Source | And | Or | Not
# And/Or/Not name ``Condition`` in their field annotations before it exists,
# so their forward references are resolved here, once the alias is defined.
And.update_forward_refs()
Or.update_forward_refs()
Not.update_forward_refs()
# Terminal actions are written as the plain strings "accept" and "reject".
Accept = Literal["accept"]
Reject = Literal["reject"]
class PrefSrc(BaseModel):
    """Action that assigns ``krt_prefsrc`` from one or more candidate addresses.

    Candidates may be of either address family; the family to use is chosen
    when the action is rendered.
    """

    pref_src: AutoList[IPvAnyAddress]
class Conditional(BaseModel):
    """Action block guarded by a condition.

    The condition is read from the input key ``if`` and the nested actions
    from the input key ``then``.
    """

    condition: Condition = Field(alias="if")
    actions: AutoList[Action] = Field(alias="then")
# An action is a terminal verdict, a preferred-source assignment, or a
# nested conditional block.
Action = Accept | Reject | PrefSrc | Conditional
# Conditional names ``Action`` in its annotations before the alias exists,
# so its forward references are resolved here.
Conditional.update_forward_refs()
# A rule is either a bare condition or a list of actions (a single action is
# also accepted, via AutoList).
Rule = Condition | AutoList[Action]
@dataclass
class Context:
    """Rendering settings shared by the statement, action and condition printers."""

    # True when rendering for IPv4, False for IPv6.
    ipv4: bool
    # String prepended to every line nested inside a block.
    indent: str
    # Leading keyword of the generated statement ("import" or "export").
    verb: str
def flatten(iterable: Iterable[Iterable[T]]) -> Iterable[T]:
    """Lazily concatenate the inner iterables of ``iterable`` into one stream."""
    chained = itertools.chain.from_iterable(iterable)
    return chained
def indent(iterable, ctx: Context) -> Iterable[str]:
    """Yield each item of ``iterable`` prefixed with ``ctx.indent``."""
    for line in iterable:
        yield f"{ctx.indent}{line}"
def filter_addrs(addrs, ctx: Context):
    """Yield the addresses whose family matches ``ctx.ipv4``.

    An address counts as IPv4 when it is an ``IPv4Address`` instance; every
    other value is treated as non-IPv4.
    """
    for addr in addrs:
        is_v4 = isinstance(addr, IPv4Address)
        if is_v4 == ctx.ipv4:
            yield addr
def quoted(string: str) -> str:
    """Return ``string`` wrapped in double quotes, with ``\\`` and ``"`` backslash-escaped."""
    # One pass over the input: each backslash or double quote gains a
    # leading backslash.
    escapes = str.maketrans({"\\": "\\\\", '"': '\\"'})
    return '"' + string.translate(escapes) + '"'
def bird_name(name: str, ipv4: bool) -> str:
    """Return ``name`` with an address-family suffix: ``4`` if ``ipv4`` else ``6``."""
    family = "4" if ipv4 else "6"
    return f"{name}{family}"
def str_of_condition(condition: Condition, ctx: bool) -> str:
match condition:
case Proto(protos=[]) | Source(sources=[]) | Or(conditions=[]):
return "false"
case And(conditions=[]):
return "true"
case Not(condition=condition):
return f"!{str_of_condition(condition)}"
case And(conditions=[condition]) | Or(conditions=[condition]):
return str_of_condition(condition, ctx)
case And(conditions=conditions):
return " && ".join(
f"({str_of_condition(c, ctx)})" for c in conditions
)
case Or(conditions=conditions):
return " || ".join(
f"({str_of_condition(c, ctx)})" for c in conditions
)
case Proto(protos=[proto]):
return f"proto = {quoted(bird_name(proto, ctx.ipv4))}"
case Proto(protos=protos):
protos = [quoted(bird_name(p, ctx.ipv4)) for p in protos]
return f"proto ~ [ {', '.join(protos)} ]"
case Source(sources=[source]):
return f"krt_source = {source}"
case Source(sources=sources):
sources = [str(s) for s in sources]
return f"krt_source ~ [ {', '.join(sources)} ]"
def lines_of_action(action: Action, ctx: Context) -> Iterable[str]:
match action:
case "accept" | "reject":
yield f"{action};"
case Conditional(condition=condition, actions=actions):
yield f"if {str_of_condition(condition, ctx)} then {'{'}"
yield from indent(
flatten(lines_of_action(a, ctx) for a in actions), ctx
)
yield "}"
case PrefSrc(pref_src=sources):
source = next(filter_addrs(sources, ctx))
yield f"krt_prefsrc = {source};"
def lines_of_stmt(rule: Rule, ctx: Context) -> Iterable[str]:
match parse_obj_as(Rule, rule):
case ["accept"]:
yield f"{ctx.verb} all;"
case [] | ["reject"]:
yield f"{ctx.verb} none;"
# FIXME
case (Proto() | Source() | And() | Or() | Not()) as condition:
# Conditional(condition=condition, actions=["accept"])
yield f"{ctx.verb} where {str_of_condition(condition, ctx)};"
case _ as actions:
yield f"{ctx.verb} filter {'{'}"
yield from indent(
flatten(lines_of_action(a, ctx) for a in actions), ctx
)
yield "};"
def bird_import(rule: Rule, ipv4: bool, indent: str = " ") -> str:
    """Render ``rule`` as an ``import`` statement and return it as one string.

    Args:
        rule: Rule data accepted by ``lines_of_stmt``.
        ipv4: True to render for IPv4, False for IPv6.
        indent: Prefix applied to lines nested inside a block.

    Returns:
        The statement's lines joined with newlines.
    """
    context = Context(ipv4=ipv4, indent=indent, verb="import")
    rendered = lines_of_stmt(rule, context)
    return "\n".join(rendered)
def bird_export(rule: Rule, ipv4: bool, indent: str = " ") -> str:
    """Render ``rule`` as an ``export`` statement and return it as one string.

    Args:
        rule: Rule data accepted by ``lines_of_stmt``.
        ipv4: True to render for IPv4, False for IPv6.
        indent: Prefix applied to lines nested inside a block.

    Returns:
        The statement's lines joined with newlines.
    """
    context = Context(ipv4=ipv4, indent=indent, verb="export")
    rendered = lines_of_stmt(rule, context)
    return "\n".join(rendered)
class FilterModule:
    """Plugin class exposing this module's filters by name."""

    def filters(self):
        """Return a mapping from filter name to the implementing function."""
        return dict(
            bird_import=bird_import,
            bird_export=bird_export,
            bird_name=bird_name,
        )