firewall/firewall.py

248 lines
5.6 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
from argparse import ArgumentParser, FileType
from enum import Enum
from graphlib import TopologicalSorter
2023-08-27 12:56:41 +02:00
from netaddr import IPSet
from nftables import Nftables
from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyAddress,
IPvAnyNetwork,
conint,
parse_obj_as,
validator,
root_validator,
)
2023-08-27 12:56:41 +02:00
from typing import TypeAlias
from yaml import safe_load
2023-08-27 12:56:41 +02:00
import nft
# ==========[ YAML MODEL ]======================================================
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
    """Common base for every model of the YAML rule file.

    The model is configured with ``extra=Extra.forbid``, so keys that do not
    correspond to a declared field are rejected during validation.
    """
# Ports
# A TCP/UDP port number. Ports are 16-bit values, so the highest valid port
# is 2**16 - 1 (65535); 2**16 itself must be rejected.
Port: TypeAlias = conint(ge=0, le=2**16 - 1)
class PortRange(str):
    """Pydantic custom type for an inclusive ``start..end`` port range.

    Validation turns the textual form into a ``range`` that contains every
    port from ``start`` to ``end``, both bounds included.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validators pydantic must run for this type."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``v`` as a ``start..end`` port range.

        Args:
            v: Raw value taken from the rule file.

        Returns:
            ``range(start, end + 1)``: iterating it, or testing membership
            with ``in``, covers ``start`` through ``end`` inclusive.
            Its ``stop`` attribute is therefore ``end + 1``.

        Raises:
            ValueError: If ``v`` is not of the form ``start..end`` or if
                ``start`` is greater than ``end``. The error raised by
                ``parse_obj_as(Port, ...)`` propagates when a bound (or a
                non-string ``v``) is not a valid ``Port``.
        """
        try:
            start, end = v.split("..")
        except AttributeError:
            # v has no ``split`` (it is not a string). Re-run the Port
            # validation so that its error is the one reported.
            parse_obj_as(Port, v)  # This is the expected error
            raise ValueError(
                "invalid port range: must be in the form start..end"
            )
        except ValueError:
            # Wrong number of ".."-separated parts.
            raise ValueError(
                "invalid port range: must be in the form start..end"
            )

        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
        if start > end:
            raise ValueError(
                "invalid port range: start must be less than or equal to end"
            )

        # Python ranges exclude their stop value; add one so that ``end``
        # belongs to the range and ``N..N`` denotes the single port N.
        return range(start, end + 1)
# Zones
class ZoneName(str):
    """String subtype used to annotate fields that hold a zone name."""
class ZoneEntry(RestrictiveBaseModel):
    """Definition of one zone in the YAML rule file.

    Every field is optional and defaults to an empty/false value.
    """

    # Networks listed directly in the zone.
    addrs: set[IPvAnyNetwork] = set()
    # Paths validated as ``FilePath``.
    # NOTE(review): presumably files listing more addresses -- confirm.
    files: set[FilePath] = set()
    # NOTE(review): presumably inverts the zone's address set -- confirm.
    negate: bool = False
    # Names of other zones referenced by this zone.
    zones: set[ZoneName] = set()
# Blacklist
class BlackList(RestrictiveBaseModel):
    """``blacklist`` section of the rule file: a list of IP addresses."""

    # IPv4 or IPv6 addresses; an omitted key yields an empty list.
    addr: list[IPvAnyAddress] = []
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
    """``reverse_path_filter`` section of the rule file."""

    # Off unless the rule file turns it on explicitly.
    enabled: bool = False
# Filters
class Verdict(str, Enum):
    """Verdict a rule may carry.

    Members are also ``str`` instances, so they compare equal to their
    textual value.
    """

    accept = "accept"
    drop = "drop"
    reject = "reject"
class TcpProtocol(RestrictiveBaseModel):
    """TCP port selectors of a rule; both lists default to empty."""

    # Destination ports: single ports and/or ``start..end`` ranges.
    dport: list[Port | PortRange] = []
    # Source ports: single ports and/or ``start..end`` ranges.
    sport: list[Port | PortRange] = []
class UdpProtocol(RestrictiveBaseModel):
    """UDP port selectors of a rule; both lists default to empty."""

    # Destination ports: single ports and/or ``start..end`` ranges.
    dport: list[Port | PortRange] = []
    # Source ports: single ports and/or ``start..end`` ranges.
    sport: list[Port | PortRange] = []
class Protocols(RestrictiveBaseModel):
    """``protocols`` section of a rule.

    ICMP, OSPF and VRRP are plain on/off flags (off by default); TCP and UDP
    carry their own port selectors.
    """

    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol = TcpProtocol()
    udp: UdpProtocol = UdpProtocol()
    vrrp: bool = False
class Rule(RestrictiveBaseModel):
    """One filter rule of the rule file."""

    # NOTE(review): presumably input / output interface names -- confirm.
    iif: str | None
    oif: str | None

    # Protocol selectors; the default is a ``Protocols()`` with all defaults.
    protocols: Protocols = Protocols()

    # Source and destination: one network, one zone name, or a list mixing
    # both.
    src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
    dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None

    # Rules accept unless stated otherwise.
    verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
    """Rule variant used by the ``forward`` list of ``Filter``."""

    # NOTE(review): ``Rule`` already declares ``dst``; ``dest`` is a separate
    # field that does not accept a bare network -- confirm both are intended.
    dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None
class Filter(RestrictiveBaseModel):
    """``filter`` section: rule lists for input, output and forward.

    Each list defaults to empty when its key is omitted.
    """

    input: list[Rule] = []
    output: list[Rule] = []
    forward: list[ForwardRule] = []
# Nat
class SNat(RestrictiveBaseModel):
    """``snat`` part of a NAT entry."""

    # Required IPv4 or IPv6 address.
    # NOTE(review): presumably the address sources are rewritten to.
    addr: IPvAnyAddress
    # Enabled unless the rule file disables it.
    persistent: bool = True
class Nat(RestrictiveBaseModel):
    """One entry of the ``nat`` list; both fields are required."""

    # Name of the zone the entry applies to.
    src: ZoneName
    # Source-NAT settings for that zone.
    snat: SNat
# Root model
class Firewall(RestrictiveBaseModel):
    """Root model of the YAML rule file.

    Every section is optional; an omitted section takes the default below.
    """

    # Zone definitions keyed by zone name. The default must be a dict to
    # agree with the annotation (defaults are used as given, so a list here
    # would leak through to code that expects a mapping).
    zones: dict[ZoneName, ZoneEntry] = {}
    blacklist: BlackList = BlackList()
    reverse_path_filter: ReversePathFilter = ReversePathFilter()
    filter: Filter = Filter()
    nat: list[Nat] = []
# ==========[ ZONES ]===========================================================
# Zones: Graph resolver
def resolve_zones(
    zones_entries: dict[ZoneName, ZoneEntry] | list[ZoneEntry],
) -> None:
    """Print the zone names in dependency order.

    A zone depends on the zones listed in its ``zones`` attribute; every
    dependency is printed before the zones that reference it.

    Args:
        zones_entries: Either a mapping of zone name to entry (the shape of
            ``Firewall.zones``) or an iterable of entries that expose both
            ``name`` and ``zones`` attributes. ``ZoneEntry`` declares no
            ``name`` field, so the mapping form is the one that works with
            the models of this file.

    Raises:
        graphlib.CycleError: If the zones reference each other in a cycle.
    """
    if isinstance(zones_entries, dict):
        dependencies = {
            name: entry.zones for name, entry in zones_entries.items()
        }
    else:
        dependencies = {entry.name: entry.zones for entry in zones_entries}

    for name in TopologicalSorter(dependencies).static_order():
        print(name)
    # TODO: Check negation inclusion
# ==========[ PARSER ]==========================================================
2023-08-27 12:56:41 +02:00
def parse_blacklist(blacklist: BlackList) -> nft.Table:
    """Build the ``inet`` table named ``blacklist``.

    The table declares one interval set per address family and a single
    ``prerouting`` filter chain. The chain holds one rule per family that
    drops packets whose source address matches the corresponding set.

    NOTE(review): ``blacklist`` is not read in this function, so the sets
    are declared without elements -- confirm whether ``blacklist.addr``
    is supposed to populate them.

    Args:
        blacklist: Blacklist section of the rule file.

    Returns:
        The assembled table.
    """
    table = nft.Table(name="blacklist", family="inet")

    # Sets
    v4_addrs = nft.Set(
        name="blacklist_v4", type="ipv4_addr", flags=["interval"]
    )
    v6_addrs = nft.Set(
        name="blacklist_v6", type="ipv6_addr", flags=["interval"]
    )
    table.sets.extend([v4_addrs, v6_addrs])

    # Chains
    prerouting = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-310,
    )

    # One drop rule per family: IPv4 first, then IPv6.
    families = (("ip", "@blacklist_v4"), ("ip6", "@blacklist_v6"))
    for protocol, set_reference in families:
        source_in_set = nft.Match(
            op="==",
            left=nft.Payload(protocol=protocol, field="saddr"),
            right=nft.Immediate(set_reference),
        )
        prerouting.rules.append(
            nft.Rule([source_in_set, nft.Verdict("drop")])
        )

    table.chains.extend([prerouting])
    return table
2023-08-27 12:56:41 +02:00
def parse_firewall(firewall: Firewall) -> nft.Ruleset:
    """Translate the validated rule file into an ``nft.Ruleset``.

    The ruleset is created with ``flush=True``. Only the blacklist section
    is translated here; it contributes a single table.

    Args:
        firewall: Validated root model of the rule file.

    Returns:
        The ruleset holding the generated tables.
    """
    ruleset = nft.Ruleset(flush=True)
    ruleset.tables.extend([parse_blacklist(firewall.blacklist)])
    return ruleset
# ==========[ MAIN ]============================================================
2023-08-27 12:56:41 +02:00
def send_to_nftables(cmd: nft.JsonNftables) -> int:
    """Print, validate and submit a JSON command through ``Nftables``.

    Args:
        cmd: JSON-serialisable command.

    Returns:
        ``0`` on success, ``1`` when validation raises or when the command
        returns a non-zero code.
    """
    import json

    # Named ``backend`` so the local does not hide the ``nft`` module.
    backend = Nftables()

    # Always show what is about to be submitted.
    print(json.dumps(cmd, indent=4))

    try:
        backend.json_validate(cmd)
    except Exception as exc:
        print(f"JSON validation failed: {exc}")
        return 1

    status, output, error = backend.json_cmd(cmd)
    if status != 0:
        print(f"nft returned {status}: {error}")
        return 1

    if len(output) > 0:
        print(output)
    return 0
def main() -> int:
    """Load the YAML rule file named on the command line and apply it.

    Returns:
        The status returned by ``send_to_nftables`` (``0`` on success).
    """
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")
    args = parser.parse_args()

    # FileType opens the file while parsing arguments; close it as soon as
    # its content has been read instead of leaking the handle.
    with args.file:
        document = safe_load(args.file)

    # An empty YAML document loads as None; treat it as an empty mapping so
    # the model defaults apply rather than failing on ``**None``.
    if document is None:
        document = {}

    rules = Firewall(**document)
    return send_to_nftables(parse_firewall(rules).to_nft())
if __name__ == "__main__":
    # ``exit`` is a helper installed by the ``site`` module for interactive
    # sessions and is absent under ``python -S``; raising SystemExit is the
    # equivalent that needs no import.
    raise SystemExit(main())