|
|
|
@ -1,8 +1,9 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import itertools
|
|
|
|
|
import re
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from ipaddress import IPv4Address
|
|
|
|
|
from ipaddress import IPv4Address, IPv4Network, IPv6Network, ip_network
|
|
|
|
|
from typing import Any, Generic, Iterator, Literal, TypeVar
|
|
|
|
|
|
|
|
|
|
from pydantic import (
|
|
|
|
@ -29,6 +30,30 @@ class AutoList(list[T], Generic[T]):
|
|
|
|
|
return [parse_obj_as(T, value)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
VARIABLES = {
|
|
|
|
|
"net.len": "net.len",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def interpolate(string: str, ctx: Context) -> str:
|
|
|
|
|
pattern = r"(?<!\\)(\$\{[_a-z][_a-z0-9.]*\})"
|
|
|
|
|
|
|
|
|
|
def lookup(var: str) -> str:
|
|
|
|
|
try:
|
|
|
|
|
return VARIABLES[var]
|
|
|
|
|
except KeyError:
|
|
|
|
|
return quoted(getattr(ctx, var))
|
|
|
|
|
|
|
|
|
|
split = re.split(pattern, string)
|
|
|
|
|
parts = [
|
|
|
|
|
(lookup(p[2:-1]) if re.match(pattern, p) else quoted(p))
|
|
|
|
|
for p in split
|
|
|
|
|
if p
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
return ", ".join(parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Proto(BaseModel):
|
|
|
|
|
protos: AutoList[str]
|
|
|
|
|
|
|
|
|
@ -49,7 +74,53 @@ class Not(BaseModel):
|
|
|
|
|
condition: Condition = Field(alias="not")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Condition = Proto | Source | And | Or | Not
|
|
|
|
|
class AsPathContains(BaseModel):
|
|
|
|
|
contains: AutoList[int] = Field(alias="as_path.contains")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsPathLength(BaseModel):
|
|
|
|
|
length: list[int] = Field(
|
|
|
|
|
ge=0, min_items=2, max_items=2, alias="as_path.len"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPv4orIPv6(BaseModel):
|
|
|
|
|
ipv4: list[int] = Field(ge=0, min_items=2, max_items=2)
|
|
|
|
|
ipv6: list[int] = Field(ge=0, min_items=2, max_items=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPFlag:
|
|
|
|
|
@classmethod
|
|
|
|
|
def __get_validators__(cls):
|
|
|
|
|
yield cls.validate
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def validate(cls, v):
|
|
|
|
|
pattern = r"(?P<ip>.*?)(?P<flag>[+-]|\{[0-9]+,[0-9]+\})?$"
|
|
|
|
|
parts = re.match(pattern, v)
|
|
|
|
|
|
|
|
|
|
return (ip_network(parts.group("ip")), parts.group("flag") or "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NetMatch(BaseModel):
|
|
|
|
|
matches: list[IPFlag] = Field(alias="net.match")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NetLength(BaseModel):
|
|
|
|
|
length: IPv4orIPv6 = Field(alias="net.len")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Condition = (
|
|
|
|
|
Proto
|
|
|
|
|
| Source
|
|
|
|
|
| And
|
|
|
|
|
| Or
|
|
|
|
|
| Not
|
|
|
|
|
| AsPathContains
|
|
|
|
|
| AsPathLength
|
|
|
|
|
| NetLength
|
|
|
|
|
| NetMatch
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
And.update_forward_refs()
|
|
|
|
|
Or.update_forward_refs()
|
|
|
|
@ -61,6 +132,10 @@ Accept = Literal["accept"]
|
|
|
|
|
Reject = Literal["reject"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RejectWithMsg(BaseModel):
|
|
|
|
|
reject: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PrefSrc(BaseModel):
|
|
|
|
|
pref_src: AutoList[IPvAnyAddress]
|
|
|
|
|
|
|
|
|
@ -70,7 +145,7 @@ class Conditional(BaseModel):
|
|
|
|
|
actions: AutoList[Action] = Field(alias="then")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Action = Accept | Reject | PrefSrc | Conditional
|
|
|
|
|
Action = Accept | Reject | RejectWithMsg | PrefSrc | Conditional
|
|
|
|
|
|
|
|
|
|
Conditional.update_forward_refs()
|
|
|
|
|
|
|
|
|
@ -144,12 +219,37 @@ def str_of_condition(condition: Condition, ctx: bool) -> str:
|
|
|
|
|
sources = [str(s) for s in sources]
|
|
|
|
|
return f"krt_source ~ [ {', '.join(sources)} ]"
|
|
|
|
|
|
|
|
|
|
case AsPathContains(contains=contains):
|
|
|
|
|
return (
|
|
|
|
|
f"bgp_path ~ [ {', '.join([str(asn) for asn in contains])} ]"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
case AsPathLength(length=[min_len, max_len]):
|
|
|
|
|
return f"{min_len} <= bgp_path.len && bgp_path.len <= {max_len}"
|
|
|
|
|
|
|
|
|
|
case NetLength(
|
|
|
|
|
length=IPv4orIPv6(ipv4=[min_v4, max_v4], ipv6=[min_v6, max_v6])
|
|
|
|
|
):
|
|
|
|
|
if ctx.ipv4:
|
|
|
|
|
return f"{min_v4} <= net.len && net.len <= {max_v4}"
|
|
|
|
|
else:
|
|
|
|
|
return f"{min_v6} <= net.len && net.len <= {max_v6}"
|
|
|
|
|
|
|
|
|
|
case NetMatch(matches=matches):
|
|
|
|
|
networkType = IPv4Network if ctx.ipv4 else IPv6Network
|
|
|
|
|
networks = [m for m in matches if isinstance(m[0], networkType)]
|
|
|
|
|
|
|
|
|
|
return f"net ~ [ {', '.join([f'{network}{flag}' for (network, flag) in networks])} ]"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def lines_of_action(action: Action, ctx: Context) -> Iterable[str]:
|
|
|
|
|
match action:
|
|
|
|
|
case "accept" | "reject":
|
|
|
|
|
yield f"{action};"
|
|
|
|
|
|
|
|
|
|
case RejectWithMsg(reject=reason):
|
|
|
|
|
yield f"reject {interpolate(reason, ctx)};"
|
|
|
|
|
|
|
|
|
|
case Conditional(condition=condition, actions=actions):
|
|
|
|
|
yield f"if {str_of_condition(condition, ctx)} then {'{'}"
|
|
|
|
|
yield from indent(
|
|
|
|
|