Files
fwl/src/FWL/Compile.hs
2026-05-03 17:46:52 -07:00

314 lines
12 KiB
Haskell

{-# LANGUAGE OverloadedStrings #-}
{- | Compile a checked FWL program to nftables JSON using Aeson.
All policies (Filter and NAT) go into one table named by Config.
Layer stripping: Frame patterns that omit Ether compile identically
to those that include it.
-}
module FWL.Compile
( compileProgram
, compileToJson
) where
import Data.List (intercalate)
import Data.Maybe (mapMaybe)
import qualified Data.Map.Strict as Map
import Data.Aeson ((.=), Value(..), object, toJSON)
import qualified Data.Aeson as A
import qualified Data.Text as T
import qualified Data.ByteString.Lazy as BL
import Data.Aeson.Encode.Pretty (encodePretty)
import FWL.AST
-- ─── Entry points ────────────────────────────────────────────────────────────
-- | Serialise a program to pretty-printed nftables JSON bytes.
compileToJson :: Program -> BL.ByteString
compileToJson prog = encodePretty (programToValue prog)
-- | The nftables document as an Aeson 'Value', before serialisation.
-- Exposed so tests can inspect the structure directly.
compileProgram :: Program -> Value
compileProgram prog = programToValue prog
-- | Build the top-level @{"nftables": [...]}@ document.
--
-- The list is emitted in a fixed order: metainfo, the single table, one
-- chain per policy, one map per map-valued @let@, then every rule.
programToValue :: Program -> Value
programToValue (Program cfg decls) =
  object ["nftables" .= toJSON entries]
  where
    entries = [metainfo, tableObj] ++ chainObjs ++ mapObjs ++ ruleObjs

    env = buildEnv decls
    tbl = configTable cfg

    metainfo =
      object ["metainfo" .= object ["json_schema_version" .= (1 :: Int)]]
    tableObj = object ["table" .= tableValue tbl]

    -- One chain declaration per policy, in declaration order.
    chainObjs = [chainDeclValue tbl n pm | DPolicy n _ pm _ <- decls]

    -- Only lets for which 'letToMapValue' yields a value become maps.
    mapObjs =
      [m | DLet n _ e <- decls, Just m <- [letToMapValue tbl n e]]

    -- Rules keep policy order, then arm order within each policy.
    ruleObjs =
      [ rule
      | DPolicy n _ _ arms <- decls
      , arm <- arms
      , rule <- armToRuleValues env tbl n arm
      ]
-- ─── Table / Chain declarations ──────────────────────────────────────────────
-- | Body of the @table@ object: always family @inet@, named @tbl@.
tableValue :: String -> Value
tableValue tbl = object ["family" .= family, "name" .= tbl]
  where
    family = "inet" :: String
-- | Declare the base chain for policy @n@ inside table @tbl@.
--
-- Type, hook and priority come from the policy metadata; the chain's
-- default policy is derived from the hook via 'defaultPolicyStr'.
chainDeclValue :: String -> Name -> PolicyMeta -> Value
chainDeclValue tbl n pm = object ["chain" .= chainBody]
  where
    hook = pmHook pm
    chainBody =
      object
        [ "family" .= ("inet" :: String)
        , "table"  .= tbl
        , "name"   .= n
        , "type"   .= chainTypeStr (pmTable pm)
        , "hook"   .= hookStr hook
        , "prio"   .= priorityInt (pmPriority pm)
        , "policy" .= defaultPolicyStr hook
        ]
-- | Chain @type@ string for a policy's table kind.
chainTypeStr :: TableName -> String
chainTypeStr t = case t of
  TFilter -> "filter"
  TNAT    -> "nat"
-- | Hook name string for each 'Hook' constructor.
hookStr :: Hook -> String
hookStr hook = case hook of
  HInput       -> "input"
  HForward     -> "forward"
  HOutput      -> "output"
  HPrerouting  -> "prerouting"
  HPostrouting -> "postrouting"
-- | Chain priority as the integer written to the @prio@ field.
priorityInt :: Priority -> Int
priorityInt prio = priorityValue prio
-- | Default chain policy by hook: input and forward chains drop,
-- every other hook accepts.
defaultPolicyStr :: Hook -> String
defaultPolicyStr hook = case hook of
  HInput   -> "drop"
  HForward -> "drop"
  _        -> "accept"
-- ─── Arm → Rule objects ──────────────────────────────────────────────────────
-- | Compile one policy arm into zero or one @rule@ objects.
--
-- An arm whose body has no verdict ('compileAction' returns 'Nothing')
-- produces no rule. Otherwise the rule's expression list is the pattern
-- matches, then the guard matches, then the verdict.
armToRuleValues :: CompileEnv -> String -> Name -> Arm -> [Value]
armToRuleValues env tbl chain (Arm p mg body) =
  [ruleFor verdict | Just verdict <- [compileAction env body]]
  where
    matches = compilePat env p ++ maybe [] (compileGuard env) mg

    ruleFor verdict =
      object
        [ "rule" .= object
            [ "family" .= ("inet" :: String)
            , "table"  .= tbl
            , "chain"  .= chain
            , "expr"   .= toJSON (matches ++ [verdict])
            ]
        ]
-- ─── Pattern → [Value] ───────────────────────────────────────────────────────
-- | Lookup table from a declaration's name to the declaration itself,
-- as built by 'buildEnv'.
type CompileEnv = Map.Map String Decl
-- | Index declarations by name.
--
-- When two declarations share a name, the one appearing first in the
-- input list is kept: the list is reversed before 'Map.fromList', which
-- lets later entries overwrite earlier ones.
buildEnv :: [Decl] -> CompileEnv
buildEnv decls = Map.fromList [(keyOf d, d) | d <- reverse decls]
  where
    keyOf decl = case decl of
      DInterface n _ _ -> n
      DZone n _        -> n
      DPattern n _ _   -> n
      DFlow n _        -> n
      DRule n _ _      -> n
      DPolicy n _ _ _  -> n
      DLet n _ _       -> n
      DImport n _ _    -> n
-- | Translate a pattern into the list of match expressions it implies.
--
-- Wildcards, variable binders and byte patterns contribute no matches.
-- A frame contributes its optional path matches followed by the matches
-- of the inner pattern; a tuple concatenates its components in order.
compilePat :: CompileEnv -> Pat -> [Value]
compilePat env pat = case pat of
  PWild           -> []
  PVar _          -> []
  PNamed n        -> expandNamedPat env n
  PFrame mp inner -> foldMap (compilePathPat env) mp ++ compilePat env inner
  PCtor n ps      -> compileCtorPat env n ps
  PRecord n fs    -> compileRecordPat n fs
  PTuple ps       -> concatMap (compilePat env) ps
  PBytes _        -> []
-- | Inline a named pattern by compiling its definition from the
-- environment. Names that are unbound, or bound to something other than
-- a pattern declaration, yield no matches.
expandNamedPat :: CompileEnv -> Name -> [Value]
expandNamedPat env n
  | Just (DPattern _ _ p) <- Map.lookup n env = compilePat env p
  | otherwise                                 = []
-- | Compile a constructor pattern: the protocol test implied by the
-- constructor name (if any), followed by the matches of its sub-patterns.
--
-- @Ether@ and unrecognised constructors add no test of their own, which
-- is what makes patterns with and without an explicit Ether layer
-- compile identically.
--
-- TCP and UDP are selected with @meta l4proto@. The transport-header
-- pseudo protocol @th@ has no @protocol@ field to match on, so a payload
-- match on it does not express "the transport protocol is tcp/udp".
compileCtorPat :: CompileEnv -> String -> [Pat] -> [Value]
compileCtorPat env ctor ps = case ctor of
  "Ether"  -> children
  "IPv4"   -> matchMeta "nfproto" "ipv4" : children
  "IPv6"   -> matchMeta "nfproto" "ipv6" : children
  "TCP"    -> matchMeta "l4proto" "tcp" : children
  "UDP"    -> matchMeta "l4proto" "udp" : children
  "ICMPv6" -> matchPayload "ip6" "nexthdr" "ipv6-icmp" : children
  "ICMP"   -> matchPayload "ip" "protocol" "icmp" : children
  _        -> children
  where
    -- Matches contributed by the constructor's arguments, in order.
    children = concatMap (compilePat env) ps
-- | Compile a record pattern into payload equality matches on @proto@.
-- Only equality field patterns ('FPEq') produce a match; every other
-- kind of field pattern is skipped.
compileRecordPat :: String -> [FieldPat] -> [Value]
compileRecordPat proto fieldPats =
  [ matchPayload proto field (renderLit lit)
  | FPEq field lit <- fieldPats
  ]
-- | Compile a path pattern: the first endpoint is matched against
-- @iifname@ and the second against @oifname@. An absent endpoint adds
-- nothing. The environment argument is unused.
compilePathPat :: CompileEnv -> PathPat -> [Value]
compilePathPat _ (PathPat ms md) = inbound ++ outbound
  where
    inbound  = foldMap (compileEndpoint "iifname") ms
    outbound = foldMap (compileEndpoint "oifname") md
-- | Compile one endpoint of a path against the meta key @dir@.
--
-- A wildcard matches anything (no expression), a name is an equality
-- test, and a membership pattern is an @in@ test against a one-element
-- set holding the second component.
compileEndpoint :: String -> EndpointPat -> [Value]
compileEndpoint dir endpoint = case endpoint of
  EPWild       -> []
  EPName n     -> [matchMeta dir n]
  EPMember _ z -> [matchInSet (metaVal dir) [z]]
-- ─── Guard → [Value] ─────────────────────────────────────────────────────────
-- | Translate a guard expression into match expressions.
--
-- Conjunctions are flattened into consecutive matches; @in@, @==@ and
-- @!=@ each produce a single match.
--
-- NOTE(review): any other guard shape compiles to no matches at all, so
-- the resulting rule is less restrictive than the source guard. Confirm
-- that unsupported guards are rejected before compilation.
compileGuard :: CompileEnv -> Expr -> [Value]
compileGuard env guardExpr = case guardExpr of
  EInfix OpAnd l r -> compileGuard env l ++ compileGuard env r
  EInfix OpIn l r  -> [compileInExpr l r]
  EInfix OpEq l r  -> [compare' "==" l r]
  EInfix OpNeq l r -> [compare' "!=" l r]
  _                -> []
  where
    compare' op l r = matchExpr op (exprVal l) (exprVal r)
-- | Compile @lhs in rhs@.
--
-- The @ct.state@ / @ct.status@ alternatives are listed before the
-- generic set case so the more specific patterns are tried first. When
-- the right-hand side is not a set literal the test degrades to a plain
-- equality match.
compileInExpr :: Expr -> Expr -> Value
compileInExpr lhs rhs = case (lhs, rhs) of
  (EQual ["ct", "state"], ESet vs)  -> ctMatch "state" vs
  (EQual ["ct", "status"], ESet vs) -> ctMatch "status" vs
  (_, ESet vs) -> matchExpr "in" (exprVal lhs) (setVal (map exprToStr vs))
  _            -> matchExpr "==" (exprVal lhs) (exprVal rhs)
-- | An @in@ match of the conntrack key @key@ against the set formed by
-- rendering each element of @vs@ with 'exprToStr'.
ctMatch :: String -> [Expr] -> Value
ctMatch key vs = matchExpr "in" ctKey members
  where
    ctKey   = object ["ct" .= object ["key" .= key]]
    members = setVal (map exprToStr vs)
-- ─── Action → Maybe Value ─────────────────────────────────────────────────────
-- | Translate an arm body into its verdict/statement object.
--
-- @Continue@ is the only body that yields 'Nothing'. An application of a
-- name bound to a rule declaration becomes a jump to that name.
--
-- NOTE(review): every unrecognised body, including an application of a
-- name that is not a rule, compiles to @accept@. Confirm this fail-open
-- default is intended.
compileAction :: CompileEnv -> Expr -> Maybe Value
compileAction env action = case action of
  EVar "Allow"      -> Just accept
  EVar "Drop"       -> Just (object ["drop" .= Null])
  EVar "Continue"   -> Nothing
  EVar "Masquerade" -> Just (object ["masquerade" .= Null])
  EApp (EVar "DNAT") arg ->
    Just (object ["dnat" .= object ["addr" .= exprToStr arg]])
  EApp (EVar "DNATMap") arg ->
    Just (object ["dnat" .= object ["addr" .= dnatMap arg]])
  EApp (EVar rn) _ ->
    case Map.lookup rn env of
      Just (DRule _ _ _) -> Just (object ["jump" .= object ["target" .= rn]])
      _                  -> Just accept
  _ -> Just accept
  where
    accept = object ["accept" .= Null]

    -- Map lookup with an empty concat key and the rendered argument as data.
    dnatMap arg =
      object
        [ "map" .= object
            [ "key"  .= object ["concat" .= Array mempty]
            , "data" .= exprToStr arg
            ]
        ]
-- ─── Let → Map object ────────────────────────────────────────────────────────
-- | Turn a map-valued @let@ into an nftables @map@ declaration in table
-- @tbl@; any other expression yields 'Nothing'.
--
-- The key and data types are fixed: keys are @inet_proto . inet_service@
-- and data is @ipv4_addr . inet_service@. Concatenated types are written
-- as JSON arrays of datatype names. The previous @" . "@-joined strings
-- used the misspelled names @inetproto@ / @inetservice@, which are not
-- nftables datatypes.
--
-- NOTE(review): the array form for the data type (@"map"@) may need a
-- recent nft; confirm against the minimum supported version.
letToMapValue :: String -> Name -> Expr -> Maybe Value
letToMapValue tbl n (EMap entries) = Just $ object
  [ "map" .= object
      [ "family" .= ("inet" :: String)
      , "table"  .= tbl
      , "name"   .= n
      , "type"   .= (["inet_proto", "inet_service"] :: [String])
      , "map"    .= (["ipv4_addr", "inet_service"] :: [String])
      , "elem"   .= toJSON (map renderMapElem entries)
      ]
  ]
letToMapValue _ _ _ = Nothing
-- | Render one map entry as a two-element JSON array @[key, data]@.
--
-- The key is wrapped in a @concat@ holding the rendered key as its only
-- entry; the data is the rendered value as a plain string.
--
-- NOTE(review): a tuple key is rendered by 'exprToStr' as one
-- @" . "@-joined string rather than one concat entry per component.
-- Confirm nft accepts that form.
renderMapElem :: (Expr, Expr) -> Value
renderMapElem (k, v) = toJSON [keyVal, dataVal]
  where
    keyVal  = object ["concat" .= [exprToStr k]]
    dataVal = A.String (toText (exprToStr v))
-- ─── Aeson building blocks ───────────────────────────────────────────────────
-- | Build a @match@ statement with operator @op@ and the given left and
-- right operands.
matchExpr :: String -> Value -> Value -> Value
matchExpr op l r = object ["match" .= operands]
  where
    operands =
      object
        [ "op"    .= op
        , "left"  .= l
        , "right" .= r
        ]
-- | Equality match of the meta key @key@ against the string @val@.
matchMeta :: String -> String -> Value
matchMeta key val = matchExpr "==" lhs rhs
  where
    lhs = metaVal key
    rhs = A.String (toText val)
-- | Equality match of the payload field @proto@.@field@ against the
-- string @val@.
matchPayload :: String -> String -> String -> Value
matchPayload proto field val = matchExpr "==" lhs rhs
  where
    lhs = payloadVal proto field
    rhs = A.String (toText val)
-- | Membership match: @lhs@ is tested with @in@ against the anonymous
-- set of @vals@.
matchInSet :: Value -> [String] -> Value
matchInSet lhs = matchExpr "in" lhs . setVal
-- | A @meta@ expression referring to the given key.
metaVal :: String -> Value
metaVal key = object ["meta" .= keyObj]
  where
    keyObj = object ["key" .= key]
-- | A @payload@ expression referring to @field@ of protocol @proto@.
payloadVal :: String -> String -> Value
payloadVal proto field = object ["payload" .= ref]
  where
    ref = object ["protocol" .= proto, "field" .= field]
-- | An anonymous @set@ expression whose members are the given strings.
setVal :: [String] -> Value
setVal members = object ["set" .= members]
-- ─── Expression helpers ───────────────────────────────────────────────────────
-- | Render an expression as a match operand.
--
-- Alternative order matters: the @ct.<key>@ form must be tried before the
-- generic two-component qualified name (a payload reference), which in
-- turn precedes the catch-all for qualified names of any other length.
-- A bare variable is treated as a meta key.
exprVal :: Expr -> Value
exprVal expr = case expr of
  EQual ["ct", k] -> object ["ct" .= object ["key" .= (k :: String)]]
  EQual [p, f]    -> payloadVal p f
  EQual ns        -> A.String (toText (intercalate "." ns))
  EVar n          -> metaVal n
  ELit l          -> A.String (toText (renderLit l))
  ESet vs         -> setVal (map exprToStr vs)
  _               -> A.String (toText (exprToStr expr))
-- | Flatten an expression to a plain string.
--
-- Qualified names are joined with @"."@ and tuple components with
-- @" . "@. Any expression form not handled here renders as @"_"@.
exprToStr :: Expr -> String
exprToStr expr = case expr of
  EVar n    -> n
  ELit l    -> renderLit l
  EQual ns  -> intercalate "." ns
  ETuple es -> intercalate " . " (map exprToStr es)
  _         -> "_"
-- | Convert a runtime 'String' to 'T.Text'.
--
-- String literals already get the right type through OverloadedStrings;
-- this helper is for values only known at run time, and uses 'T.pack'
-- instead of the earlier @read (show s)@ round-trip.
toText :: String -> T.Text
toText str = T.pack str
-- | Render a literal as the string placed in the generated JSON.
--
-- NOTE(review): an IPv6 literal always renders as @"::1"@ regardless of
-- its payload. This looks like a placeholder; confirm and implement real
-- formatting once the payload representation is known.
-- NOTE(review): hex literals go through 'show', so the output depends on
-- the payload type's 'Show' instance. Confirm that is the intended form.
renderLit :: Literal -> String
renderLit lit = case lit of
  LInt n                -> show n
  LString s             -> s
  LBool b               -> if b then "true" else "false"
  LIPv4 (a, b, c, d)    -> intercalate "." [show a, show b, show c, show d]
  LIPv6 _               -> "::1"
  LCIDR ip p            -> renderLit ip ++ "/" ++ show p
  LPort p               -> show p
  LDuration n Seconds   -> show n ++ "s"
  LDuration n Millis    -> show n ++ "ms"
  LDuration n Minutes   -> show n ++ "m"
  LDuration n Hours     -> show n ++ "h"
  LHex b                -> show b