2023-08-26 01:08:30 +02:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
from argparse import ArgumentParser, FileType
|
|
|
|
from enum import Enum
|
|
|
|
from graphlib import TopologicalSorter
|
2023-08-27 12:56:41 +02:00
|
|
|
from netaddr import IPSet
|
2023-08-26 01:08:30 +02:00
|
|
|
from nftables import Nftables
|
|
|
|
from pydantic import (
|
|
|
|
BaseModel,
|
|
|
|
Extra,
|
|
|
|
FilePath,
|
|
|
|
IPvAnyAddress,
|
|
|
|
IPvAnyNetwork,
|
|
|
|
conint,
|
|
|
|
parse_obj_as,
|
|
|
|
validator,
|
|
|
|
root_validator,
|
|
|
|
)
|
2023-08-27 12:56:41 +02:00
|
|
|
from typing import TypeAlias
|
2023-08-26 01:08:30 +02:00
|
|
|
from yaml import safe_load
|
2023-08-27 12:56:41 +02:00
|
|
|
import nft
|
2023-08-26 01:08:30 +02:00
|
|
|
|
|
|
|
|
|
|
|
# ==========[ YAML MODEL ]======================================================
|
|
|
|
|
|
|
|
|
|
|
|
class RestrictiveBaseModel(BaseModel):
    # Shared base for every YAML model in this file. Unknown keys are
    # rejected (extra = forbid), so a misspelled option in the rule file
    # fails validation instead of being silently dropped. The setting is
    # inherited by all subclasses.

    class Config:
        extra = Extra.forbid
|
|
|
|
|
|
|
|
|
|
|
|
# Ports
# A single TCP/UDP port number. Ports are 16-bit, so the largest valid value
# is 2**16 - 1 (65535); 2**16 itself is out of range.
Port: TypeAlias = conint(ge=0, le=2**16 - 1)
|
|
|
|
|
|
|
|
|
|
|
|
class PortRange(str):
    """Pydantic custom type for an inclusive port range written ``"start..end"``.

    Validation converts the string into a ``range`` that covers every port
    from ``start`` to ``end``, both ends included.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``v`` as ``"start..end"`` and return ``range(start, end + 1)``.

        Raises:
            ValueError: if ``v`` is not of the form ``start..end``, if either
                bound is not a valid ``Port``, or if ``start > end``.
        """
        try:
            start, end = v.split("..")
        except AttributeError:
            # Not a string (e.g. an int). Re-validate it as a single Port so
            # that an out-of-range number reports the Port error instead.
            parse_obj_as(Port, v)  # This is the expected error
            raise ValueError(
                "invalid port range: must be in the form start..end"
            ) from None
        except ValueError:
            # The string did not split into exactly two ".."-separated parts.
            raise ValueError(
                "invalid port range: must be in the form start..end"
            ) from None

        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)

        # Equal bounds are allowed: "80..80" is the one-port range {80}.
        if start > end:
            raise ValueError(
                "invalid port range: start must be less than or equal to end"
            )

        # range() excludes its stop value; + 1 keeps `end` inside the range.
        return range(start, end + 1)
|
|
|
|
|
|
|
|
|
|
|
|
# Zones
class ZoneName(str):
    """Marker ``str`` subclass used to type zone identifiers in the models."""
|
|
|
|
|
|
|
|
|
|
|
|
class ZoneEntry(RestrictiveBaseModel):
    # Definition of one zone. Every field is optional and defaults to empty.

    # Networks that belong to the zone directly.
    addrs: set[IPvAnyNetwork] = set()

    # Paths that must exist as files (pydantic FilePath). Presumably these
    # hold additional addresses for the zone — confirm against the consumer.
    files: set[FilePath] = set()

    # Flag stored on the model; its effect is not implemented in the code
    # visible here (presumably inverts membership — TODO confirm).
    negate: bool = False

    # Names of other zones this zone includes; these edges form the
    # dependency graph that zone resolution orders topologically.
    zones: set[ZoneName] = set()
|
|
|
|
|
|
|
|
|
|
|
|
# Blacklist
class BlackList(RestrictiveBaseModel):
    # Individual IPv4/IPv6 addresses to blacklist; defaults to an empty list.
    addr: list[IPvAnyAddress] = list()
|
|
|
|
|
|
|
|
|
|
|
|
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
    # Toggle for reverse-path filtering; disabled unless set in the YAML.
    enabled: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
# Filters
# Action applied when a rule matches. Members are str-valued, so the YAML
# strings "accept" / "drop" / "reject" validate directly into these members.
class Verdict(str, Enum):
    accept = "accept"
    drop = "drop"
    reject = "reject"
|
|
|
|
|
|
|
|
|
|
|
|
class TcpProtocol(RestrictiveBaseModel):
    # TCP port selectors. Each entry is either a single Port or a
    # "start..end" PortRange string; Port is tried first because it is listed
    # first in the union. Both lists default to empty.
    dport: list[Port | PortRange] = list()
    sport: list[Port | PortRange] = list()
|
|
|
|
|
|
|
|
|
|
|
|
class UdpProtocol(RestrictiveBaseModel):
    # UDP port selectors, same shape as TcpProtocol: each entry is a single
    # Port or a "start..end" PortRange string. Both lists default to empty.
    dport: list[Port | PortRange] = list()
    sport: list[Port | PortRange] = list()
|
|
|
|
|
|
|
|
|
|
|
|
class Protocols(RestrictiveBaseModel):
    # Protocol selectors for a rule.
    # icmp / ospf / vrrp: boolean toggles, all off by default.
    icmp: bool = False
    ospf: bool = False
    # tcp / udp: port selectors; default to empty selector models.
    tcp: TcpProtocol = TcpProtocol()
    udp: UdpProtocol = UdpProtocol()
    vrrp: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
class Rule(RestrictiveBaseModel):
    # One filter rule, as used by the input and output chains.

    # Input / output interface names. Both are optional (typed `| None`
    # without an explicit default).
    iif: str | None
    oif: str | None

    protocols: Protocols = Protocols()

    # Source / destination: a network, a zone name, or a list mixing both.
    # Optional, like the interface fields.
    src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
    dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None

    # Action on match; accept unless specified.
    verdict: Verdict = Verdict.accept
|
|
|
|
|
|
|
|
|
|
|
|
class ForwardRule(Rule):
    # Rule type used for the forward chain. It inherits every Rule field and
    # adds `dest`.
    # NOTE(review): `dest` looks like it duplicates the inherited `dst` and,
    # unlike `dst`, does not accept a bare network. Confirm it is intentional
    # before relying on it.
    dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None
|
|
|
|
|
|
|
|
|
|
|
|
class Filter(RestrictiveBaseModel):
    # Rule lists per chain; each defaults to empty.
    input: list[Rule] = list()
    output: list[Rule] = list()
    # Forward rules use ForwardRule, which extends Rule with an extra field.
    forward: list[ForwardRule] = list()
|
|
|
|
|
|
|
|
|
|
|
|
# Nat
class SNat(RestrictiveBaseModel):
    # Source-NAT target.
    # addr: address to translate to (required).
    addr: IPvAnyAddress
    # persistent: on by default. Presumably maps to nftables' `persistent`
    # NAT flag — confirm once NAT translation is implemented.
    persistent: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
class Nat(RestrictiveBaseModel):
    # One NAT entry: traffic from zone `src` gets source-NATed per `snat`.
    # Both fields are required.
    src: ZoneName
    snat: SNat
|
|
|
|
|
|
|
|
|
|
|
|
# Root model
class Firewall(RestrictiveBaseModel):
    # Top-level YAML document; every section is optional.

    # Zone name -> zone definition. The default must be a dict to match the
    # annotation: pydantic does not validate defaults, so a `list()` default
    # would surface as a list and break dict-style access such as `.items()`.
    zones: dict[ZoneName, ZoneEntry] = dict()

    blacklist: BlackList = BlackList()

    reverse_path_filter: ReversePathFilter = ReversePathFilter()

    filter: Filter = Filter()

    nat: list[Nat] = list()
|
|
|
|
|
|
|
|
|
|
|
|
# ==========[ ZONES ]===========================================================
|
|
|
|
|
|
|
|
|
|
|
|
# Zones: Graph resolver
def resolve_zones(
    zones_entries: "Mapping[ZoneName, ZoneEntry] | list[ZoneEntry]",
) -> "list[ZoneName]":
    """Order zones so that every zone comes after the zones it includes.

    Args:
        zones_entries: Either a mapping of zone name -> entry (the shape of
            ``Firewall.zones``) or an iterable of entries that each carry a
            ``name`` attribute. Every entry must expose ``zones``, the set of
            zone names it depends on.

    Returns:
        Zone names in dependency order. Each name is also printed as it is
        emitted. Names that are referenced but never defined still appear,
        because graphlib adds predecessors as nodes.

    Raises:
        graphlib.CycleError: if the zone references form a cycle.
    """
    from collections.abc import Mapping

    if isinstance(zones_entries, Mapping):
        # Firewall.zones shape. ZoneEntry has no `name` field; the name is
        # the dict key.
        graph = {name: entry.zones for name, entry in zones_entries.items()}
    else:
        # Legacy shape: entries carrying their own `name`.
        graph = {entry.name: entry.zones for entry in zones_entries}

    # static_order() checks for cycles before yielding anything, so
    # materialising it first does not change what is printed.
    order = list(TopologicalSorter(graph).static_order())
    for name in order:
        print(name)

    # TODO: Check negation inclusion

    return order
|
|
|
|
|
|
|
|
|
|
|
|
# ==========[ PARSER ]==========================================================
|
|
|
|
|
|
|
|
|
2023-08-27 12:56:41 +02:00
|
|
|
def parse_blacklist(blacklist: BlackList) -> nft.Table:
    """Build the ``inet`` table named ``blacklist``.

    The table contains two interval sets, ``blacklist_v4`` (ipv4_addr) and
    ``blacklist_v6`` (ipv6_addr). It also has one ``filter`` chain hooked on
    ``prerouting`` with priority -310 and policy accept. That chain holds one
    rule per address family that drops packets whose source address matches
    the corresponding set.

    NOTE(review): ``blacklist`` is not read here; ``blacklist.addr`` is never
    added to the sets, so they are created empty. Confirm whether elements
    are meant to be populated elsewhere.
    """
    table = nft.Table(name="blacklist", family="inet")

    # One interval set per address family.
    address_sets = [
        nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]),
        nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]),
    ]
    table.sets.extend(address_sets)

    drop_chain = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-310,
    )

    # "<proto> saddr == @<set>" for each family, in v4-then-v6 order.
    saddr_lookups = [
        nft.Match(
            op="==",
            left=nft.Payload(protocol=proto, field="saddr"),
            right=nft.Immediate(set_ref),
        )
        for proto, set_ref in (("ip", "@blacklist_v4"), ("ip6", "@blacklist_v6"))
    ]
    for lookup in saddr_lookups:
        drop_chain.rules.append(nft.Rule([lookup, nft.Verdict("drop")]))

    table.chains.extend([drop_chain])
    return table
|
|
|
|
|
|
|
|
|
2023-08-27 12:56:41 +02:00
|
|
|
def parse_firewall(firewall: Firewall) -> nft.Ruleset:
    """Translate a validated ``Firewall`` model into an ``nft.Ruleset``.

    The ruleset is constructed with ``flush=True`` (presumably flushing the
    existing ruleset before loading — confirm in the ``nft`` module). Only
    the blacklist section is translated so far; the other ``Firewall``
    sections are not read here.
    """
    ruleset = nft.Ruleset(flush=True)
    generated_tables = [parse_blacklist(firewall.blacklist)]
    ruleset.tables.extend(generated_tables)
    return ruleset
|
|
|
|
|
|
|
|
|
|
|
|
# ==========[ MAIN ]============================================================
|
|
|
|
|
|
|
|
|
2023-08-27 12:56:41 +02:00
|
|
|
def send_to_nftables(cmd: nft.JsonNftables) -> int:
    """Validate ``cmd`` against the libnftables JSON schema, then apply it.

    Args:
        cmd: nftables JSON command document. ``main`` passes the result of
            ``nft.Ruleset.to_nft()``.

    Returns:
        0 on success. Returns 1 if schema validation fails or nft reports a
        non-zero return code; the diagnostic goes to stderr.
    """
    import json
    import sys

    # Named `nft_ctx`, not `nft`, so it does not shadow the project `nft`
    # module referenced in this function's signature.
    nft_ctx = Nftables()

    # Debug dump of the generated document on stdout.
    # NOTE(review): printed on every run; consider gating it behind a flag.
    print(json.dumps(cmd, indent=4))

    try:
        nft_ctx.json_validate(cmd)
    except Exception as e:
        # Kept broad: schema errors and a missing schema validator presumably
        # surface as different exception types — confirm against
        # python-nftables. Either way, report the error and abort before
        # touching the ruleset.
        print(f"JSON validation failed: {e}", file=sys.stderr)
        return 1

    rc, output, error = nft_ctx.json_cmd(cmd)

    if rc != 0:
        print(f"nft returned {rc}: {error}", file=sys.stderr)
        return 1

    # `output` may be an empty string or parsed JSON; truthiness covers both.
    if output:
        print(output)

    return 0
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> int:
    """Command-line entry point: load a YAML rule file and apply it.

    Returns:
        The exit status from ``send_to_nftables`` (0 on success).
    """
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")

    args = parser.parse_args()

    # argparse's FileType opens the file; close it once it has been read.
    with args.file as rule_file:
        # safe_load, not yaml.load: the rule file is external input.
        data = safe_load(rule_file)

    # An empty YAML document loads as None; treat it as an all-defaults config
    # instead of crashing on `Firewall(**None)`.
    if data is None:
        data = {}

    # parse_obj reports a proper validation error for non-mapping documents.
    rules = Firewall.parse_obj(data)

    return send_to_nftables(parse_firewall(rules).to_nft())
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # SystemExit rather than the site-provided exit(). exit() is intended for
    # the interactive interpreter and is absent under `python -S`.
    raise SystemExit(main())
|