305 lines
7.6 KiB
Python
305 lines
7.6 KiB
Python
from __future__ import annotations

import itertools
import re
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv4Network, IPv6Network, ip_network
from typing import Any, Generic, Iterable, Iterator, Literal, TypeVar

from pydantic import (
    BaseModel,
    Field,
    IPvAnyAddress,
    ValidationError,
    parse_obj_as,
)
|
|
|
|
# Element type variable shared by AutoList and the flatten() annotation.
T = TypeVar("T")


class AutoList(list[T], Generic[T]):
    """List-typed field that also accepts one bare value instead of a list.

    The validator first tries to read the input as a list; if pydantic
    rejects that, the value is wrapped in a one-element list.

    NOTE(review): ``T`` inside the validator is the module-level TypeVar,
    not the subscripted argument (e.g. ``str`` in ``AutoList[str]``), so
    this step does not itself check element types; presumably pydantic
    validates the items against the subscripted type afterwards — confirm.
    """

    @classmethod
    def __get_validators__(cls):
        # pydantic (v1) custom-type hook: yields the validators to apply.
        yield from (cls.__validator__,)

    @classmethod
    def __validator__(cls, value):
        # Despite the dunder-style name, this is an ordinary classmethod.
        try:
            result = parse_obj_as(list[T], value)
        except ValidationError:
            # Not list-like: treat the input as a single element.
            result = [parse_obj_as(T, value)]
        return result
|
|
|
|
|
|
# Placeholder names that interpolate() replaces with raw (unquoted) BIRD
# expression text instead of a quoted Context attribute:
# placeholder name -> BIRD expression.
VARIABLES = {"net.len": "net.len"}
|
|
|
|
|
|
def interpolate(string: str, ctx: Context) -> str:
    """Render *string* as a comma-separated list of BIRD expressions.

    Literal text becomes quoted BIRD strings. Each ``${name}`` placeholder
    that is not preceded by a backslash becomes one of two things:
    the raw BIRD expression registered for *name* in VARIABLES, or
    otherwise the quoted ``str()`` of attribute *name* on *ctx*. Empty pieces
    are dropped, so an empty *string* yields ``""``.

    Raises:
        AttributeError: if a placeholder is neither a VARIABLES key nor an
            attribute of *ctx*.
    """
    pattern = r"(?<!\\)(\$\{[_a-z][_a-z0-9.]*\})"

    def lookup(var: str) -> str:
        try:
            return VARIABLES[var]
        except KeyError:
            # Not every Context attribute is a str (``ipv4`` is a bool) and
            # quoted() relies on str methods, so convert first.
            return quoted(str(getattr(ctx, var)))

    # re.split with exactly one capturing group alternates literal text
    # (even indices) with the captured placeholders (odd indices), so the
    # index tells the two apart without re-running the regex.
    pieces = re.split(pattern, string)
    parts = [
        lookup(piece[2:-1]) if index % 2 else quoted(piece)
        for index, piece in enumerate(pieces)
        if piece
    ]

    return ", ".join(parts)
|
|
|
|
|
|
class Proto(BaseModel):
    """Condition comparing BIRD's ``proto`` against the given protocol names.

    Input key is ``protos`` (no alias). AutoList is meant to also accept a
    single name. When rendered, each name gets the "4"/"6" suffix from
    bird_name, several names are OR-ed together, and an empty list renders
    as ``false``.
    """

    protos: AutoList[str]
|
|
|
|
|
|
class Source(BaseModel):
    """Condition on BIRD's ``krt_source`` attribute.

    Input key is ``sources`` (no alias). One value renders as
    ``krt_source = N``; several render as a set membership test
    ``krt_source ~ [ ... ]``; an empty list renders as ``false``.
    """

    sources: AutoList[int]
|
|
|
|
|
|
class And(BaseModel):
    """Conjunction of nested conditions; input key is the alias ``and``.

    An empty list renders as ``true``; a single condition is rendered
    unwrapped; otherwise each operand is parenthesized and joined by ``&&``.
    ``Condition`` is a forward reference resolved by the
    ``And.update_forward_refs()`` call after the Condition union.
    """

    conditions: AutoList[Condition] = Field(alias="and")
|
|
|
|
|
|
class Or(BaseModel):
    """Disjunction of nested conditions; input key is the alias ``or``.

    An empty list renders as ``false``; a single condition is rendered
    unwrapped; otherwise each operand is parenthesized and joined by ``||``.
    ``Condition`` is a forward reference resolved by the
    ``Or.update_forward_refs()`` call after the Condition union.
    """

    conditions: AutoList[Condition] = Field(alias="or")
|
|
|
|
|
|
class Not(BaseModel):
    """Negation of exactly one nested condition; input key is the alias ``not``.

    ``Condition`` is a forward reference resolved by the
    ``Not.update_forward_refs()`` call after the Condition union.
    """

    condition: Condition = Field(alias="not")
|
|
|
|
|
|
class AsPathContains(BaseModel):
    """Condition rendered as ``bgp_path ~ [ asn, ... ]``.

    Input key is the alias ``as_path.contains``; the value is one or more
    AS numbers.
    """

    contains: AutoList[int] = Field(alias="as_path.contains")
|
|
|
|
|
|
class AsPathLength(BaseModel):
    """Condition rendered as ``min <= bgp_path.len && bgp_path.len <= max``.

    Input key is the alias ``as_path.len``; the value is an inclusive
    ``[min, max]`` pair (exactly two ints).
    NOTE(review): ``ge=0`` is presumably applied to each item. Nothing checks
    that min <= max, so a reversed pair renders an always-false test.
    """

    length: list[int] = Field(
        ge=0, min_items=2, max_items=2, alias="as_path.len"
    )
|
|
|
|
|
|
class IPv4orIPv6(BaseModel):
    """Per-address-family inclusive ``[min, max]`` pair (the NetLength value).

    Both ``ipv4`` and ``ipv6`` are declared without defaults, and each is
    exactly two ints. NOTE(review): ``ge=0`` is presumably applied per item.
    """

    ipv4: list[int] = Field(ge=0, min_items=2, max_items=2)
    ipv6: list[int] = Field(ge=0, min_items=2, max_items=2)
|
|
|
|
|
|
class IPFlag(str):
    """Custom pydantic field type for one ``net.match`` prefix pattern.

    Accepts text such as ``"10.0.0.0/8"``, ``"10.0.0.0/8+"``,
    ``"10.0.0.0/8-"`` or ``"2001:db8::/32{48,64}"``. It validates to a
    ``(network, flag)`` tuple, where *network* comes from ``ip_network`` and
    *flag* is the trailing modifier or ``""`` when there is none.
    """

    @classmethod
    def __get_validators__(cls):
        # pydantic (v1) custom-type hook.
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Split *v* into ``(ip_network(prefix), flag)``.

        Raises:
            TypeError: if *v* is not a str.
            ValueError: if the prefix part is not a valid network
                (``ip_network`` is strict, so host bits must be zero).
        """
        if not isinstance(v, str):
            raise TypeError(
                f"network pattern must be a string, not {type(v).__name__}"
            )
        # The lazy prefix group plus an optional flag always matches, so
        # `parts` is never None; bad prefixes are rejected by ip_network().
        parts = re.fullmatch(
            r"(?P<ip>.*?)(?P<flag>[+-]|\{[0-9]+,[0-9]+\})?", v.strip()
        )
        network = ip_network(parts.group("ip").strip())
        # "" rather than None: callers format the flag straight into BIRD
        # text, where None would render as the literal "None".
        return (network, parts.group("flag") or "")
|
|
|
|
|
|
class NetMatch(BaseModel):
    """Condition rendered as ``net ~ [ prefix, ... ]``.

    Input key is the alias ``net.match``. Each entry is parsed by IPFlag
    into a ``(network, flag)`` tuple. When rendering, only networks of the
    address family being generated are used. The field is declared as
    AutoList, like the other multi-valued conditions (Proto, Source,
    AsPathContains). A list input behaves exactly as before, and a single
    bare pattern is also accepted.
    """

    matches: AutoList[IPFlag] = Field(alias="net.match")
|
|
|
|
|
|
class NetLength(BaseModel):
    """Condition rendered as ``min <= net.len && net.len <= max`` (inclusive).

    Input key is the alias ``net.len``. The bounds come from the ``ipv4`` or
    the ``ipv6`` pair of IPv4orIPv6, depending on the family being rendered.
    """

    length: IPv4orIPv6 = Field(alias="net.len")
|
|
|
|
|
|
# Union of every condition model; And/Or/Not refer to it by name.
# NOTE(review): pydantic (v1) is expected to try Union members left to
# right, so reordering this union may change which model an ambiguous
# mapping is parsed into — keep the order stable.
Condition = (
    Proto
    | Source
    | And
    | Or
    | Not
    | AsPathContains
    | AsPathLength
    | NetLength
    | NetMatch
)

# And, Or and Not annotate fields with "Condition" before it existed. The
# file uses `from __future__ import annotations`, so those annotations are
# strings; resolve the forward references now that the union is defined.
And.update_forward_refs()
Or.update_forward_refs()
Not.update_forward_refs()
|
|
|
|
|
|
# Bare-string actions. lines_of_action renders them as "accept;" /
# "reject;", and lines_of_stmt turns a lone ["accept"] / ["reject"] into
# "<verb> all;" / "<verb> none;".
Accept = Literal["accept"]

Reject = Literal["reject"]
|
|
|
|
|
|
class RejectWithMsg(BaseModel):
    """Action rendered as ``reject <message>;``; input key is ``reject``.

    The message is passed through interpolate(), so ``${...}`` placeholders
    are expanded.
    """

    reject: str
|
|
|
|
|
|
class PrefSrc(BaseModel):
    """Action assigning BIRD's ``krt_prefsrc``; input key is ``pref_src``.

    Holds one or more IPv4 and/or IPv6 addresses. lines_of_action uses the
    first address whose family matches the one being rendered.
    """

    pref_src: AutoList[IPvAnyAddress]
|
|
|
|
|
|
class Conditional(BaseModel):
    """Action rendered as ``if <condition> then { <actions> }``.

    Input keys are the aliases ``if`` (one condition) and ``then`` (one or
    more actions). ``Action`` is used before it is defined. It is resolved
    by the ``Conditional.update_forward_refs()`` call that follows the
    Action alias.
    """

    condition: Condition = Field(alias="if")
    actions: AutoList[Action] = Field(alias="then")
|
|
|
|
|
|
# Any single action: a bare "accept"/"reject", a reject with a message, a
# krt_prefsrc assignment, or a nested if/then. Member order is kept as-is
# (NOTE(review): it may matter for pydantic's left-to-right Union parsing).
Action = Accept | Reject | RejectWithMsg | PrefSrc | Conditional

# Conditional.actions refers to "Action", which only exists from here on.
Conditional.update_forward_refs()
|
|
|
|
|
|
# Top-level rule accepted by bird_import/bird_export: either a single
# condition or one or more actions. lines_of_stmt decides how each shape is
# rendered.
Rule = Condition | AutoList[Action]
|
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class Context:
    """Rendering settings threaded through the BIRD text generators."""

    # True when rendering for IPv4. It selects the "4"/"6" protocol-name
    # suffix and which addresses / networks / length bounds are used.
    ipv4: bool
    # Prefix prepended to every line nested inside a `{ ... }` body.
    indent: str
    # Statement keyword written by lines_of_stmt ("import" or "export").
    verb: str
|
|
|
|
|
|
def flatten(iterable: Iterable[Iterable[T]]) -> Iterable[T]:
    """Lazily concatenate the inner iterables of *iterable* (one level deep)."""
    outer = iter(iterable)
    return itertools.chain.from_iterable(outer)
|
|
|
|
|
|
def indent(iterable, ctx: Context) -> Iterable[str]:
    """Yield each item of *iterable* formatted with ``ctx.indent`` in front."""
    for line in iterable:
        yield f"{ctx.indent}{line}"
|
|
|
|
|
|
def filter_addrs(addrs, ctx: Context):
    """Yield the addresses in *addrs* whose family matches ``ctx.ipv4``.

    An address counts as IPv4 when it is an ``IPv4Address``; with
    ``ctx.ipv4`` false, everything else is kept instead.
    """
    for addr in addrs:
        is_v4 = isinstance(addr, IPv4Address)
        if is_v4 == ctx.ipv4:
            yield addr
|
|
|
|
|
|
def quoted(string: str) -> str:
    """Wrap *string* in double quotes, backslash-escaping ``\\`` and ``"``."""
    # One pass over the text; each special character maps to its escape.
    escapes = str.maketrans({"\\": "\\\\", '"': '\\"'})
    return '"' + string.translate(escapes) + '"'
|
|
|
|
|
|
def bird_name(name: str, ipv4: bool) -> str:
    """Return *name* with the address-family suffix: "4" if *ipv4*, else "6"."""
    suffix = "4" if ipv4 else "6"
    return f"{name}{suffix}"
|
|
|
|
|
|
def str_of_condition(condition: Condition, ctx: bool) -> str:
|
|
match condition:
|
|
case Proto(protos=[]) | Source(sources=[]) | Or(conditions=[]):
|
|
return "false"
|
|
|
|
case And(conditions=[]):
|
|
return "true"
|
|
|
|
case Not(condition=condition):
|
|
return f"!{str_of_condition(condition)}"
|
|
|
|
case And(conditions=[condition]) | Or(conditions=[condition]):
|
|
return str_of_condition(condition, ctx)
|
|
|
|
case And(conditions=conditions):
|
|
return " && ".join(
|
|
f"({str_of_condition(c, ctx)})" for c in conditions
|
|
)
|
|
|
|
case Or(conditions=conditions):
|
|
return " || ".join(
|
|
f"({str_of_condition(c, ctx)})" for c in conditions
|
|
)
|
|
|
|
case Proto(protos=[proto]):
|
|
return f"proto = {quoted(bird_name(proto, ctx.ipv4))}"
|
|
|
|
case Proto(protos=protos):
|
|
protos = [quoted(bird_name(p, ctx.ipv4)) for p in protos]
|
|
return " || ".join(f"proto = {p}" for p in protos)
|
|
|
|
case Source(sources=[source]):
|
|
return f"krt_source = {source}"
|
|
|
|
case Source(sources=sources):
|
|
sources = [str(s) for s in sources]
|
|
return f"krt_source ~ [ {', '.join(sources)} ]"
|
|
|
|
case AsPathContains(contains=contains):
|
|
return (
|
|
f"bgp_path ~ [ {', '.join([str(asn) for asn in contains])} ]"
|
|
)
|
|
|
|
case AsPathLength(length=[min_len, max_len]):
|
|
return f"{min_len} <= bgp_path.len && bgp_path.len <= {max_len}"
|
|
|
|
case NetLength(
|
|
length=IPv4orIPv6(ipv4=[min_v4, max_v4], ipv6=[min_v6, max_v6])
|
|
):
|
|
if ctx.ipv4:
|
|
return f"{min_v4} <= net.len && net.len <= {max_v4}"
|
|
else:
|
|
return f"{min_v6} <= net.len && net.len <= {max_v6}"
|
|
|
|
case NetMatch(matches=matches):
|
|
if ctx.ipv4:
|
|
networks = [
|
|
m for m in matches if isinstance(m[0], IPv4Network)
|
|
]
|
|
else:
|
|
networks = [
|
|
m for m in matches if isinstance(m[0], IPv6Network)
|
|
]
|
|
|
|
return f"net ~ [ {', '.join([f'{network}{str(flag)}' for (network, flag) in networks])} ]"
|
|
|
|
|
|
def lines_of_action(action: Action, ctx: Context) -> Iterable[str]:
|
|
match action:
|
|
case "accept" | "reject":
|
|
yield f"{action};"
|
|
|
|
case RejectWithMsg(reject=reason):
|
|
yield f"reject {interpolate(reason, ctx)};"
|
|
|
|
case Conditional(condition=condition, actions=actions):
|
|
yield f"if {str_of_condition(condition, ctx)} then {'{'}"
|
|
yield from indent(
|
|
flatten(lines_of_action(a, ctx) for a in actions), ctx
|
|
)
|
|
yield "}"
|
|
|
|
case PrefSrc(pref_src=sources):
|
|
source = next(filter_addrs(sources, ctx))
|
|
yield f"krt_prefsrc = {source};"
|
|
|
|
|
|
def lines_of_stmt(rule: Rule, ctx: Context) -> Iterable[str]:
|
|
match parse_obj_as(Rule, rule):
|
|
case ["accept"]:
|
|
yield f"{ctx.verb} all;"
|
|
case [] | ["reject"]:
|
|
yield f"{ctx.verb} none;"
|
|
# FIXME
|
|
case (Proto() | Source() | And() | Or() | Not()) as condition:
|
|
# Conditional(condition=condition, actions=["accept"])
|
|
yield f"{ctx.verb} where {str_of_condition(condition, ctx)};"
|
|
case _ as actions:
|
|
yield f"{ctx.verb} filter {'{'}"
|
|
yield from indent(
|
|
flatten(lines_of_action(a, ctx) for a in actions), ctx
|
|
)
|
|
yield "};"
|
|
|
|
|
|
def bird_import(rule: Rule, ipv4: bool, indent: str = " ") -> str:
    """Render *rule* as a BIRD ``import`` statement.

    The output is ``import all;``, ``import none;``, ``import where ...;`` or
    a multi-line ``import filter { ... };``. Lines are joined with newlines.
    *indent* prefixes nested body lines; *ipv4* picks the address family.
    """
    settings = Context(ipv4=ipv4, indent=indent, verb="import")
    lines = lines_of_stmt(rule, settings)
    return "\n".join(lines)
|
|
|
|
|
|
def bird_export(rule: Rule, ipv4: bool, indent: str = " ") -> str:
    """Render *rule* as a BIRD ``export`` statement.

    The output is ``export all;``, ``export none;``, ``export where ...;`` or
    a multi-line ``export filter { ... };``. Lines are joined with newlines.
    *indent* prefixes nested body lines; *ipv4* picks the address family.
    """
    settings = Context(ipv4=ipv4, indent=indent, verb="export")
    lines = lines_of_stmt(rule, settings)
    return "\n".join(lines)
|
|
|
|
|
|
class FilterModule:
    """Filter-plugin entry point mapping filter names to functions.

    NOTE(review): presumably loaded as an Ansible filter plugin, which
    expects a ``FilterModule`` class with a ``filters()`` method — confirm.
    """

    def filters(self):
        """Return the exported filters, keyed by the names templates use."""
        return dict(
            bird_import=bird_import,
            bird_export=bird_export,
            bird_name=bird_name,
        )
|