diff --git a/roles/bird/filter_plugins/bird.py b/roles/bird/filter_plugins/bird.py index b2764d9..2cf01de 100644 --- a/roles/bird/filter_plugins/bird.py +++ b/roles/bird/filter_plugins/bird.py @@ -3,7 +3,7 @@ from __future__ import annotations import itertools import re from dataclasses import dataclass -from ipaddress import IPv4Address +from ipaddress import IPv4Address, IPv4Network, IPv6Network, ip_network from typing import Any, Generic, Iterator, Literal, TypeVar from pydantic import ( @@ -89,12 +89,37 @@ class IPv4orIPv6(BaseModel): ipv6: list[int] = Field(ge=0, min_items=2, max_items=2) +class IPFlag(str): + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v): + pattern = r"(?P.*?)(?P[+-]|\{[0-9]+,[0-9]+\})?$" + parts = re.match(pattern, v) + + return (ip_network(parts.group("ip")), parts.group("flag")) + + +class NetMatch(BaseModel): + matches: list[IPFlag] = Field(alias="net.match") + + class NetLength(BaseModel): length: IPv4orIPv6 = Field(alias="net.len") Condition = ( - Proto | Source | And | Or | Not | AsPathContains | AsPathLength | NetLength + Proto + | Source + | And + | Or + | Not + | AsPathContains + | AsPathLength + | NetLength + | NetMatch ) And.update_forward_refs() @@ -210,6 +235,18 @@ def str_of_condition(condition: Condition, ctx: bool) -> str: else: return f"{min_v6} <= net.len && net.len <= {max_v6}" + case NetMatch(matches=matches): + if ctx.ipv4: + networks = [ + m for m in matches if isinstance(m[0], IPv4Network) + ] + else: + networks = [ + m for m in matches if isinstance(m[0], IPv6Network) + ] + + return f"net ~ [ {', '.join([f'{network}{str(flag)}' for (network, flag) in networks])} ]" + def lines_of_action(action: Action, ctx: Context) -> Iterable[str]: match action: