from dataclasses import dataclass, field
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from itertools import chain
from typing import Any, Generic, TypeVar, get_args

# Type variable for the element type handled by ``flatten``.
T = TypeVar("T")
# One JSON object (string keys, arbitrary values) of the document produced by
# the ``to_nft`` methods in this module.
JsonNftables = dict[str, Any]


def flatten(l: list[list[T]]) -> list[T]:
    """Concatenate the inner lists of ``l`` into one new list, keeping order."""
    merged: list[T] = []
    for part in l:
        merged.extend(part)
    return merged


# Literal values that may appear directly as expression operands.
# ``imm_to_nft`` encodes ``range`` and the network types specially, encodes
# ``set`` element-by-element, and passes the remaining members through
# unchanged. ``expr_to_nft`` reads this union's members at runtime via
# ``get_args``, so adding a type here changes dispatch there as well.
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network


@dataclass
class Ct:
    """nftables ``ct`` expression selecting the field named ``key``."""

    key: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"ct": {"key": key}}``."""
        selector = {"key": self.key}
        return {"ct": selector}


@dataclass
class Fib:
    """nftables ``fib`` lookup expression.

    Attributes:
        flags: Lookup flags, emitted verbatim (the same list object).
        result: Requested lookup result, emitted verbatim.
    """

    flags: list[str]
    result: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"fib": {"flags": ..., "result": ...}}``."""
        lookup = {"flags": self.flags, "result": self.result}
        return {"fib": lookup}


@dataclass
class Meta:
    """nftables ``meta`` expression reading the field named ``key``."""

    key: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"meta": {"key": key}}``."""
        selector = {"key": self.key}
        return {"meta": selector}


@dataclass
class Payload:
    """nftables ``payload`` expression: ``field`` of header ``protocol``."""

    protocol: str
    field: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"payload": {"protocol": ..., "field": ...}}``."""
        header, member = self.protocol, self.field
        return {"payload": {"protocol": header, "field": member}}


# Any operand accepted by ``Match``: either an ``Immediate`` literal or one of
# the expression objects above exposing ``to_nft()`` (see ``expr_to_nft``).
Expression = Ct | Fib | Immediate | Meta | Payload


def imm_to_nft(value: Immediate) -> Any:
    if isinstance(value, range):
        return {"range": [value.start, value.stop - 1]}

    elif isinstance(value, IPv4Network | IPv6Network):
        return {
            "prefix": {
                "addr": str(value.network_address),
                "len": value.prefixlen,
            }
        }

    elif isinstance(value, set):
        return {"set": [expr_to_nft(e) for e in value]}

    return value


def expr_to_nft(value: Expression) -> Any:
    """Encode an expression operand as nftables JSON.

    Values that are instances of a member of ``Immediate`` are delegated to
    ``imm_to_nft``. Everything else must provide its own ``to_nft()``.
    """
    immediate_types = get_args(Immediate)
    if not isinstance(value, immediate_types):
        return value.to_nft()  # type: ignore
    return imm_to_nft(value)  # type: ignore


# Statements
@dataclass
class Counter:
    """Anonymous ``counter`` statement whose packet and byte counts start at 0."""

    def to_nft(self) -> JsonNftables:
        """Return ``{"counter": {"packets": 0, "bytes": 0}}``."""
        zeroed = {"packets": 0, "bytes": 0}
        return {"counter": zeroed}


@dataclass
class Goto:
    """``goto`` statement naming the chain ``target``."""

    target: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"goto": {"target": target}}``."""
        destination = {"target": self.target}
        return {"goto": destination}


@dataclass
class Jump:
    """``jump`` statement naming the chain ``target``."""

    target: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"jump": {"target": target}}``."""
        destination = {"target": self.target}
        return {"jump": destination}


@dataclass
class Match:
    """Match statement comparing ``left`` to ``right`` using operator ``op``.

    Both operands are encoded with ``expr_to_nft``, left first.
    """

    op: str
    left: Expression
    right: Expression

    def to_nft(self) -> JsonNftables:
        """Return ``{"match": {"op": ..., "left": ..., "right": ...}}``."""
        lhs = expr_to_nft(self.left)
        rhs = expr_to_nft(self.right)
        return {"match": {"op": self.op, "left": lhs, "right": rhs}}


@dataclass
class Snat:
    """Source-NAT statement.

    Attributes:
        addr: A single address, or a network that is emitted as the inclusive
            range from its first to its last address.
        port: Optional port range, encoded with ``imm_to_nft``.
        persistent: When true, the ``"persistent"`` flag is emitted.
    """

    addr: IPv4Network | IPv4Address
    port: range | None
    persistent: bool

    def to_nft(self) -> JsonNftables:
        """Return ``{"snat": {...}}``. Keys are added in the order addr, port, flags."""
        target = self.addr
        if not isinstance(target, IPv4Network):
            encoded_addr = str(target)
        else:
            lowest, highest = target[0], target[-1]
            encoded_addr = {"range": [str(lowest), str(highest)]}

        body: JsonNftables = {"addr": encoded_addr}
        if self.port is not None:
            body["port"] = imm_to_nft(self.port)
        if self.persistent:
            body["flags"] = "persistent"
        return {"snat": body}


@dataclass
class Verdict:
    """Generic verdict statement keyed by the verdict name (e.g. ``accept``).

    Without a ``target`` the verdict is emitted as ``{verdict: None}``. With a
    ``target`` (for chain-transferring verdicts), the target is wrapped as
    ``{"target": target}``. This matches the shape produced by the ``Goto`` and
    ``Jump`` statements in this module. The previous ``{verdict: target}``
    form (e.g. ``{"jump": "x"}``) does not follow that schema.
    """

    verdict: str

    target: str | None = None

    def to_nft(self) -> JsonNftables:
        """Return the JSON object for this verdict."""
        if self.target is None:
            return {self.verdict: None}
        return {self.verdict: {"target": self.target}}


# Any statement that can appear in a ``Rule``; ``Rule.to_nft`` calls
# ``to_nft()`` on each one.
Statement = Counter | Goto | Jump | Match | Snat | Verdict


# Ruleset
@dataclass
class Set:
    """A named set declared inside a table.

    Attributes:
        name: Set name.
        type: Element type string, emitted verbatim.
        flags: Optional set flags; omitted when ``None`` or empty.
        elements: Initial elements, each encoded with ``imm_to_nft``; omitted
            when empty.
    """

    name: str
    type: str

    flags: list[str] | None = None
    elements: list[Immediate] = field(default_factory=list)

    def to_nft(self, family: str, table: str) -> JsonNftables:
        """Return an ``add set`` command for this set in ``family``/``table``."""
        spec: JsonNftables = {
            "name": self.name,
            "family": family,
            "table": table,
            "type": self.type,
        }
        if self.elements:
            spec["elem"] = [imm_to_nft(item) for item in self.elements]
        if self.flags:
            spec["flags"] = self.flags
        return {"add": {"set": spec}}


@dataclass
class Rule:
    """A rule made of an ordered list of statements."""

    stmts: list[Statement] = field(default_factory=list)

    def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
        """Return an ``add rule`` command placing this rule in ``chain``."""
        encoded = [statement.to_nft() for statement in self.stmts]
        spec = {"family": family, "table": table, "chain": chain, "expr": encoded}
        return {"add": {"rule": spec}}


@dataclass
class Chain:
    """A chain and the rules it contains.

    Attributes:
        name: Chain name.
        type, hook, priority, policy: Optional properties; each is emitted only
            when not ``None`` (``priority`` under the JSON key ``"prio"``).
        rules: Rules emitted after the chain declaration, in order.
    """

    name: str

    type: str | None = None
    hook: str | None = None
    priority: int | None = None
    policy: str | None = None

    rules: list[Rule] = field(default_factory=list)

    def to_nft(self, family: str, table: str) -> list[JsonNftables]:
        """Return the ``add chain`` command followed by one command per rule."""
        spec: JsonNftables = {"name": self.name, "family": family, "table": table}
        optional = (
            ("type", self.type),
            ("hook", self.hook),
            ("prio", self.priority),
            ("policy", self.policy),
        )
        # Checked against None explicitly so that e.g. priority 0 is kept.
        spec.update((key, val) for key, val in optional if val is not None)

        declaration: JsonNftables = {"add": {"chain": spec}}
        return [declaration] + [r.to_nft(family, table, self.name) for r in self.rules]


@dataclass
class Table:
    """A table holding sets and chains.

    ``to_nft`` emits the table declaration, then every set, then every chain's
    commands, each group in list order.
    """

    family: str
    name: str

    chains: list[Chain] = field(default_factory=list)
    sets: list[Set] = field(default_factory=list)

    def to_nft(self) -> list[JsonNftables]:
        """Return the command list that creates this table and its contents."""
        fam, tbl = self.family, self.name
        header: JsonNftables = {"add": {"table": {"family": fam, "name": tbl}}}
        set_cmds = [nft_set.to_nft(fam, tbl) for nft_set in self.sets]
        chain_cmds = [cmd for nft_chain in self.chains for cmd in nft_chain.to_nft(fam, tbl)]
        return [header, *set_cmds, *chain_cmds]


@dataclass
class Ruleset:
    """Complete ruleset document built from ``tables``.

    When ``flush`` is true, a ``flush ruleset`` command precedes all others.
    """

    flush: bool

    tables: list[Table] = field(default_factory=list)

    def to_nft(self) -> JsonNftables:
        """Return ``{"nftables": [...]}`` with every table's commands in order."""
        body = flatten([t.to_nft() for t in self.tables])
        preamble: list[JsonNftables] = [{"flush": {"ruleset": None}}] if self.flush else []
        return {"nftables": preamble + body}