234 lines
7.2 KiB
Haskell
234 lines
7.2 KiB
Haskell
module FWL.AST where
|
|
|
|
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
import Data.List (intercalate)
import Data.Word (Word8) -- Word8 still used for ByteElem/hex literals
import Numeric (showHex)
|
|
|
|
-- | Identifiers of every sort (interfaces, zones, declarations,
-- variables, constructors, fields) are kept as plain strings.
type Name = String
|
|
|
|
-- ─── Program ────────────────────────────────────────────────────────────────
|
|
|
|
-- | A complete parsed program: its configuration plus its top-level
-- declarations.
data Program = Program
  { progConfig :: Config  -- ^ program-wide settings
  , progDecls  :: [Decl]  -- ^ top-level declarations
  }
  deriving (Show)
|
|
|
|
-- | Program-wide settings.
data Config = Config
  { configTable :: String
    -- ^ Table name (presumably the generated nftables table; confirm).
    -- 'defaultConfig' sets it to @"fwl"@.
  }
  deriving (Show)
|
|
|
|
-- | The default configuration: the table name is @"fwl"@.
defaultConfig :: Config
defaultConfig =
  Config
    { configTable = "fwl"
    }
|
|
|
|
-- ─── Declarations ───────────────────────────────────────────────────────────
|
|
|
|
-- | A top-level declaration.
data Decl
  = DInterface Name IfaceKind [IfaceProp]
    -- ^ interface: name, kind and properties
  | DZone Name [Name]
    -- ^ zone: name and a list of names (presumably its members; confirm)
  | DImport Name Type FilePath
    -- ^ import: bound name, its declared type, and a file path
  | DLet Name Type Expr
    -- ^ typed value binding
  | DPattern Name Type Pat
    -- ^ named, typed pattern
  | DFlow Name FlowExpr
    -- ^ named flow expression
  | DRule Name Type Expr
    -- ^ named, typed rule expression
  | DPolicy Name Type PolicyMeta ArmBlock
    -- ^ policy: name, type, hook/table/priority metadata and its arms
  deriving (Show)
|
|
|
|
-- | Hook, table and priority attached to a 'DPolicy' declaration.
data PolicyMeta = PolicyMeta
  { pmHook     :: Hook       -- ^ hook the policy is attached to
  , pmTable    :: TableName  -- ^ filter or NAT table
  , pmPriority :: Priority   -- ^ numeric chain priority
  }
  deriving (Show)
|
|
|
|
-- | Hook a policy is attached to (see 'pmHook').
data Hook
  = HInput       -- ^ @input@
  | HForward     -- ^ @forward@
  | HOutput      -- ^ @output@
  | HPrerouting  -- ^ @prerouting@
  | HPostrouting -- ^ @postrouting@
  deriving (Show, Eq)
|
|
-- | Kind of table a policy targets (see 'pmTable').
data TableName
  = TFilter -- ^ filter table
  | TNAT    -- ^ NAT table
  deriving (Show, Eq)
|
|
-- | Chain priority. The nftables JSON always carries priorities as
-- integers, so named priorities are turned into their numeric values
-- when the program is parsed.
newtype Priority = Priority
  { priorityValue :: Int -- ^ the numeric priority
  }
  deriving (Show, Eq)
|
|
|
|
-- | Standard netfilter/nftables priority values, in ascending numeric
-- order. 'pRaw', 'pMangle', 'pDstNat', 'pFilter', 'pSecurity' and
-- 'pSrcNat' correspond to the named priorities nft(8) accepts for the
-- ip/ip6/inet families. The two conntrack values mirror the kernel's
-- NF_IP_PRI_CONNTRACK_DEFRAG and NF_IP_PRI_CONNTRACK.
pConnTrackDefrag, pRaw, pConnTrack, pMangle, pDstNat,
  pFilter, pSecurity, pSrcNat :: Priority
pConnTrackDefrag = Priority (-400)
pRaw             = Priority (-300)
pConnTrack       = Priority (-200)
pMangle          = Priority (-150)
pDstNat          = Priority (-100)
pFilter          = Priority 0
pSecurity        = Priority 50
pSrcNat          = Priority 100
|
|
|
|
-- | Kind of a declared interface: one of the built-in kinds, or a
-- user-defined kind referred to by name.
data IfaceKind
  = IWan
  | ILan
  | IWireGuard
  | IUser Name -- ^ user-defined kind
  deriving (Show)
|
|
|
|
-- | A property listed in an interface declaration.
data IfaceProp
  = IPDynamic      -- ^ interface is marked dynamic
  | IPCidr4 [CIDR] -- ^ IPv4 CIDR blocks
  | IPCidr6 [CIDR] -- ^ IPv6 CIDR blocks
  deriving (Show)
|
|
|
|
-- | A CIDR block: base address literal paired with prefix length.
-- For example, @(LIP IPv4 0x0A000000, 8)@ — the same value as
-- @(ipv4Lit 10 0 0 0, 8)@ — represents 10.0.0.0/8.
type CIDR = (Literal, Int)
|
|
|
|
-- ─── Patterns ───────────────────────────────────────────────────────────────
|
|
|
|
-- | Patterns, as used in 'DPattern' declarations and in case/policy arms.
data Pat
  = PWild                      -- ^ wildcard
  | PVar Name                  -- ^ variable pattern
  | PNamed Name                -- ^ reference by name (presumably to a 'DPattern'; confirm)
  | PCtor Name [Pat]           -- ^ constructor applied to sub-patterns
  | PRecord Name [FieldPat]    -- ^ record pattern with per-field patterns
  | PTuple [Pat]               -- ^ tuple of sub-patterns
  | PFrame (Maybe PathPat) Pat -- ^ frame pattern with an optional path pattern
  | PBytes [ByteElem]          -- ^ byte-sequence pattern
  deriving (Show)
|
|
|
|
-- | One field inside a 'PRecord' pattern.
data FieldPat
  = FPEq Name Literal -- ^ field compared against a literal
  | FPBind Name       -- ^ field bound by its own name (presumably; confirm)
  | FPAs Name Name    -- ^ field bound under another name (argument order to confirm)
  deriving (Show)
|
|
|
|
-- | A path constraint made of two optional endpoint patterns
-- (presumably source then destination; confirm).
data PathPat
  = PathPat (Maybe EndpointPat) (Maybe EndpointPat)
  deriving (Show)
|
|
|
|
-- | One endpoint of a 'PathPat'.
data EndpointPat
  = EPWild             -- ^ any endpoint
  | EPName Name        -- ^ endpoint given by name
  | EPMember Name Name -- ^ endpoint given by two names (presumably container and member)
  deriving (Show)
|
|
|
|
-- | One element of a 'PBytes' pattern.
data ByteElem
  = BEHex Word8 -- ^ an exact byte value
  | BEWild      -- ^ wildcard element (presumably a single byte)
  | BEWildStar  -- ^ star wildcard (presumably any run of bytes)
  deriving (Show)
|
|
|
|
-- ─── Flow ───────────────────────────────────────────────────────────────────
|
|
|
|
-- | Flow expressions: a named atom, or two flows in sequence with an
-- optional 'Duration' attached to the step.
data FlowExpr
  = FAtom Name
  | FSeq FlowExpr FlowExpr (Maybe Duration)
  deriving (Show)
|
|
|
|
-- | An amount together with its unit, e.g. @(5, Seconds)@.
type Duration = (Int, TimeUnit)
|
|
|
|
-- | Time units for durations.
--
-- 'Eq' is required: 'Literal' embeds a unit in 'LDuration' and derives
-- 'Eq', which needs 'Eq' on every constituent type.
data TimeUnit
  = Seconds
  | Millis
  | Minutes
  | Hours
  deriving (Show, Eq)
|
|
|
|
-- ─── Types ──────────────────────────────────────────────────────────────────
|
|
|
|
-- | Type annotations.
data Type
  = TName Name [Type]   -- ^ named type applied to type arguments
  | TTuple [Type]       -- ^ tuple type
  | TFun Type Type      -- ^ function type: argument, then result
  | TEffect [Name] Type -- ^ type tagged with a list of effect names
  deriving (Show)
|
|
|
|
-- ─── Expressions ────────────────────────────────────────────────────────────
|
|
|
|
-- | Expressions.
data Expr
  = EVar Name                -- ^ variable reference
  | EQual [Name]             -- ^ qualified name, one component per element
  | ELit Literal             -- ^ literal value
  | ELam Name Expr           -- ^ single-parameter lambda
  | EApp Expr Expr           -- ^ function application
  | ECase Expr ArmBlock      -- ^ case analysis over arms
  | EIf Expr Expr Expr       -- ^ condition, then-branch, else-branch
  | EDo [DoStmt]             -- ^ do-block
  | ELet Name Expr Expr      -- ^ bind a name, then a body expression
  | ETuple [Expr]            -- ^ tuple
  | ESet [Expr]              -- ^ set literal
  | EMap [(Expr, Expr)]      -- ^ map literal as key/value pairs
  | EPerform [Name] [Expr]   -- ^ perform, with a list of names and argument expressions
  | EInfix InfixOp Expr Expr -- ^ binary operator application
  | ENot Expr                -- ^ negation
  deriving (Show)
|
|
|
|
-- | Binary operators carried by 'EInfix'.
data InfixOp
  = OpAnd
  | OpOr
  -- comparison operators
  | OpEq
  | OpNeq
  | OpLt
  | OpLte
  | OpGt
  | OpGte
  -- remaining operators
  | OpIn
  | OpConcat
  | OpThen
  | OpBind
  deriving (Show, Eq)
|
|
|
|
-- | A statement inside an 'EDo' block.
data DoStmt
  = DSBind Name Expr -- ^ statement binding a name to an expression's result
  | DSExpr Expr      -- ^ bare expression statement
  deriving (Show)
|
|
|
|
-- | The arms of an 'ECase' expression or a 'DPolicy' body.
type ArmBlock = [Arm]

-- | One arm: a pattern, an optional extra expression (presumably a
-- guard; confirm), and the arm's result expression.
data Arm
  = Arm Pat (Maybe Expr) Expr
  deriving (Show)
|
|
|
|
-- ─── Literals ───────────────────────────────────────────────────────────────
|
|
|
|
-- | Address family of an 'LIP' literal.
--
-- IP addresses are stored as plain 'Integer's for easy arithmetic,
-- CIDR validation (masking host bits) and future subnet math:
--
-- * IPv4: a 32-bit value in the low 32 bits.
-- * IPv6: a 128-bit value.
--
-- CIDR host-bit validation is done by 'cidrHostBitsZero', which checks
-- @addr .&. (2^(bits - prefix) - 1) == 0@.
data IPVersion
  = IPv4
  | IPv6
  deriving (Show, Eq)
|
|
|
|
-- | Literal values.
data Literal
  = LInt Int               -- ^ integer
  | LString String         -- ^ string
  | LBool Bool             -- ^ boolean
  | LIP IPVersion Integer  -- ^ IP address, encoded as an 'Integer' (IPv4 in the low 32 bits)
  | LCIDR Literal Int      -- ^ CIDR block: base address literal and prefix length
  | LPort Int              -- ^ port number
  | LDuration Int TimeUnit -- ^ amount and time unit
  | LHex Word8             -- ^ single byte given in hex
  deriving (Show, Eq)
|
|
|
|
-- ─── IP address helpers ──────────────────────────────────────────────────────
|
|
|
|
-- | Build an IPv4 'LIP' literal from four octets, most significant first:
-- @ipv4Lit 10 0 0 1@ encodes 10.0.0.1 as @0x0A000001@.
--
-- Octets are not range-checked; a value outside 0-255 spills into the
-- neighbouring bit positions.
ipv4Lit :: Int -> Int -> Int -> Int -> Literal
ipv4Lit a b c d = LIP IPv4 (foldr (.|.) 0 placed)
  where
    -- Each octet shifted into its byte position within the 32-bit value.
    placed = zipWith shiftL (map fromIntegral [a, b, c, d]) [24, 16, 8, 0]
|
|
|
|
-- | Check that a CIDR block has no host bits set, i.e. that @addr@ is the
-- network address of the block @addr/prefix@ in a @bits@-wide address
-- space (e.g. 32 for IPv4, 128 for IPv6).
--
-- A prefix outside @[0, bits]@ is malformed and yields 'False'. Without
-- this guard, @prefix > bits@ shifted by a negative amount, which
-- "Data.Bits" does not permit (recent GHCs throw an overflow error).
--
-- >>> cidrHostBitsZero 0x0A000000 8 32
-- True
-- >>> cidrHostBitsZero 0x0A000001 8 32
-- False
cidrHostBitsZero :: Integer -> Int -> Int -> Bool
cidrHostBitsZero addr prefix bits
  | prefix < 0 || prefix > bits = False
  | otherwise = addr .&. hostMask == 0
  where
    -- Ones in the (bits - prefix) low-order host positions.
    hostMask = (1 `shiftL` (bits - prefix)) - 1
|
|
|
|
-- | Render an IPv4 address held in the low 32 bits of an 'Integer' as a
-- dotted-decimal string, e.g. @renderIPv4 0x0A000001 == "10.0.0.1"@.
-- Bits above the low 32 are ignored.
renderIPv4 :: Integer -> String
renderIPv4 n = drop 1 (concatMap dottedOctet [24, 16, 8, 0])
  where
    -- Every octet is emitted with a leading '.'; 'drop 1' removes the first.
    dottedOctet shiftBy = '.' : show ((n `shiftR` shiftBy) .&. 0xff)
|
|
|
|
-- | Render a 128-bit IPv6 address as a condensed colon-hex string in the
-- canonical form of RFC 5952, section 4:
--
-- * each 16-bit group in lowercase hex with leading zeros suppressed;
-- * the longest run of two or more all-zero groups replaced by @::@,
--   choosing the leftmost run on a tie;
-- * a lone zero group is kept as @0@.
--
-- Bits above the low 128 are ignored.
--
-- >>> renderIPv6 1
-- "::1"
renderIPv6 :: Integer -> String
renderIPv6 n =
  case bestZeroRun of
    Nothing -> intercalate ":" hexGroups
    Just (start, len) ->
      intercalate ":" (take start hexGroups)
        ++ "::"
        ++ intercalate ":" (drop (start + len) hexGroups)
  where
    -- The eight 16-bit groups, most significant first.
    groups :: [Int]
    groups =
      [ fromIntegral ((n `shiftR` (i * 16)) .&. 0xffff) | i <- [7, 6 .. 0] ]

    -- Groups are non-negative (masked with 0xffff), as showHex requires.
    hexGroups :: [String]
    hexGroups = map (`showHex` "") groups

    -- Every maximal run of zero groups as (start index, length), left to right.
    zeroRuns :: [(Int, Int)]
    zeroRuns = go 0 groups
      where
        go _ [] = []
        go i gs@(g : rest)
          | g == 0 =
              let len = length (takeWhile (== 0) gs)
              in (i, len) : go (i + len) (drop len gs)
          | otherwise = go (i + 1) rest

    -- Longest run of at least two zero groups. 'foldr' visits runs right to
    -- left, and '>=' lets an earlier run replace an equal later one, so
    -- the leftmost run wins ties.
    bestZeroRun :: Maybe (Int, Int)
    bestZeroRun = foldr keepLonger Nothing [ r | r@(_, len) <- zeroRuns, len >= 2 ]
      where
        keepLonger r Nothing = Just r
        keepLonger r@(_, len) (Just best@(_, bestLen))
          | len >= bestLen = Just r
          | otherwise = Just best
|