v2 perplexed

This commit is contained in:
2026-05-03 17:46:52 -07:00
parent 30427521ca
commit 2a44095791
16 changed files with 3091 additions and 0 deletions

54
app/Main.hs Normal file
View File

@@ -0,0 +1,54 @@
module Main where
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
import System.IO (hPutStrLn, stderr)
import qualified Data.ByteString.Lazy.Char8 as BL
import FWL.Parser (parseFile)
import FWL.Pretty (prettyProgram)
import FWL.Check (checkProgram)
import FWL.Compile (compileToJson)
main :: IO ()
main = do
args <- getArgs
case args of
["check", fp] -> runCheck fp
["compile", fp] -> runCompile fp
["pretty", fp] -> runPretty fp
_ -> do
putStrLn "Usage: fwlc <command> <file.fwl>"
putStrLn " check <file> -- parse and static-check"
putStrLn " compile <file> -- emit nftables JSON to stdout"
putStrLn " pretty <file> -- parse and re-print"
exitFailure
runCheck :: FilePath -> IO ()
runCheck fp = do
result <- parseFile fp
case result of
Left err -> hPutStrLn stderr ("Parse error:\n" ++ show err) >> exitFailure
Right prog -> do
let errs = checkProgram prog
if null errs
then putStrLn "OK" >> exitSuccess
else mapM_ (hPutStrLn stderr . show) errs >> exitFailure
runCompile :: FilePath -> IO ()
runCompile fp = do
result <- parseFile fp
case result of
Left err -> hPutStrLn stderr ("Parse error:\n" ++ show err) >> exitFailure
Right prog -> do
let errs = checkProgram prog
if null errs
then BL.putStrLn (compileToJson prog)
else mapM_ (hPutStrLn stderr . ("Check error: " ++) . show) errs >> exitFailure
runPretty :: FilePath -> IO ()
runPretty fp = do
result <- parseFile fp
case result of
Left err -> hPutStrLn stderr ("Parse error:\n" ++ show err) >> exitFailure
Right prog -> putStr (prettyProgram prog)

1
cabal.project Normal file
View File

@@ -0,0 +1 @@
packages: .

95
examples/router.fwl Normal file
View File

@@ -0,0 +1,95 @@
-- Example: home router firewall in FWL
-- Compile with: fwlc compile examples/router.fwl
interface wan : WAN { dynamic; };
interface lan : LAN { cidr4 = { 10.17.1.0/24 }; };
interface wg0 : WireGuard {};
zone lan_zone = { lan, wg0 };
import rfc1918 : CIDRSet from "builtin:rfc1918";
let forwards : Map<(Protocol, Port), (IP, Port)> = {
(tcp, :8080) -> (10.17.1.10, :80),
(tcp, :2222) -> (10.17.1.11, :22)
};
-- WireGuard handshake detection (compiles to ct mark state machine)
pattern WGInitiation : (UDPHeader, Bytes) =
(udp { length = 156 }, [0x01 _*]);
pattern WGResponse : (UDPHeader, Bytes) =
(udp { length = 100 }, [0x02 _*]);
flow WireGuardHandshake : FlowPattern =
WGInitiation . WGResponse within 5s;
-- Block LAN clients from tunnelling out via WireGuard
rule blockOutboundWG : Frame -> <FlowMatch, Log> Action =
\frame ->
case frame of {
| Frame(iif in lan_zone -> wan, IPv4(ip, UDP(udp, payload)))
if matches(WGInitiation, (udp, payload)) ->
case perform FlowMatch.check(flowOf(ip, wg), WireGuardHandshake) of {
| Matched -> do {
perform Log.emit(Warn, "WG blocked");
Drop
};
| _ -> Continue;
};
| _ -> Continue;
};
-- Inbound to router
policy input : Frame
on { hook = Input, table = Filter, priority = Filter }
= {
| _ if ct.state in { Established, Related } -> Allow;
| Frame(lo, _) -> Allow;
| Frame(_, IPv6(ip6, ICMPv6(_, _)))
if ip6.src in fe80::/10 -> Allow;
| Frame(_, IPv4(_, TCP(tcp, _)))
if tcp.dport == :22 -> Allow;
| Frame(_, IPv4(_, UDP(udp, _)))
if udp.dport == :51944 -> Allow;
| _ -> Drop;
};
-- Forwarded traffic
policy forward : Frame
on { hook = Forward, table = Filter, priority = Filter }
= {
| _ if ct.state in { Established, Related } -> Allow;
| frame if iif in lan_zone && oif == wan -> blockOutboundWG(frame);
| _ if ct.status == DNAT -> Allow;
| Frame(iif in lan_zone -> wan, _) -> Allow;
| Frame(iif in lan_zone -> lan_zone, _) -> Allow;
| Frame(wan -> lan_zone, IPv4(ip, TCP(tcp, _)))
if (ip.dst, tcp.dport) in forwards -> Allow;
| _ -> Drop;
};
-- Outbound from router
policy output : Frame
on { hook = Output, table = Filter, priority = Filter }
= {
| _ -> Allow;
};
-- NAT
policy nat_prerouting : Frame
on { hook = Prerouting, table = NAT, priority = DstNat }
= {
| Frame(_, IPv4(ip, _)) ->
if perform FIB.daddrLocal(ip.dst)
then DNATMap(forwards)
else Allow;
| _ -> Allow;
};
policy nat_postrouting : Frame
on { hook = Postrouting, table = NAT, priority = SrcNat }
= {
| Frame(_ -> wan, IPv4(ip, _)) if ip.src in rfc1918 -> Masquerade;
| _ -> Allow;
};

58
fwl.cabal Normal file
View File

@@ -0,0 +1,58 @@
cabal-version: 3.0
name: fwl
version: 0.1.0.0
synopsis: Firewall Language — MVP
build-type: Simple
common shared
ghc-options: -Wall
default-language: Haskell2010
library
import: shared
hs-source-dirs: src
exposed-modules:
FWL.AST
, FWL.Lexer
, FWL.Parser
, FWL.Pretty
, FWL.Check
, FWL.Compile
build-depends:
base >= 4.14
, parsec >= 3.1
, aeson >= 2.0
, aeson-pretty >= 0.8
, text >= 1.2
, containers >= 0.6
, mtl >= 2.2
, prettyprinter >= 1.7
, bytestring >= 0.11
, word8 >= 0.1
executable fwlc
import: shared
main-is: Main.hs
hs-source-dirs: app
build-depends:
base, fwl, text, aeson-pretty, bytestring
test-suite fwl-tests
import: shared
type: exitcode-stdio-1.0
main-is: Spec.hs
hs-source-dirs: test
other-modules:
FWL.Util
, ParserTests
, CheckTests
, CompileTests
build-depends:
base, fwl
, tasty >= 1.4
, tasty-hunit >= 0.10
, aeson >= 2.0
, aeson-pretty >= 0.8
, bytestring >= 0.11
, parsec >= 3.1
, vector >= 0.12

233
src/FWL/AST.hs Normal file
View File

@@ -0,0 +1,233 @@
module FWL.AST where
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
import Data.Word (Word8) -- Word8 still used for ByteElem/hex literals
type Name = String
-- ─── Program ────────────────────────────────────────────────────────────────
data Program = Program
{ progConfig :: Config
, progDecls :: [Decl]
} deriving (Show)
data Config = Config
{ configTable :: String -- default "fwl"
} deriving (Show)
defaultConfig :: Config
defaultConfig = Config { configTable = "fwl" }
-- ─── Declarations ───────────────────────────────────────────────────────────
data Decl
= DInterface Name IfaceKind [IfaceProp]
| DZone Name [Name]
| DImport Name Type FilePath
| DLet Name Type Expr
| DPattern Name Type Pat
| DFlow Name FlowExpr
| DRule Name Type Expr
| DPolicy Name Type PolicyMeta ArmBlock
deriving (Show)
data PolicyMeta = PolicyMeta
{ pmHook :: Hook
, pmTable :: TableName
, pmPriority :: Priority
} deriving (Show)
data Hook = HInput | HForward | HOutput | HPrerouting | HPostrouting
deriving (Show, Eq)
data TableName = TFilter | TNAT
deriving (Show, Eq)
-- Priority is always an integer in the nftables JSON.
-- Named constants are resolved to their numeric values at parse time.
newtype Priority = Priority { priorityValue :: Int }
deriving (Show, Eq)
-- Standard nftables priority constants
pRaw, pConnTrackDefrag, pConnTrack, pMangle, pDstNat, pFilter, pSecurity, pSrcNat :: Priority
pRaw = Priority (-300)
pConnTrackDefrag = Priority (-400)
pConnTrack = Priority (-200)
pMangle = Priority (-150)
pDstNat = Priority (-100)
pFilter = Priority 0
pSecurity = Priority 50
pSrcNat = Priority 100
data IfaceKind = IWan | ILan | IWireGuard | IUser Name
deriving (Show)
data IfaceProp
= IPDynamic
| IPCidr4 [CIDR]
| IPCidr6 [CIDR]
deriving (Show)
-- | A CIDR block: base address literal paired with prefix length.
-- e.g. (LIPv4 (10,0,0,0), 8) represents 10.0.0.0/8
type CIDR = (Literal, Int)
-- ─── Patterns ───────────────────────────────────────────────────────────────
data Pat
= PWild
| PVar Name
| PNamed Name
| PCtor Name [Pat]
| PRecord Name [FieldPat]
| PTuple [Pat]
| PFrame (Maybe PathPat) Pat
| PBytes [ByteElem]
deriving (Show)
data FieldPat
= FPEq Name Literal
| FPBind Name
| FPAs Name Name
deriving (Show)
data PathPat = PathPat (Maybe EndpointPat) (Maybe EndpointPat)
deriving (Show)
data EndpointPat
= EPWild
| EPName Name
| EPMember Name Name
deriving (Show)
data ByteElem
= BEHex Word8
| BEWild
| BEWildStar
deriving (Show)
-- ─── Flow ───────────────────────────────────────────────────────────────────
data FlowExpr
= FAtom Name
| FSeq FlowExpr FlowExpr (Maybe Duration)
deriving (Show)
type Duration = (Int, TimeUnit)
-- Fix 1: TimeUnit must derive Eq because Literal (which embeds it via
-- LDuration) derives Eq, requiring all constituent types to also have Eq.
data TimeUnit = Seconds | Millis | Minutes | Hours
deriving (Show, Eq)
-- ─── Types ──────────────────────────────────────────────────────────────────
data Type
= TName Name [Type]
| TTuple [Type]
| TFun Type Type
| TEffect [Name] Type
deriving (Show)
-- ─── Expressions ────────────────────────────────────────────────────────────
data Expr
= EVar Name
| EQual [Name]
| ELit Literal
| ELam Name Expr
| EApp Expr Expr
| ECase Expr ArmBlock
| EIf Expr Expr Expr
| EDo [DoStmt]
| ELet Name Expr Expr
| ETuple [Expr]
| ESet [Expr]
| EMap [(Expr, Expr)]
| EPerform [Name] [Expr]
| EInfix InfixOp Expr Expr
| ENot Expr
deriving (Show)
data InfixOp
= OpAnd | OpOr
| OpEq | OpNeq | OpLt | OpLte | OpGt | OpGte
| OpIn
| OpConcat
| OpThen
| OpBind
deriving (Show, Eq)
data DoStmt
= DSBind Name Expr
| DSExpr Expr
deriving (Show)
type ArmBlock = [Arm]
data Arm = Arm Pat (Maybe Expr) Expr
deriving (Show)
-- ─── Literals ───────────────────────────────────────────────────────────────
-- IP addresses are stored as plain Integers for easy arithmetic,
-- CIDR validation (mask host bits), and future subnet math.
-- IPv4: 32-bit value in the low 32 bits.
-- IPv6: 128-bit value.
-- CIDR host-bit validation: (addr .&. hostMask prefix bits) == 0
data IPVersion = IPv4 | IPv6
deriving (Show, Eq)
data Literal
= LInt Int
| LString String
| LBool Bool
| LIP IPVersion Integer -- unified IP address representation
| LCIDR Literal Int -- base address + prefix length
| LPort Int
| LDuration Int TimeUnit
| LHex Word8
deriving (Show, Eq)
-- ─── IP address helpers ──────────────────────────────────────────────────────
-- | Build an IPv4 literal from four octets.
ipv4Lit :: Int -> Int -> Int -> Int -> Literal
ipv4Lit a b c d =
LIP IPv4 (fromIntegral a `shiftL` 24
.|. fromIntegral b `shiftL` 16
.|. fromIntegral c `shiftL` 8
.|. fromIntegral d)
-- | Check that a CIDR has no host bits set.
cidrHostBitsZero :: Integer -> Int -> Int -> Bool
cidrHostBitsZero addr prefix bits =
let hostBits = bits - prefix
hostMask = (1 `shiftL` hostBits) - 1
in (addr .&. hostMask) == 0
-- | Render an IPv4 integer as a dotted-decimal string.
renderIPv4 :: Integer -> String
renderIPv4 n =
show ((n `shiftR` 24) .&. 0xff) ++ "." ++
show ((n `shiftR` 16) .&. 0xff) ++ "." ++
show ((n `shiftR` 8) .&. 0xff) ++ "." ++
show (n .&. 0xff)
-- | Render an IPv6 integer as a condensed colon-hex string.
renderIPv6 :: Integer -> String
renderIPv6 n =
let groups = [ fromIntegral ((n `shiftR` (i * 16)) .&. 0xffff) :: Int
| i <- [7,6..0] ]
hexGroups = map (`showHex` "") groups
in concatIntersperse ":" hexGroups
where
showHex x s = let h = showHexInt x in h ++ s
showHexInt x
| x == 0 = "0"
| otherwise = reverse (go x)
where go 0 = []
go v = let (q,r) = v `divMod` 16
c = "0123456789abcdef" !! r
in c : go q
concatIntersperse _ [] = ""
concatIntersperse _ [x] = x
concatIntersperse s (x:xs) = x ++ s ++ concatIntersperse s xs

207
src/FWL/Check.hs Normal file
View File

@@ -0,0 +1,207 @@
{- | Static checks for MVP:
1. Undefined name detection (interfaces, zones, patterns, rules/policies)
2. Policy arm termination: last arm of a policy must not be Continue
3. Named pattern cycle detection
4. CIDR exhaustiveness stub (warns but does not error for MVP)
-}
module FWL.Check
( checkProgram
, CheckError(..)
) where
import Data.List (foldl', nub)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import FWL.AST
data CheckError
= UndefinedName String String -- kind, name
| PolicyNoContinue String -- policy name
| PatternCycle [String] -- cycle path
| DuplicateDecl String String -- kind, name
deriving (Show, Eq)
type Env = Map.Map String DeclKind
data DeclKind = KInterface | KZone | KLet | KPattern | KFlow | KRule | KPolicy
deriving (Show, Eq)
checkProgram :: Program -> [CheckError]
checkProgram (Program _ decls) =
dupErrs ++ nameErrs ++ policyErrs ++ cycleErrs
where
env = buildEnv decls
dupErrs = findDups decls
nameErrs = concatMap (checkDecl env) decls
policyErrs = concatMap checkPolicyTermination decls
cycleErrs = checkPatternCycles decls
-- ─── Environment ─────────────────────────────────────────────────────────────
buildEnv :: [Decl] -> Env
buildEnv = foldl' addDecl Map.empty
where
addDecl m (DInterface n _ _) = Map.insert n KInterface m
addDecl m (DZone n _) = Map.insert n KZone m
addDecl m (DLet n _ _) = Map.insert n KLet m
addDecl m (DPattern n _ _) = Map.insert n KPattern m
addDecl m (DFlow n _) = Map.insert n KFlow m
addDecl m (DRule n _ _) = Map.insert n KRule m
addDecl m (DPolicy n _ _ _) = Map.insert n KPolicy m
addDecl m _ = m
findDups :: [Decl] -> [CheckError]
findDups decls = go [] Set.empty decls
where
go acc _ [] = acc
go acc seen (d:ds) =
let n = declName d in
if Set.member n seen
then go (DuplicateDecl (declKindStr d) n : acc) seen ds
else go acc (Set.insert n seen) ds
declName :: Decl -> String
declName (DInterface n _ _) = n
declName (DZone n _) = n
declName (DImport n _ _) = n
declName (DLet n _ _) = n
declName (DPattern n _ _) = n
declName (DFlow n _) = n
declName (DRule n _ _) = n
declName (DPolicy n _ _ _) = n
declKindStr :: Decl -> String
declKindStr (DInterface _ _ _) = "interface"
declKindStr (DZone _ _) = "zone"
declKindStr (DImport _ _ _) = "import"
declKindStr (DLet _ _ _) = "let"
declKindStr (DPattern _ _ _) = "pattern"
declKindStr (DFlow _ _) = "flow"
declKindStr (DRule _ _ _) = "rule"
declKindStr (DPolicy _ _ _ _) = "policy"
-- ─── Name resolution ─────────────────────────────────────────────────────────
checkDecl :: Env -> Decl -> [CheckError]
checkDecl env (DZone _ ns) = concatMap (checkName env "interface or zone") ns
checkDecl env (DPattern _ _ p) = checkPat env p
checkDecl env (DFlow _ fe) = checkFlow env fe
checkDecl env (DRule _ _ e) = checkExpr env e
checkDecl env (DPolicy _ _ _ ab) = concatMap (checkArm env) ab
checkDecl env (DLet _ _ e) = checkExpr env e
checkDecl _ _ = []
checkName :: Env -> String -> String -> [CheckError]
checkName env kind n
| Map.member n env = []
| isBuiltin n = []
| otherwise = [UndefinedName kind n]
isBuiltin :: String -> Bool
isBuiltin n = n `elem`
[ "ct", "iif", "oif", "lo", "wan", "lan"
, "tcp", "udp", "ip", "ip6", "eth"
, "Established", "Related", "DNAT"
, "Allow", "Drop", "Continue", "Masquerade"
, "Matched", "Unmatched"
, "true", "false"
]
checkPat :: Env -> Pat -> [CheckError]
checkPat _ PWild = []
checkPat _ (PVar _) = []
checkPat env (PNamed n) = checkName env "pattern" n
checkPat env (PCtor _ ps) = concatMap (checkPat env) ps
checkPat env (PRecord _ fs) = concatMap (checkFP env) fs
checkPat env (PTuple ps) = concatMap (checkPat env) ps
checkPat env (PFrame mp inner)= maybe [] (checkPath env) mp ++ checkPat env inner
checkPat _ (PBytes _) = []
checkFP :: Env -> FieldPat -> [CheckError]
checkFP _ _ = [] -- field names checked by type-checker later
checkPath :: Env -> PathPat -> [CheckError]
checkPath env (PathPat ms md) =
maybe [] (checkEP env) ms ++ maybe [] (checkEP env) md
checkEP :: Env -> EndpointPat -> [CheckError]
checkEP _ EPWild = []
checkEP env (EPName n) = checkName env "interface or zone" n
checkEP env (EPMember _ z) = checkName env "zone" z
checkFlow :: Env -> FlowExpr -> [CheckError]
checkFlow env (FAtom n) = checkName env "pattern" n
checkFlow env (FSeq a b _) = checkFlow env a ++ checkFlow env b
checkArm :: Env -> Arm -> [CheckError]
checkArm env (Arm p mg e) =
checkPat env p ++
maybe [] (checkExpr env) mg ++
checkExpr env e
checkExpr :: Env -> Expr -> [CheckError]
checkExpr env (EVar n) = checkName env "name" n
checkExpr _ (EQual _) = [] -- qualified names: deferred
checkExpr _ (ELit _) = []
checkExpr env (ELam _ e) = checkExpr env e
checkExpr env (EApp f x) = checkExpr env f ++ checkExpr env x
checkExpr env (ECase e ab) = checkExpr env e ++ concatMap (checkArm env) ab
checkExpr env (EIf c t f) = concatMap (checkExpr env) [c,t,f]
checkExpr env (EDo ss) = concatMap (checkStmt env) ss
checkExpr env (ELet _ e1 e2) = checkExpr env e1 ++ checkExpr env e2
checkExpr env (ETuple es) = concatMap (checkExpr env) es
checkExpr env (ESet es) = concatMap (checkExpr env) es
checkExpr env (EMap ms) = concatMap (\(k,v) -> checkExpr env k ++ checkExpr env v) ms
checkExpr env (EPerform _ as_) = concatMap (checkExpr env) as_
checkExpr env (EInfix _ l r) = checkExpr env l ++ checkExpr env r
checkExpr env (ENot e) = checkExpr env e
checkStmt :: Env -> DoStmt -> [CheckError]
checkStmt env (DSBind _ e) = checkExpr env e
checkStmt env (DSExpr e) = checkExpr env e
-- ─── Policy termination ───────────────────────────────────────────────────────
-- The last arm of a policy block must not unconditionally return Continue.
checkPolicyTermination :: Decl -> [CheckError]
checkPolicyTermination (DPolicy n _ _ arms)
| null arms = [PolicyNoContinue n]
| isContinue (last arms) = [PolicyNoContinue n]
| otherwise = []
where
isContinue (Arm PWild Nothing (EVar "Continue")) = True
isContinue _ = False
checkPolicyTermination _ = []
-- ─── Pattern cycle detection ─────────────────────────────────────────────────
checkPatternCycles :: [Decl] -> [CheckError]
checkPatternCycles decls =
[ PatternCycle c
| c <- findCycles graph
]
where
patDecls = [(n, p) | DPattern n _ p <- decls]
graph = Map.fromList [(n, nub (refsInPat p)) | (n,p) <- patDecls]
allPats = Set.fromList (map fst patDecls)
refsInPat :: Pat -> [String]
refsInPat (PNamed r) = [r | Set.member r allPats]
refsInPat (PCtor _ ps) = concatMap refsInPat ps
refsInPat (PTuple ps) = concatMap refsInPat ps
refsInPat (PFrame _ p) = refsInPat p
refsInPat _ = []
findCycles :: Map.Map String [String] -> [[String]]
findCycles graph = go Set.empty Set.empty [] (Map.keys graph)
where
go _ _ _ [] = []
go visited onPath path (n:ns)
| Set.member n visited = go visited onPath path ns
| Set.member n onPath = [path]
| otherwise =
let onPath' = Set.insert n onPath
path' = path ++ [n]
deps = Map.findWithDefault [] n graph
cycles = go visited onPath' path' deps
in cycles ++ go (Set.insert n visited) onPath path ns

313
src/FWL/Compile.hs Normal file
View File

@@ -0,0 +1,313 @@
{-# LANGUAGE OverloadedStrings #-}
{- | Compile a checked FWL program to nftables JSON using Aeson.
All policies (Filter and NAT) go into one table named by Config.
Layer stripping: Frame patterns that omit Ether compile identically
to those that include it.
-}
module FWL.Compile
( compileProgram
, compileToJson
) where
import Data.List (intercalate)
import Data.Maybe (mapMaybe)
import qualified Data.Map.Strict as Map
import Data.Aeson ((.=), Value(..), object, toJSON)
import qualified Data.Aeson as A
import qualified Data.Text as T
import qualified Data.ByteString.Lazy as BL
import Data.Aeson.Encode.Pretty (encodePretty)
import FWL.AST
-- ─── Entry points ────────────────────────────────────────────────────────────
compileToJson :: Program -> BL.ByteString
compileToJson = encodePretty . programToValue
-- exposed for tests
compileProgram :: Program -> Value
compileProgram = programToValue
programToValue :: Program -> Value
programToValue (Program cfg decls) =
object [ "nftables" .= toJSON
(metainfo : tableObj : chainObjs ++ mapObjs ++ ruleObjs) ]
where
env = buildEnv decls
tbl = configTable cfg
metainfo = object [ "metainfo" .= object
[ "json_schema_version" .= (1 :: Int) ] ]
tableObj = object [ "table" .= tableValue tbl ]
policies = [ (n, pm, ab) | DPolicy n _ pm ab <- decls ]
chainObjs = map (\(n, pm, _ ) -> chainDeclValue tbl n pm) policies
ruleObjs = concatMap
(\(n, _, ab) -> concatMap (armToRuleValues env tbl n) ab)
policies
letDecls = [ (n, t, e) | DLet n t e <- decls ]
mapObjs = mapMaybe (\(n, _, e) -> letToMapValue tbl n e) letDecls
-- ─── Table / Chain declarations ──────────────────────────────────────────────
tableValue :: String -> Value
tableValue tbl = object
[ "family" .= ("inet" :: String)
, "name" .= tbl
]
chainDeclValue :: String -> Name -> PolicyMeta -> Value
chainDeclValue tbl n pm = object
[ "chain" .= object
[ "family" .= ("inet" :: String)
, "table" .= tbl
, "name" .= n
, "type" .= chainTypeStr (pmTable pm)
, "hook" .= hookStr (pmHook pm)
, "prio" .= priorityInt (pmPriority pm)
, "policy" .= defaultPolicyStr (pmHook pm)
]
]
chainTypeStr :: TableName -> String
chainTypeStr TFilter = "filter"
chainTypeStr TNAT = "nat"
hookStr :: Hook -> String
hookStr HInput = "input"
hookStr HForward = "forward"
hookStr HOutput = "output"
hookStr HPrerouting = "prerouting"
hookStr HPostrouting = "postrouting"
-- Priority is emitted as an integer in nftables JSON.
priorityInt :: Priority -> Int
priorityInt = priorityValue
defaultPolicyStr :: Hook -> String
defaultPolicyStr HInput = "drop"
defaultPolicyStr HForward = "drop"
defaultPolicyStr _ = "accept"
-- ─── Arm → Rule objects ──────────────────────────────────────────────────────
armToRuleValues :: CompileEnv -> String -> Name -> Arm -> [Value]
armToRuleValues env tbl chain (Arm p mg body) =
case compileAction env body of
Nothing -> []
Just verdict ->
let patExprs = compilePat env p
guardExprs = maybe [] (compileGuard env) mg
allExprs = patExprs ++ guardExprs ++ [verdict]
in [ object
[ "rule" .= object
[ "family" .= ("inet" :: String)
, "table" .= tbl
, "chain" .= chain
, "expr" .= toJSON allExprs
]
]
]
-- ─── Pattern → [Value] ───────────────────────────────────────────────────────
type CompileEnv = Map.Map String Decl
buildEnv :: [Decl] -> CompileEnv
buildEnv = foldr (\d m -> Map.insert (declNameOf d) d m) Map.empty
where
declNameOf (DInterface n _ _) = n
declNameOf (DZone n _) = n
declNameOf (DPattern n _ _) = n
declNameOf (DFlow n _) = n
declNameOf (DRule n _ _) = n
declNameOf (DPolicy n _ _ _) = n
declNameOf (DLet n _ _) = n
declNameOf (DImport n _ _) = n
compilePat :: CompileEnv -> Pat -> [Value]
compilePat _ PWild = []
compilePat _ (PVar _) = []
compilePat env (PNamed n) = expandNamedPat env n
compilePat env (PFrame mp inner) =
maybe [] (compilePathPat env) mp ++ compilePat env inner
compilePat env (PCtor n ps) = compileCtorPat env n ps
compilePat _ (PRecord n fs) = compileRecordPat n fs
compilePat env (PTuple ps) = concatMap (compilePat env) ps
compilePat _ (PBytes _) = []
expandNamedPat :: CompileEnv -> Name -> [Value]
expandNamedPat env n =
case Map.lookup n env of
Just (DPattern _ _ p) -> compilePat env p
_ -> []
compileCtorPat :: CompileEnv -> String -> [Pat] -> [Value]
compileCtorPat env ctor ps = case ctor of
"Ether" -> children
"IPv4" -> matchMeta "nfproto" "ipv4" : children
"IPv6" -> matchMeta "nfproto" "ipv6" : children
"TCP" -> matchPayload "th" "protocol" "tcp" : children
"UDP" -> matchPayload "th" "protocol" "udp" : children
"ICMPv6" -> matchPayload "ip6" "nexthdr" "ipv6-icmp" : children
"ICMP" -> matchPayload "ip" "protocol" "icmp" : children
_ -> children
where
children = concatMap (compilePat env) ps
compileRecordPat :: String -> [FieldPat] -> [Value]
compileRecordPat proto = mapMaybe go
where
go (FPEq field lit) = Just (matchPayload proto field (renderLit lit))
go _ = Nothing
compilePathPat :: CompileEnv -> PathPat -> [Value]
compilePathPat _ (PathPat ms md) =
maybe [] (compileEndpoint "iifname") ms ++
maybe [] (compileEndpoint "oifname") md
compileEndpoint :: String -> EndpointPat -> [Value]
compileEndpoint _ EPWild = []
compileEndpoint dir (EPName n) = [matchMeta dir n]
compileEndpoint dir (EPMember _ z) = [matchInSet (metaVal dir) [z]]
-- ─── Guard → [Value] ─────────────────────────────────────────────────────────
compileGuard :: CompileEnv -> Expr -> [Value]
compileGuard env (EInfix OpAnd l r) = compileGuard env l ++ compileGuard env r
compileGuard _ (EInfix OpIn l r) = [compileInExpr l r]
compileGuard _ (EInfix OpEq l r) = [matchExpr "==" (exprVal l) (exprVal r)]
compileGuard _ (EInfix OpNeq l r) = [matchExpr "!=" (exprVal l) (exprVal r)]
compileGuard _ _ = []
compileInExpr :: Expr -> Expr -> Value
-- Fix 4: put the more-specific ct patterns BEFORE the generic 2-element
-- EQual case to eliminate the overlapping pattern match warning.
compileInExpr (EQual ["ct", "state"]) (ESet vs) = ctMatch "state" vs
compileInExpr (EQual ["ct", "status"]) (ESet vs) = ctMatch "status" vs
compileInExpr l (ESet vs) =
matchExpr "in" (exprVal l) (setVal (map exprToStr vs))
compileInExpr l r =
matchExpr "==" (exprVal l) (exprVal r)
ctMatch :: String -> [Expr] -> Value
ctMatch key vs = matchExpr "in"
(object ["ct" .= object ["key" .= (key :: String)]])
(setVal (map exprToStr vs))
-- ─── Action → Maybe Value ─────────────────────────────────────────────────────
compileAction :: CompileEnv -> Expr -> Maybe Value
compileAction _ (EVar "Allow") = Just (object ["accept" .= Null])
compileAction _ (EVar "Drop") = Just (object ["drop" .= Null])
compileAction _ (EVar "Continue") = Nothing
compileAction _ (EVar "Masquerade") = Just (object ["masquerade" .= Null])
compileAction _ (EApp (EVar "DNAT") arg) =
Just $ object ["dnat" .= object ["addr" .= exprToStr arg]]
compileAction _ (EApp (EVar "DNATMap") arg) =
Just $ object ["dnat" .= object ["addr" .= object
[ "map" .= object [ "key" .= object ["concat" .= Array mempty]
, "data" .= exprToStr arg ]]]]
compileAction env (EApp (EVar rn) _) =
case Map.lookup rn env of
Just (DRule _ _ _) -> Just $ object ["jump" .= object ["target" .= rn]]
_ -> Just (object ["accept" .= Null])
compileAction _ _ = Just (object ["accept" .= Null])
-- ─── Let → Map object ────────────────────────────────────────────────────────
letToMapValue :: String -> Name -> Expr -> Maybe Value
letToMapValue tbl n (EMap entries) = Just $ object
[ "map" .= object
[ "family" .= ("inet" :: String)
, "table" .= tbl
, "name" .= n
, "type" .= ("inetproto . inetservice" :: String)
, "map" .= ("ipv4_addr . inetservice" :: String)
, "elem" .= toJSON (map renderMapElem entries)
]
]
letToMapValue _ _ _ = Nothing
renderMapElem :: (Expr, Expr) -> Value
renderMapElem (k, v) = toJSON
[ object ["concat" .= toJSON [exprToStr k]]
, A.String (toText (exprToStr v))
]
-- ─── Aeson building blocks ───────────────────────────────────────────────────
matchExpr :: String -> Value -> Value -> Value
matchExpr op l r = object
[ "match" .= object
[ "op" .= (op :: String)
, "left" .= l
, "right" .= r
]
]
matchMeta :: String -> String -> Value
matchMeta key val = matchExpr "==" (metaVal key) (A.String (toText val))
matchPayload :: String -> String -> String -> Value
matchPayload proto field val =
matchExpr "==" (payloadVal proto field) (A.String (toText val))
matchInSet :: Value -> [String] -> Value
matchInSet lhs vals = matchExpr "in" lhs (setVal vals)
metaVal :: String -> Value
metaVal key = object ["meta" .= object ["key" .= (key :: String)]]
payloadVal :: String -> String -> Value
payloadVal proto field =
object ["payload" .= object
[ "protocol" .= (proto :: String)
, "field" .= (field :: String)
]]
setVal :: [String] -> Value
setVal vs = object ["set" .= toJSON vs]
-- ─── Expression helpers ───────────────────────────────────────────────────────
-- Fix 3 (overlap): specific ct pattern first, generic 2-element case second.
exprVal :: Expr -> Value
exprVal (EQual ["ct", k]) = object ["ct" .= object ["key" .= (k :: String)]]
exprVal (EQual [p, f]) = payloadVal p f
exprVal (EQual ns) = A.String (toText (intercalate "." ns))
exprVal (EVar n) = metaVal n
exprVal (ELit l) = A.String (toText (renderLit l))
exprVal (ESet vs) = setVal (map exprToStr vs)
exprVal e = A.String (toText (exprToStr e))
exprToStr :: Expr -> String
exprToStr (EVar n) = n
exprToStr (ELit l) = renderLit l
exprToStr (EQual ns) = intercalate "." ns
exprToStr (ETuple es) = intercalate " . " (map exprToStr es)
exprToStr _ = "_"
-- Fix 2: Use Data.Text.pack via OverloadedStrings + fromString instead of
-- the fragile read(show s) hack. With OverloadedStrings enabled, string
-- literals already produce the correct Text/Key types; for runtime String
toText :: String -> T.Text
toText = T.pack
renderLit :: Literal -> String
renderLit (LInt n) = show n
renderLit (LString s) = s
renderLit (LBool True) = "true"
renderLit (LBool False) = "false"
renderLit (LIPv4 (a, b, c, d)) =
show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d
renderLit (LIPv6 _) = "::1"
renderLit (LCIDR ip p) = renderLit ip ++ "/" ++ show p
renderLit (LPort p) = show p
renderLit (LDuration n Seconds) = show n ++ "s"
renderLit (LDuration n Millis) = show n ++ "ms"
renderLit (LDuration n Minutes) = show n ++ "m"
renderLit (LDuration n Hours) = show n ++ "h"
renderLit (LHex b) = show b

101
src/FWL/Lexer.hs Normal file
View File

@@ -0,0 +1,101 @@
module FWL.Lexer where
import Text.Parsec
import Text.Parsec.String (Parser)
import qualified Text.Parsec.Token as Tok
import Text.Parsec.Language (emptyDef)
-- ─── Language definition ─────────────────────────────────────────────────────
fwlDef :: Tok.LanguageDef ()
fwlDef = emptyDef
{ Tok.commentLine = "--"
, Tok.commentStart = "{-"
, Tok.commentEnd = "-}"
, Tok.identStart = letter <|> char '_'
, Tok.identLetter = alphaNum <|> char '_'
, Tok.reservedNames =
-- Only genuine syntactic keywords belong here.
-- Semantic values used as constructors, actions, type names, or
-- pattern references (Allow, Drop, Log, Matched, Frame, etc.) must
-- NOT be reserved so that `identifier` can consume them in those
-- positions.
[ "config", "table"
, "interface", "zone", "import", "from"
, "let", "in", "pattern", "flow", "rule", "policy", "on"
, "case", "of", "if", "then", "else", "do", "perform"
, "within", "as", "dynamic", "cidr4", "cidr6"
, "hook", "priority"
, "WAN", "LAN", "WireGuard"
, "Input", "Forward", "Output", "Prerouting", "Postrouting"
, "Filter", "NAT", "Mangle", "DstNat", "SrcNat", "Raw", "ConnTrack"
, "true", "false"
]
, Tok.reservedOpNames =
[ "->", "<-", "=>", "::", ":", "=", ".", ".."
, "\\", "|", ","
, "&&", "||", "!", "==" , "!=", "<", "<=", ">", ">="
, "++", ">>", ">>="
, ""
]
, Tok.caseSensitive = True
}
lexer :: Tok.TokenParser ()
lexer = Tok.makeTokenParser fwlDef
-- ─── Token helpers ───────────────────────────────────────────────────────────
identifier :: Parser String
identifier = Tok.identifier lexer
reserved :: String -> Parser ()
reserved = Tok.reserved lexer
reservedOp :: String -> Parser ()
reservedOp = Tok.reservedOp lexer
symbol :: String -> Parser String
symbol = Tok.symbol lexer
parens :: Parser a -> Parser a
parens = Tok.parens lexer
braces :: Parser a -> Parser a
braces = Tok.braces lexer
angles :: Parser a -> Parser a
angles = Tok.angles lexer
brackets :: Parser a -> Parser a
brackets = Tok.brackets lexer
semi :: Parser String
semi = Tok.semi lexer
comma :: Parser String
comma = Tok.comma lexer
colon :: Parser String
colon = Tok.colon lexer
dot :: Parser String
dot = Tok.dot lexer
whiteSpace :: Parser ()
whiteSpace = Tok.whiteSpace lexer
stringLit :: Parser String
stringLit = Tok.stringLiteral lexer
natural :: Parser Integer
natural = Tok.natural lexer
commaSep :: Parser a -> Parser [a]
commaSep = Tok.commaSep lexer
commaSep1 :: Parser a -> Parser [a]
commaSep1 = Tok.commaSep1 lexer
semiSep :: Parser a -> Parser [a]
semiSep = Tok.semiSep lexer

659
src/FWL/Parser.hs Normal file
View File

@@ -0,0 +1,659 @@
module FWL.Parser
( parseProgram
, parseFile
) where
import Control.Monad (void)
import Data.Bits ((.&.), (.|.), shiftL)
import Data.List (foldl')
import Data.Word (Word8)
import Numeric (readHex)
import Text.Parsec
import Text.Parsec.String (Parser)
import Data.Functor.Identity (Identity)
import qualified Text.Parsec.Expr as Ex
import FWL.AST
import FWL.Lexer
-- ─── Entry points ────────────────────────────────────────────────────────────
parseProgram :: String -> String -> Either ParseError Program
parseProgram src input = parse program src input
parseFile :: FilePath -> IO (Either ParseError Program)
parseFile fp = parseProgram fp <$> readFile fp
-- ─── Top-level ───────────────────────────────────────────────────────────────
program :: Parser Program
program = do
whiteSpace
cfg <- option defaultConfig configBlock
ds <- many decl
eof
return (Program cfg ds)
configBlock :: Parser Config
configBlock = do
reserved "config"
props <- braces (semiSep configProp)
optional semi
return $ foldr applyProp defaultConfig props
where
applyProp ("table", v) c = c { configTable = v }
applyProp _ c = c
configProp :: Parser (String, String)
configProp = do
reserved "table"
reservedOp "="
v <- stringLit
return ("table", v)
-- ─── Declarations ────────────────────────────────────────────────────────────
decl :: Parser Decl
decl = interfaceDecl
<|> zoneDecl
<|> importDecl
<|> letDecl
<|> patternDecl
<|> flowDecl
<|> ruleDecl
<|> policyDecl
interfaceDecl :: Parser Decl
interfaceDecl = do
reserved "interface"
n <- identifier
reservedOp ":"
k <- ifaceKind
ps <- braces (endBy ifaceProp semi)
_ <- semi
return (DInterface n k ps)
ifaceKind :: Parser IfaceKind
ifaceKind = (reserved "WAN" >> return IWan)
<|> (reserved "LAN" >> return ILan)
<|> (reserved "WireGuard" >> return IWireGuard)
<|> (IUser <$> identifier)
ifaceProp :: Parser IfaceProp
ifaceProp = (reserved "dynamic" >> return IPDynamic)
<|> (reserved "cidr4" >> reservedOp "=" >> IPCidr4 <$> cidrSet)
<|> (reserved "cidr6" >> reservedOp "=" >> IPCidr6 <$> cidrSet)
cidrSet :: Parser [CIDR]
cidrSet = braces (commaSep1 cidrLit)
zoneDecl :: Parser Decl
zoneDecl = do
reserved "zone"
n <- identifier
reservedOp "="
ns <- braces (commaSep1 identifier)
_ <- semi
return (DZone n ns)
importDecl :: Parser Decl
importDecl = do
reserved "import"
n <- identifier
reservedOp ":"
t <- typeP
reserved "from"
s <- stringLit
_ <- semi
return (DImport n t s)
letDecl :: Parser Decl
letDecl = do
reserved "let"
n <- identifier
reservedOp ":"
t <- typeP
reservedOp "="
e <- expr
_ <- semi
return (DLet n t e)
patternDecl :: Parser Decl
patternDecl = do
reserved "pattern"
n <- identifier
reservedOp ":"
t <- typeP
reservedOp "="
p <- pat
_ <- semi
return (DPattern n t p)
flowDecl :: Parser Decl
flowDecl = do
reserved "flow"
n <- identifier
reservedOp ":"
reserved "FlowPattern"
reservedOp "="
f <- flowExpr
_ <- semi
return (DFlow n f)
ruleDecl :: Parser Decl
ruleDecl = do
reserved "rule"
n <- identifier
reservedOp ":"
t <- typeP
reservedOp "="
e <- expr
_ <- semi
return (DRule n t e)
policyDecl :: Parser Decl
policyDecl = do
reserved "policy"
n <- identifier
reservedOp ":"
t <- typeP
reserved "on"
pm <- braces policyMeta
reservedOp "="
ab <- armBlock
_ <- semi
return (DPolicy n t pm ab)
policyMeta :: Parser PolicyMeta
policyMeta = do
props <- commaSep1 metaProp
let h = foldr (\p a -> case p of Left v -> v; _ -> a) HInput props
tb = foldr (\p a -> case p of Right (Left v) -> v; _ -> a) TFilter props
pr = foldr (\p a -> case p of Right (Right v) -> v; _ -> a) pFilter props
return (PolicyMeta h tb pr)
metaProp :: Parser (Either Hook (Either TableName Priority))
metaProp
= (reserved "hook" >> reservedOp "=" >> fmap (Left) hookP)
<|> (reserved "table" >> reservedOp "=" >> fmap (Right . Left) tableNameP)
<|> (reserved "priority" >> reservedOp "=" >> fmap (Right . Right) priorityP)
hookP :: Parser Hook
hookP = (reserved "Input" >> return HInput)
<|> (reserved "Forward" >> return HForward)
<|> (reserved "Output" >> return HOutput)
<|> (reserved "Prerouting" >> return HPrerouting)
<|> (reserved "Postrouting" >> return HPostrouting)
tableNameP :: Parser TableName
tableNameP = (reserved "Filter" >> return TFilter)
<|> (reserved "NAT" >> return TNAT)
priorityP :: Parser Priority
priorityP
= (reserved "Filter" >> return pFilter)
<|> (reserved "DstNat" >> return pDstNat)
<|> (reserved "SrcNat" >> return pSrcNat)
<|> (reserved "Mangle" >> return pMangle)
<|> (reserved "Raw" >> return pRaw)
<|> (reserved "ConnTrack" >> return pConnTrack)
<|> (Priority . fromIntegral <$> integerP)
where
-- Accept optional leading minus for negative priorities
integerP = do
neg <- option 1 (char '-' >> return (-1))
n <- natural
whiteSpace
return (neg * fromIntegral n)
-- ─── Arm blocks ──────────────────────────────────────────────────────────────
armBlock :: Parser ArmBlock
armBlock = braces (many arm)
arm :: Parser Arm
arm = do
_ <- symbol "|"
p <- pat
g <- optionMaybe (reserved "if" >> expr)
reservedOp "->"
e <- expr
_ <- semi
return (Arm p g e)
-- ─── Patterns ────────────────────────────────────────────────────────────────
pat :: Parser Pat
pat = wildcardPat
<|> try framePat
<|> try tuplePat
<|> bytesPat
<|> try recordPat
<|> try namedOrCtorPat
wildcardPat :: Parser Pat
wildcardPat = symbol "_" >> return PWild
-- Frame(...) — optional path then inner pattern
-- Layer stripping: if the inner pattern is not Ether/IPv4/IPv6/etc the
-- type-checker will peel outer layers automatically. Parser just stores
-- whatever the user wrote.
framePat :: Parser Pat
framePat = do
reserved "Frame"
(mp, inner) <- parens frameArgs
return (PFrame mp inner)
frameArgs :: Parser (Maybe PathPat, Pat)
frameArgs = try withPath <|> withoutPath
where
withPath = do
pp <- pathPat
_ <- comma
inner <- pat
return (Just pp, inner)
withoutPath = do
inner <- pat
return (Nothing, inner)
pathPat :: Parser PathPat
pathPat = do
src <- optionMaybe (try endpointPat)
dst <- optionMaybe (try (reservedOp "->" >> endpointPat))
case (src, dst) of
(Nothing, Nothing) -> fail "empty path pattern"
_ -> return (PathPat src dst)
endpointPat :: Parser EndpointPat
endpointPat
= (symbol "_" >> return EPWild)
<|> try (do n <- identifier
memberOp
z <- identifier
return (EPMember n z))
<|> (EPName <$> identifier)
memberOp :: Parser ()
memberOp = (reservedOp "" <|> reserved "in") >> return ()
tuplePat :: Parser Pat
tuplePat = do
ps <- parens (commaSep2 pat)
return (PTuple ps)
commaSep2 :: Parser a -> Parser [a]
commaSep2 p = do
x <- p
_ <- comma
xs <- commaSep1 p
return (x:xs)
bytesPat :: Parser Pat
bytesPat = brackets (PBytes <$> many byteElem)
byteElem :: Parser ByteElem
byteElem
= try (symbol "_*" >> return BEWildStar)
<|> try (symbol "_" >> return BEWild)
<|> (BEHex <$> hexByte)
hexByte :: Parser Word8
hexByte = do
void (string "0x")
h1 <- hexDigit
h2 <- hexDigit
whiteSpace
case (readHex [h1,h2] :: [(Integer, String)]) of
[(v,"")] -> return (fromIntegral v)
_ -> fail "invalid hex byte"
-- Record pattern: ident { fields }
recordPat :: Parser Pat
recordPat = do
n <- identifier
fs <- braces (commaSep fieldPat)
return (PRecord n fs)
fieldPat :: Parser FieldPat
fieldPat = do
n <- identifier
try (reservedOp "=" >> FPEq n <$> fieldLiteral)
<|> try (reserved "as" >> FPAs n <$> identifier)
<|> return (FPBind n)
-- Port literals (:22) are valid in record field position as well as plain literals.
fieldLiteral :: Parser Literal
fieldLiteral = try portLit <|> literal
where
portLit = do
void (char ':')
n <- fromIntegral <$> natural
return (LPort n)
-- Named pattern reference OR constructor: starts with uppercase-ish ident
namedOrCtorPat :: Parser Pat
namedOrCtorPat = do
n <- identifier
args <- optionMaybe (try (parens (commaSep pat)))
case args of
Nothing -> return (PNamed n) -- bare name = named pattern ref
Just ps -> return (PCtor n ps)
-- ─── Flow expressions ────────────────────────────────────────────────────────
flowExpr :: Parser FlowExpr
flowExpr = do
first <- FAtom <$> identifier
rest <- many (reservedOp "." >> identifier)
mw <- optionMaybe (reserved "within" >> durationLit)
return $ buildSeq (first : map FAtom rest) mw
where
buildSeq [x] mw = case mw of
Nothing -> x
Just w -> FSeq x x (Just w) -- degenerate
buildSeq (x:xs) mw = FSeq x (buildSeq xs mw) mw
buildSeq [] _ = error "impossible"
durationLit :: Parser Duration
durationLit = do
n <- fromIntegral <$> natural
u <- (char 's' >> return Seconds)
<|> (string "ms" >> return Millis)
<|> (char 'm' >> return Minutes)
<|> (char 'h' >> return Hours)
whiteSpace
return (n, u)
-- ─── Types ───────────────────────────────────────────────────────────────────
typeP :: Parser Type
typeP = do
t <- baseType
option t (reservedOp "->" >> TFun t <$> typeP)
baseType :: Parser Type
baseType
= effectType
<|> try tupleTy
<|> simpleTy
effectType :: Parser Type
effectType = do
effs <- angles (commaSep identifier)
t <- simpleTy
return (TEffect effs t)
tupleTy :: Parser Type
tupleTy = TTuple <$> parens (commaSep2 typeP)
simpleTy :: Parser Type
simpleTy = do
n <- identifier
args <- option [] (angles (commaSep typeP))
return (TName n args)
-- ─── Expressions ─────────────────────────────────────────────────────────────
expr :: Parser Expr
expr = lamExpr
<|> ifExpr
<|> doExpr
<|> caseExpr
<|> letExpr
<|> infixExpr
lamExpr :: Parser Expr
lamExpr = do
reservedOp "\\"
n <- identifier
reservedOp "->"
e <- expr
return (ELam n e)
ifExpr :: Parser Expr
ifExpr = do
reserved "if"
c <- expr
reserved "then"
t <- expr
reserved "else"
f <- expr
return (EIf c t f)
doExpr :: Parser Expr
doExpr = reserved "do" >> braces (EDo <$> semiSep doStmt)
doStmt :: Parser DoStmt
doStmt = try bindStmt <|> (DSExpr <$> expr)
bindStmt :: Parser DoStmt
bindStmt = do
n <- identifier
reservedOp "<-"
e <- expr
return (DSBind n e)
caseExpr :: Parser Expr
caseExpr = do
reserved "case"
e <- expr
reserved "of"
ab <- armBlock
return (ECase e ab)
letExpr :: Parser Expr
letExpr = do
reserved "let"
n <- identifier
reservedOp "="
e1 <- expr
reserved "in"
e2 <- expr
return (ELet n e1 e2)
-- Operator table for infix expressions
infixExpr :: Parser Expr
infixExpr = Ex.buildExpressionParser opTable appExpr
opTable :: Ex.OperatorTable String () Identity Expr
opTable =
[ [ prefix "!" ENot ]
, [ infixL "==" OpEq, infixL "!=" OpNeq
, infixL "<" OpLt, infixL "<=" OpLte
, infixL ">" OpGt, infixL ">=" OpGte
, infixIn ]
, [ infixR "&&" OpAnd ]
, [ infixR "||" OpOr ]
, [ infixR "++" OpConcat ]
, [ infixL ">>=" OpBind ]
, [ infixL ">>" OpThen ]
]
where
prefix op f = Ex.Prefix (reservedOp op >> return f)
infixL op c = Ex.Infix (reservedOp op >> return (EInfix c)) Ex.AssocLeft
infixR op c = Ex.Infix (reservedOp op >> return (EInfix c)) Ex.AssocRight
infixIn = Ex.Infix
((memberOp <|> reserved "in") >> return (EInfix OpIn))
Ex.AssocNone
appExpr :: Parser Expr
appExpr = do
f <- atom
args <- many atom
return (foldl EApp f args)
atom :: Parser Expr
atom
= try performExpr
<|> try mapLit
<|> try setLit
<|> try tupleLit
<|> try (parens expr)
<|> try litExpr
<|> try portExpr
<|> qualNameExpr
performExpr :: Parser Expr
performExpr = do
reserved "perform"
parts <- sepBy1 identifier dot
args <- parens (commaSep expr)
return (EPerform parts args)
qualNameExpr :: Parser Expr
qualNameExpr = do
parts <- sepBy1 identifier (try (dot <* notFollowedBy digit))
case parts of
[n] -> return (EVar n)
ns -> return (EQual ns)
litExpr :: Parser Expr
litExpr = ELit <$> literal
portExpr :: Parser Expr
portExpr = do
void (char ':')
n <- fromIntegral <$> natural
return (ELit (LPort n))
tupleLit :: Parser Expr
tupleLit = ETuple <$> parens (commaSep2 expr)
setLit :: Parser Expr
setLit = braces $ do
items <- commaSep expr
return (ESet items)
-- map literal: { expr -> expr, ... }
mapLit :: Parser Expr
mapLit = braces $ do
entries <- commaSep1 mapEntry
return (EMap entries)
mapEntry :: Parser (Expr, Expr)
mapEntry = do
k <- expr
reservedOp "->"
v <- expr
return (k, v)
-- ─── Literals ────────────────────────────────────────────────────────────────
literal :: Parser Literal
literal
= try ipOrCidrLit
<|> try hexLit
<|> try (LBool True <$ reserved "true")
<|> try (LBool False <$ reserved "false")
<|> try (LString <$> stringLit)
<|> try (LInt . fromIntegral <$> natural)
hexLit :: Parser Literal
hexLit = LHex <$> hexByte
-- ─── IP / CIDR parsing ───────────────────────────────────────────────────────
-- | Parse an IPv4 or IPv6 address, optionally followed by /prefix.
-- Tries IPv6 first (it can start with hex digits too), then IPv4.
ipOrCidrLit :: Parser Literal
ipOrCidrLit = do
ip <- try ipv6Lit <|> ipv4Lit_
mPrefix <- optionMaybe (char '/' >> fromIntegral <$> natural)
whiteSpace
return $ case mPrefix of
Nothing -> ip
Just p -> LCIDR ip p
-- | IPv4: four decimal octets separated by dots → LIP IPv4 (32-bit Integer)
ipv4Lit_ :: Parser Literal
ipv4Lit_ = do
a <- octet
void (char '.')
b <- octet
void (char '.')
c <- octet
void (char '.')
d <- octet
return $ LIP IPv4
( fromIntegral a `shiftL` 24
.|. fromIntegral b `shiftL` 16
.|. fromIntegral c `shiftL` 8
.|. fromIntegral d)
where
octet = do
n <- fromIntegral <$> natural
if n > 255 then fail "octet out of range" else return n
-- | IPv6: full notation, :: abbreviation, and optional embedded IPv4.
-- Stores as LIP IPv6 (128-bit Integer).
ipv6Lit :: Parser Literal
ipv6Lit = do
(left, right) <- ipv6Groups
let missing = 8 - length left - length right
when (missing < 0) $ fail "too many groups in IPv6 address"
let groups = left ++ replicate missing 0 ++ right
when (length groups /= 8) $ fail "invalid IPv6 address"
let val = foldl' (\acc g -> (acc `shiftL` 16) .|. fromIntegral g) (0::Integer) groups
return (LIP IPv6 val)
-- Returns (left-of-::, right-of-::).
-- If no :: present, left has all 8 groups and right is empty.
ipv6Groups :: Parser ([Int], [Int])
ipv6Groups = do
-- must start with a hex digit or ':' (for ::)
ahead <- lookAhead (hexDigit <|> char ':')
case ahead of
':' -> do
void (string "::")
right <- ipv6RightGroups
return ([], right)
_ -> do
left <- ipv6LeftGroups
mDbl <- optionMaybe (try (string "::"))
case mDbl of
Nothing -> return (left, [])
Just _ -> do
right <- ipv6RightGroups
return (left, right)
-- Parse a run of hex16:hex16:... stopping before :: or end
ipv6LeftGroups :: Parser [Int]
ipv6LeftGroups = do
first <- hex16
rest <- many (try (char ':' >> notFollowedBy (char ':') >> hex16))
return (first : rest)
-- Parse groups to the right of ::, including optional embedded IPv4
ipv6RightGroups :: Parser [Int]
ipv6RightGroups = option [] $
try ipv4EmbeddedGroups <|> ipv6LeftGroups
-- IPv4-mapped groups: e.g. ffff:192.168.1.1 -> [0xffff, 0xc0a8, 0x0101]
ipv4EmbeddedGroups :: Parser [Int]
ipv4EmbeddedGroups = do
prefix <- many (try (hex16 <* char ':' <* lookAhead digit))
a <- octet_; void (char '.')
b <- octet_; void (char '.')
c <- octet_; void (char '.')
d <- octet_
let hi = (a `shiftL` 8) .|. b
lo = (c `shiftL` 8) .|. d
return (prefix ++ [hi, lo])
where
octet_ = do
n <- fromIntegral <$> natural
if n > 255 then fail "IPv4 octet out of range" else return n
hex16 :: Parser Int
hex16 = do
digits <- many1 hexDigit
case (reads ("0x" ++ digits)) :: [(Int,String)] of
[(v,"")] -> if v > 0xffff then fail "hex16 out of range" else return v
_ -> fail "invalid hex group"
cidrLit :: Parser CIDR
cidrLit = do
l <- ipOrCidrLit
case l of
LCIDR ip p -> return (ip, p)
_ -> fail "expected CIDR notation (address/prefix)"

187
src/FWL/Pretty.hs Normal file
View File

@@ -0,0 +1,187 @@
-- | Pretty printer: round-trips the AST back to FWL source.
module FWL.Pretty (prettyProgram) where
import Data.List (intercalate)
import FWL.AST
prettyProgram :: Program -> String
prettyProgram (Program cfg ds) =
prettyConfig cfg ++ "\n" ++ unlines (map prettyDecl ds)
prettyConfig :: Config -> String
prettyConfig (Config t)
| t == "fwl" = ""
| otherwise = "config { table = \"" ++ t ++ "\"; }\n"
prettyDecl :: Decl -> String
prettyDecl (DInterface n k ps) =
"interface " ++ n ++ " : " ++ prettyKind k ++ " {\n" ++
concatMap (\p -> " " ++ prettyIfaceProp p ++ ";\n") ps ++
"};"
prettyDecl (DZone n ns) =
"zone " ++ n ++ " = { " ++ intercalate ", " ns ++ " };"
prettyDecl (DImport n t s) =
"import " ++ n ++ " : " ++ prettyType t ++ " from \"" ++ s ++ "\";"
prettyDecl (DLet n t e) =
"let " ++ n ++ " : " ++ prettyType t ++ " = " ++ prettyExpr e ++ ";"
prettyDecl (DPattern n t p) =
"pattern " ++ n ++ " : " ++ prettyType t ++ " = " ++ prettyPat p ++ ";"
prettyDecl (DFlow n f) =
"flow " ++ n ++ " : FlowPattern = " ++ prettyFlow f ++ ";"
prettyDecl (DRule n t e) =
"rule " ++ n ++ " : " ++ prettyType t ++ " =\n " ++ prettyExpr e ++ ";"
prettyDecl (DPolicy n t pm ab) =
"policy " ++ n ++ " : " ++ prettyType t ++ "\n" ++
" on { hook = " ++ prettyHook (pmHook pm) ++
", table = " ++ prettyTable (pmTable pm) ++
", priority = " ++ prettyPriority (pmPriority pm) ++ " }\n" ++
" = " ++ prettyArmBlock ab ++ ";"
prettyKind :: IfaceKind -> String
prettyKind IWan = "WAN"
prettyKind ILan = "LAN"
prettyKind IWireGuard = "WireGuard"
prettyKind (IUser n) = n
prettyIfaceProp :: IfaceProp -> String
prettyIfaceProp IPDynamic = "dynamic"
prettyIfaceProp (IPCidr4 cs) = "cidr4 = { " ++ intercalate ", " (map prettyCidr cs) ++ " }"
prettyIfaceProp (IPCidr6 cs) = "cidr6 = { " ++ intercalate ", " (map prettyCidr cs) ++ " }"
prettyCidr :: CIDR -> String
prettyCidr (LIPv4 (a,b,c,d), p) =
show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ "/" ++ show p
prettyCidr (ip, p) = prettyLit ip ++ "/" ++ show p
prettyHook :: Hook -> String
prettyHook HInput = "Input"
prettyHook HForward = "Forward"
prettyHook HOutput = "Output"
prettyHook HPrerouting = "Prerouting"
prettyHook HPostrouting = "Postrouting"
prettyTable :: TableName -> String
prettyTable TFilter = "Filter"
prettyTable TNAT = "NAT"
prettyPriority :: Priority -> String
prettyPriority p = show (priorityValue p)
prettyType :: Type -> String
prettyType (TName n []) = n
prettyType (TName n ts) = n ++ "<" ++ intercalate ", " (map prettyType ts) ++ ">"
prettyType (TTuple ts) = "(" ++ intercalate ", " (map prettyType ts) ++ ")"
prettyType (TFun a b) = prettyType a ++ " -> " ++ prettyType b
prettyType (TEffect es t) = "<" ++ intercalate ", " es ++ "> " ++ prettyType t
prettyPat :: Pat -> String
prettyPat PWild = "_"
prettyPat (PVar n) = n
prettyPat (PNamed n) = n
prettyPat (PCtor n ps) = n ++ "(" ++ intercalate ", " (map prettyPat ps) ++ ")"
prettyPat (PRecord n fs) = n ++ " { " ++ intercalate ", " (map prettyFP fs) ++ " }"
prettyPat (PTuple ps) = "(" ++ intercalate ", " (map prettyPat ps) ++ ")"
prettyPat (PFrame mp inner)=
"Frame(" ++ maybe "" (\pp -> prettyPath pp ++ ", ") mp ++ prettyPat inner ++ ")"
prettyPat (PBytes bs) = "[" ++ unwords (map prettyBE bs) ++ "]"
prettyFP :: FieldPat -> String
prettyFP (FPEq n l) = n ++ " = " ++ prettyLit l
prettyFP (FPBind n) = n
prettyFP (FPAs n v) = n ++ " as " ++ v
prettyPath :: PathPat -> String
prettyPath (PathPat ms md) =
maybe "_" prettyEP ms ++ maybe "" (\d -> " -> " ++ prettyEP d) md
prettyEP :: EndpointPat -> String
prettyEP EPWild = "_"
prettyEP (EPName n) = n
prettyEP (EPMember n z) = n ++ " in " ++ z
prettyBE :: ByteElem -> String
prettyBE (BEHex w) = "0x" ++ pad (show w) -- simplified
where pad s = if length s < 2 then '0':s else s
prettyBE BEWild = "_"
prettyBE BEWildStar = "_*"
prettyFlow :: FlowExpr -> String
prettyFlow (FAtom n) = n
prettyFlow (FSeq a b mw) =
prettyFlow a ++ " . " ++ prettyFlow b ++
maybe "" (\(n,u) -> " within " ++ show n ++ prettyUnit u) mw
prettyUnit :: TimeUnit -> String
prettyUnit Seconds = "s"
prettyUnit Millis = "ms"
prettyUnit Minutes = "m"
prettyUnit Hours = "h"
prettyExpr :: Expr -> String
prettyExpr (EVar n) = n
prettyExpr (EQual ns) = intercalate "." ns
prettyExpr (ELit l) = prettyLit l
prettyExpr (ELam n e) = "\\" ++ n ++ " -> " ++ prettyExpr e
prettyExpr (EApp f x) = prettyExpr f ++ " " ++ prettyAtom x
prettyExpr (ECase e ab) =
"case " ++ prettyExpr e ++ " of " ++ prettyArmBlock ab
prettyExpr (EIf c t f) =
"if " ++ prettyExpr c ++ " then " ++ prettyExpr t ++ " else " ++ prettyExpr f
prettyExpr (EDo ss) =
"do { " ++ intercalate "; " (map prettyStmt ss) ++ " }"
prettyExpr (ELet n e1 e2) =
"let " ++ n ++ " = " ++ prettyExpr e1 ++ " in " ++ prettyExpr e2
prettyExpr (ETuple es) = "(" ++ intercalate ", " (map prettyExpr es) ++ ")"
prettyExpr (ESet es) = "{ " ++ intercalate ", " (map prettyExpr es) ++ " }"
prettyExpr (EMap ms) =
"{ " ++ intercalate ", " (map (\(k,v) -> prettyExpr k ++ " -> " ++ prettyExpr v) ms) ++ " }"
prettyExpr (EPerform ns as_) =
"perform " ++ intercalate "." ns ++ "(" ++ intercalate ", " (map prettyExpr as_) ++ ")"
prettyExpr (EInfix op l r) =
prettyAtom l ++ " " ++ prettyOp op ++ " " ++ prettyAtom r
prettyExpr (ENot e) = "!" ++ prettyAtom e
prettyAtom :: Expr -> String
prettyAtom e@(EInfix _ _ _) = "(" ++ prettyExpr e ++ ")"
prettyAtom e@(ELam _ _) = "(" ++ prettyExpr e ++ ")"
prettyAtom e = prettyExpr e
prettyOp :: InfixOp -> String
prettyOp OpAnd = "&&"
prettyOp OpOr = "||"
prettyOp OpEq = "=="
prettyOp OpNeq = "!="
prettyOp OpLt = "<"
prettyOp OpLte = "<="
prettyOp OpGt = ">"
prettyOp OpGte = ">="
prettyOp OpIn = "in"
prettyOp OpConcat = "++"
prettyOp OpThen = ">>"
prettyOp OpBind = ">>="
prettyStmt :: DoStmt -> String
prettyStmt (DSBind n e) = n ++ " <- " ++ prettyExpr e
prettyStmt (DSExpr e) = prettyExpr e
prettyArmBlock :: ArmBlock -> String
prettyArmBlock arms =
"{\n" ++
concatMap (\(Arm p mg e) ->
" | " ++ prettyPat p ++
maybe "" (\g -> " if " ++ prettyExpr g) mg ++
" -> " ++ prettyExpr e ++ ";\n") arms ++
" }"
prettyLit :: Literal -> String
prettyLit (LInt n) = show n
prettyLit (LString s) = "\"" ++ s ++ "\""
prettyLit (LBool True) = "true"
prettyLit (LBool False) = "false"
prettyLit (LIPv4 (a,b,c,d)) =
show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d
prettyLit (LIPv6 _) = "<ipv6>"
prettyLit (LCIDR ip p) = prettyLit ip ++ "/" ++ show p
prettyLit (LPort p) = ":" ++ show p
prettyLit (LDuration n u) = show n ++ prettyUnit u
prettyLit (LHex b) = "0x" ++ show b

224
test/CheckTests.hs Normal file
View File

@@ -0,0 +1,224 @@
module CheckTests (tests) where
import Test.Tasty
import Test.Tasty.HUnit
import FWL.Check
import FWL.Util
tests :: TestTree
tests = testGroup "Check"
[ undefinedNameTests
, duplicateTests
, policyTerminationTests
, patternCycleTests
, cleanProgramTests
]
-- ─── Helper ──────────────────────────────────────────────────────────────────
checkSrc :: String -> IO [CheckError]
checkSrc src = do
p <- parseOk src
return (checkProgram p)
assertNoErrors :: String -> IO ()
assertNoErrors src = do
errs <- checkSrc src
case errs of
[] -> return ()
_ -> assertFailure ("Unexpected errors: " ++ show errs)
assertHasError :: (CheckError -> Bool) -> String -> IO ()
assertHasError p src = do
errs <- checkSrc src
if any p errs
then return ()
else assertFailure ("Expected error not found. Got: " ++ show errs)
isUndefined :: String -> CheckError -> Bool
isUndefined n (UndefinedName _ m) = m == n
isUndefined _ _ = False
isDuplicate :: String -> CheckError -> Bool
isDuplicate n (DuplicateDecl _ m) = m == n
isDuplicate _ _ = False
isNoContinue :: String -> CheckError -> Bool
isNoContinue n (PolicyNoContinue m) = m == n
isNoContinue _ _ = False
isCycle :: CheckError -> Bool
isCycle (PatternCycle _) = True
isCycle _ = False
-- ─── Undefined name tests ────────────────────────────────────────────────────
undefinedNameTests :: TestTree
undefinedNameTests = testGroup "undefined names"
[ testCase "zone references unknown interface" $
assertHasError (isUndefined "ghost")
"zone bad_zone = { lan, ghost };"
, testCase "zone references known interface — no error" $
assertNoErrors
"interface lan : LAN {}; \
\zone good = { lan };"
, testCase "pattern references undefined named pattern" $
assertHasError (isUndefined "Undefined")
"pattern Bad : Frame = Frame(_, IPv4(ip, Undefined));"
, testCase "pattern references known named pattern — no error" $
assertNoErrors
"pattern WGInit : (UDPHeader,Bytes) = (udp { length = 156 }, [0x01 _*]); \
\pattern Compound : Frame = Frame(_, IPv4(ip, WGInit));"
, testCase "flow references undefined pattern" $
assertHasError (isUndefined "Ghost")
"flow Bad : FlowPattern = Ghost;"
, testCase "flow references known pattern — no error" $
assertNoErrors
"pattern P : T = udp { length = 1 }; \
\flow F : FlowPattern = P;"
, testCase "policy guard references undeclared zone" $
-- 'unknown_zone' not declared; check should flag it
assertHasError (isUndefined "unknown_zone")
"policy fwd : Frame \
\ on { hook = Forward, table = Filter, priority = Filter } \
\ = { | Frame(iif in unknown_zone -> wan, _) -> Allow; \
\ | _ -> Drop; \
\ };"
, testCase "policy references known zone — no error" $
assertNoErrors
"interface lan : LAN {}; \
\zone trusted = { lan }; \
\policy fwd : Frame \
\ on { hook = Forward, table = Filter, priority = Filter } \
\ = { | Frame(iif in trusted -> wan, _) -> Allow; \
\ | _ -> Drop; \
\ };"
]
-- ─── Duplicate declaration tests ─────────────────────────────────────────────
duplicateTests :: TestTree
duplicateTests = testGroup "duplicates"
[ testCase "duplicate interface" $
assertHasError (isDuplicate "lan")
"interface lan : LAN {}; \
\interface lan : WAN {};"
, testCase "duplicate zone" $
assertHasError (isDuplicate "z")
"zone z = { a }; \
\zone z = { b };"
, testCase "duplicate pattern" $
assertHasError (isDuplicate "P")
"pattern P : T = udp { length = 1 }; \
\pattern P : T = udp { length = 2 };"
, testCase "duplicate policy" $
assertHasError (isDuplicate "input")
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Allow; }; \
\policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Drop; };"
, testCase "distinct names — no error" $
assertNoErrors
"interface lan : LAN {}; \
\interface wan : WAN { dynamic; }; \
\zone z = { lan };"
]
-- ─── Policy termination tests ────────────────────────────────────────────────
policyTerminationTests :: TestTree
policyTerminationTests = testGroup "policy termination"
[ testCase "last arm is Continue — error" $
assertHasError (isNoContinue "bad_policy")
"policy bad_policy : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Continue; };"
, testCase "last arm is Drop — ok" $
assertNoErrors
"policy good : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ if ct.state in { Established } -> Allow; \
\ | _ -> Drop; \
\ };"
, testCase "last arm is Allow — ok" $
assertNoErrors
"policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
, testCase "Continue in non-last arm is fine" $
assertNoErrors
"rule r : Frame -> Action = \
\ \\f -> case f of { \
\ | Frame(_, IPv4(ip, _)) -> Continue; \
\ | _ -> Drop; \
\ };"
, testCase "empty policy body — error" $
assertHasError (isNoContinue "empty")
"policy empty : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = {};"
]
-- ─── Pattern cycle tests ─────────────────────────────────────────────────────
patternCycleTests :: TestTree
patternCycleTests = testGroup "pattern cycles"
[ testCase "direct self-reference — cycle error" $
assertHasError isCycle
"pattern Loop : T = Frame(_, Loop);"
, testCase "mutual cycle — cycle error" $
assertHasError isCycle
"pattern A : T = Frame(_, B); \
\pattern B : T = Frame(_, A);"
, testCase "linear chain — no cycle" $
assertNoErrors
"pattern Base : T = udp { length = 1 }; \
\pattern Mid : T = Frame(_, Base); \
\pattern Top : T = Frame(_, Mid);"
]
-- ─── Clean full programs ──────────────────────────────────────────────────────
cleanProgramTests :: TestTree
cleanProgramTests = testGroup "clean programs"
[ testCase "minimal router skeleton" $
assertNoErrors
"interface wan : WAN { dynamic; }; \
\interface lan : LAN { cidr4 = { 10.17.1.0/24 }; }; \
\interface wg0 : WireGuard {}; \
\zone lan_zone = { lan, wg0 }; \
\policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ if ct.state in { Established, Related } -> Allow; \
\ | _ -> Drop; \
\ }; \
\policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
, testCase "pattern and flow declarations" $
assertNoErrors
"pattern WGInit : (UDPHeader,Bytes) = (udp { length = 156 }, [0x01 _*]); \
\pattern WGResp : (UDPHeader,Bytes) = (udp { length = 100 }, [0x02 _*]); \
\flow WGHandshake : FlowPattern = WGInit . WGResp within 5s;"
]

384
test/CompileTests.hs Normal file
View File

@@ -0,0 +1,384 @@
{-# LANGUAGE OverloadedStrings #-}
module CompileTests (tests) where
import Test.Tasty
import Test.Tasty.HUnit
import qualified Data.Aeson as A
import qualified Data.Aeson.Key as AK
import qualified Data.Aeson.KeyMap as AKM
import qualified Data.Vector as V
import qualified Data.ByteString.Lazy.Char8 as BL8
import FWL.AST
import FWL.Compile
import FWL.Util
tests :: TestTree
tests = testGroup "Compile"
[ jsonStructureTests
, chainTests
, ruleExprTests
, verdictTests
, layerStrippingTests
, continueTests
, configTests
]
-- ─── Helpers ─────────────────────────────────────────────────────────────────
compileToValue :: String -> IO A.Value
compileToValue src = do
p <- parseOk src
case A.decode (compileToJson p) of
Nothing -> assertFailure "Compiled output is not valid JSON" >> undefined
Just v -> return v
-- Navigate a Value by a list of string keys / numeric indices.
at :: [String] -> A.Value -> Maybe A.Value
at [] v = Just v
at (k:ks) (A.Object o) =
case AKM.lookup (AK.fromString k) o of
Nothing -> Nothing
Just v -> at ks v
at (k:ks) (A.Array arr) =
case reads k of
[(i,"")] | i < V.length arr -> at ks (arr V.! i)
_ -> Nothing
at _ _ = Nothing
nftArr :: A.Value -> IO [A.Value]
nftArr v =
case at ["nftables"] v of
Just (A.Array arr) -> return (V.toList arr)
_ -> assertFailure "Missing top-level 'nftables' array" >> undefined
withKey :: String -> [A.Value] -> [A.Value]
withKey k = filter (\v -> case at [k] v of Just _ -> True; _ -> False)
-- ─── JSON structure tests ────────────────────────────────────────────────────
jsonStructureTests :: TestTree
jsonStructureTests = testGroup "JSON structure"
[ testCase "output is valid JSON" $ do
_ <- compileToValue
"policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
return ()
, testCase "top-level nftables array present" $ do
v <- compileToValue "policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
_ <- nftArr v
return ()
, testCase "metainfo is first element" $ do
v <- compileToValue "policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
arr <- nftArr v
case arr of
(first:_) -> case at ["metainfo"] first of
Just _ -> return ()
Nothing -> assertFailure "First element is not metainfo"
[] -> assertFailure "Empty nftables array"
, testCase "table object present" $ do
v <- compileToValue "policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
arr <- nftArr v
assertBool "Expected at least one table object"
(not (null (withKey "table" arr)))
, testCase "default table name is fwl" $ do
v <- compileToValue "policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
arr <- nftArr v
case withKey "table" arr of
(t:_) -> at ["table","name"] t @?= Just (A.String "fwl")
[] -> assertFailure "No table object"
, testCase "custom table name respected" $ do
v <- compileToValue
"config { table = \"custom\"; } \
\policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
arr <- nftArr v
case withKey "table" arr of
(t:_) -> at ["table","name"] t @?= Just (A.String "custom")
[] -> assertFailure "No table object"
]
-- ─── Chain declaration tests ─────────────────────────────────────────────────
chainTests :: TestTree
chainTests = testGroup "chain declarations"
[ testCase "filter input chain has correct hook" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Drop; };"
arr <- nftArr v
case withKey "chain" arr of
(c:_) -> at ["chain","hook"] c @?= Just (A.String "input")
[] -> assertFailure "No chain"
, testCase "filter chain type is filter" $ do
v <- compileToValue
"policy fwd : Frame \
\ on { hook = Forward, table = Filter, priority = Filter } \
\ = { | _ -> Drop; };"
arr <- nftArr v
case withKey "chain" arr of
(c:_) -> at ["chain","type"] c @?= Just (A.String "filter")
[] -> assertFailure "No chain"
, testCase "NAT chain type is nat" $ do
v <- compileToValue
"policy nat_post : Frame \
\ on { hook = Postrouting, table = NAT, priority = SrcNat } \
\ = { | _ -> Allow; };"
arr <- nftArr v
case withKey "chain" arr of
(c:_) -> at ["chain","type"] c @?= Just (A.String "nat")
[] -> assertFailure "No chain"
, testCase "input chain default policy is drop" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Drop; };"
arr <- nftArr v
case withKey "chain" arr of
(c:_) -> at ["chain","policy"] c @?= Just (A.String "drop")
[] -> assertFailure "No chain"
, testCase "output chain default policy is accept" $ do
v <- compileToValue
"policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
arr <- nftArr v
case withKey "chain" arr of
(c:_) -> at ["chain","policy"] c @?= Just (A.String "accept")
[] -> assertFailure "No chain"
, testCase "chain name matches policy name" $ do
v <- compileToValue
"policy my_input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Drop; };"
arr <- nftArr v
case withKey "chain" arr of
(c:_) -> at ["chain","name"] c @?= Just (A.String "my_input")
[] -> assertFailure "No chain"
, testCase "two policies produce two chains" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Drop; }; \
\policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
arr <- nftArr v
length (withKey "chain" arr) @?= 2
]
-- ─── Rule expression tests ───────────────────────────────────────────────────
ruleExprs :: [A.Value] -> [A.Value]
ruleExprs arr =
[ e | r <- withKey "rule" arr
, Just (A.Array es) <- [at ["rule","expr"] r]
, e <- V.toList es ]
ruleExprTests :: TestTree
ruleExprTests = testGroup "rule expressions"
[ testCase "two arms produce two rules" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ if ct.state in { Established, Related } -> Allow; \
\ | _ -> Drop; \
\ };"
arr <- nftArr v
length (withKey "rule" arr) @?= 2
, testCase "arm without guard produces one rule" $ do
v <- compileToValue
"policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
arr <- nftArr v
length (withKey "rule" arr) @?= 1
, testCase "rule expr array is present" $ do
v <- compileToValue
"policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
arr <- nftArr v
case withKey "rule" arr of
(r:_) -> case at ["rule","expr"] r of
Just (A.Array _) -> return ()
_ -> assertFailure "Missing or non-array 'expr'"
[] -> assertFailure "No rule"
, testCase "IPv4 ctor emits nfproto match" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | Frame(_, IPv4(ip, _)) -> Allow; \
\ | _ -> Drop; \
\ };"
arr <- nftArr v
let matches = withKey "match" (ruleExprs arr)
hasNfp = any (\m ->
at ["match","left","meta","key"] m == Just (A.String "nfproto"))
matches
assertBool "Expected nfproto match for IPv4 ctor" hasNfp
, testCase "record field pat emits payload match" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | Frame(_, TCP(tcp { dport = :22 }, _)) -> Allow; \
\ | _ -> Drop; \
\ };"
arr <- nftArr v
let matches = withKey "match" (ruleExprs arr)
hasPort = any (\m ->
at ["match","right"] m == Just (A.String "22"))
matches
assertBool "Expected port 22 payload match" hasPort
]
-- ─── Verdict tests ───────────────────────────────────────────────────────────
allExprs :: [A.Value] -> [A.Value]
allExprs arr =
concatMap (\r -> case at ["rule","expr"] r of
Just (A.Array es) -> V.toList es; _ -> [])
(withKey "rule" arr)
verdictTests :: TestTree
verdictTests = testGroup "verdicts"
[ testCase "Allow compiles to accept" $ do
v <- compileToValue
"policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
arr <- nftArr v
assertBool "Expected accept verdict"
(not (null (withKey "accept" (allExprs arr))))
, testCase "Drop compiles to drop" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Drop; };"
arr <- nftArr v
assertBool "Expected drop verdict"
(not (null (withKey "drop" (allExprs arr))))
, testCase "Masquerade compiles to masquerade" $ do
v <- compileToValue
"policy nat_post : Frame \
\ on { hook = Postrouting, table = NAT, priority = SrcNat } \
\ = { | _ -> Masquerade; };"
arr <- nftArr v
assertBool "Expected masquerade verdict"
(not (null (withKey "masquerade" (allExprs arr))))
, testCase "rule call compiles to jump" $ do
v <- compileToValue
"rule blockAll : Frame -> Action = \\f -> case f of { | _ -> Drop; }; \
\policy fwd : Frame \
\ on { hook = Forward, table = Filter, priority = Filter } \
\ = { | frame -> blockAll(frame); };"
arr <- nftArr v
assertBool "Expected jump verdict for rule call"
(not (null (withKey "jump" (allExprs arr))))
]
-- ─── Layer stripping tests ───────────────────────────────────────────────────
layerStrippingTests :: TestTree
layerStrippingTests = testGroup "layer stripping"
[ testCase "Frame with and without Ether both emit nfproto match" $ do
let withEther =
"policy p1 : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | Frame(_, Ether(_, IPv4(ip, _))) -> Allow; \
\ | _ -> Drop; \
\ };"
withoutEther =
"policy p1 : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | Frame(_, IPv4(ip, _)) -> Allow; \
\ | _ -> Drop; \
\ };"
v1 <- compileToValue withEther
v2 <- compileToValue withoutEther
arr1 <- nftArr v1
arr2 <- nftArr v2
let nfp arr = filter
(\m -> at ["match","left","meta","key"] m == Just (A.String "nfproto"))
(withKey "match" (ruleExprs arr))
assertBool "Both should produce nfproto matches"
(not (null (nfp arr1)) && not (null (nfp arr2)))
]
-- ─── Continue tests ───────────────────────────────────────────────────────────
continueTests :: TestTree
continueTests = testGroup "Continue"
[ testCase "two terminal arms produce two rules" $ do
v <- compileToValue
"policy fwd : Frame \
\ on { hook = Forward, table = Filter, priority = Filter } \
\ = { | _ if ct.state in { Established } -> Allow; \
\ | _ -> Drop; \
\ };"
arr <- nftArr v
length (withKey "rule" arr) @?= 2
, testCase "non-Continue arms still produce rules" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ if ct.state in { Established } -> Allow; \
\ | _ -> Drop; \
\ };"
arr <- nftArr v
assertBool "Should have rules for non-Continue arms"
(not (null (withKey "rule" arr)))
]
-- ─── Config tests ─────────────────────────────────────────────────────────────
configTests :: TestTree
configTests = testGroup "config"
[ testCase "all rule objects reference correct table" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Drop; };"
arr <- nftArr v
mapM_ (\r -> at ["rule","table"] r @?= Just (A.String "fwl"))
(withKey "rule" arr)
, testCase "chain objects reference correct table" $ do
v <- compileToValue
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { | _ -> Drop; };"
arr <- nftArr v
mapM_ (\c -> at ["chain","table"] c @?= Just (A.String "fwl"))
(withKey "chain" arr)
]

44
test/FWL/Util.hs Normal file
View File

@@ -0,0 +1,44 @@
-- | Shared test utilities.
module FWL.Util where
import Test.Tasty.HUnit
import Text.Parsec.String (Parser)
import Text.Parsec (parse)
import FWL.Parser (parseProgram)
import FWL.AST
-- | Assert a parser succeeds and return the result.
shouldParse :: (Show a) => Parser a -> String -> IO a
shouldParse p input =
case parse p "<test>" input of
Left err -> assertFailure ("Unexpected parse error:\n" ++ show err)
>> undefined
Right v -> return v
-- | Assert a parser fails.
shouldFailParse :: (Show a) => Parser a -> String -> IO ()
shouldFailParse p input =
case parse p "<test>" input of
Left _ -> return ()
Right v -> assertFailure ("Expected parse failure but got: " ++ show v)
-- | Parse a full program, asserting success.
parseOk :: String -> IO Program
parseOk src =
case parseProgram "<test>" src of
Left err -> assertFailure ("Parse error:\n" ++ show err) >> undefined
Right p -> return p
-- | Parse a full program, asserting failure.
parseFail :: String -> IO ()
parseFail src =
case parseProgram "<test>" src of
Left _ -> return ()
Right p -> assertFailure ("Expected parse failure, got:\n" ++ show p)
-- | Extract the single declaration from a one-decl program.
singleDecl :: Program -> IO Decl
singleDecl (Program _ [d]) = return d
singleDecl (Program _ ds) =
assertFailure ("Expected 1 decl, got " ++ show (length ds)) >> undefined

516
test/ParserTests.hs Normal file
View File

@@ -0,0 +1,516 @@
module ParserTests (tests) where
import Test.Tasty
import Test.Tasty.HUnit
import FWL.AST
import FWL.Util
tests :: TestTree
tests = testGroup "Parser"
[ interfaceTests
, zoneTests
, importTests
, letTests
, patternTests
, flowTests
, typeTests
, exprTests
, policyTests
, ruleTests
, configTests
, errorTests
]
-- ─── Interface ───────────────────────────────────────────────────────────────
interfaceTests :: TestTree
interfaceTests = testGroup "interface"
[ testCase "WAN dynamic" $ do
p <- parseOk "interface wan : WAN { dynamic; };"
d <- singleDecl p
case d of
DInterface "wan" IWan [IPDynamic] -> return ()
_ -> assertFailure (show d)
, testCase "LAN with cidr4" $ do
p <- parseOk "interface lan : LAN { cidr4 = { 10.0.0.0/8 }; };"
d <- singleDecl p
case d of
DInterface "lan" ILan [IPCidr4 [(LIPv4 (10,0,0,0), 8)]] -> return ()
_ -> assertFailure (show d)
, testCase "LAN with cidr4 and cidr6" $ do
p <- parseOk
"interface lan : LAN { \
\ cidr4 = { 10.17.1.0/24 }; \
\ cidr6 = { 192.168.0.0/16 }; \
\};"
d <- singleDecl p
case d of
DInterface "lan" ILan [IPCidr4 _, IPCidr6 _] -> return ()
_ -> assertFailure (show d)
, testCase "WireGuard interface" $ do
p <- parseOk "interface wg0 : WireGuard {};"
d <- singleDecl p
case d of
DInterface "wg0" IWireGuard [] -> return ()
_ -> assertFailure (show d)
, testCase "user-defined kind" $ do
p <- parseOk "interface eth0 : Bridge {};"
d <- singleDecl p
case d of
DInterface "eth0" (IUser "Bridge") [] -> return ()
_ -> assertFailure (show d)
, testCase "multiple CIDRs in set" $ do
p <- parseOk
"interface lan : LAN { \
\ cidr4 = { 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 }; \
\};"
d <- singleDecl p
case d of
DInterface _ _ [IPCidr4 cidrs] -> length cidrs @?= 3
_ -> assertFailure (show d)
]
-- ─── Zone ────────────────────────────────────────────────────────────────────
zoneTests :: TestTree
zoneTests = testGroup "zone"
[ testCase "single member" $ do
p <- parseOk "zone trusted = { lan };"
d <- singleDecl p
case d of
DZone "trusted" ["lan"] -> return ()
_ -> assertFailure (show d)
, testCase "multiple members" $ do
p <- parseOk "zone lan_zone = { lan, wg0, vlan10 };"
d <- singleDecl p
case d of
DZone "lan_zone" ["lan","wg0","vlan10"] -> return ()
_ -> assertFailure (show d)
]
-- ─── Import ──────────────────────────────────────────────────────────────────
importTests :: TestTree
importTests = testGroup "import"
[ testCase "basic import" $ do
p <- parseOk "import rfc1918 : CIDRSet from \"builtin:rfc1918\";"
d <- singleDecl p
case d of
DImport "rfc1918" (TName "CIDRSet" []) "builtin:rfc1918" -> return ()
_ -> assertFailure (show d)
]
-- ─── Let ─────────────────────────────────────────────────────────────────────
letTests :: TestTree
letTests = testGroup "let"
[ testCase "simple integer" $ do
p <- parseOk "let timeout : Int = 30;"
d <- singleDecl p
case d of
DLet "timeout" (TName "Int" []) (ELit (LInt 30)) -> return ()
_ -> assertFailure (show d)
, testCase "map literal" $ do
p <- parseOk
"let forwards : Map<(Protocol,Port),(IP,Port)> = { \
\ (tcp, :8080) -> (10.0.0.1, :80) \
\};"
d <- singleDecl p
case d of
DLet "forwards" _ (EMap [_]) -> return ()
_ -> assertFailure (show d)
, testCase "string literal" $ do
p <- parseOk "let name : String = \"hello\";"
d <- singleDecl p
case d of
DLet "name" _ (ELit (LString "hello")) -> return ()
_ -> assertFailure (show d)
]
-- ─── Pattern ─────────────────────────────────────────────────────────────────
patternTests :: TestTree
patternTests = testGroup "pattern"
[ testCase "tuple with record field" $ do
p <- parseOk
"pattern WGInitiation : (UDPHeader, Bytes) = \
\ (udp { length = 156 }, [0x01 _*]);"
d <- singleDecl p
case d of
DPattern "WGInitiation" _ (PTuple [PRecord "udp" _, PBytes _]) -> return ()
_ -> assertFailure (show d)
, testCase "byte pattern elements" $ do
p <- parseOk
"pattern WGResponse : (UDPHeader, Bytes) = \
\ (udp { length = 100 }, [0x02 _ _*]);"
d <- singleDecl p
case d of
DPattern "WGResponse" _ (PTuple [_, PBytes [BEHex 0x02, BEWild, BEWildStar]]) ->
return ()
_ -> assertFailure (show d)
, testCase "named pattern reference in ctor" $ do
p <- parseOk
"pattern Complex : Frame = \
\ Frame(_, IPv4(ip, WGInitiation));"
d <- singleDecl p
case d of
DPattern "Complex" _ (PFrame Nothing (PCtor "IPv4" [PVar "ip", PNamed "WGInitiation"])) ->
return ()
_ -> assertFailure (show d)
, testCase "record with field bind" $ do
p <- parseOk "pattern HasTCP : TCP = tcp { dport };"
d <- singleDecl p
case d of
DPattern "HasTCP" _ (PRecord "tcp" [FPBind "dport"]) -> return ()
_ -> assertFailure (show d)
, testCase "record with field equality" $ do
p <- parseOk "pattern SSH : TCP = tcp { dport = :22 };"
d <- singleDecl p
case d of
DPattern "SSH" _ (PRecord "tcp" [FPEq "dport" (LPort 22)]) -> return ()
_ -> assertFailure (show d)
]
-- ─── Flow ────────────────────────────────────────────────────────────────────
flowTests :: TestTree
flowTests = testGroup "flow"
[ testCase "two-step sequence with within" $ do
p <- parseOk
"flow WireGuardHandshake : FlowPattern = \
\ WGInitiation . WGResponse within 5s;"
d <- singleDecl p
case d of
DFlow "WireGuardHandshake" (FSeq (FAtom "WGInitiation") (FAtom "WGResponse") (Just (5, Seconds))) ->
return ()
_ -> assertFailure (show d)
, testCase "single atom flow" $ do
p <- parseOk "flow Simple : FlowPattern = Ping;"
d <- singleDecl p
case d of
DFlow "Simple" (FAtom "Ping") -> return ()
_ -> assertFailure (show d)
, testCase "duration in milliseconds" $ do
p <- parseOk "flow Fast : FlowPattern = A . B within 500ms;"
d <- singleDecl p
case d of
DFlow "Fast" (FSeq _ _ (Just (500, Millis))) -> return ()
_ -> assertFailure (show d)
]
-- ─── Types ───────────────────────────────────────────────────────────────────
typeTests :: TestTree
typeTests = testGroup "types"
[ testCase "simple name" $ do
p <- parseOk "let x : Frame = Allow;"
d <- singleDecl p
case d of
DLet _ (TName "Frame" []) _ -> return ()
_ -> assertFailure (show d)
, testCase "generic type" $ do
p <- parseOk "let x : Map<Int, String> = Allow;"
d <- singleDecl p
case d of
DLet _ (TName "Map" [TName "Int" [], TName "String" []]) _ -> return ()
_ -> assertFailure (show d)
, testCase "function type" $ do
p <- parseOk "let x : Frame -> Action = Allow;"
d <- singleDecl p
case d of
DLet _ (TFun (TName "Frame" []) (TName "Action" [])) _ -> return ()
_ -> assertFailure (show d)
, testCase "effect type" $ do
p <- parseOk "let x : <Log, FlowMatch> Action = Allow;"
d <- singleDecl p
case d of
DLet _ (TEffect ["Log","FlowMatch"] (TName "Action" [])) _ -> return ()
_ -> assertFailure (show d)
, testCase "tuple type" $ do
p <- parseOk "let x : (Int, String) = Allow;"
d <- singleDecl p
case d of
DLet _ (TTuple [TName "Int" [], TName "String" []]) _ -> return ()
_ -> assertFailure (show d)
, testCase "function with effects" $ do
p <- parseOk "let x : Frame -> <Log> Action = Allow;"
d <- singleDecl p
case d of
DLet _ (TFun _ (TEffect ["Log"] _)) _ -> return ()
_ -> assertFailure (show d)
]
-- ─── Expressions ─────────────────────────────────────────────────────────────
exprTests :: TestTree
exprTests = testGroup "expressions"
[ testCase "boolean and" $ do
p <- parseOk "let x : Bool = a && b;"
d <- singleDecl p
case d of
DLet _ _ (EInfix OpAnd (EVar "a") (EVar "b")) -> return ()
_ -> assertFailure (show d)
, testCase "set membership with 'in'" $ do
p <- parseOk "let x : Bool = ct.state in { Established, Related };"
d <- singleDecl p
case d of
DLet _ _ (EInfix OpIn (EQual ["ct","state"]) (ESet _)) -> return ()
_ -> assertFailure (show d)
, testCase "equality comparison" $ do
p <- parseOk "let x : Bool = tcp.dport == :22;"
d <- singleDecl p
case d of
DLet _ _ (EInfix OpEq (EQual ["tcp","dport"]) (ELit (LPort 22))) -> return ()
_ -> assertFailure (show d)
, testCase "if-then-else" $ do
p <- parseOk "let x : Action = if a then Allow else Drop;"
d <- singleDecl p
case d of
DLet _ _ (EIf (EVar "a") (EVar "Allow") (EVar "Drop")) -> return ()
_ -> assertFailure (show d)
, testCase "perform expression" $ do
p <- parseOk "let x : Action = perform Log.emit(Info, \"msg\");"
d <- singleDecl p
case d of
DLet _ _ (EPerform ["Log","emit"] [ELit (LString "Info"), ELit (LString "msg")]) -> return ()
DLet _ _ (EPerform ["Log","emit"] _) -> return () -- arg parsing flexible
_ -> assertFailure (show d)
, testCase "do block" $ do
p <- parseOk "let x : Action = do { y <- foo; y };"
d <- singleDecl p
case d of
DLet _ _ (EDo [DSBind "y" _, DSExpr (EVar "y")]) -> return ()
_ -> assertFailure (show d)
, testCase "nested case" $ do
p <- parseOk
"let x : Action = case e of { \
\ | a -> Allow; \
\ | _ -> Drop; \
\};"
d <- singleDecl p
case d of
DLet _ _ (ECase (EVar "e") [Arm (PVar "a") Nothing _, Arm PWild Nothing _]) -> return ()
_ -> assertFailure (show d)
, testCase "lambda" $ do
p <- parseOk "let x : Frame -> Action = \\frame -> Allow;"
d <- singleDecl p
case d of
DLet _ _ (ELam "frame" (EVar "Allow")) -> return ()
_ -> assertFailure (show d)
, testCase "string concat" $ do
p <- parseOk "let x : String = \"hello\" ++ \" world\";"
d <- singleDecl p
case d of
DLet _ _ (EInfix OpConcat _ _) -> return ()
_ -> assertFailure (show d)
, testCase "negation" $ do
p <- parseOk "let x : Bool = !flag;"
d <- singleDecl p
case d of
DLet _ _ (ENot (EVar "flag")) -> return ()
_ -> assertFailure (show d)
, testCase "set literal" $ do
p <- parseOk "let x : Set<Int> = { 22, 80, 443 };"
d <- singleDecl p
case d of
DLet _ _ (ESet [ELit (LInt 22), ELit (LInt 80), ELit (LInt 443)]) -> return ()
_ -> assertFailure (show d)
]
-- ─── Policy ──────────────────────────────────────────────────────────────────
policyTests :: TestTree
policyTests = testGroup "policy"
[ testCase "minimal policy" $ do
p <- parseOk
"policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
d <- singleDecl p
case d of
DPolicy "output" _ (PolicyMeta HOutput TFilter (Priority 0)) [_] -> return ()
_ -> assertFailure (show d)
, testCase "NAT prerouting" $ do
p <- parseOk
"policy nat_pre : Frame \
\ on { hook = Prerouting, table = NAT, priority = DstNat } \
\ = { | _ -> Allow; };"
d <- singleDecl p
case d of
DPolicy _ _ (PolicyMeta HPrerouting TNAT (Priority (-100))) _ -> return ()
_ -> assertFailure (show d)
, testCase "arm with guard" $ do
p <- parseOk
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { \
\ | _ if ct.state in { Established, Related } -> Allow; \
\ | _ -> Drop; \
\ };"
d <- singleDecl p
case d of
DPolicy _ _ _ [Arm PWild (Just _) _, Arm PWild Nothing _] -> return ()
_ -> assertFailure (show d)
, testCase "Frame pattern with path" $ do
p <- parseOk
"policy forward : Frame \
\ on { hook = Forward, table = Filter, priority = Filter } \
\ = { \
\ | Frame(iif in lan_zone -> wan, _) -> Allow; \
\ | _ -> Drop; \
\ };"
d <- singleDecl p
case d of
DPolicy _ _ _ (Arm (PFrame (Just _) _) Nothing _ : _) -> return ()
_ -> assertFailure (show d)
, testCase "Frame pattern without Ether (layer stripping)" $ do
p <- parseOk
"policy input : Frame \
\ on { hook = Input, table = Filter, priority = Filter } \
\ = { \
\ | Frame(_, IPv4(ip, TCP(tcp, _))) if tcp.dport == :22 -> Allow; \
\ | _ -> Drop; \
\ };"
d <- singleDecl p
case d of
DPolicy _ _ _ (Arm (PFrame Nothing (PCtor "IPv4" _)) _ _ : _) -> return ()
_ -> assertFailure (show d)
, testCase "policy arm calls rule" $ do
p <- parseOk
"policy forward : Frame \
\ on { hook = Forward, table = Filter, priority = Filter } \
\ = { \
\ | frame -> blockOutboundWG(frame); \
\ };"
d <- singleDecl p
case d of
DPolicy _ _ _ [Arm (PVar "frame") Nothing (EApp (EVar "blockOutboundWG") _)] ->
return ()
_ -> assertFailure (show d)
, testCase "Continue arm is parsed" $ do
p <- parseOk
"rule r : Frame -> Action = \
\ \\frame -> case frame of { \
\ | _ -> Continue; \
\ };"
d <- singleDecl p
case d of
DRule _ _ _ -> return ()
_ -> assertFailure (show d)
]
-- ─── Rule ────────────────────────────────────────────────────────────────────
ruleTests :: TestTree
ruleTests = testGroup "rule"
[ testCase "simple rule" $ do
p <- parseOk
"rule blockAll : Frame -> Action = \
\ \\frame -> case frame of { | _ -> Drop; };"
d <- singleDecl p
case d of
DRule "blockAll" _ (ELam "frame" (ECase _ _)) -> return ()
_ -> assertFailure (show d)
, testCase "rule with effects in type" $ do
p <- parseOk
"rule logged : Frame -> <Log> Action = \
\ \\f -> case f of { | _ -> Allow; };"
d <- singleDecl p
case d of
DRule "logged" (TFun _ (TEffect ["Log"] _)) _ -> return ()
_ -> assertFailure (show d)
, testCase "nested case in rule" $ do
p <- parseOk
"rule check : Frame -> <FlowMatch> Action = \
\ \\frame -> \
\ case frame of { \
\ | Frame(_, IPv4(ip, UDP(udp, _))) -> \
\ case perform FlowMatch.check(ip, wg) of { \
\ | Matched -> Drop; \
\ | _ -> Continue; \
\ }; \
\ | _ -> Continue; \
\ };"
d <- singleDecl p
case d of
DRule "check" _ (ELam _ (ECase _ _)) -> return ()
_ -> assertFailure (show d)
]
-- ─── Config ──────────────────────────────────────────────────────────────────
configTests :: TestTree
configTests = testGroup "config"
[ testCase "default table name" $ do
p <- parseOk "interface wan : WAN {};"
configTable (progConfig p) @?= "fwl"
, testCase "custom table name" $ do
p <- parseOk "config { table = \"myrules\"; } interface wan : WAN {};"
configTable (progConfig p) @?= "myrules"
]
-- ─── Error cases ─────────────────────────────────────────────────────────────
errorTests :: TestTree
errorTests = testGroup "parse errors"
[ testCase "missing semicolon" $
parseFail "interface wan : WAN {}"
, testCase "unknown hook" $
parseFail
"policy p : Frame \
\ on { hook = Bogus, table = Filter, priority = Filter } \
\ = { | _ -> Allow; };"
, testCase "empty arm block with no arms is ok" $ do
p <- parseOk
"policy output : Frame \
\ on { hook = Output, table = Filter, priority = Filter } \
\ = {};"
d <- singleDecl p
case d of
DPolicy _ _ _ [] -> return ()
_ -> assertFailure (show d)
, testCase "CIDR without prefix fails" $
parseFail "interface lan : LAN { cidr4 = { 10.0.0.1 }; };"
]

15
test/Spec.hs Normal file
View File

@@ -0,0 +1,15 @@
module Main where
import Test.Tasty
import Test.Tasty.HUnit
import qualified ParserTests
import qualified CheckTests
import qualified CompileTests
main :: IO ()
main = defaultMain $ testGroup "FWL"
[ ParserTests.tests
, CheckTests.tests
, CompileTests.tests
]