feat(input): Add iif + protocols

This commit is contained in:
v-lafeychine 2023-08-28 02:32:40 +02:00
parent 26fec920b8
commit cf7903c9c3
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 210 additions and 33 deletions

View file

@ -120,11 +120,17 @@ class TcpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet() dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet()
def __getitem__(self, key):
return getattr(self, key)
class UdpProtocol(RestrictiveBaseModel): class UdpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet() dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet()
def __getitem__(self, key):
return getattr(self, key)
class Protocols(RestrictiveBaseModel): class Protocols(RestrictiveBaseModel):
icmp: bool = False icmp: bool = False
@ -133,6 +139,9 @@ class Protocols(RestrictiveBaseModel):
udp: UdpProtocol = UdpProtocol() udp: UdpProtocol = UdpProtocol()
vrrp: bool = False vrrp: bool = False
def __getitem__(self, key):
return getattr(self, key)
class Rule(RestrictiveBaseModel): class Rule(RestrictiveBaseModel):
iif: str | None iif: str | None
@ -226,6 +235,16 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
# ==========[ PARSER ]========================================================== # ==========[ PARSER ]==========================================================
def unmarshall_ports(
elements: set[Port | PortRange],
) -> Generator[int, None, None]:
for element in elements:
if isinstance(element, int):
yield element
if isinstance(element, range):
yield from element
def split_v4_v6( def split_v4_v6(
addrs: Generator[IPvAnyNetwork, None, None] addrs: Generator[IPvAnyNetwork, None, None]
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]: ) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
@ -242,37 +261,38 @@ def split_v4_v6(
return v4, v6 return v4, v6
def zones_blacklist( def zones_into_ip(
blacklist: Blacklist, zones: Zones elements: set[IPvAnyNetwork | ZoneName],
zones: Zones,
allow_negate: bool = True,
) -> Generator[IPvAnyNetwork, None, None]: ) -> Generator[IPvAnyNetwork, None, None]:
for blocked in blacklist.blocked: for element in elements:
match blocked: match element:
case ZoneName(): case ZoneName():
zone = zones[blocked] zone = zones[element]
if zone.negate: if not allow_negate and zone.negate:
raise ValueError( raise ValueError(f"zone '{element}' cannot be negated")
f"zone '{blocked}' cannot be negated in the blacklist"
)
yield from zone.addrs yield from zone.addrs
case IPv4Network() | IPv6Network(): case IPv4Network() | IPv6Network():
yield blocked yield element
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
# Sets # Sets blacklist_v4 and blacklist_v6
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]) set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
# Elements ip_v4, ip_v6 = split_v4_v6(
ip_v4, ip_v6 = split_v4_v6(zones_blacklist(blacklist, zones)) zones_into_ip(blacklist.blocked, zones, allow_negate=False)
)
set_v4.elements.extend(ip_v4) set_v4.elements.extend(ip_v4)
set_v6.elements.extend(ip_v6) set_v6.elements.extend(ip_v6)
# Chains # Chain filter
chain_filter = nft.Chain( chain_filter = nft.Chain(
name="filter", name="filter",
type="filter", type="filter",
@ -281,21 +301,21 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
priority=-310, priority=-310,
) )
chain_v4 = nft.Match( rule_v4 = nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip", field="saddr"), left=nft.Payload(protocol="ip", field="saddr"),
right=nft.Immediate("@blacklist_v4"), right=nft.Immediate("@blacklist_v4"),
) )
chain_v6 = nft.Match( rule_v6 = nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip6", field="saddr"), left=nft.Payload(protocol="ip6", field="saddr"),
right=nft.Immediate("@blacklist_v6"), right=nft.Immediate("@blacklist_v6"),
) )
chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")])) chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")])) chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")]))
# Generate elements # Resulting table
table = nft.Table(name="blacklist", family="inet") table = nft.Table(name="blacklist", family="inet")
table.chains.extend([chain_filter]) table.chains.extend([chain_filter])
@ -305,12 +325,12 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
# Sets # Set disabled_ifs
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname") disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces)) disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
# Chains # Chain filter
chain_filter = nft.Chain( chain_filter = nft.Chain(
name="filter", name="filter",
type="filter", type="filter",
@ -319,31 +339,29 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
priority=-300, priority=-300,
) )
chain_iifname = nft.Match( rule_iifname = nft.Match(
op="!=", op="!=",
left=nft.Meta("iifname"), left=nft.Meta("iifname"),
right=nft.Immediate("@disabled_ifs"), right=nft.Immediate("@disabled_ifs"),
) )
chain_fib = nft.Match( rule_fib = nft.Match(
op="==", op="==",
left=nft.Fib(flags=["saddr", "iif"], result="oif"), left=nft.Fib(flags=["saddr", "iif"], result="oif"),
right=nft.Immediate(False), right=nft.Immediate(False),
) )
chain_pkttype = nft.Match( rule_pkttype = nft.Match(
op="==", op="==",
left=nft.Meta("pkttype"), left=nft.Meta("pkttype"),
right=nft.Immediate("host"), right=nft.Immediate("host"),
) )
chain_filter.rules.append( chain_filter.rules.append(
nft.Rule( nft.Rule([rule_iifname, rule_fib, rule_pkttype, nft.Verdict("drop")])
[chain_iifname, chain_fib, chain_pkttype, nft.Verdict("drop")]
)
) )
# Generate elements # Resulting table
table = nft.Table(name="reverse_path_filter", family="inet") table = nft.Table(name="reverse_path_filter", family="inet")
table.chains.extend([chain_filter]) table.chains.extend([chain_filter])
@ -352,15 +370,149 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
return table return table
# Create a chain "input_filter" and for each rule from the DSL:
# - Create a specific chain "input_rules_{i}"
# - Add a rule to "input_filter" that jumps to chain "input_rules_{i}"
def parse_filter_input(rules: list[Rule], zones: Zones) -> list[nft.Chain]:
all_chains = []
chain_input = nft.Chain(
name="input_filter",
type="filter",
hook="input",
policy="drop",
priority=0,
)
for i, rule in enumerate(rules):
chain_spec_rules: list[nft.Statement] = []
chain_input_rules: list[nft.Statement] = []
# Input interface: chain "input_filter"
if rule.iif is not None:
chain_input_rules.append(
nft.Match(
op="==",
left=nft.Meta("iif"),
right=nft.Immediate(rule.iif),
)
)
# Protocols (ICMP/OSPF/VRRP): chain "input_filter"
protocols_v4 = set()
protocols_v6 = set()
for v4, v6 in [
("icmp", "ipv6-icmp"),
("ospf", "ospf"),
("vrrp", "vrrp"),
]:
if rule.protocols[v4]:
protocols_v4.add(v4)
protocols_v6.add(v6)
if protocols_v4:
chain_spec_rules.append(
nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="protocol"),
right=nft.Immediate(protocols_v4),
)
)
chain_spec_rules.append(
nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="nexthdr"),
right=nft.Immediate(protocols_v6),
)
)
# Protocol UDP/TCP: chain "input_filter"
for proto, port in [
("udp", "dport"),
("udp", "sport"),
("tcp", "dport"),
("tcp", "sport"),
]:
if rule.protocols[proto][port]:
ports = set(unmarshall_ports(rule.protocols[proto][port]))
chain_spec_rules.append(
nft.Match(
op="==",
left=nft.Payload(protocol=proto, field=port),
right=nft.Immediate(ports),
)
)
# Verdict: specific chain "input_rules_{i}"
if rule.verdict == Verdict.accept:
rules_verdict = nft.Verdict("accept")
elif rule.verdict == Verdict.drop:
rules_verdict = nft.Verdict("drop")
elif rule.verdict == Verdict.reject:
rules_verdict = nft.Verdict("reject")
# Create the chain "input_rules_{i}"
chain = nft.Chain(name=f"input_rules_{i}")
for spec_rule in chain_spec_rules:
chain.rules.append(nft.Rule([spec_rule, rules_verdict]))
all_chains.append(chain)
# Add the chain "input_rules_{i}" to the chain "input_filter"
chain_input_rules.append(nft.Goto(f"input_rules_{i}"))
chain_input.rules.append(nft.Rule(chain_input_rules))
return all_chains + [chain_input]
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
# Conntrack
chain_conntrack = nft.Chain(name="conntrack")
rule_ct_accept = nft.Match(
op="==",
left=nft.Ct("state"),
right=nft.Immediate({"established", "related"}),
)
rule_ct_drop = nft.Match(
op="in",
left=nft.Ct("state"),
right=nft.Immediate("invalid"),
)
chain_conntrack.rules.extend(
[
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
]
)
# Inputs
chains_inputs = parse_filter_input(filter.input, zones)
# Resulting table
table = nft.Table(name="filter", family="inet")
table.chains.extend([chain_conntrack])
table.chains.extend(chains_inputs)
return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset: def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
# Tables # Tables
blacklist = parse_blacklist(firewall.blacklist, zones) blacklist = parse_blacklist(firewall.blacklist, zones)
rpf = parse_reverse_path_filter(firewall.reverse_path_filter) rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
filter = parse_filter(firewall.filter, zones)
# Ruleset # Resulting ruleset
ruleset = nft.Ruleset(flush=True) ruleset = nft.Ruleset(flush=True)
ruleset.tables.extend([blacklist, rpf]) ruleset.tables.extend([blacklist, rpf, filter])
return ruleset return ruleset

31
nft.py
View file

@ -16,6 +16,14 @@ def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
# Expressions # Expressions
@dataclass
class Ct:
key: str
def to_nft(self) -> JsonNftables:
return {"ct": {"key": self.key}}
@dataclass @dataclass
class Fib: class Fib:
flags: list[str] flags: list[str]
@ -35,6 +43,9 @@ class Immediate(Generic[T]):
): ):
return ip_to_nft(self.value) return ip_to_nft(self.value)
if isinstance(self.value, set):
return {"set": list(self.value)}
return self.value return self.value
@ -55,10 +66,24 @@ class Payload:
return {"payload": {"protocol": self.protocol, "field": self.field}} return {"payload": {"protocol": self.protocol, "field": self.field}}
Expression = Fib | Immediate | Meta | Payload Expression = Ct | Fib | Immediate | Meta | Payload
# Statements # Statements
@dataclass
class Counter:
def to_nft(self) -> JsonNftables:
return {"counter": {"packets": 0, "bytes": 0}}
@dataclass
class Goto:
target: str
def to_nft(self) -> JsonNftables:
return {"goto": {"target": self.target}}
@dataclass @dataclass
class Match: class Match:
op: str op: str
@ -85,7 +110,7 @@ class Verdict:
return {self.verdict: self.target} return {self.verdict: self.target}
Statement = Match | Verdict Statement = Counter | Goto | Match | Verdict
# Ruleset # Ruleset
@ -94,7 +119,7 @@ class Set:
name: str name: str
type: str type: str
flags: str | None = None flags: list[str] | None = None
elements: list[Immediate[Any]] = field(default_factory=list) elements: list[Immediate[Any]] = field(default_factory=list)