diff --git a/app/Main.hs b/app/Main.hs new file mode 100644 index 0000000..ca98eaa --- /dev/null +++ b/app/Main.hs @@ -0,0 +1,54 @@ +module Main where + +import System.Environment (getArgs) +import System.Exit (exitFailure, exitSuccess) +import System.IO (hPutStrLn, stderr) +import qualified Data.ByteString.Lazy.Char8 as BL + +import FWL.Parser (parseFile) +import FWL.Pretty (prettyProgram) +import FWL.Check (checkProgram) +import FWL.Compile (compileToJson) + +main :: IO () +main = do + args <- getArgs + case args of + ["check", fp] -> runCheck fp + ["compile", fp] -> runCompile fp + ["pretty", fp] -> runPretty fp + _ -> do + putStrLn "Usage: fwlc " + putStrLn " check -- parse and static-check" + putStrLn " compile -- emit nftables JSON to stdout" + putStrLn " pretty -- parse and re-print" + exitFailure + +runCheck :: FilePath -> IO () +runCheck fp = do + result <- parseFile fp + case result of + Left err -> hPutStrLn stderr ("Parse error:\n" ++ show err) >> exitFailure + Right prog -> do + let errs = checkProgram prog + if null errs + then putStrLn "OK" >> exitSuccess + else mapM_ (hPutStrLn stderr . show) errs >> exitFailure + +runCompile :: FilePath -> IO () +runCompile fp = do + result <- parseFile fp + case result of + Left err -> hPutStrLn stderr ("Parse error:\n" ++ show err) >> exitFailure + Right prog -> do + let errs = checkProgram prog + if null errs + then BL.putStrLn (compileToJson prog) + else mapM_ (hPutStrLn stderr . ("Check error: " ++) . show) errs >> exitFailure + +runPretty :: FilePath -> IO () +runPretty fp = do + result <- parseFile fp + case result of + Left err -> hPutStrLn stderr ("Parse error:\n" ++ show err) >> exitFailure + Right prog -> putStr (prettyProgram prog) diff --git a/cabal.project b/cabal.project new file mode 100644 index 0000000..e6fdbad --- /dev/null +++ b/cabal.project @@ -0,0 +1 @@ +packages: . diff --git a/doc/initial_proposal.md b/doc/proposal.md similarity index 100% rename from doc/initial_proposal.md rename to doc/proposal.md diff --git a/examples/router.fwl b/examples/router.fwl new file mode 100644 index 0000000..889049a --- /dev/null +++ b/examples/router.fwl @@ -0,0 +1,95 @@ +-- Example: home router firewall in FWL +-- Compile with: fwlc compile examples/router.fwl + +interface wan : WAN { dynamic; }; +interface lan : LAN { cidr4 = { 10.17.1.0/24 }; }; +interface wg0 : WireGuard {}; + +zone lan_zone = { lan, wg0 }; + +import rfc1918 : CIDRSet from "builtin:rfc1918"; + +let forwards : Map<(Protocol, Port), (IP, Port)> = { + (tcp, :8080) -> (10.17.1.10, :80), + (tcp, :2222) -> (10.17.1.11, :22) +}; + +-- WireGuard handshake detection (compiles to ct mark state machine) +pattern WGInitiation : (UDPHeader, Bytes) = + (udp { length = 156 }, [0x01 _*]); + +pattern WGResponse : (UDPHeader, Bytes) = + (udp { length = 100 }, [0x02 _*]); + +flow WireGuardHandshake : FlowPattern = + WGInitiation . WGResponse within 5s; + +-- Block LAN clients from tunnelling out via WireGuard +rule blockOutboundWG : Frame -> Action = + \frame -> + case frame of { + | Frame(iif in lan_zone -> wan, IPv4(ip, UDP(udp, payload))) + if matches(WGInitiation, (udp, payload)) -> + case perform FlowMatch.check(flowOf(ip, wg), WireGuardHandshake) of { + | Matched -> do { + perform Log.emit(Warn, "WG blocked"); + Drop + }; + | _ -> Continue; + }; + | _ -> Continue; + }; + +-- Inbound to router +policy input : Frame + on { hook = Input, table = Filter, priority = Filter } + = { + | _ if ct.state in { Established, Related } -> Allow; + | Frame(lo, _) -> Allow; + | Frame(_, IPv6(ip6, ICMPv6(_, _))) + if ip6.src in fe80::/10 -> Allow; + | Frame(_, IPv4(_, TCP(tcp, _))) + if tcp.dport == :22 -> Allow; + | Frame(_, IPv4(_, UDP(udp, _))) + if udp.dport == :51944 -> Allow; + | _ -> Drop; + }; + +-- Forwarded traffic +policy forward : Frame + on { hook = Forward, table = Filter, priority = Filter } + = { + | _ if ct.state in { Established, Related } -> Allow; + | frame if iif in lan_zone && oif == wan -> blockOutboundWG(frame); + | _ if ct.status == DNAT -> Allow; + | Frame(iif in lan_zone -> wan, _) -> Allow; + | Frame(iif in lan_zone -> lan_zone, _) -> Allow; + | Frame(wan -> lan_zone, IPv4(ip, TCP(tcp, _))) + if (ip.dst, tcp.dport) in forwards -> Allow; + | _ -> Drop; + }; + +-- Outbound from router +policy output : Frame + on { hook = Output, table = Filter, priority = Filter } + = { + | _ -> Allow; + }; + +-- NAT +policy nat_prerouting : Frame + on { hook = Prerouting, table = NAT, priority = DstNat } + = { + | Frame(_, IPv4(ip, _)) -> + if perform FIB.daddrLocal(ip.dst) + then DNATMap(forwards) + else Allow; + | _ -> Allow; + }; + +policy nat_postrouting : Frame + on { hook = Postrouting, table = NAT, priority = SrcNat } + = { + | Frame(_ -> wan, IPv4(ip, _)) if ip.src in rfc1918 -> Masquerade; + | _ -> Allow; + }; diff --git a/fwl.cabal b/fwl.cabal new file mode 100644 index 0000000..aca2c07 --- /dev/null +++ b/fwl.cabal @@ -0,0 +1,58 @@ +cabal-version: 3.0 +name: fwl +version: 0.1.0.0 +synopsis: Firewall Language — MVP +build-type: Simple + +common shared + ghc-options: -Wall + default-language: Haskell2010 + +library + import: shared + hs-source-dirs: src + exposed-modules: + FWL.AST + , FWL.Lexer + , FWL.Parser + , FWL.Pretty + , FWL.Check + , FWL.Compile + build-depends: + base >= 4.14 + , parsec >= 3.1 + , aeson >= 2.0 + , aeson-pretty >= 0.8 + , text >= 1.2 + , containers >= 0.6 + , mtl >= 2.2 + , prettyprinter >= 1.7 + , bytestring >= 0.11 + , word8 >= 0.1 + +executable fwlc + import: shared + main-is: Main.hs + hs-source-dirs: app + build-depends: + base, fwl, text, aeson-pretty, bytestring + +test-suite fwl-tests + import: shared + type: exitcode-stdio-1.0 + main-is: Spec.hs + hs-source-dirs: test + other-modules: + FWL.Util + , ParserTests + , CheckTests + , CompileTests + build-depends: + base, fwl + , tasty >= 1.4 + , tasty-hunit >= 0.10 + , aeson >= 2.0 + , aeson-pretty >= 0.8 + , bytestring >= 0.11 + , parsec >= 3.1 + , vector >= 0.12 diff --git a/src/FWL/AST.hs b/src/FWL/AST.hs new file mode 100644 index 0000000..2049661 --- /dev/null +++ b/src/FWL/AST.hs @@ -0,0 +1,233 @@ +module FWL.AST where + +import Data.Bits ((.&.), (.|.), shiftL, shiftR) +import Data.Word (Word8) -- Word8 still used for ByteElem/hex literals + +type Name = String + +-- ─── Program ──────────────────────────────────────────────────────────────── + +data Program = Program + { progConfig :: Config + , progDecls :: [Decl] + } deriving (Show) + +data Config = Config + { configTable :: String -- default "fwl" + } deriving (Show) + +defaultConfig :: Config +defaultConfig = Config { configTable = "fwl" } + +-- ─── Declarations ─────────────────────────────────────────────────────────── + +data Decl + = DInterface Name IfaceKind [IfaceProp] + | DZone Name [Name] + | DImport Name Type FilePath + | DLet Name Type Expr + | DPattern Name Type Pat + | DFlow Name FlowExpr + | DRule Name Type Expr + | DPolicy Name Type PolicyMeta ArmBlock + deriving (Show) + +data PolicyMeta = PolicyMeta + { pmHook :: Hook + , pmTable :: TableName + , pmPriority :: Priority + } deriving (Show) + +data Hook = HInput | HForward | HOutput | HPrerouting | HPostrouting + deriving (Show, Eq) +data TableName = TFilter | TNAT + deriving (Show, Eq) +-- Priority is always an integer in the nftables JSON. +-- Named constants are resolved to their numeric values at parse time. +newtype Priority = Priority { priorityValue :: Int } + deriving (Show, Eq) + +-- Standard nftables priority constants +pRaw, pConnTrackDefrag, pConnTrack, pMangle, pDstNat, pFilter, pSecurity, pSrcNat :: Priority +pRaw = Priority (-300) +pConnTrackDefrag = Priority (-400) +pConnTrack = Priority (-200) +pMangle = Priority (-150) +pDstNat = Priority (-100) +pFilter = Priority 0 +pSecurity = Priority 50 +pSrcNat = Priority 100 + +data IfaceKind = IWan | ILan | IWireGuard | IUser Name + deriving (Show) + +data IfaceProp + = IPDynamic + | IPCidr4 [CIDR] + | IPCidr6 [CIDR] + deriving (Show) + +-- | A CIDR block: base address literal paired with prefix length. +-- e.g. (LIPv4 (10,0,0,0), 8) represents 10.0.0.0/8 +type CIDR = (Literal, Int) + +-- ─── Patterns ─────────────────────────────────────────────────────────────── + +data Pat + = PWild + | PVar Name + | PNamed Name + | PCtor Name [Pat] + | PRecord Name [FieldPat] + | PTuple [Pat] + | PFrame (Maybe PathPat) Pat + | PBytes [ByteElem] + deriving (Show) + +data FieldPat + = FPEq Name Literal + | FPBind Name + | FPAs Name Name + deriving (Show) + +data PathPat = PathPat (Maybe EndpointPat) (Maybe EndpointPat) + deriving (Show) + +data EndpointPat + = EPWild + | EPName Name + | EPMember Name Name + deriving (Show) + +data ByteElem + = BEHex Word8 + | BEWild + | BEWildStar + deriving (Show) + +-- ─── Flow ─────────────────────────────────────────────────────────────────── + +data FlowExpr + = FAtom Name + | FSeq FlowExpr FlowExpr (Maybe Duration) + deriving (Show) + +type Duration = (Int, TimeUnit) + +-- Fix 1: TimeUnit must derive Eq because Literal (which embeds it via +-- LDuration) derives Eq, requiring all constituent types to also have Eq. +data TimeUnit = Seconds | Millis | Minutes | Hours + deriving (Show, Eq) + +-- ─── Types ────────────────────────────────────────────────────────────────── + +data Type + = TName Name [Type] + | TTuple [Type] + | TFun Type Type + | TEffect [Name] Type + deriving (Show) + +-- ─── Expressions ──────────────────────────────────────────────────────────── + +data Expr + = EVar Name + | EQual [Name] + | ELit Literal + | ELam Name Expr + | EApp Expr Expr + | ECase Expr ArmBlock + | EIf Expr Expr Expr + | EDo [DoStmt] + | ELet Name Expr Expr + | ETuple [Expr] + | ESet [Expr] + | EMap [(Expr, Expr)] + | EPerform [Name] [Expr] + | EInfix InfixOp Expr Expr + | ENot Expr + deriving (Show) + +data InfixOp + = OpAnd | OpOr + | OpEq | OpNeq | OpLt | OpLte | OpGt | OpGte + | OpIn + | OpConcat + | OpThen + | OpBind + deriving (Show, Eq) + +data DoStmt + = DSBind Name Expr + | DSExpr Expr + deriving (Show) + +type ArmBlock = [Arm] +data Arm = Arm Pat (Maybe Expr) Expr + deriving (Show) + +-- ─── Literals ─────────────────────────────────────────────────────────────── + +-- IP addresses are stored as plain Integers for easy arithmetic, +-- CIDR validation (mask host bits), and future subnet math. +-- IPv4: 32-bit value in the low 32 bits. +-- IPv6: 128-bit value. +-- CIDR host-bit validation: (addr .&. hostMask prefix bits) == 0 +data IPVersion = IPv4 | IPv6 + deriving (Show, Eq) + +data Literal + = LInt Int + | LString String + | LBool Bool + | LIP IPVersion Integer -- unified IP address representation + | LCIDR Literal Int -- base address + prefix length + | LPort Int + | LDuration Int TimeUnit + | LHex Word8 + deriving (Show, Eq) + +-- ─── IP address helpers ────────────────────────────────────────────────────── + +-- | Build an IPv4 literal from four octets. +ipv4Lit :: Int -> Int -> Int -> Int -> Literal +ipv4Lit a b c d = + LIP IPv4 (fromIntegral a `shiftL` 24 + .|. fromIntegral b `shiftL` 16 + .|. fromIntegral c `shiftL` 8 + .|. fromIntegral d) + +-- | Check that a CIDR has no host bits set. +cidrHostBitsZero :: Integer -> Int -> Int -> Bool +cidrHostBitsZero addr prefix bits = + let hostBits = bits - prefix + hostMask = (1 `shiftL` hostBits) - 1 + in (addr .&. hostMask) == 0 + +-- | Render an IPv4 integer as a dotted-decimal string. +renderIPv4 :: Integer -> String +renderIPv4 n = + show ((n `shiftR` 24) .&. 0xff) ++ "." ++ + show ((n `shiftR` 16) .&. 0xff) ++ "." ++ + show ((n `shiftR` 8) .&. 0xff) ++ "." ++ + show (n .&. 0xff) + +-- | Render an IPv6 integer as a condensed colon-hex string. +renderIPv6 :: Integer -> String +renderIPv6 n = + let groups = [ fromIntegral ((n `shiftR` (i * 16)) .&. 0xffff) :: Int + | i <- [7,6..0] ] + hexGroups = map (`showHex` "") groups + in concatIntersperse ":" hexGroups + where + showHex x s = let h = showHexInt x in h ++ s + showHexInt x + | x == 0 = "0" + | otherwise = reverse (go x) + where go 0 = [] + go v = let (q,r) = v `divMod` 16 + c = "0123456789abcdef" !! r + in c : go q + concatIntersperse _ [] = "" + concatIntersperse _ [x] = x + concatIntersperse s (x:xs) = x ++ s ++ concatIntersperse s xs diff --git a/src/FWL/Check.hs b/src/FWL/Check.hs new file mode 100644 index 0000000..f3bd15d --- /dev/null +++ b/src/FWL/Check.hs @@ -0,0 +1,207 @@ +{- | Static checks for MVP: + 1. Undefined name detection (interfaces, zones, patterns, rules/policies) + 2. Policy arm termination: last arm of a policy must not be Continue + 3. Named pattern cycle detection + 4. CIDR exhaustiveness stub (warns but does not error for MVP) +-} +module FWL.Check + ( checkProgram + , CheckError(..) + ) where + +import Data.List (foldl', nub) +import qualified Data.Map.Strict as Map +import qualified Data.Set as Set + +import FWL.AST + +data CheckError + = UndefinedName String String -- kind, name + | PolicyNoContinue String -- policy name + | PatternCycle [String] -- cycle path + | DuplicateDecl String String -- kind, name + deriving (Show, Eq) + +type Env = Map.Map String DeclKind +data DeclKind = KInterface | KZone | KLet | KPattern | KFlow | KRule | KPolicy + deriving (Show, Eq) + +checkProgram :: Program -> [CheckError] +checkProgram (Program _ decls) = + dupErrs ++ nameErrs ++ policyErrs ++ cycleErrs + where + env = buildEnv decls + dupErrs = findDups decls + nameErrs = concatMap (checkDecl env) decls + policyErrs = concatMap checkPolicyTermination decls + cycleErrs = checkPatternCycles decls + +-- ─── Environment ───────────────────────────────────────────────────────────── + +buildEnv :: [Decl] -> Env +buildEnv = foldl' addDecl Map.empty + where + addDecl m (DInterface n _ _) = Map.insert n KInterface m + addDecl m (DZone n _) = Map.insert n KZone m + addDecl m (DLet n _ _) = Map.insert n KLet m + addDecl m (DPattern n _ _) = Map.insert n KPattern m + addDecl m (DFlow n _) = Map.insert n KFlow m + addDecl m (DRule n _ _) = Map.insert n KRule m + addDecl m (DPolicy n _ _ _) = Map.insert n KPolicy m + addDecl m _ = m + +findDups :: [Decl] -> [CheckError] +findDups decls = go [] Set.empty decls + where + go acc _ [] = acc + go acc seen (d:ds) = + let n = declName d in + if Set.member n seen + then go (DuplicateDecl (declKindStr d) n : acc) seen ds + else go acc (Set.insert n seen) ds + +declName :: Decl -> String +declName (DInterface n _ _) = n +declName (DZone n _) = n +declName (DImport n _ _) = n +declName (DLet n _ _) = n +declName (DPattern n _ _) = n +declName (DFlow n _) = n +declName (DRule n _ _) = n +declName (DPolicy n _ _ _) = n + +declKindStr :: Decl -> String +declKindStr (DInterface _ _ _) = "interface" +declKindStr (DZone _ _) = "zone" +declKindStr (DImport _ _ _) = "import" +declKindStr (DLet _ _ _) = "let" +declKindStr (DPattern _ _ _) = "pattern" +declKindStr (DFlow _ _) = "flow" +declKindStr (DRule _ _ _) = "rule" +declKindStr (DPolicy _ _ _ _) = "policy" + +-- ─── Name resolution ───────────────────────────────────────────────────────── + +checkDecl :: Env -> Decl -> [CheckError] +checkDecl env (DZone _ ns) = concatMap (checkName env "interface or zone") ns +checkDecl env (DPattern _ _ p) = checkPat env p +checkDecl env (DFlow _ fe) = checkFlow env fe +checkDecl env (DRule _ _ e) = checkExpr env e +checkDecl env (DPolicy _ _ _ ab) = concatMap (checkArm env) ab +checkDecl env (DLet _ _ e) = checkExpr env e +checkDecl _ _ = [] + +checkName :: Env -> String -> String -> [CheckError] +checkName env kind n + | Map.member n env = [] + | isBuiltin n = [] + | otherwise = [UndefinedName kind n] + +isBuiltin :: String -> Bool +isBuiltin n = n `elem` + [ "ct", "iif", "oif", "lo", "wan", "lan" + , "tcp", "udp", "ip", "ip6", "eth" + , "Established", "Related", "DNAT" + , "Allow", "Drop", "Continue", "Masquerade" + , "Matched", "Unmatched" + , "true", "false" + ] + +checkPat :: Env -> Pat -> [CheckError] +checkPat _ PWild = [] +checkPat _ (PVar _) = [] +checkPat env (PNamed n) = checkName env "pattern" n +checkPat env (PCtor _ ps) = concatMap (checkPat env) ps +checkPat env (PRecord _ fs) = concatMap (checkFP env) fs +checkPat env (PTuple ps) = concatMap (checkPat env) ps +checkPat env (PFrame mp inner)= maybe [] (checkPath env) mp ++ checkPat env inner +checkPat _ (PBytes _) = [] + +checkFP :: Env -> FieldPat -> [CheckError] +checkFP _ _ = [] -- field names checked by type-checker later + +checkPath :: Env -> PathPat -> [CheckError] +checkPath env (PathPat ms md) = + maybe [] (checkEP env) ms ++ maybe [] (checkEP env) md + +checkEP :: Env -> EndpointPat -> [CheckError] +checkEP _ EPWild = [] +checkEP env (EPName n) = checkName env "interface or zone" n +checkEP env (EPMember _ z) = checkName env "zone" z + +checkFlow :: Env -> FlowExpr -> [CheckError] +checkFlow env (FAtom n) = checkName env "pattern" n +checkFlow env (FSeq a b _) = checkFlow env a ++ checkFlow env b + +checkArm :: Env -> Arm -> [CheckError] +checkArm env (Arm p mg e) = + checkPat env p ++ + maybe [] (checkExpr env) mg ++ + checkExpr env e + +checkExpr :: Env -> Expr -> [CheckError] +checkExpr env (EVar n) = checkName env "name" n +checkExpr _ (EQual _) = [] -- qualified names: deferred +checkExpr _ (ELit _) = [] +checkExpr env (ELam _ e) = checkExpr env e +checkExpr env (EApp f x) = checkExpr env f ++ checkExpr env x +checkExpr env (ECase e ab) = checkExpr env e ++ concatMap (checkArm env) ab +checkExpr env (EIf c t f) = concatMap (checkExpr env) [c,t,f] +checkExpr env (EDo ss) = concatMap (checkStmt env) ss +checkExpr env (ELet _ e1 e2) = checkExpr env e1 ++ checkExpr env e2 +checkExpr env (ETuple es) = concatMap (checkExpr env) es +checkExpr env (ESet es) = concatMap (checkExpr env) es +checkExpr env (EMap ms) = concatMap (\(k,v) -> checkExpr env k ++ checkExpr env v) ms +checkExpr env (EPerform _ as_) = concatMap (checkExpr env) as_ +checkExpr env (EInfix _ l r) = checkExpr env l ++ checkExpr env r +checkExpr env (ENot e) = checkExpr env e + +checkStmt :: Env -> DoStmt -> [CheckError] +checkStmt env (DSBind _ e) = checkExpr env e +checkStmt env (DSExpr e) = checkExpr env e + +-- ─── Policy termination ─────────────────────────────────────────────────────── + +-- The last arm of a policy block must not unconditionally return Continue. +checkPolicyTermination :: Decl -> [CheckError] +checkPolicyTermination (DPolicy n _ _ arms) + | null arms = [PolicyNoContinue n] + | isContinue (last arms) = [PolicyNoContinue n] + | otherwise = [] + where + isContinue (Arm PWild Nothing (EVar "Continue")) = True + isContinue _ = False +checkPolicyTermination _ = [] + +-- ─── Pattern cycle detection ───────────────────────────────────────────────── + +checkPatternCycles :: [Decl] -> [CheckError] +checkPatternCycles decls = + [ PatternCycle c + | c <- findCycles graph + ] + where + patDecls = [(n, p) | DPattern n _ p <- decls] + graph = Map.fromList [(n, nub (refsInPat p)) | (n,p) <- patDecls] + allPats = Set.fromList (map fst patDecls) + + refsInPat :: Pat -> [String] + refsInPat (PNamed r) = [r | Set.member r allPats] + refsInPat (PCtor _ ps) = concatMap refsInPat ps + refsInPat (PTuple ps) = concatMap refsInPat ps + refsInPat (PFrame _ p) = refsInPat p + refsInPat _ = [] + +findCycles :: Map.Map String [String] -> [[String]] +findCycles graph = go Set.empty Set.empty [] (Map.keys graph) + where + go _ _ _ [] = [] + go visited onPath path (n:ns) + | Set.member n visited = go visited onPath path ns + | Set.member n onPath = [path] + | otherwise = + let onPath' = Set.insert n onPath + path' = path ++ [n] + deps = Map.findWithDefault [] n graph + cycles = go visited onPath' path' deps + in cycles ++ go (Set.insert n visited) onPath path ns diff --git a/src/FWL/Compile.hs b/src/FWL/Compile.hs new file mode 100644 index 0000000..83227e4 --- /dev/null +++ b/src/FWL/Compile.hs @@ -0,0 +1,313 @@ +{-# LANGUAGE OverloadedStrings #-} +{- | Compile a checked FWL program to nftables JSON using Aeson. + All policies (Filter and NAT) go into one table named by Config. + Layer stripping: Frame patterns that omit Ether compile identically + to those that include it. +-} +module FWL.Compile + ( compileProgram + , compileToJson + ) where + +import Data.List (intercalate) +import Data.Maybe (mapMaybe) +import qualified Data.Map.Strict as Map +import Data.Aeson ((.=), Value(..), object, toJSON) +import qualified Data.Aeson as A +import qualified Data.Text as T +import qualified Data.ByteString.Lazy as BL +import Data.Aeson.Encode.Pretty (encodePretty) + +import FWL.AST + +-- ─── Entry points ──────────────────────────────────────────────────────────── + +compileToJson :: Program -> BL.ByteString +compileToJson = encodePretty . programToValue + +-- exposed for tests +compileProgram :: Program -> Value +compileProgram = programToValue + +programToValue :: Program -> Value +programToValue (Program cfg decls) = + object [ "nftables" .= toJSON + (metainfo : tableObj : chainObjs ++ mapObjs ++ ruleObjs) ] + where + env = buildEnv decls + tbl = configTable cfg + + metainfo = object [ "metainfo" .= object + [ "json_schema_version" .= (1 :: Int) ] ] + tableObj = object [ "table" .= tableValue tbl ] + + policies = [ (n, pm, ab) | DPolicy n _ pm ab <- decls ] + chainObjs = map (\(n, pm, _ ) -> chainDeclValue tbl n pm) policies + ruleObjs = concatMap + (\(n, _, ab) -> concatMap (armToRuleValues env tbl n) ab) + policies + + letDecls = [ (n, t, e) | DLet n t e <- decls ] + mapObjs = mapMaybe (\(n, _, e) -> letToMapValue tbl n e) letDecls + +-- ─── Table / Chain declarations ────────────────────────────────────────────── + +tableValue :: String -> Value +tableValue tbl = object + [ "family" .= ("inet" :: String) + , "name" .= tbl + ] + +chainDeclValue :: String -> Name -> PolicyMeta -> Value +chainDeclValue tbl n pm = object + [ "chain" .= object + [ "family" .= ("inet" :: String) + , "table" .= tbl + , "name" .= n + , "type" .= chainTypeStr (pmTable pm) + , "hook" .= hookStr (pmHook pm) + , "prio" .= priorityInt (pmPriority pm) + , "policy" .= defaultPolicyStr (pmHook pm) + ] + ] + +chainTypeStr :: TableName -> String +chainTypeStr TFilter = "filter" +chainTypeStr TNAT = "nat" + +hookStr :: Hook -> String +hookStr HInput = "input" +hookStr HForward = "forward" +hookStr HOutput = "output" +hookStr HPrerouting = "prerouting" +hookStr HPostrouting = "postrouting" + +-- Priority is emitted as an integer in nftables JSON. +priorityInt :: Priority -> Int +priorityInt = priorityValue + +defaultPolicyStr :: Hook -> String +defaultPolicyStr HInput = "drop" +defaultPolicyStr HForward = "drop" +defaultPolicyStr _ = "accept" + +-- ─── Arm → Rule objects ────────────────────────────────────────────────────── + +armToRuleValues :: CompileEnv -> String -> Name -> Arm -> [Value] +armToRuleValues env tbl chain (Arm p mg body) = + case compileAction env body of + Nothing -> [] + Just verdict -> + let patExprs = compilePat env p + guardExprs = maybe [] (compileGuard env) mg + allExprs = patExprs ++ guardExprs ++ [verdict] + in [ object + [ "rule" .= object + [ "family" .= ("inet" :: String) + , "table" .= tbl + , "chain" .= chain + , "expr" .= toJSON allExprs + ] + ] + ] + +-- ─── Pattern → [Value] ─────────────────────────────────────────────────────── + +type CompileEnv = Map.Map String Decl + +buildEnv :: [Decl] -> CompileEnv +buildEnv = foldr (\d m -> Map.insert (declNameOf d) d m) Map.empty + where + declNameOf (DInterface n _ _) = n + declNameOf (DZone n _) = n + declNameOf (DPattern n _ _) = n + declNameOf (DFlow n _) = n + declNameOf (DRule n _ _) = n + declNameOf (DPolicy n _ _ _) = n + declNameOf (DLet n _ _) = n + declNameOf (DImport n _ _) = n + +compilePat :: CompileEnv -> Pat -> [Value] +compilePat _ PWild = [] +compilePat _ (PVar _) = [] +compilePat env (PNamed n) = expandNamedPat env n +compilePat env (PFrame mp inner) = + maybe [] (compilePathPat env) mp ++ compilePat env inner +compilePat env (PCtor n ps) = compileCtorPat env n ps +compilePat _ (PRecord n fs) = compileRecordPat n fs +compilePat env (PTuple ps) = concatMap (compilePat env) ps +compilePat _ (PBytes _) = [] + +expandNamedPat :: CompileEnv -> Name -> [Value] +expandNamedPat env n = + case Map.lookup n env of + Just (DPattern _ _ p) -> compilePat env p + _ -> [] + +compileCtorPat :: CompileEnv -> String -> [Pat] -> [Value] +compileCtorPat env ctor ps = case ctor of + "Ether" -> children + "IPv4" -> matchMeta "nfproto" "ipv4" : children + "IPv6" -> matchMeta "nfproto" "ipv6" : children + "TCP" -> matchPayload "th" "protocol" "tcp" : children + "UDP" -> matchPayload "th" "protocol" "udp" : children + "ICMPv6" -> matchPayload "ip6" "nexthdr" "ipv6-icmp" : children + "ICMP" -> matchPayload "ip" "protocol" "icmp" : children + _ -> children + where + children = concatMap (compilePat env) ps + +compileRecordPat :: String -> [FieldPat] -> [Value] +compileRecordPat proto = mapMaybe go + where + go (FPEq field lit) = Just (matchPayload proto field (renderLit lit)) + go _ = Nothing + +compilePathPat :: CompileEnv -> PathPat -> [Value] +compilePathPat _ (PathPat ms md) = + maybe [] (compileEndpoint "iifname") ms ++ + maybe [] (compileEndpoint "oifname") md + +compileEndpoint :: String -> EndpointPat -> [Value] +compileEndpoint _ EPWild = [] +compileEndpoint dir (EPName n) = [matchMeta dir n] +compileEndpoint dir (EPMember _ z) = [matchInSet (metaVal dir) [z]] + +-- ─── Guard → [Value] ───────────────────────────────────────────────────────── + +compileGuard :: CompileEnv -> Expr -> [Value] +compileGuard env (EInfix OpAnd l r) = compileGuard env l ++ compileGuard env r +compileGuard _ (EInfix OpIn l r) = [compileInExpr l r] +compileGuard _ (EInfix OpEq l r) = [matchExpr "==" (exprVal l) (exprVal r)] +compileGuard _ (EInfix OpNeq l r) = [matchExpr "!=" (exprVal l) (exprVal r)] +compileGuard _ _ = [] + +compileInExpr :: Expr -> Expr -> Value +-- Fix 4: put the more-specific ct patterns BEFORE the generic 2-element +-- EQual case to eliminate the overlapping pattern match warning. +compileInExpr (EQual ["ct", "state"]) (ESet vs) = ctMatch "state" vs +compileInExpr (EQual ["ct", "status"]) (ESet vs) = ctMatch "status" vs +compileInExpr l (ESet vs) = + matchExpr "in" (exprVal l) (setVal (map exprToStr vs)) +compileInExpr l r = + matchExpr "==" (exprVal l) (exprVal r) + +ctMatch :: String -> [Expr] -> Value +ctMatch key vs = matchExpr "in" + (object ["ct" .= object ["key" .= (key :: String)]]) + (setVal (map exprToStr vs)) + +-- ─── Action → Maybe Value ───────────────────────────────────────────────────── + +compileAction :: CompileEnv -> Expr -> Maybe Value +compileAction _ (EVar "Allow") = Just (object ["accept" .= Null]) +compileAction _ (EVar "Drop") = Just (object ["drop" .= Null]) +compileAction _ (EVar "Continue") = Nothing +compileAction _ (EVar "Masquerade") = Just (object ["masquerade" .= Null]) +compileAction _ (EApp (EVar "DNAT") arg) = + Just $ object ["dnat" .= object ["addr" .= exprToStr arg]] +compileAction _ (EApp (EVar "DNATMap") arg) = + Just $ object ["dnat" .= object ["addr" .= object + [ "map" .= object [ "key" .= object ["concat" .= Array mempty] + , "data" .= exprToStr arg ]]]] +compileAction env (EApp (EVar rn) _) = + case Map.lookup rn env of + Just (DRule _ _ _) -> Just $ object ["jump" .= object ["target" .= rn]] + _ -> Just (object ["accept" .= Null]) +compileAction _ _ = Just (object ["accept" .= Null]) + +-- ─── Let → Map object ──────────────────────────────────────────────────────── + +letToMapValue :: String -> Name -> Expr -> Maybe Value +letToMapValue tbl n (EMap entries) = Just $ object + [ "map" .= object + [ "family" .= ("inet" :: String) + , "table" .= tbl + , "name" .= n + , "type" .= ("inetproto . inetservice" :: String) + , "map" .= ("ipv4_addr . inetservice" :: String) + , "elem" .= toJSON (map renderMapElem entries) + ] + ] +letToMapValue _ _ _ = Nothing + +renderMapElem :: (Expr, Expr) -> Value +renderMapElem (k, v) = toJSON + [ object ["concat" .= toJSON [exprToStr k]] + , A.String (toText (exprToStr v)) + ] + +-- ─── Aeson building blocks ─────────────────────────────────────────────────── + +matchExpr :: String -> Value -> Value -> Value +matchExpr op l r = object + [ "match" .= object + [ "op" .= (op :: String) + , "left" .= l + , "right" .= r + ] + ] + +matchMeta :: String -> String -> Value +matchMeta key val = matchExpr "==" (metaVal key) (A.String (toText val)) + +matchPayload :: String -> String -> String -> Value +matchPayload proto field val = + matchExpr "==" (payloadVal proto field) (A.String (toText val)) + +matchInSet :: Value -> [String] -> Value +matchInSet lhs vals = matchExpr "in" lhs (setVal vals) + +metaVal :: String -> Value +metaVal key = object ["meta" .= object ["key" .= (key :: String)]] + +payloadVal :: String -> String -> Value +payloadVal proto field = + object ["payload" .= object + [ "protocol" .= (proto :: String) + , "field" .= (field :: String) + ]] + +setVal :: [String] -> Value +setVal vs = object ["set" .= toJSON vs] + +-- ─── Expression helpers ─────────────────────────────────────────────────────── + +-- Fix 3 (overlap): specific ct pattern first, generic 2-element case second. +exprVal :: Expr -> Value +exprVal (EQual ["ct", k]) = object ["ct" .= object ["key" .= (k :: String)]] +exprVal (EQual [p, f]) = payloadVal p f +exprVal (EQual ns) = A.String (toText (intercalate "." ns)) +exprVal (EVar n) = metaVal n +exprVal (ELit l) = A.String (toText (renderLit l)) +exprVal (ESet vs) = setVal (map exprToStr vs) +exprVal e = A.String (toText (exprToStr e)) + +exprToStr :: Expr -> String +exprToStr (EVar n) = n +exprToStr (ELit l) = renderLit l +exprToStr (EQual ns) = intercalate "." ns +exprToStr (ETuple es) = intercalate " . " (map exprToStr es) +exprToStr _ = "_" + +-- Fix 2: Use Data.Text.pack via OverloadedStrings + fromString instead of +-- the fragile read(show s) hack. With OverloadedStrings enabled, string +-- literals already produce the correct Text/Key types; for runtime String +toText :: String -> T.Text +toText = T.pack + +renderLit :: Literal -> String +renderLit (LInt n) = show n +renderLit (LString s) = s +renderLit (LBool True) = "true" +renderLit (LBool False) = "false" +renderLit (LIPv4 (a, b, c, d)) = + show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d +renderLit (LIPv6 _) = "::1" +renderLit (LCIDR ip p) = renderLit ip ++ "/" ++ show p +renderLit (LPort p) = show p +renderLit (LDuration n Seconds) = show n ++ "s" +renderLit (LDuration n Millis) = show n ++ "ms" +renderLit (LDuration n Minutes) = show n ++ "m" +renderLit (LDuration n Hours) = show n ++ "h" +renderLit (LHex b) = show b diff --git a/src/FWL/Lexer.hs b/src/FWL/Lexer.hs new file mode 100644 index 0000000..b6511e6 --- /dev/null +++ b/src/FWL/Lexer.hs @@ -0,0 +1,101 @@ +module FWL.Lexer where + +import Text.Parsec +import Text.Parsec.String (Parser) +import qualified Text.Parsec.Token as Tok +import Text.Parsec.Language (emptyDef) + +-- ─── Language definition ───────────────────────────────────────────────────── + +fwlDef :: Tok.LanguageDef () +fwlDef = emptyDef + { Tok.commentLine = "--" + , Tok.commentStart = "{-" + , Tok.commentEnd = "-}" + , Tok.identStart = letter <|> char '_' + , Tok.identLetter = alphaNum <|> char '_' + , Tok.reservedNames = + -- Only genuine syntactic keywords belong here. + -- Semantic values used as constructors, actions, type names, or + -- pattern references (Allow, Drop, Log, Matched, Frame, etc.) must + -- NOT be reserved so that `identifier` can consume them in those + -- positions. + [ "config", "table" + , "interface", "zone", "import", "from" + , "let", "in", "pattern", "flow", "rule", "policy", "on" + , "case", "of", "if", "then", "else", "do", "perform" + , "within", "as", "dynamic", "cidr4", "cidr6" + , "hook", "priority" + , "WAN", "LAN", "WireGuard" + , "Input", "Forward", "Output", "Prerouting", "Postrouting" + , "Filter", "NAT", "Mangle", "DstNat", "SrcNat", "Raw", "ConnTrack" + , "true", "false" + ] + , Tok.reservedOpNames = + [ "->", "<-", "=>", "::", ":", "=", ".", ".." + , "\\", "|", "," + , "&&", "||", "!", "==" , "!=", "<", "<=", ">", ">=" + , "++", ">>", ">>=" + , "∈" + ] + , Tok.caseSensitive = True + } + +lexer :: Tok.TokenParser () +lexer = Tok.makeTokenParser fwlDef + +-- ─── Token helpers ─────────────────────────────────────────────────────────── + +identifier :: Parser String +identifier = Tok.identifier lexer + +reserved :: String -> Parser () +reserved = Tok.reserved lexer + +reservedOp :: String -> Parser () +reservedOp = Tok.reservedOp lexer + +symbol :: String -> Parser String +symbol = Tok.symbol lexer + +parens :: Parser a -> Parser a +parens = Tok.parens lexer + +braces :: Parser a -> Parser a +braces = Tok.braces lexer + +angles :: Parser a -> Parser a +angles = Tok.angles lexer + +brackets :: Parser a -> Parser a +brackets = Tok.brackets lexer + +semi :: Parser String +semi = Tok.semi lexer + +comma :: Parser String +comma = Tok.comma lexer + +colon :: Parser String +colon = Tok.colon lexer + +dot :: Parser String +dot = Tok.dot lexer + +whiteSpace :: Parser () +whiteSpace = Tok.whiteSpace lexer + +stringLit :: Parser String +stringLit = Tok.stringLiteral lexer + +natural :: Parser Integer +natural = Tok.natural lexer + +commaSep :: Parser a -> Parser [a] +commaSep = Tok.commaSep lexer + +commaSep1 :: Parser a -> Parser [a] +commaSep1 = Tok.commaSep1 lexer + +semiSep :: Parser a -> Parser [a] +semiSep = Tok.semiSep lexer diff --git a/src/FWL/Parser.hs b/src/FWL/Parser.hs new file mode 100644 index 0000000..24a1dfd --- /dev/null +++ b/src/FWL/Parser.hs @@ -0,0 +1,659 @@ +module FWL.Parser + ( parseProgram + , parseFile + ) where + +import Control.Monad (void) +import Data.Bits ((.&.), (.|.), shiftL) +import Data.List (foldl') +import Data.Word (Word8) +import Numeric (readHex) +import Text.Parsec +import Text.Parsec.String (Parser) +import Data.Functor.Identity (Identity) +import qualified Text.Parsec.Expr as Ex + +import FWL.AST +import FWL.Lexer + +-- ─── Entry points ──────────────────────────────────────────────────────────── + +parseProgram :: String -> String -> Either ParseError Program +parseProgram src input = parse program src input + +parseFile :: FilePath -> IO (Either ParseError Program) +parseFile fp = parseProgram fp <$> readFile fp + +-- ─── Top-level ─────────────────────────────────────────────────────────────── + +program :: Parser Program +program = do + whiteSpace + cfg <- option defaultConfig configBlock + ds <- many decl + eof + return (Program cfg ds) + +configBlock :: Parser Config +configBlock = do + reserved "config" + props <- braces (semiSep configProp) + optional semi + return $ foldr applyProp defaultConfig props + where + applyProp ("table", v) c = c { configTable = v } + applyProp _ c = c + +configProp :: Parser (String, String) +configProp = do + reserved "table" + reservedOp "=" + v <- stringLit + return ("table", v) + +-- ─── Declarations ──────────────────────────────────────────────────────────── + +decl :: Parser Decl +decl = interfaceDecl + <|> zoneDecl + <|> importDecl + <|> letDecl + <|> patternDecl + <|> flowDecl + <|> ruleDecl + <|> policyDecl + +interfaceDecl :: Parser Decl +interfaceDecl = do + reserved "interface" + n <- identifier + reservedOp ":" + k <- ifaceKind + ps <- braces (endBy ifaceProp semi) + _ <- semi + return (DInterface n k ps) + +ifaceKind :: Parser IfaceKind +ifaceKind = (reserved "WAN" >> return IWan) + <|> (reserved "LAN" >> return ILan) + <|> (reserved "WireGuard" >> return IWireGuard) + <|> (IUser <$> identifier) + +ifaceProp :: Parser IfaceProp +ifaceProp = (reserved "dynamic" >> return IPDynamic) + <|> (reserved "cidr4" >> reservedOp "=" >> IPCidr4 <$> cidrSet) + <|> (reserved "cidr6" >> reservedOp "=" >> IPCidr6 <$> cidrSet) + +cidrSet :: Parser [CIDR] +cidrSet = braces (commaSep1 cidrLit) + +zoneDecl :: Parser Decl +zoneDecl = do + reserved "zone" + n <- identifier + reservedOp "=" + ns <- braces (commaSep1 identifier) + _ <- semi + return (DZone n ns) + +importDecl :: Parser Decl +importDecl = do + reserved "import" + n <- identifier + reservedOp ":" + t <- typeP + reserved "from" + s <- stringLit + _ <- semi + return (DImport n t s) + +letDecl :: Parser Decl +letDecl = do + reserved "let" + n <- identifier + reservedOp ":" + t <- typeP + reservedOp "=" + e <- expr + _ <- semi + return (DLet n t e) + +patternDecl :: Parser Decl +patternDecl = do + reserved "pattern" + n <- identifier + reservedOp ":" + t <- typeP + reservedOp "=" + p <- pat + _ <- semi + return (DPattern n t p) + +flowDecl :: Parser Decl +flowDecl = do + reserved "flow" + n <- identifier + reservedOp ":" + reserved "FlowPattern" + reservedOp "=" + f <- flowExpr + _ <- semi + return (DFlow n f) + +ruleDecl :: Parser Decl +ruleDecl = do + reserved "rule" + n <- identifier + reservedOp ":" + t <- typeP + reservedOp "=" + e <- expr + _ <- semi + return (DRule n t e) + +policyDecl :: Parser Decl +policyDecl = do + reserved "policy" + n <- identifier + reservedOp ":" + t <- typeP + reserved "on" + pm <- braces policyMeta + reservedOp "=" + ab <- armBlock + _ <- semi + return (DPolicy n t pm ab) + +policyMeta :: Parser PolicyMeta +policyMeta = do + props <- commaSep1 metaProp + let h = foldr (\p a -> case p of Left v -> v; _ -> a) HInput props + tb = foldr (\p a -> case p of Right (Left v) -> v; _ -> a) TFilter props + pr = foldr (\p a -> case p of Right (Right v) -> v; _ -> a) pFilter props + return (PolicyMeta h tb pr) + +metaProp :: Parser (Either Hook (Either TableName Priority)) +metaProp + = (reserved "hook" >> reservedOp "=" >> fmap (Left) hookP) + <|> (reserved "table" >> reservedOp "=" >> fmap (Right . Left) tableNameP) + <|> (reserved "priority" >> reservedOp "=" >> fmap (Right . Right) priorityP) + +hookP :: Parser Hook +hookP = (reserved "Input" >> return HInput) + <|> (reserved "Forward" >> return HForward) + <|> (reserved "Output" >> return HOutput) + <|> (reserved "Prerouting" >> return HPrerouting) + <|> (reserved "Postrouting" >> return HPostrouting) + +tableNameP :: Parser TableName +tableNameP = (reserved "Filter" >> return TFilter) + <|> (reserved "NAT" >> return TNAT) + +priorityP :: Parser Priority +priorityP + = (reserved "Filter" >> return pFilter) + <|> (reserved "DstNat" >> return pDstNat) + <|> (reserved "SrcNat" >> return pSrcNat) + <|> (reserved "Mangle" >> return pMangle) + <|> (reserved "Raw" >> return pRaw) + <|> (reserved "ConnTrack" >> return pConnTrack) + <|> (Priority . fromIntegral <$> integerP) + where + -- Accept optional leading minus for negative priorities + integerP = do + neg <- option 1 (char '-' >> return (-1)) + n <- natural + whiteSpace + return (neg * fromIntegral n) + +-- ─── Arm blocks ────────────────────────────────────────────────────────────── + +armBlock :: Parser ArmBlock +armBlock = braces (many arm) + +arm :: Parser Arm +arm = do + _ <- symbol "|" + p <- pat + g <- optionMaybe (reserved "if" >> expr) + reservedOp "->" + e <- expr + _ <- semi + return (Arm p g e) + +-- ─── Patterns ──────────────────────────────────────────────────────────────── + +pat :: Parser Pat +pat = wildcardPat + <|> try framePat + <|> try tuplePat + <|> bytesPat + <|> try recordPat + <|> try namedOrCtorPat + +wildcardPat :: Parser Pat +wildcardPat = symbol "_" >> return PWild + +-- Frame(...) — optional path then inner pattern +-- Layer stripping: if the inner pattern is not Ether/IPv4/IPv6/etc the +-- type-checker will peel outer layers automatically. Parser just stores +-- whatever the user wrote. +framePat :: Parser Pat +framePat = do + reserved "Frame" + (mp, inner) <- parens frameArgs + return (PFrame mp inner) + +frameArgs :: Parser (Maybe PathPat, Pat) +frameArgs = try withPath <|> withoutPath + where + withPath = do + pp <- pathPat + _ <- comma + inner <- pat + return (Just pp, inner) + withoutPath = do + inner <- pat + return (Nothing, inner) + +pathPat :: Parser PathPat +pathPat = do + src <- optionMaybe (try endpointPat) + dst <- optionMaybe (try (reservedOp "->" >> endpointPat)) + case (src, dst) of + (Nothing, Nothing) -> fail "empty path pattern" + _ -> return (PathPat src dst) + +endpointPat :: Parser EndpointPat +endpointPat + = (symbol "_" >> return EPWild) + <|> try (do n <- identifier + memberOp + z <- identifier + return (EPMember n z)) + <|> (EPName <$> identifier) + +memberOp :: Parser () +memberOp = (reservedOp "∈" <|> reserved "in") >> return () + +tuplePat :: Parser Pat +tuplePat = do + ps <- parens (commaSep2 pat) + return (PTuple ps) + +commaSep2 :: Parser a -> Parser [a] +commaSep2 p = do + x <- p + _ <- comma + xs <- commaSep1 p + return (x:xs) + +bytesPat :: Parser Pat +bytesPat = brackets (PBytes <$> many byteElem) + +byteElem :: Parser ByteElem +byteElem + = try (symbol "_*" >> return BEWildStar) + <|> try (symbol "_" >> return BEWild) + <|> (BEHex <$> hexByte) + +hexByte :: Parser Word8 +hexByte = do + void (string "0x") + h1 <- hexDigit + h2 <- hexDigit + whiteSpace + case (readHex [h1,h2] :: [(Integer, String)]) of + [(v,"")] -> return (fromIntegral v) + _ -> fail "invalid hex byte" + +-- Record pattern: ident { fields } +recordPat :: Parser Pat +recordPat = do + n <- identifier + fs <- braces (commaSep fieldPat) + return (PRecord n fs) + +fieldPat :: Parser FieldPat +fieldPat = do + n <- identifier + try (reservedOp "=" >> FPEq n <$> fieldLiteral) + <|> try (reserved "as" >> FPAs n <$> identifier) + <|> return (FPBind n) + +-- Port literals (:22) are valid in record field position as well as plain literals. +fieldLiteral :: Parser Literal +fieldLiteral = try portLit <|> literal + where + portLit = do + void (char ':') + n <- fromIntegral <$> natural + return (LPort n) + +-- Named pattern reference OR constructor: starts with uppercase-ish ident +namedOrCtorPat :: Parser Pat +namedOrCtorPat = do + n <- identifier + args <- optionMaybe (try (parens (commaSep pat))) + case args of + Nothing -> return (PNamed n) -- bare name = named pattern ref + Just ps -> return (PCtor n ps) + +-- ─── Flow expressions ──────────────────────────────────────────────────────── + +flowExpr :: Parser FlowExpr +flowExpr = do + first <- FAtom <$> identifier + rest <- many (reservedOp "." >> identifier) + mw <- optionMaybe (reserved "within" >> durationLit) + return $ buildSeq (first : map FAtom rest) mw + where + buildSeq [x] mw = case mw of + Nothing -> x + Just w -> FSeq x x (Just w) -- degenerate + buildSeq (x:xs) mw = FSeq x (buildSeq xs mw) mw + buildSeq [] _ = error "impossible" + +durationLit :: Parser Duration +durationLit = do + n <- fromIntegral <$> natural + u <- (char 's' >> return Seconds) + <|> (string "ms" >> return Millis) + <|> (char 'm' >> return Minutes) + <|> (char 'h' >> return Hours) + whiteSpace + return (n, u) + +-- ─── Types ─────────────────────────────────────────────────────────────────── + +typeP :: Parser Type +typeP = do + t <- baseType + option t (reservedOp "->" >> TFun t <$> typeP) + +baseType :: Parser Type +baseType + = effectType + <|> try tupleTy + <|> simpleTy + +effectType :: Parser Type +effectType = do + effs <- angles (commaSep identifier) + t <- simpleTy + return (TEffect effs t) + +tupleTy :: Parser Type +tupleTy = TTuple <$> parens (commaSep2 typeP) + +simpleTy :: Parser Type +simpleTy = do + n <- identifier + args <- option [] (angles (commaSep typeP)) + return (TName n args) + +-- ─── Expressions ───────────────────────────────────────────────────────────── + +expr :: Parser Expr +expr = lamExpr + <|> ifExpr + <|> doExpr + <|> caseExpr + <|> letExpr + <|> infixExpr + +lamExpr :: Parser Expr +lamExpr = do + reservedOp "\\" + n <- identifier + reservedOp "->" + e <- expr + return (ELam n e) + +ifExpr :: Parser Expr +ifExpr = do + reserved "if" + c <- expr + reserved "then" + t <- expr + reserved "else" + f <- expr + return (EIf c t f) + +doExpr :: Parser Expr +doExpr = reserved "do" >> braces (EDo <$> semiSep doStmt) + +doStmt :: Parser DoStmt +doStmt = try bindStmt <|> (DSExpr <$> expr) + +bindStmt :: Parser DoStmt +bindStmt = do + n <- identifier + reservedOp "<-" + e <- expr + return (DSBind n e) + +caseExpr :: Parser Expr +caseExpr = do + reserved "case" + e <- expr + reserved "of" + ab <- armBlock + return (ECase e ab) + +letExpr :: Parser Expr +letExpr = do + reserved "let" + n <- identifier + reservedOp "=" + e1 <- expr + reserved "in" + e2 <- expr + return (ELet n e1 e2) + +-- Operator table for infix expressions +infixExpr :: Parser Expr +infixExpr = Ex.buildExpressionParser opTable appExpr + +opTable :: Ex.OperatorTable String () Identity Expr +opTable = + [ [ prefix "!" ENot ] + , [ infixL "==" OpEq, infixL "!=" OpNeq + , infixL "<" OpLt, infixL "<=" OpLte + , infixL ">" OpGt, infixL ">=" OpGte + , infixIn ] + , [ infixR "&&" OpAnd ] + , [ infixR "||" OpOr ] + , [ infixR "++" OpConcat ] + , [ infixL ">>=" OpBind ] + , [ infixL ">>" OpThen ] + ] + where + prefix op f = Ex.Prefix (reservedOp op >> return f) + infixL op c = Ex.Infix (reservedOp op >> return (EInfix c)) Ex.AssocLeft + infixR op c = Ex.Infix (reservedOp op >> return (EInfix c)) Ex.AssocRight + infixIn = Ex.Infix + ((memberOp <|> reserved "in") >> return (EInfix OpIn)) + Ex.AssocNone + +appExpr :: Parser Expr +appExpr = do + f <- atom + args <- many atom + return (foldl EApp f args) + +atom :: Parser Expr +atom + = try performExpr + <|> try mapLit + <|> try setLit + <|> try tupleLit + <|> try (parens expr) + <|> try litExpr + <|> try portExpr + <|> qualNameExpr + +performExpr :: Parser Expr +performExpr = do + reserved "perform" + parts <- sepBy1 identifier dot + args <- parens (commaSep expr) + return (EPerform parts args) + +qualNameExpr :: Parser Expr +qualNameExpr = do + parts <- sepBy1 identifier (try (dot <* notFollowedBy digit)) + case parts of + [n] -> return (EVar n) + ns -> return (EQual ns) + +litExpr :: Parser Expr +litExpr = ELit <$> literal + +portExpr :: Parser Expr +portExpr = do + void (char ':') + n <- fromIntegral <$> natural + return (ELit (LPort n)) + +tupleLit :: Parser Expr +tupleLit = ETuple <$> parens (commaSep2 expr) + +setLit :: Parser Expr +setLit = braces $ do + items <- commaSep expr + return (ESet items) + +-- map literal: { expr -> expr, ... } +mapLit :: Parser Expr +mapLit = braces $ do + entries <- commaSep1 mapEntry + return (EMap entries) + +mapEntry :: Parser (Expr, Expr) +mapEntry = do + k <- expr + reservedOp "->" + v <- expr + return (k, v) + +-- ─── Literals ──────────────────────────────────────────────────────────────── + +literal :: Parser Literal +literal + = try ipOrCidrLit + <|> try hexLit + <|> try (LBool True <$ reserved "true") + <|> try (LBool False <$ reserved "false") + <|> try (LString <$> stringLit) + <|> try (LInt . fromIntegral <$> natural) + +hexLit :: Parser Literal +hexLit = LHex <$> hexByte + +-- ─── IP / CIDR parsing ─────────────────────────────────────────────────────── + +-- | Parse an IPv4 or IPv6 address, optionally followed by /prefix. +-- Tries IPv6 first (it can start with hex digits too), then IPv4. +ipOrCidrLit :: Parser Literal +ipOrCidrLit = do + ip <- try ipv6Lit <|> ipv4Lit_ + mPrefix <- optionMaybe (char '/' >> fromIntegral <$> natural) + whiteSpace + return $ case mPrefix of + Nothing -> ip + Just p -> LCIDR ip p + +-- | IPv4: four decimal octets separated by dots → LIP IPv4 (32-bit Integer) +ipv4Lit_ :: Parser Literal +ipv4Lit_ = do + a <- octet + void (char '.') + b <- octet + void (char '.') + c <- octet + void (char '.') + d <- octet + return $ LIP IPv4 + ( fromIntegral a `shiftL` 24 + .|. fromIntegral b `shiftL` 16 + .|. fromIntegral c `shiftL` 8 + .|. fromIntegral d) + where + octet = do + n <- fromIntegral <$> natural + if n > 255 then fail "octet out of range" else return n + +-- | IPv6: full notation, :: abbreviation, and optional embedded IPv4. +-- Stores as LIP IPv6 (128-bit Integer). +ipv6Lit :: Parser Literal +ipv6Lit = do + (left, right) <- ipv6Groups + let missing = 8 - length left - length right + when (missing < 0) $ fail "too many groups in IPv6 address" + let groups = left ++ replicate missing 0 ++ right + when (length groups /= 8) $ fail "invalid IPv6 address" + let val = foldl' (\acc g -> (acc `shiftL` 16) .|. fromIntegral g) (0::Integer) groups + return (LIP IPv6 val) + +-- Returns (left-of-::, right-of-::). +-- If no :: present, left has all 8 groups and right is empty. +ipv6Groups :: Parser ([Int], [Int]) +ipv6Groups = do + -- must start with a hex digit or ':' (for ::) + ahead <- lookAhead (hexDigit <|> char ':') + case ahead of + ':' -> do + void (string "::") + right <- ipv6RightGroups + return ([], right) + _ -> do + left <- ipv6LeftGroups + mDbl <- optionMaybe (try (string "::")) + case mDbl of + Nothing -> return (left, []) + Just _ -> do + right <- ipv6RightGroups + return (left, right) + +-- Parse a run of hex16:hex16:... stopping before :: or end +ipv6LeftGroups :: Parser [Int] +ipv6LeftGroups = do + first <- hex16 + rest <- many (try (char ':' >> notFollowedBy (char ':') >> hex16)) + return (first : rest) + +-- Parse groups to the right of ::, including optional embedded IPv4 +ipv6RightGroups :: Parser [Int] +ipv6RightGroups = option [] $ + try ipv4EmbeddedGroups <|> ipv6LeftGroups + +-- IPv4-mapped groups: e.g. ffff:192.168.1.1 -> [0xffff, 0xc0a8, 0x0101] +ipv4EmbeddedGroups :: Parser [Int] +ipv4EmbeddedGroups = do + prefix <- many (try (hex16 <* char ':' <* lookAhead digit)) + a <- octet_; void (char '.') + b <- octet_; void (char '.') + c <- octet_; void (char '.') + d <- octet_ + let hi = (a `shiftL` 8) .|. b + lo = (c `shiftL` 8) .|. d + return (prefix ++ [hi, lo]) + where + octet_ = do + n <- fromIntegral <$> natural + if n > 255 then fail "IPv4 octet out of range" else return n + +hex16 :: Parser Int +hex16 = do + digits <- many1 hexDigit + case (reads ("0x" ++ digits)) :: [(Int,String)] of + [(v,"")] -> if v > 0xffff then fail "hex16 out of range" else return v + _ -> fail "invalid hex group" + +cidrLit :: Parser CIDR +cidrLit = do + l <- ipOrCidrLit + case l of + LCIDR ip p -> return (ip, p) + _ -> fail "expected CIDR notation (address/prefix)" diff --git a/src/FWL/Pretty.hs b/src/FWL/Pretty.hs new file mode 100644 index 0000000..a2430fa --- /dev/null +++ b/src/FWL/Pretty.hs @@ -0,0 +1,187 @@ +-- | Pretty printer: round-trips the AST back to FWL source. +module FWL.Pretty (prettyProgram) where + +import Data.List (intercalate) +import FWL.AST + +prettyProgram :: Program -> String +prettyProgram (Program cfg ds) = + prettyConfig cfg ++ "\n" ++ unlines (map prettyDecl ds) + +prettyConfig :: Config -> String +prettyConfig (Config t) + | t == "fwl" = "" + | otherwise = "config { table = \"" ++ t ++ "\"; }\n" + +prettyDecl :: Decl -> String +prettyDecl (DInterface n k ps) = + "interface " ++ n ++ " : " ++ prettyKind k ++ " {\n" ++ + concatMap (\p -> " " ++ prettyIfaceProp p ++ ";\n") ps ++ + "};" +prettyDecl (DZone n ns) = + "zone " ++ n ++ " = { " ++ intercalate ", " ns ++ " };" +prettyDecl (DImport n t s) = + "import " ++ n ++ " : " ++ prettyType t ++ " from \"" ++ s ++ "\";" +prettyDecl (DLet n t e) = + "let " ++ n ++ " : " ++ prettyType t ++ " = " ++ prettyExpr e ++ ";" +prettyDecl (DPattern n t p) = + "pattern " ++ n ++ " : " ++ prettyType t ++ " = " ++ prettyPat p ++ ";" +prettyDecl (DFlow n f) = + "flow " ++ n ++ " : FlowPattern = " ++ prettyFlow f ++ ";" +prettyDecl (DRule n t e) = + "rule " ++ n ++ " : " ++ prettyType t ++ " =\n " ++ prettyExpr e ++ ";" +prettyDecl (DPolicy n t pm ab) = + "policy " ++ n ++ " : " ++ prettyType t ++ "\n" ++ + " on { hook = " ++ prettyHook (pmHook pm) ++ + ", table = " ++ prettyTable (pmTable pm) ++ + ", priority = " ++ prettyPriority (pmPriority pm) ++ " }\n" ++ + " = " ++ prettyArmBlock ab ++ ";" + +prettyKind :: IfaceKind -> String +prettyKind IWan = "WAN" +prettyKind ILan = "LAN" +prettyKind IWireGuard = "WireGuard" +prettyKind (IUser n) = n + +prettyIfaceProp :: IfaceProp -> String +prettyIfaceProp IPDynamic = "dynamic" +prettyIfaceProp (IPCidr4 cs) = "cidr4 = { " ++ intercalate ", " (map prettyCidr cs) ++ " }" +prettyIfaceProp (IPCidr6 cs) = "cidr6 = { " ++ intercalate ", " (map prettyCidr cs) ++ " }" + +prettyCidr :: CIDR -> String +prettyCidr (LIPv4 (a,b,c,d), p) = + show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ "/" ++ show p +prettyCidr (ip, p) = prettyLit ip ++ "/" ++ show p + +prettyHook :: Hook -> String +prettyHook HInput = "Input" +prettyHook HForward = "Forward" +prettyHook HOutput = "Output" +prettyHook HPrerouting = "Prerouting" +prettyHook HPostrouting = "Postrouting" + +prettyTable :: TableName -> String +prettyTable TFilter = "Filter" +prettyTable TNAT = "NAT" + +prettyPriority :: Priority -> String +prettyPriority p = show (priorityValue p) + +prettyType :: Type -> String +prettyType (TName n []) = n +prettyType (TName n ts) = n ++ "<" ++ intercalate ", " (map prettyType ts) ++ ">" +prettyType (TTuple ts) = "(" ++ intercalate ", " (map prettyType ts) ++ ")" +prettyType (TFun a b) = prettyType a ++ " -> " ++ prettyType b +prettyType (TEffect es t) = "<" ++ intercalate ", " es ++ "> " ++ prettyType t + +prettyPat :: Pat -> String +prettyPat PWild = "_" +prettyPat (PVar n) = n +prettyPat (PNamed n) = n +prettyPat (PCtor n ps) = n ++ "(" ++ intercalate ", " (map prettyPat ps) ++ ")" +prettyPat (PRecord n fs) = n ++ " { " ++ intercalate ", " (map prettyFP fs) ++ " }" +prettyPat (PTuple ps) = "(" ++ intercalate ", " (map prettyPat ps) ++ ")" +prettyPat (PFrame mp inner)= + "Frame(" ++ maybe "" (\pp -> prettyPath pp ++ ", ") mp ++ prettyPat inner ++ ")" +prettyPat (PBytes bs) = "[" ++ unwords (map prettyBE bs) ++ "]" + +prettyFP :: FieldPat -> String +prettyFP (FPEq n l) = n ++ " = " ++ prettyLit l +prettyFP (FPBind n) = n +prettyFP (FPAs n v) = n ++ " as " ++ v + +prettyPath :: PathPat -> String +prettyPath (PathPat ms md) = + maybe "_" prettyEP ms ++ maybe "" (\d -> " -> " ++ prettyEP d) md + +prettyEP :: EndpointPat -> String +prettyEP EPWild = "_" +prettyEP (EPName n) = n +prettyEP (EPMember n z) = n ++ " in " ++ z + +prettyBE :: ByteElem -> String +prettyBE (BEHex w) = "0x" ++ pad (show w) -- simplified + where pad s = if length s < 2 then '0':s else s +prettyBE BEWild = "_" +prettyBE BEWildStar = "_*" + +prettyFlow :: FlowExpr -> String +prettyFlow (FAtom n) = n +prettyFlow (FSeq a b mw) = + prettyFlow a ++ " . " ++ prettyFlow b ++ + maybe "" (\(n,u) -> " within " ++ show n ++ prettyUnit u) mw + +prettyUnit :: TimeUnit -> String +prettyUnit Seconds = "s" +prettyUnit Millis = "ms" +prettyUnit Minutes = "m" +prettyUnit Hours = "h" + +prettyExpr :: Expr -> String +prettyExpr (EVar n) = n +prettyExpr (EQual ns) = intercalate "." ns +prettyExpr (ELit l) = prettyLit l +prettyExpr (ELam n e) = "\\" ++ n ++ " -> " ++ prettyExpr e +prettyExpr (EApp f x) = prettyExpr f ++ " " ++ prettyAtom x +prettyExpr (ECase e ab) = + "case " ++ prettyExpr e ++ " of " ++ prettyArmBlock ab +prettyExpr (EIf c t f) = + "if " ++ prettyExpr c ++ " then " ++ prettyExpr t ++ " else " ++ prettyExpr f +prettyExpr (EDo ss) = + "do { " ++ intercalate "; " (map prettyStmt ss) ++ " }" +prettyExpr (ELet n e1 e2) = + "let " ++ n ++ " = " ++ prettyExpr e1 ++ " in " ++ prettyExpr e2 +prettyExpr (ETuple es) = "(" ++ intercalate ", " (map prettyExpr es) ++ ")" +prettyExpr (ESet es) = "{ " ++ intercalate ", " (map prettyExpr es) ++ " }" +prettyExpr (EMap ms) = + "{ " ++ intercalate ", " (map (\(k,v) -> prettyExpr k ++ " -> " ++ prettyExpr v) ms) ++ " }" +prettyExpr (EPerform ns as_) = + "perform " ++ intercalate "." ns ++ "(" ++ intercalate ", " (map prettyExpr as_) ++ ")" +prettyExpr (EInfix op l r) = + prettyAtom l ++ " " ++ prettyOp op ++ " " ++ prettyAtom r +prettyExpr (ENot e) = "!" ++ prettyAtom e + +prettyAtom :: Expr -> String +prettyAtom e@(EInfix _ _ _) = "(" ++ prettyExpr e ++ ")" +prettyAtom e@(ELam _ _) = "(" ++ prettyExpr e ++ ")" +prettyAtom e = prettyExpr e + +prettyOp :: InfixOp -> String +prettyOp OpAnd = "&&" +prettyOp OpOr = "||" +prettyOp OpEq = "==" +prettyOp OpNeq = "!=" +prettyOp OpLt = "<" +prettyOp OpLte = "<=" +prettyOp OpGt = ">" +prettyOp OpGte = ">=" +prettyOp OpIn = "in" +prettyOp OpConcat = "++" +prettyOp OpThen = ">>" +prettyOp OpBind = ">>=" + +prettyStmt :: DoStmt -> String +prettyStmt (DSBind n e) = n ++ " <- " ++ prettyExpr e +prettyStmt (DSExpr e) = prettyExpr e + +prettyArmBlock :: ArmBlock -> String +prettyArmBlock arms = + "{\n" ++ + concatMap (\(Arm p mg e) -> + " | " ++ prettyPat p ++ + maybe "" (\g -> " if " ++ prettyExpr g) mg ++ + " -> " ++ prettyExpr e ++ ";\n") arms ++ + " }" + +prettyLit :: Literal -> String +prettyLit (LInt n) = show n +prettyLit (LString s) = "\"" ++ s ++ "\"" +prettyLit (LBool True) = "true" +prettyLit (LBool False) = "false" +prettyLit (LIPv4 (a,b,c,d)) = + show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d +prettyLit (LIPv6 _) = "" +prettyLit (LCIDR ip p) = prettyLit ip ++ "/" ++ show p +prettyLit (LPort p) = ":" ++ show p +prettyLit (LDuration n u) = show n ++ prettyUnit u +prettyLit (LHex b) = "0x" ++ show b diff --git a/test/CheckTests.hs b/test/CheckTests.hs new file mode 100644 index 0000000..ad00f4e --- /dev/null +++ b/test/CheckTests.hs @@ -0,0 +1,224 @@ +module CheckTests (tests) where + +import Test.Tasty +import Test.Tasty.HUnit + +import FWL.Check +import FWL.Util + +tests :: TestTree +tests = testGroup "Check" + [ undefinedNameTests + , duplicateTests + , policyTerminationTests + , patternCycleTests + , cleanProgramTests + ] + +-- ─── Helper ────────────────────────────────────────────────────────────────── + +checkSrc :: String -> IO [CheckError] +checkSrc src = do + p <- parseOk src + return (checkProgram p) + +assertNoErrors :: String -> IO () +assertNoErrors src = do + errs <- checkSrc src + case errs of + [] -> return () + _ -> assertFailure ("Unexpected errors: " ++ show errs) + +assertHasError :: (CheckError -> Bool) -> String -> IO () +assertHasError p src = do + errs <- checkSrc src + if any p errs + then return () + else assertFailure ("Expected error not found. Got: " ++ show errs) + +isUndefined :: String -> CheckError -> Bool +isUndefined n (UndefinedName _ m) = m == n +isUndefined _ _ = False + +isDuplicate :: String -> CheckError -> Bool +isDuplicate n (DuplicateDecl _ m) = m == n +isDuplicate _ _ = False + +isNoContinue :: String -> CheckError -> Bool +isNoContinue n (PolicyNoContinue m) = m == n +isNoContinue _ _ = False + +isCycle :: CheckError -> Bool +isCycle (PatternCycle _) = True +isCycle _ = False + +-- ─── Undefined name tests ──────────────────────────────────────────────────── + +undefinedNameTests :: TestTree +undefinedNameTests = testGroup "undefined names" + [ testCase "zone references unknown interface" $ + assertHasError (isUndefined "ghost") + "zone bad_zone = { lan, ghost };" + + , testCase "zone references known interface — no error" $ + assertNoErrors + "interface lan : LAN {}; \ + \zone good = { lan };" + + , testCase "pattern references undefined named pattern" $ + assertHasError (isUndefined "Undefined") + "pattern Bad : Frame = Frame(_, IPv4(ip, Undefined));" + + , testCase "pattern references known named pattern — no error" $ + assertNoErrors + "pattern WGInit : (UDPHeader,Bytes) = (udp { length = 156 }, [0x01 _*]); \ + \pattern Compound : Frame = Frame(_, IPv4(ip, WGInit));" + + , testCase "flow references undefined pattern" $ + assertHasError (isUndefined "Ghost") + "flow Bad : FlowPattern = Ghost;" + + , testCase "flow references known pattern — no error" $ + assertNoErrors + "pattern P : T = udp { length = 1 }; \ + \flow F : FlowPattern = P;" + + , testCase "policy guard references undeclared zone" $ + -- 'unknown_zone' not declared; check should flag it + assertHasError (isUndefined "unknown_zone") + "policy fwd : Frame \ + \ on { hook = Forward, table = Filter, priority = Filter } \ + \ = { | Frame(iif in unknown_zone -> wan, _) -> Allow; \ + \ | _ -> Drop; \ + \ };" + + , testCase "policy references known zone — no error" $ + assertNoErrors + "interface lan : LAN {}; \ + \zone trusted = { lan }; \ + \policy fwd : Frame \ + \ on { hook = Forward, table = Filter, priority = Filter } \ + \ = { | Frame(iif in trusted -> wan, _) -> Allow; \ + \ | _ -> Drop; \ + \ };" + ] + +-- ─── Duplicate declaration tests ───────────────────────────────────────────── + +duplicateTests :: TestTree +duplicateTests = testGroup "duplicates" + [ testCase "duplicate interface" $ + assertHasError (isDuplicate "lan") + "interface lan : LAN {}; \ + \interface lan : WAN {};" + + , testCase "duplicate zone" $ + assertHasError (isDuplicate "z") + "zone z = { a }; \ + \zone z = { b };" + + , testCase "duplicate pattern" $ + assertHasError (isDuplicate "P") + "pattern P : T = udp { length = 1 }; \ + \pattern P : T = udp { length = 2 };" + + , testCase "duplicate policy" $ + assertHasError (isDuplicate "input") + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; }; \ + \policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Drop; };" + + , testCase "distinct names — no error" $ + assertNoErrors + "interface lan : LAN {}; \ + \interface wan : WAN { dynamic; }; \ + \zone z = { lan };" + ] + +-- ─── Policy termination tests ──────────────────────────────────────────────── + +policyTerminationTests :: TestTree +policyTerminationTests = testGroup "policy termination" + [ testCase "last arm is Continue — error" $ + assertHasError (isNoContinue "bad_policy") + "policy bad_policy : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Continue; };" + + , testCase "last arm is Drop — ok" $ + assertNoErrors + "policy good : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ if ct.state in { Established } -> Allow; \ + \ | _ -> Drop; \ + \ };" + + , testCase "last arm is Allow — ok" $ + assertNoErrors + "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + + , testCase "Continue in non-last arm is fine" $ + assertNoErrors + "rule r : Frame -> Action = \ + \ \\f -> case f of { \ + \ | Frame(_, IPv4(ip, _)) -> Continue; \ + \ | _ -> Drop; \ + \ };" + + , testCase "empty policy body — error" $ + assertHasError (isNoContinue "empty") + "policy empty : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = {};" + ] + +-- ─── Pattern cycle tests ───────────────────────────────────────────────────── + +patternCycleTests :: TestTree +patternCycleTests = testGroup "pattern cycles" + [ testCase "direct self-reference — cycle error" $ + assertHasError isCycle + "pattern Loop : T = Frame(_, Loop);" + + , testCase "mutual cycle — cycle error" $ + assertHasError isCycle + "pattern A : T = Frame(_, B); \ + \pattern B : T = Frame(_, A);" + + , testCase "linear chain — no cycle" $ + assertNoErrors + "pattern Base : T = udp { length = 1 }; \ + \pattern Mid : T = Frame(_, Base); \ + \pattern Top : T = Frame(_, Mid);" + ] + +-- ─── Clean full programs ────────────────────────────────────────────────────── + +cleanProgramTests :: TestTree +cleanProgramTests = testGroup "clean programs" + [ testCase "minimal router skeleton" $ + assertNoErrors + "interface wan : WAN { dynamic; }; \ + \interface lan : LAN { cidr4 = { 10.17.1.0/24 }; }; \ + \interface wg0 : WireGuard {}; \ + \zone lan_zone = { lan, wg0 }; \ + \policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ if ct.state in { Established, Related } -> Allow; \ + \ | _ -> Drop; \ + \ }; \ + \policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + + , testCase "pattern and flow declarations" $ + assertNoErrors + "pattern WGInit : (UDPHeader,Bytes) = (udp { length = 156 }, [0x01 _*]); \ + \pattern WGResp : (UDPHeader,Bytes) = (udp { length = 100 }, [0x02 _*]); \ + \flow WGHandshake : FlowPattern = WGInit . WGResp within 5s;" + ] diff --git a/test/CompileTests.hs b/test/CompileTests.hs new file mode 100644 index 0000000..3a05ad1 --- /dev/null +++ b/test/CompileTests.hs @@ -0,0 +1,384 @@ +{-# LANGUAGE OverloadedStrings #-} +module CompileTests (tests) where + +import Test.Tasty +import Test.Tasty.HUnit +import qualified Data.Aeson as A +import qualified Data.Aeson.Key as AK +import qualified Data.Aeson.KeyMap as AKM +import qualified Data.Vector as V +import qualified Data.ByteString.Lazy.Char8 as BL8 + +import FWL.AST +import FWL.Compile +import FWL.Util + +tests :: TestTree +tests = testGroup "Compile" + [ jsonStructureTests + , chainTests + , ruleExprTests + , verdictTests + , layerStrippingTests + , continueTests + , configTests + ] + +-- ─── Helpers ───────────────────────────────────────────────────────────────── + +compileToValue :: String -> IO A.Value +compileToValue src = do + p <- parseOk src + case A.decode (compileToJson p) of + Nothing -> assertFailure "Compiled output is not valid JSON" >> undefined + Just v -> return v + +-- Navigate a Value by a list of string keys / numeric indices. +at :: [String] -> A.Value -> Maybe A.Value +at [] v = Just v +at (k:ks) (A.Object o) = + case AKM.lookup (AK.fromString k) o of + Nothing -> Nothing + Just v -> at ks v +at (k:ks) (A.Array arr) = + case reads k of + [(i,"")] | i < V.length arr -> at ks (arr V.! i) + _ -> Nothing +at _ _ = Nothing + +nftArr :: A.Value -> IO [A.Value] +nftArr v = + case at ["nftables"] v of + Just (A.Array arr) -> return (V.toList arr) + _ -> assertFailure "Missing top-level 'nftables' array" >> undefined + +withKey :: String -> [A.Value] -> [A.Value] +withKey k = filter (\v -> case at [k] v of Just _ -> True; _ -> False) + +-- ─── JSON structure tests ──────────────────────────────────────────────────── + +jsonStructureTests :: TestTree +jsonStructureTests = testGroup "JSON structure" + [ testCase "output is valid JSON" $ do + _ <- compileToValue + "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + return () + + , testCase "top-level nftables array present" $ do + v <- compileToValue "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + _ <- nftArr v + return () + + , testCase "metainfo is first element" $ do + v <- compileToValue "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + case arr of + (first:_) -> case at ["metainfo"] first of + Just _ -> return () + Nothing -> assertFailure "First element is not metainfo" + [] -> assertFailure "Empty nftables array" + + , testCase "table object present" $ do + v <- compileToValue "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + assertBool "Expected at least one table object" + (not (null (withKey "table" arr))) + + , testCase "default table name is fwl" $ do + v <- compileToValue "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + case withKey "table" arr of + (t:_) -> at ["table","name"] t @?= Just (A.String "fwl") + [] -> assertFailure "No table object" + + , testCase "custom table name respected" $ do + v <- compileToValue + "config { table = \"custom\"; } \ + \policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + case withKey "table" arr of + (t:_) -> at ["table","name"] t @?= Just (A.String "custom") + [] -> assertFailure "No table object" + ] + +-- ─── Chain declaration tests ───────────────────────────────────────────────── + +chainTests :: TestTree +chainTests = testGroup "chain declarations" + [ testCase "filter input chain has correct hook" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Drop; };" + arr <- nftArr v + case withKey "chain" arr of + (c:_) -> at ["chain","hook"] c @?= Just (A.String "input") + [] -> assertFailure "No chain" + + , testCase "filter chain type is filter" $ do + v <- compileToValue + "policy fwd : Frame \ + \ on { hook = Forward, table = Filter, priority = Filter } \ + \ = { | _ -> Drop; };" + arr <- nftArr v + case withKey "chain" arr of + (c:_) -> at ["chain","type"] c @?= Just (A.String "filter") + [] -> assertFailure "No chain" + + , testCase "NAT chain type is nat" $ do + v <- compileToValue + "policy nat_post : Frame \ + \ on { hook = Postrouting, table = NAT, priority = SrcNat } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + case withKey "chain" arr of + (c:_) -> at ["chain","type"] c @?= Just (A.String "nat") + [] -> assertFailure "No chain" + + , testCase "input chain default policy is drop" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Drop; };" + arr <- nftArr v + case withKey "chain" arr of + (c:_) -> at ["chain","policy"] c @?= Just (A.String "drop") + [] -> assertFailure "No chain" + + , testCase "output chain default policy is accept" $ do + v <- compileToValue + "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + case withKey "chain" arr of + (c:_) -> at ["chain","policy"] c @?= Just (A.String "accept") + [] -> assertFailure "No chain" + + , testCase "chain name matches policy name" $ do + v <- compileToValue + "policy my_input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Drop; };" + arr <- nftArr v + case withKey "chain" arr of + (c:_) -> at ["chain","name"] c @?= Just (A.String "my_input") + [] -> assertFailure "No chain" + + , testCase "two policies produce two chains" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Drop; }; \ + \policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + length (withKey "chain" arr) @?= 2 + ] + +-- ─── Rule expression tests ─────────────────────────────────────────────────── + +ruleExprs :: [A.Value] -> [A.Value] +ruleExprs arr = + [ e | r <- withKey "rule" arr + , Just (A.Array es) <- [at ["rule","expr"] r] + , e <- V.toList es ] + +ruleExprTests :: TestTree +ruleExprTests = testGroup "rule expressions" + [ testCase "two arms produce two rules" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ if ct.state in { Established, Related } -> Allow; \ + \ | _ -> Drop; \ + \ };" + arr <- nftArr v + length (withKey "rule" arr) @?= 2 + + , testCase "arm without guard produces one rule" $ do + v <- compileToValue + "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + length (withKey "rule" arr) @?= 1 + + , testCase "rule expr array is present" $ do + v <- compileToValue + "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + case withKey "rule" arr of + (r:_) -> case at ["rule","expr"] r of + Just (A.Array _) -> return () + _ -> assertFailure "Missing or non-array 'expr'" + [] -> assertFailure "No rule" + + , testCase "IPv4 ctor emits nfproto match" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | Frame(_, IPv4(ip, _)) -> Allow; \ + \ | _ -> Drop; \ + \ };" + arr <- nftArr v + let matches = withKey "match" (ruleExprs arr) + hasNfp = any (\m -> + at ["match","left","meta","key"] m == Just (A.String "nfproto")) + matches + assertBool "Expected nfproto match for IPv4 ctor" hasNfp + + , testCase "record field pat emits payload match" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | Frame(_, TCP(tcp { dport = :22 }, _)) -> Allow; \ + \ | _ -> Drop; \ + \ };" + arr <- nftArr v + let matches = withKey "match" (ruleExprs arr) + hasPort = any (\m -> + at ["match","right"] m == Just (A.String "22")) + matches + assertBool "Expected port 22 payload match" hasPort + ] + +-- ─── Verdict tests ─────────────────────────────────────────────────────────── + +allExprs :: [A.Value] -> [A.Value] +allExprs arr = + concatMap (\r -> case at ["rule","expr"] r of + Just (A.Array es) -> V.toList es; _ -> []) + (withKey "rule" arr) + +verdictTests :: TestTree +verdictTests = testGroup "verdicts" + [ testCase "Allow compiles to accept" $ do + v <- compileToValue + "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + arr <- nftArr v + assertBool "Expected accept verdict" + (not (null (withKey "accept" (allExprs arr)))) + + , testCase "Drop compiles to drop" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Drop; };" + arr <- nftArr v + assertBool "Expected drop verdict" + (not (null (withKey "drop" (allExprs arr)))) + + , testCase "Masquerade compiles to masquerade" $ do + v <- compileToValue + "policy nat_post : Frame \ + \ on { hook = Postrouting, table = NAT, priority = SrcNat } \ + \ = { | _ -> Masquerade; };" + arr <- nftArr v + assertBool "Expected masquerade verdict" + (not (null (withKey "masquerade" (allExprs arr)))) + + , testCase "rule call compiles to jump" $ do + v <- compileToValue + "rule blockAll : Frame -> Action = \\f -> case f of { | _ -> Drop; }; \ + \policy fwd : Frame \ + \ on { hook = Forward, table = Filter, priority = Filter } \ + \ = { | frame -> blockAll(frame); };" + arr <- nftArr v + assertBool "Expected jump verdict for rule call" + (not (null (withKey "jump" (allExprs arr)))) + ] + +-- ─── Layer stripping tests ─────────────────────────────────────────────────── + +layerStrippingTests :: TestTree +layerStrippingTests = testGroup "layer stripping" + [ testCase "Frame with and without Ether both emit nfproto match" $ do + let withEther = + "policy p1 : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | Frame(_, Ether(_, IPv4(ip, _))) -> Allow; \ + \ | _ -> Drop; \ + \ };" + withoutEther = + "policy p1 : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | Frame(_, IPv4(ip, _)) -> Allow; \ + \ | _ -> Drop; \ + \ };" + v1 <- compileToValue withEther + v2 <- compileToValue withoutEther + arr1 <- nftArr v1 + arr2 <- nftArr v2 + let nfp arr = filter + (\m -> at ["match","left","meta","key"] m == Just (A.String "nfproto")) + (withKey "match" (ruleExprs arr)) + assertBool "Both should produce nfproto matches" + (not (null (nfp arr1)) && not (null (nfp arr2))) + ] + +-- ─── Continue tests ─────────────────────────────────────────────────────────── + +continueTests :: TestTree +continueTests = testGroup "Continue" + [ testCase "two terminal arms produce two rules" $ do + v <- compileToValue + "policy fwd : Frame \ + \ on { hook = Forward, table = Filter, priority = Filter } \ + \ = { | _ if ct.state in { Established } -> Allow; \ + \ | _ -> Drop; \ + \ };" + arr <- nftArr v + length (withKey "rule" arr) @?= 2 + + , testCase "non-Continue arms still produce rules" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ if ct.state in { Established } -> Allow; \ + \ | _ -> Drop; \ + \ };" + arr <- nftArr v + assertBool "Should have rules for non-Continue arms" + (not (null (withKey "rule" arr))) + ] + +-- ─── Config tests ───────────────────────────────────────────────────────────── + +configTests :: TestTree +configTests = testGroup "config" + [ testCase "all rule objects reference correct table" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Drop; };" + arr <- nftArr v + mapM_ (\r -> at ["rule","table"] r @?= Just (A.String "fwl")) + (withKey "rule" arr) + + , testCase "chain objects reference correct table" $ do + v <- compileToValue + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { | _ -> Drop; };" + arr <- nftArr v + mapM_ (\c -> at ["chain","table"] c @?= Just (A.String "fwl")) + (withKey "chain" arr) + ] diff --git a/test/FWL/Util.hs b/test/FWL/Util.hs new file mode 100644 index 0000000..b5bf60e --- /dev/null +++ b/test/FWL/Util.hs @@ -0,0 +1,44 @@ +-- | Shared test utilities. +module FWL.Util where + +import Test.Tasty.HUnit +import Text.Parsec.String (Parser) +import Text.Parsec (parse) + +import FWL.Parser (parseProgram) +import FWL.AST + +-- | Assert a parser succeeds and return the result. +shouldParse :: (Show a) => Parser a -> String -> IO a +shouldParse p input = + case parse p "" input of + Left err -> assertFailure ("Unexpected parse error:\n" ++ show err) + >> undefined + Right v -> return v + +-- | Assert a parser fails. +shouldFailParse :: (Show a) => Parser a -> String -> IO () +shouldFailParse p input = + case parse p "" input of + Left _ -> return () + Right v -> assertFailure ("Expected parse failure but got: " ++ show v) + +-- | Parse a full program, asserting success. +parseOk :: String -> IO Program +parseOk src = + case parseProgram "" src of + Left err -> assertFailure ("Parse error:\n" ++ show err) >> undefined + Right p -> return p + +-- | Parse a full program, asserting failure. +parseFail :: String -> IO () +parseFail src = + case parseProgram "" src of + Left _ -> return () + Right p -> assertFailure ("Expected parse failure, got:\n" ++ show p) + +-- | Extract the single declaration from a one-decl program. +singleDecl :: Program -> IO Decl +singleDecl (Program _ [d]) = return d +singleDecl (Program _ ds) = + assertFailure ("Expected 1 decl, got " ++ show (length ds)) >> undefined diff --git a/test/ParserTests.hs b/test/ParserTests.hs new file mode 100644 index 0000000..c1f544a --- /dev/null +++ b/test/ParserTests.hs @@ -0,0 +1,516 @@ +module ParserTests (tests) where + +import Test.Tasty +import Test.Tasty.HUnit + +import FWL.AST +import FWL.Util + +tests :: TestTree +tests = testGroup "Parser" + [ interfaceTests + , zoneTests + , importTests + , letTests + , patternTests + , flowTests + , typeTests + , exprTests + , policyTests + , ruleTests + , configTests + , errorTests + ] + +-- ─── Interface ─────────────────────────────────────────────────────────────── + +interfaceTests :: TestTree +interfaceTests = testGroup "interface" + [ testCase "WAN dynamic" $ do + p <- parseOk "interface wan : WAN { dynamic; };" + d <- singleDecl p + case d of + DInterface "wan" IWan [IPDynamic] -> return () + _ -> assertFailure (show d) + + , testCase "LAN with cidr4" $ do + p <- parseOk "interface lan : LAN { cidr4 = { 10.0.0.0/8 }; };" + d <- singleDecl p + case d of + DInterface "lan" ILan [IPCidr4 [(LIPv4 (10,0,0,0), 8)]] -> return () + _ -> assertFailure (show d) + + , testCase "LAN with cidr4 and cidr6" $ do + p <- parseOk + "interface lan : LAN { \ + \ cidr4 = { 10.17.1.0/24 }; \ + \ cidr6 = { 192.168.0.0/16 }; \ + \};" + d <- singleDecl p + case d of + DInterface "lan" ILan [IPCidr4 _, IPCidr6 _] -> return () + _ -> assertFailure (show d) + + , testCase "WireGuard interface" $ do + p <- parseOk "interface wg0 : WireGuard {};" + d <- singleDecl p + case d of + DInterface "wg0" IWireGuard [] -> return () + _ -> assertFailure (show d) + + , testCase "user-defined kind" $ do + p <- parseOk "interface eth0 : Bridge {};" + d <- singleDecl p + case d of + DInterface "eth0" (IUser "Bridge") [] -> return () + _ -> assertFailure (show d) + + , testCase "multiple CIDRs in set" $ do + p <- parseOk + "interface lan : LAN { \ + \ cidr4 = { 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 }; \ + \};" + d <- singleDecl p + case d of + DInterface _ _ [IPCidr4 cidrs] -> length cidrs @?= 3 + _ -> assertFailure (show d) + ] + +-- ─── Zone ──────────────────────────────────────────────────────────────────── + +zoneTests :: TestTree +zoneTests = testGroup "zone" + [ testCase "single member" $ do + p <- parseOk "zone trusted = { lan };" + d <- singleDecl p + case d of + DZone "trusted" ["lan"] -> return () + _ -> assertFailure (show d) + + , testCase "multiple members" $ do + p <- parseOk "zone lan_zone = { lan, wg0, vlan10 };" + d <- singleDecl p + case d of + DZone "lan_zone" ["lan","wg0","vlan10"] -> return () + _ -> assertFailure (show d) + ] + +-- ─── Import ────────────────────────────────────────────────────────────────── + +importTests :: TestTree +importTests = testGroup "import" + [ testCase "basic import" $ do + p <- parseOk "import rfc1918 : CIDRSet from \"builtin:rfc1918\";" + d <- singleDecl p + case d of + DImport "rfc1918" (TName "CIDRSet" []) "builtin:rfc1918" -> return () + _ -> assertFailure (show d) + ] + +-- ─── Let ───────────────────────────────────────────────────────────────────── + +letTests :: TestTree +letTests = testGroup "let" + [ testCase "simple integer" $ do + p <- parseOk "let timeout : Int = 30;" + d <- singleDecl p + case d of + DLet "timeout" (TName "Int" []) (ELit (LInt 30)) -> return () + _ -> assertFailure (show d) + + , testCase "map literal" $ do + p <- parseOk + "let forwards : Map<(Protocol,Port),(IP,Port)> = { \ + \ (tcp, :8080) -> (10.0.0.1, :80) \ + \};" + d <- singleDecl p + case d of + DLet "forwards" _ (EMap [_]) -> return () + _ -> assertFailure (show d) + + , testCase "string literal" $ do + p <- parseOk "let name : String = \"hello\";" + d <- singleDecl p + case d of + DLet "name" _ (ELit (LString "hello")) -> return () + _ -> assertFailure (show d) + ] + +-- ─── Pattern ───────────────────────────────────────────────────────────────── + +patternTests :: TestTree +patternTests = testGroup "pattern" + [ testCase "tuple with record field" $ do + p <- parseOk + "pattern WGInitiation : (UDPHeader, Bytes) = \ + \ (udp { length = 156 }, [0x01 _*]);" + d <- singleDecl p + case d of + DPattern "WGInitiation" _ (PTuple [PRecord "udp" _, PBytes _]) -> return () + _ -> assertFailure (show d) + + , testCase "byte pattern elements" $ do + p <- parseOk + "pattern WGResponse : (UDPHeader, Bytes) = \ + \ (udp { length = 100 }, [0x02 _ _*]);" + d <- singleDecl p + case d of + DPattern "WGResponse" _ (PTuple [_, PBytes [BEHex 0x02, BEWild, BEWildStar]]) -> + return () + _ -> assertFailure (show d) + + , testCase "named pattern reference in ctor" $ do + p <- parseOk + "pattern Complex : Frame = \ + \ Frame(_, IPv4(ip, WGInitiation));" + d <- singleDecl p + case d of + DPattern "Complex" _ (PFrame Nothing (PCtor "IPv4" [PVar "ip", PNamed "WGInitiation"])) -> + return () + _ -> assertFailure (show d) + + , testCase "record with field bind" $ do + p <- parseOk "pattern HasTCP : TCP = tcp { dport };" + d <- singleDecl p + case d of + DPattern "HasTCP" _ (PRecord "tcp" [FPBind "dport"]) -> return () + _ -> assertFailure (show d) + + , testCase "record with field equality" $ do + p <- parseOk "pattern SSH : TCP = tcp { dport = :22 };" + d <- singleDecl p + case d of + DPattern "SSH" _ (PRecord "tcp" [FPEq "dport" (LPort 22)]) -> return () + _ -> assertFailure (show d) + ] + +-- ─── Flow ──────────────────────────────────────────────────────────────────── + +flowTests :: TestTree +flowTests = testGroup "flow" + [ testCase "two-step sequence with within" $ do + p <- parseOk + "flow WireGuardHandshake : FlowPattern = \ + \ WGInitiation . WGResponse within 5s;" + d <- singleDecl p + case d of + DFlow "WireGuardHandshake" (FSeq (FAtom "WGInitiation") (FAtom "WGResponse") (Just (5, Seconds))) -> + return () + _ -> assertFailure (show d) + + , testCase "single atom flow" $ do + p <- parseOk "flow Simple : FlowPattern = Ping;" + d <- singleDecl p + case d of + DFlow "Simple" (FAtom "Ping") -> return () + _ -> assertFailure (show d) + + , testCase "duration in milliseconds" $ do + p <- parseOk "flow Fast : FlowPattern = A . B within 500ms;" + d <- singleDecl p + case d of + DFlow "Fast" (FSeq _ _ (Just (500, Millis))) -> return () + _ -> assertFailure (show d) + ] + +-- ─── Types ─────────────────────────────────────────────────────────────────── + +typeTests :: TestTree +typeTests = testGroup "types" + [ testCase "simple name" $ do + p <- parseOk "let x : Frame = Allow;" + d <- singleDecl p + case d of + DLet _ (TName "Frame" []) _ -> return () + _ -> assertFailure (show d) + + , testCase "generic type" $ do + p <- parseOk "let x : Map = Allow;" + d <- singleDecl p + case d of + DLet _ (TName "Map" [TName "Int" [], TName "String" []]) _ -> return () + _ -> assertFailure (show d) + + , testCase "function type" $ do + p <- parseOk "let x : Frame -> Action = Allow;" + d <- singleDecl p + case d of + DLet _ (TFun (TName "Frame" []) (TName "Action" [])) _ -> return () + _ -> assertFailure (show d) + + , testCase "effect type" $ do + p <- parseOk "let x : Action = Allow;" + d <- singleDecl p + case d of + DLet _ (TEffect ["Log","FlowMatch"] (TName "Action" [])) _ -> return () + _ -> assertFailure (show d) + + , testCase "tuple type" $ do + p <- parseOk "let x : (Int, String) = Allow;" + d <- singleDecl p + case d of + DLet _ (TTuple [TName "Int" [], TName "String" []]) _ -> return () + _ -> assertFailure (show d) + + , testCase "function with effects" $ do + p <- parseOk "let x : Frame -> Action = Allow;" + d <- singleDecl p + case d of + DLet _ (TFun _ (TEffect ["Log"] _)) _ -> return () + _ -> assertFailure (show d) + ] + +-- ─── Expressions ───────────────────────────────────────────────────────────── + +exprTests :: TestTree +exprTests = testGroup "expressions" + [ testCase "boolean and" $ do + p <- parseOk "let x : Bool = a && b;" + d <- singleDecl p + case d of + DLet _ _ (EInfix OpAnd (EVar "a") (EVar "b")) -> return () + _ -> assertFailure (show d) + + , testCase "set membership with 'in'" $ do + p <- parseOk "let x : Bool = ct.state in { Established, Related };" + d <- singleDecl p + case d of + DLet _ _ (EInfix OpIn (EQual ["ct","state"]) (ESet _)) -> return () + _ -> assertFailure (show d) + + , testCase "equality comparison" $ do + p <- parseOk "let x : Bool = tcp.dport == :22;" + d <- singleDecl p + case d of + DLet _ _ (EInfix OpEq (EQual ["tcp","dport"]) (ELit (LPort 22))) -> return () + _ -> assertFailure (show d) + + , testCase "if-then-else" $ do + p <- parseOk "let x : Action = if a then Allow else Drop;" + d <- singleDecl p + case d of + DLet _ _ (EIf (EVar "a") (EVar "Allow") (EVar "Drop")) -> return () + _ -> assertFailure (show d) + + , testCase "perform expression" $ do + p <- parseOk "let x : Action = perform Log.emit(Info, \"msg\");" + d <- singleDecl p + case d of + DLet _ _ (EPerform ["Log","emit"] [ELit (LString "Info"), ELit (LString "msg")]) -> return () + DLet _ _ (EPerform ["Log","emit"] _) -> return () -- arg parsing flexible + _ -> assertFailure (show d) + + , testCase "do block" $ do + p <- parseOk "let x : Action = do { y <- foo; y };" + d <- singleDecl p + case d of + DLet _ _ (EDo [DSBind "y" _, DSExpr (EVar "y")]) -> return () + _ -> assertFailure (show d) + + , testCase "nested case" $ do + p <- parseOk + "let x : Action = case e of { \ + \ | a -> Allow; \ + \ | _ -> Drop; \ + \};" + d <- singleDecl p + case d of + DLet _ _ (ECase (EVar "e") [Arm (PVar "a") Nothing _, Arm PWild Nothing _]) -> return () + _ -> assertFailure (show d) + + , testCase "lambda" $ do + p <- parseOk "let x : Frame -> Action = \\frame -> Allow;" + d <- singleDecl p + case d of + DLet _ _ (ELam "frame" (EVar "Allow")) -> return () + _ -> assertFailure (show d) + + , testCase "string concat" $ do + p <- parseOk "let x : String = \"hello\" ++ \" world\";" + d <- singleDecl p + case d of + DLet _ _ (EInfix OpConcat _ _) -> return () + _ -> assertFailure (show d) + + , testCase "negation" $ do + p <- parseOk "let x : Bool = !flag;" + d <- singleDecl p + case d of + DLet _ _ (ENot (EVar "flag")) -> return () + _ -> assertFailure (show d) + + , testCase "set literal" $ do + p <- parseOk "let x : Set = { 22, 80, 443 };" + d <- singleDecl p + case d of + DLet _ _ (ESet [ELit (LInt 22), ELit (LInt 80), ELit (LInt 443)]) -> return () + _ -> assertFailure (show d) + ] + +-- ─── Policy ────────────────────────────────────────────────────────────────── + +policyTests :: TestTree +policyTests = testGroup "policy" + [ testCase "minimal policy" $ do + p <- parseOk + "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + d <- singleDecl p + case d of + DPolicy "output" _ (PolicyMeta HOutput TFilter (Priority 0)) [_] -> return () + _ -> assertFailure (show d) + + , testCase "NAT prerouting" $ do + p <- parseOk + "policy nat_pre : Frame \ + \ on { hook = Prerouting, table = NAT, priority = DstNat } \ + \ = { | _ -> Allow; };" + d <- singleDecl p + case d of + DPolicy _ _ (PolicyMeta HPrerouting TNAT (Priority (-100))) _ -> return () + _ -> assertFailure (show d) + + , testCase "arm with guard" $ do + p <- parseOk + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { \ + \ | _ if ct.state in { Established, Related } -> Allow; \ + \ | _ -> Drop; \ + \ };" + d <- singleDecl p + case d of + DPolicy _ _ _ [Arm PWild (Just _) _, Arm PWild Nothing _] -> return () + _ -> assertFailure (show d) + + , testCase "Frame pattern with path" $ do + p <- parseOk + "policy forward : Frame \ + \ on { hook = Forward, table = Filter, priority = Filter } \ + \ = { \ + \ | Frame(iif in lan_zone -> wan, _) -> Allow; \ + \ | _ -> Drop; \ + \ };" + d <- singleDecl p + case d of + DPolicy _ _ _ (Arm (PFrame (Just _) _) Nothing _ : _) -> return () + _ -> assertFailure (show d) + + , testCase "Frame pattern without Ether (layer stripping)" $ do + p <- parseOk + "policy input : Frame \ + \ on { hook = Input, table = Filter, priority = Filter } \ + \ = { \ + \ | Frame(_, IPv4(ip, TCP(tcp, _))) if tcp.dport == :22 -> Allow; \ + \ | _ -> Drop; \ + \ };" + d <- singleDecl p + case d of + DPolicy _ _ _ (Arm (PFrame Nothing (PCtor "IPv4" _)) _ _ : _) -> return () + _ -> assertFailure (show d) + + , testCase "policy arm calls rule" $ do + p <- parseOk + "policy forward : Frame \ + \ on { hook = Forward, table = Filter, priority = Filter } \ + \ = { \ + \ | frame -> blockOutboundWG(frame); \ + \ };" + d <- singleDecl p + case d of + DPolicy _ _ _ [Arm (PVar "frame") Nothing (EApp (EVar "blockOutboundWG") _)] -> + return () + _ -> assertFailure (show d) + + , testCase "Continue arm is parsed" $ do + p <- parseOk + "rule r : Frame -> Action = \ + \ \\frame -> case frame of { \ + \ | _ -> Continue; \ + \ };" + d <- singleDecl p + case d of + DRule _ _ _ -> return () + _ -> assertFailure (show d) + ] + +-- ─── Rule ──────────────────────────────────────────────────────────────────── + +ruleTests :: TestTree +ruleTests = testGroup "rule" + [ testCase "simple rule" $ do + p <- parseOk + "rule blockAll : Frame -> Action = \ + \ \\frame -> case frame of { | _ -> Drop; };" + d <- singleDecl p + case d of + DRule "blockAll" _ (ELam "frame" (ECase _ _)) -> return () + _ -> assertFailure (show d) + + , testCase "rule with effects in type" $ do + p <- parseOk + "rule logged : Frame -> Action = \ + \ \\f -> case f of { | _ -> Allow; };" + d <- singleDecl p + case d of + DRule "logged" (TFun _ (TEffect ["Log"] _)) _ -> return () + _ -> assertFailure (show d) + + , testCase "nested case in rule" $ do + p <- parseOk + "rule check : Frame -> Action = \ + \ \\frame -> \ + \ case frame of { \ + \ | Frame(_, IPv4(ip, UDP(udp, _))) -> \ + \ case perform FlowMatch.check(ip, wg) of { \ + \ | Matched -> Drop; \ + \ | _ -> Continue; \ + \ }; \ + \ | _ -> Continue; \ + \ };" + d <- singleDecl p + case d of + DRule "check" _ (ELam _ (ECase _ _)) -> return () + _ -> assertFailure (show d) + ] + +-- ─── Config ────────────────────────────────────────────────────────────────── + +configTests :: TestTree +configTests = testGroup "config" + [ testCase "default table name" $ do + p <- parseOk "interface wan : WAN {};" + configTable (progConfig p) @?= "fwl" + + , testCase "custom table name" $ do + p <- parseOk "config { table = \"myrules\"; } interface wan : WAN {};" + configTable (progConfig p) @?= "myrules" + ] + +-- ─── Error cases ───────────────────────────────────────────────────────────── + +errorTests :: TestTree +errorTests = testGroup "parse errors" + [ testCase "missing semicolon" $ + parseFail "interface wan : WAN {}" + + , testCase "unknown hook" $ + parseFail + "policy p : Frame \ + \ on { hook = Bogus, table = Filter, priority = Filter } \ + \ = { | _ -> Allow; };" + + , testCase "empty arm block with no arms is ok" $ do + p <- parseOk + "policy output : Frame \ + \ on { hook = Output, table = Filter, priority = Filter } \ + \ = {};" + d <- singleDecl p + case d of + DPolicy _ _ _ [] -> return () + _ -> assertFailure (show d) + + , testCase "CIDR without prefix fails" $ + parseFail "interface lan : LAN { cidr4 = { 10.0.0.1 }; };" + ] diff --git a/test/Spec.hs b/test/Spec.hs new file mode 100644 index 0000000..eb96207 --- /dev/null +++ b/test/Spec.hs @@ -0,0 +1,15 @@ +module Main where + +import Test.Tasty +import Test.Tasty.HUnit + +import qualified ParserTests +import qualified CheckTests +import qualified CompileTests + +main :: IO () +main = defaultMain $ testGroup "FWL" + [ ParserTests.tests + , CheckTests.tests + , CompileTests.tests + ]