feat(input): Add iif + protocols
This commit is contained in:
parent
26fec920b8
commit
cf7903c9c3
2 changed files with 210 additions and 33 deletions
212
firewall.py
212
firewall.py
|
@ -120,11 +120,17 @@ class TcpProtocol(RestrictiveBaseModel):
|
||||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
class UdpProtocol(RestrictiveBaseModel):
|
class UdpProtocol(RestrictiveBaseModel):
|
||||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
class Protocols(RestrictiveBaseModel):
|
class Protocols(RestrictiveBaseModel):
|
||||||
icmp: bool = False
|
icmp: bool = False
|
||||||
|
@ -133,6 +139,9 @@ class Protocols(RestrictiveBaseModel):
|
||||||
udp: UdpProtocol = UdpProtocol()
|
udp: UdpProtocol = UdpProtocol()
|
||||||
vrrp: bool = False
|
vrrp: bool = False
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
class Rule(RestrictiveBaseModel):
|
class Rule(RestrictiveBaseModel):
|
||||||
iif: str | None
|
iif: str | None
|
||||||
|
@ -226,6 +235,16 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
# ==========[ PARSER ]==========================================================
|
# ==========[ PARSER ]==========================================================
|
||||||
|
|
||||||
|
|
||||||
|
def unmarshall_ports(
|
||||||
|
elements: set[Port | PortRange],
|
||||||
|
) -> Generator[int, None, None]:
|
||||||
|
for element in elements:
|
||||||
|
if isinstance(element, int):
|
||||||
|
yield element
|
||||||
|
if isinstance(element, range):
|
||||||
|
yield from element
|
||||||
|
|
||||||
|
|
||||||
def split_v4_v6(
|
def split_v4_v6(
|
||||||
addrs: Generator[IPvAnyNetwork, None, None]
|
addrs: Generator[IPvAnyNetwork, None, None]
|
||||||
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
|
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
|
||||||
|
@ -242,37 +261,38 @@ def split_v4_v6(
|
||||||
return v4, v6
|
return v4, v6
|
||||||
|
|
||||||
|
|
||||||
def zones_blacklist(
|
def zones_into_ip(
|
||||||
blacklist: Blacklist, zones: Zones
|
elements: set[IPvAnyNetwork | ZoneName],
|
||||||
|
zones: Zones,
|
||||||
|
allow_negate: bool = True,
|
||||||
) -> Generator[IPvAnyNetwork, None, None]:
|
) -> Generator[IPvAnyNetwork, None, None]:
|
||||||
for blocked in blacklist.blocked:
|
for element in elements:
|
||||||
match blocked:
|
match element:
|
||||||
case ZoneName():
|
case ZoneName():
|
||||||
zone = zones[blocked]
|
zone = zones[element]
|
||||||
|
|
||||||
if zone.negate:
|
if not allow_negate and zone.negate:
|
||||||
raise ValueError(
|
raise ValueError(f"zone '{element}' cannot be negated")
|
||||||
f"zone '{blocked}' cannot be negated in the blacklist"
|
|
||||||
)
|
|
||||||
|
|
||||||
yield from zone.addrs
|
yield from zone.addrs
|
||||||
|
|
||||||
case IPv4Network() | IPv6Network():
|
case IPv4Network() | IPv6Network():
|
||||||
yield blocked
|
yield element
|
||||||
|
|
||||||
|
|
||||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
# Sets
|
# Sets blacklist_v4 and blacklist_v6
|
||||||
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
||||||
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
||||||
|
|
||||||
# Elements
|
ip_v4, ip_v6 = split_v4_v6(
|
||||||
ip_v4, ip_v6 = split_v4_v6(zones_blacklist(blacklist, zones))
|
zones_into_ip(blacklist.blocked, zones, allow_negate=False)
|
||||||
|
)
|
||||||
|
|
||||||
set_v4.elements.extend(ip_v4)
|
set_v4.elements.extend(ip_v4)
|
||||||
set_v6.elements.extend(ip_v6)
|
set_v6.elements.extend(ip_v6)
|
||||||
|
|
||||||
# Chains
|
# Chain filter
|
||||||
chain_filter = nft.Chain(
|
chain_filter = nft.Chain(
|
||||||
name="filter",
|
name="filter",
|
||||||
type="filter",
|
type="filter",
|
||||||
|
@ -281,21 +301,21 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
priority=-310,
|
priority=-310,
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_v4 = nft.Match(
|
rule_v4 = nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
left=nft.Payload(protocol="ip", field="saddr"),
|
left=nft.Payload(protocol="ip", field="saddr"),
|
||||||
right=nft.Immediate("@blacklist_v4"),
|
right=nft.Immediate("@blacklist_v4"),
|
||||||
)
|
)
|
||||||
chain_v6 = nft.Match(
|
rule_v6 = nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
left=nft.Payload(protocol="ip6", field="saddr"),
|
left=nft.Payload(protocol="ip6", field="saddr"),
|
||||||
right=nft.Immediate("@blacklist_v6"),
|
right=nft.Immediate("@blacklist_v6"),
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")]))
|
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
|
||||||
chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")]))
|
chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")]))
|
||||||
|
|
||||||
# Generate elements
|
# Resulting table
|
||||||
table = nft.Table(name="blacklist", family="inet")
|
table = nft.Table(name="blacklist", family="inet")
|
||||||
|
|
||||||
table.chains.extend([chain_filter])
|
table.chains.extend([chain_filter])
|
||||||
|
@ -305,12 +325,12 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
|
|
||||||
|
|
||||||
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||||
# Sets
|
# Set disabled_ifs
|
||||||
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
||||||
|
|
||||||
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
|
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
|
||||||
|
|
||||||
# Chains
|
# Chain filter
|
||||||
chain_filter = nft.Chain(
|
chain_filter = nft.Chain(
|
||||||
name="filter",
|
name="filter",
|
||||||
type="filter",
|
type="filter",
|
||||||
|
@ -319,31 +339,29 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||||
priority=-300,
|
priority=-300,
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_iifname = nft.Match(
|
rule_iifname = nft.Match(
|
||||||
op="!=",
|
op="!=",
|
||||||
left=nft.Meta("iifname"),
|
left=nft.Meta("iifname"),
|
||||||
right=nft.Immediate("@disabled_ifs"),
|
right=nft.Immediate("@disabled_ifs"),
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_fib = nft.Match(
|
rule_fib = nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
||||||
right=nft.Immediate(False),
|
right=nft.Immediate(False),
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_pkttype = nft.Match(
|
rule_pkttype = nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
left=nft.Meta("pkttype"),
|
left=nft.Meta("pkttype"),
|
||||||
right=nft.Immediate("host"),
|
right=nft.Immediate("host"),
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_filter.rules.append(
|
chain_filter.rules.append(
|
||||||
nft.Rule(
|
nft.Rule([rule_iifname, rule_fib, rule_pkttype, nft.Verdict("drop")])
|
||||||
[chain_iifname, chain_fib, chain_pkttype, nft.Verdict("drop")]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate elements
|
# Resulting table
|
||||||
table = nft.Table(name="reverse_path_filter", family="inet")
|
table = nft.Table(name="reverse_path_filter", family="inet")
|
||||||
|
|
||||||
table.chains.extend([chain_filter])
|
table.chains.extend([chain_filter])
|
||||||
|
@ -352,15 +370,149 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
# Create a chain "input_filter" and for each rule from the DSL:
|
||||||
|
# - Create a specific chain "input_rules_{i}"
|
||||||
|
# - Add a rule to "input_filter" that jumps to chain "input_rules_{i}"
|
||||||
|
def parse_filter_input(rules: list[Rule], zones: Zones) -> list[nft.Chain]:
|
||||||
|
all_chains = []
|
||||||
|
|
||||||
|
chain_input = nft.Chain(
|
||||||
|
name="input_filter",
|
||||||
|
type="filter",
|
||||||
|
hook="input",
|
||||||
|
policy="drop",
|
||||||
|
priority=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, rule in enumerate(rules):
|
||||||
|
chain_spec_rules: list[nft.Statement] = []
|
||||||
|
chain_input_rules: list[nft.Statement] = []
|
||||||
|
|
||||||
|
# Input interface: chain "input_filter"
|
||||||
|
if rule.iif is not None:
|
||||||
|
chain_input_rules.append(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Meta("iif"),
|
||||||
|
right=nft.Immediate(rule.iif),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Protocols (ICMP/OSPF/VRRP): chain "input_filter"
|
||||||
|
protocols_v4 = set()
|
||||||
|
protocols_v6 = set()
|
||||||
|
|
||||||
|
for v4, v6 in [
|
||||||
|
("icmp", "ipv6-icmp"),
|
||||||
|
("ospf", "ospf"),
|
||||||
|
("vrrp", "vrrp"),
|
||||||
|
]:
|
||||||
|
if rule.protocols[v4]:
|
||||||
|
protocols_v4.add(v4)
|
||||||
|
protocols_v6.add(v6)
|
||||||
|
|
||||||
|
if protocols_v4:
|
||||||
|
chain_spec_rules.append(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol="ip", field="protocol"),
|
||||||
|
right=nft.Immediate(protocols_v4),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
chain_spec_rules.append(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol="ip6", field="nexthdr"),
|
||||||
|
right=nft.Immediate(protocols_v6),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Protocol UDP/TCP: chain "input_filter"
|
||||||
|
for proto, port in [
|
||||||
|
("udp", "dport"),
|
||||||
|
("udp", "sport"),
|
||||||
|
("tcp", "dport"),
|
||||||
|
("tcp", "sport"),
|
||||||
|
]:
|
||||||
|
if rule.protocols[proto][port]:
|
||||||
|
ports = set(unmarshall_ports(rule.protocols[proto][port]))
|
||||||
|
|
||||||
|
chain_spec_rules.append(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol=proto, field=port),
|
||||||
|
right=nft.Immediate(ports),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verdict: specific chain "input_rules_{i}"
|
||||||
|
if rule.verdict == Verdict.accept:
|
||||||
|
rules_verdict = nft.Verdict("accept")
|
||||||
|
elif rule.verdict == Verdict.drop:
|
||||||
|
rules_verdict = nft.Verdict("drop")
|
||||||
|
elif rule.verdict == Verdict.reject:
|
||||||
|
rules_verdict = nft.Verdict("reject")
|
||||||
|
|
||||||
|
# Create the chain "input_rules_{i}"
|
||||||
|
chain = nft.Chain(name=f"input_rules_{i}")
|
||||||
|
|
||||||
|
for spec_rule in chain_spec_rules:
|
||||||
|
chain.rules.append(nft.Rule([spec_rule, rules_verdict]))
|
||||||
|
|
||||||
|
all_chains.append(chain)
|
||||||
|
|
||||||
|
# Add the chain "input_rules_{i}" to the chain "input_filter"
|
||||||
|
chain_input_rules.append(nft.Goto(f"input_rules_{i}"))
|
||||||
|
chain_input.rules.append(nft.Rule(chain_input_rules))
|
||||||
|
|
||||||
|
return all_chains + [chain_input]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
||||||
|
# Conntrack
|
||||||
|
chain_conntrack = nft.Chain(name="conntrack")
|
||||||
|
|
||||||
|
rule_ct_accept = nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Ct("state"),
|
||||||
|
right=nft.Immediate({"established", "related"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
rule_ct_drop = nft.Match(
|
||||||
|
op="in",
|
||||||
|
left=nft.Ct("state"),
|
||||||
|
right=nft.Immediate("invalid"),
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_conntrack.rules.extend(
|
||||||
|
[
|
||||||
|
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
|
||||||
|
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
chains_inputs = parse_filter_input(filter.input, zones)
|
||||||
|
|
||||||
|
# Resulting table
|
||||||
|
table = nft.Table(name="filter", family="inet")
|
||||||
|
|
||||||
|
table.chains.extend([chain_conntrack])
|
||||||
|
table.chains.extend(chains_inputs)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||||
# Tables
|
# Tables
|
||||||
blacklist = parse_blacklist(firewall.blacklist, zones)
|
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||||
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
||||||
|
filter = parse_filter(firewall.filter, zones)
|
||||||
|
|
||||||
# Ruleset
|
# Resulting ruleset
|
||||||
ruleset = nft.Ruleset(flush=True)
|
ruleset = nft.Ruleset(flush=True)
|
||||||
|
|
||||||
ruleset.tables.extend([blacklist, rpf])
|
ruleset.tables.extend([blacklist, rpf, filter])
|
||||||
|
|
||||||
return ruleset
|
return ruleset
|
||||||
|
|
||||||
|
|
31
nft.py
31
nft.py
|
@ -16,6 +16,14 @@ def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
|
||||||
|
|
||||||
|
|
||||||
# Expressions
|
# Expressions
|
||||||
|
@dataclass
|
||||||
|
class Ct:
|
||||||
|
key: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"ct": {"key": self.key}}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Fib:
|
class Fib:
|
||||||
flags: list[str]
|
flags: list[str]
|
||||||
|
@ -35,6 +43,9 @@ class Immediate(Generic[T]):
|
||||||
):
|
):
|
||||||
return ip_to_nft(self.value)
|
return ip_to_nft(self.value)
|
||||||
|
|
||||||
|
if isinstance(self.value, set):
|
||||||
|
return {"set": list(self.value)}
|
||||||
|
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,10 +66,24 @@ class Payload:
|
||||||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||||
|
|
||||||
|
|
||||||
Expression = Fib | Immediate | Meta | Payload
|
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||||
|
|
||||||
|
|
||||||
# Statements
|
# Statements
|
||||||
|
@dataclass
|
||||||
|
class Counter:
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"counter": {"packets": 0, "bytes": 0}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Goto:
|
||||||
|
target: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"goto": {"target": self.target}}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Match:
|
class Match:
|
||||||
op: str
|
op: str
|
||||||
|
@ -85,7 +110,7 @@ class Verdict:
|
||||||
return {self.verdict: self.target}
|
return {self.verdict: self.target}
|
||||||
|
|
||||||
|
|
||||||
Statement = Match | Verdict
|
Statement = Counter | Goto | Match | Verdict
|
||||||
|
|
||||||
|
|
||||||
# Ruleset
|
# Ruleset
|
||||||
|
@ -94,7 +119,7 @@ class Set:
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: str
|
||||||
|
|
||||||
flags: str | None = None
|
flags: list[str] | None = None
|
||||||
|
|
||||||
elements: list[Immediate[Any]] = field(default_factory=list)
|
elements: list[Immediate[Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue