feat(input): Add iif + protocols
This commit is contained in:
parent
26fec920b8
commit
cf7903c9c3
2 changed files with 210 additions and 33 deletions
212
firewall.py
212
firewall.py
|
@ -120,11 +120,17 @@ class TcpProtocol(RestrictiveBaseModel):
|
|||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class UdpProtocol(RestrictiveBaseModel):
|
||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class Protocols(RestrictiveBaseModel):
|
||||
icmp: bool = False
|
||||
|
@ -133,6 +139,9 @@ class Protocols(RestrictiveBaseModel):
|
|||
udp: UdpProtocol = UdpProtocol()
|
||||
vrrp: bool = False
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class Rule(RestrictiveBaseModel):
|
||||
iif: str | None
|
||||
|
@ -226,6 +235,16 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|||
# ==========[ PARSER ]==========================================================
|
||||
|
||||
|
||||
def unmarshall_ports(
|
||||
elements: set[Port | PortRange],
|
||||
) -> Generator[int, None, None]:
|
||||
for element in elements:
|
||||
if isinstance(element, int):
|
||||
yield element
|
||||
if isinstance(element, range):
|
||||
yield from element
|
||||
|
||||
|
||||
def split_v4_v6(
|
||||
addrs: Generator[IPvAnyNetwork, None, None]
|
||||
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
|
||||
|
@ -242,37 +261,38 @@ def split_v4_v6(
|
|||
return v4, v6
|
||||
|
||||
|
||||
def zones_blacklist(
|
||||
blacklist: Blacklist, zones: Zones
|
||||
def zones_into_ip(
|
||||
elements: set[IPvAnyNetwork | ZoneName],
|
||||
zones: Zones,
|
||||
allow_negate: bool = True,
|
||||
) -> Generator[IPvAnyNetwork, None, None]:
|
||||
for blocked in blacklist.blocked:
|
||||
match blocked:
|
||||
for element in elements:
|
||||
match element:
|
||||
case ZoneName():
|
||||
zone = zones[blocked]
|
||||
zone = zones[element]
|
||||
|
||||
if zone.negate:
|
||||
raise ValueError(
|
||||
f"zone '{blocked}' cannot be negated in the blacklist"
|
||||
)
|
||||
if not allow_negate and zone.negate:
|
||||
raise ValueError(f"zone '{element}' cannot be negated")
|
||||
|
||||
yield from zone.addrs
|
||||
|
||||
case IPv4Network() | IPv6Network():
|
||||
yield blocked
|
||||
yield element
|
||||
|
||||
|
||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||
# Sets
|
||||
# Sets blacklist_v4 and blacklist_v6
|
||||
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
||||
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
||||
|
||||
# Elements
|
||||
ip_v4, ip_v6 = split_v4_v6(zones_blacklist(blacklist, zones))
|
||||
ip_v4, ip_v6 = split_v4_v6(
|
||||
zones_into_ip(blacklist.blocked, zones, allow_negate=False)
|
||||
)
|
||||
|
||||
set_v4.elements.extend(ip_v4)
|
||||
set_v6.elements.extend(ip_v6)
|
||||
|
||||
# Chains
|
||||
# Chain filter
|
||||
chain_filter = nft.Chain(
|
||||
name="filter",
|
||||
type="filter",
|
||||
|
@ -281,21 +301,21 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
|||
priority=-310,
|
||||
)
|
||||
|
||||
chain_v4 = nft.Match(
|
||||
rule_v4 = nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="saddr"),
|
||||
right=nft.Immediate("@blacklist_v4"),
|
||||
)
|
||||
chain_v6 = nft.Match(
|
||||
rule_v6 = nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip6", field="saddr"),
|
||||
right=nft.Immediate("@blacklist_v6"),
|
||||
)
|
||||
|
||||
chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")]))
|
||||
chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")]))
|
||||
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
|
||||
chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")]))
|
||||
|
||||
# Generate elements
|
||||
# Resulting table
|
||||
table = nft.Table(name="blacklist", family="inet")
|
||||
|
||||
table.chains.extend([chain_filter])
|
||||
|
@ -305,12 +325,12 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
|||
|
||||
|
||||
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||
# Sets
|
||||
# Set disabled_ifs
|
||||
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
||||
|
||||
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
|
||||
|
||||
# Chains
|
||||
# Chain filter
|
||||
chain_filter = nft.Chain(
|
||||
name="filter",
|
||||
type="filter",
|
||||
|
@ -319,31 +339,29 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
|||
priority=-300,
|
||||
)
|
||||
|
||||
chain_iifname = nft.Match(
|
||||
rule_iifname = nft.Match(
|
||||
op="!=",
|
||||
left=nft.Meta("iifname"),
|
||||
right=nft.Immediate("@disabled_ifs"),
|
||||
)
|
||||
|
||||
chain_fib = nft.Match(
|
||||
rule_fib = nft.Match(
|
||||
op="==",
|
||||
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
||||
right=nft.Immediate(False),
|
||||
)
|
||||
|
||||
chain_pkttype = nft.Match(
|
||||
rule_pkttype = nft.Match(
|
||||
op="==",
|
||||
left=nft.Meta("pkttype"),
|
||||
right=nft.Immediate("host"),
|
||||
)
|
||||
|
||||
chain_filter.rules.append(
|
||||
nft.Rule(
|
||||
[chain_iifname, chain_fib, chain_pkttype, nft.Verdict("drop")]
|
||||
)
|
||||
nft.Rule([rule_iifname, rule_fib, rule_pkttype, nft.Verdict("drop")])
|
||||
)
|
||||
|
||||
# Generate elements
|
||||
# Resulting table
|
||||
table = nft.Table(name="reverse_path_filter", family="inet")
|
||||
|
||||
table.chains.extend([chain_filter])
|
||||
|
@ -352,15 +370,149 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
|||
return table
|
||||
|
||||
|
||||
# Create a chain "input_filter" and for each rule from the DSL:
|
||||
# - Create a specific chain "input_rules_{i}"
|
||||
# - Add a rule to "input_filter" that jumps to chain "input_rules_{i}"
|
||||
def parse_filter_input(rules: list[Rule], zones: Zones) -> list[nft.Chain]:
|
||||
all_chains = []
|
||||
|
||||
chain_input = nft.Chain(
|
||||
name="input_filter",
|
||||
type="filter",
|
||||
hook="input",
|
||||
policy="drop",
|
||||
priority=0,
|
||||
)
|
||||
|
||||
for i, rule in enumerate(rules):
|
||||
chain_spec_rules: list[nft.Statement] = []
|
||||
chain_input_rules: list[nft.Statement] = []
|
||||
|
||||
# Input interface: chain "input_filter"
|
||||
if rule.iif is not None:
|
||||
chain_input_rules.append(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Meta("iif"),
|
||||
right=nft.Immediate(rule.iif),
|
||||
)
|
||||
)
|
||||
|
||||
# Protocols (ICMP/OSPF/VRRP): chain "input_filter"
|
||||
protocols_v4 = set()
|
||||
protocols_v6 = set()
|
||||
|
||||
for v4, v6 in [
|
||||
("icmp", "ipv6-icmp"),
|
||||
("ospf", "ospf"),
|
||||
("vrrp", "vrrp"),
|
||||
]:
|
||||
if rule.protocols[v4]:
|
||||
protocols_v4.add(v4)
|
||||
protocols_v6.add(v6)
|
||||
|
||||
if protocols_v4:
|
||||
chain_spec_rules.append(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="protocol"),
|
||||
right=nft.Immediate(protocols_v4),
|
||||
)
|
||||
)
|
||||
chain_spec_rules.append(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip6", field="nexthdr"),
|
||||
right=nft.Immediate(protocols_v6),
|
||||
)
|
||||
)
|
||||
|
||||
# Protocol UDP/TCP: chain "input_filter"
|
||||
for proto, port in [
|
||||
("udp", "dport"),
|
||||
("udp", "sport"),
|
||||
("tcp", "dport"),
|
||||
("tcp", "sport"),
|
||||
]:
|
||||
if rule.protocols[proto][port]:
|
||||
ports = set(unmarshall_ports(rule.protocols[proto][port]))
|
||||
|
||||
chain_spec_rules.append(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol=proto, field=port),
|
||||
right=nft.Immediate(ports),
|
||||
)
|
||||
)
|
||||
|
||||
# Verdict: specific chain "input_rules_{i}"
|
||||
if rule.verdict == Verdict.accept:
|
||||
rules_verdict = nft.Verdict("accept")
|
||||
elif rule.verdict == Verdict.drop:
|
||||
rules_verdict = nft.Verdict("drop")
|
||||
elif rule.verdict == Verdict.reject:
|
||||
rules_verdict = nft.Verdict("reject")
|
||||
|
||||
# Create the chain "input_rules_{i}"
|
||||
chain = nft.Chain(name=f"input_rules_{i}")
|
||||
|
||||
for spec_rule in chain_spec_rules:
|
||||
chain.rules.append(nft.Rule([spec_rule, rules_verdict]))
|
||||
|
||||
all_chains.append(chain)
|
||||
|
||||
# Add the chain "input_rules_{i}" to the chain "input_filter"
|
||||
chain_input_rules.append(nft.Goto(f"input_rules_{i}"))
|
||||
chain_input.rules.append(nft.Rule(chain_input_rules))
|
||||
|
||||
return all_chains + [chain_input]
|
||||
|
||||
|
||||
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
||||
# Conntrack
|
||||
chain_conntrack = nft.Chain(name="conntrack")
|
||||
|
||||
rule_ct_accept = nft.Match(
|
||||
op="==",
|
||||
left=nft.Ct("state"),
|
||||
right=nft.Immediate({"established", "related"}),
|
||||
)
|
||||
|
||||
rule_ct_drop = nft.Match(
|
||||
op="in",
|
||||
left=nft.Ct("state"),
|
||||
right=nft.Immediate("invalid"),
|
||||
)
|
||||
|
||||
chain_conntrack.rules.extend(
|
||||
[
|
||||
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
|
||||
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
|
||||
]
|
||||
)
|
||||
|
||||
# Inputs
|
||||
chains_inputs = parse_filter_input(filter.input, zones)
|
||||
|
||||
# Resulting table
|
||||
table = nft.Table(name="filter", family="inet")
|
||||
|
||||
table.chains.extend([chain_conntrack])
|
||||
table.chains.extend(chains_inputs)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||
# Tables
|
||||
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
||||
filter = parse_filter(firewall.filter, zones)
|
||||
|
||||
# Ruleset
|
||||
# Resulting ruleset
|
||||
ruleset = nft.Ruleset(flush=True)
|
||||
|
||||
ruleset.tables.extend([blacklist, rpf])
|
||||
ruleset.tables.extend([blacklist, rpf, filter])
|
||||
|
||||
return ruleset
|
||||
|
||||
|
|
31
nft.py
31
nft.py
|
@ -16,6 +16,14 @@ def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
|
|||
|
||||
|
||||
# Expressions
|
||||
@dataclass
|
||||
class Ct:
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"ct": {"key": self.key}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fib:
|
||||
flags: list[str]
|
||||
|
@ -35,6 +43,9 @@ class Immediate(Generic[T]):
|
|||
):
|
||||
return ip_to_nft(self.value)
|
||||
|
||||
if isinstance(self.value, set):
|
||||
return {"set": list(self.value)}
|
||||
|
||||
return self.value
|
||||
|
||||
|
||||
|
@ -55,10 +66,24 @@ class Payload:
|
|||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||
|
||||
|
||||
Expression = Fib | Immediate | Meta | Payload
|
||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||
|
||||
|
||||
# Statements
|
||||
@dataclass
|
||||
class Counter:
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"counter": {"packets": 0, "bytes": 0}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Goto:
|
||||
target: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"goto": {"target": self.target}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match:
|
||||
op: str
|
||||
|
@ -85,7 +110,7 @@ class Verdict:
|
|||
return {self.verdict: self.target}
|
||||
|
||||
|
||||
Statement = Match | Verdict
|
||||
Statement = Counter | Goto | Match | Verdict
|
||||
|
||||
|
||||
# Ruleset
|
||||
|
@ -94,7 +119,7 @@ class Set:
|
|||
name: str
|
||||
type: str
|
||||
|
||||
flags: str | None = None
|
||||
flags: list[str] | None = None
|
||||
|
||||
elements: list[Immediate[Any]] = field(default_factory=list)
|
||||
|
||||
|
|
Loading…
Reference in a new issue