v2 perplexed
This commit is contained in:
54
app/Main.hs
Normal file
54
app/Main.hs
Normal file
@@ -0,0 +1,54 @@
|
||||
module Main where
|
||||
|
||||
import System.Environment (getArgs)
|
||||
import System.Exit (exitFailure, exitSuccess)
|
||||
import System.IO (hPutStrLn, stderr)
|
||||
import qualified Data.ByteString.Lazy.Char8 as BL
|
||||
|
||||
import FWL.Parser (parseFile)
|
||||
import FWL.Pretty (prettyProgram)
|
||||
import FWL.Check (checkProgram)
|
||||
import FWL.Compile (compileToJson)
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
args <- getArgs
|
||||
case args of
|
||||
["check", fp] -> runCheck fp
|
||||
["compile", fp] -> runCompile fp
|
||||
["pretty", fp] -> runPretty fp
|
||||
_ -> do
|
||||
putStrLn "Usage: fwlc <command> <file.fwl>"
|
||||
putStrLn " check <file> -- parse and static-check"
|
||||
putStrLn " compile <file> -- emit nftables JSON to stdout"
|
||||
putStrLn " pretty <file> -- parse and re-print"
|
||||
exitFailure
|
||||
|
||||
runCheck :: FilePath -> IO ()
|
||||
runCheck fp = do
|
||||
result <- parseFile fp
|
||||
case result of
|
||||
Left err -> hPutStrLn stderr ("Parse error:\n" ++ show err) >> exitFailure
|
||||
Right prog -> do
|
||||
let errs = checkProgram prog
|
||||
if null errs
|
||||
then putStrLn "OK" >> exitSuccess
|
||||
else mapM_ (hPutStrLn stderr . show) errs >> exitFailure
|
||||
|
||||
runCompile :: FilePath -> IO ()
|
||||
runCompile fp = do
|
||||
result <- parseFile fp
|
||||
case result of
|
||||
Left err -> hPutStrLn stderr ("Parse error:\n" ++ show err) >> exitFailure
|
||||
Right prog -> do
|
||||
let errs = checkProgram prog
|
||||
if null errs
|
||||
then BL.putStrLn (compileToJson prog)
|
||||
else mapM_ (hPutStrLn stderr . ("Check error: " ++) . show) errs >> exitFailure
|
||||
|
||||
runPretty :: FilePath -> IO ()
|
||||
runPretty fp = do
|
||||
result <- parseFile fp
|
||||
case result of
|
||||
Left err -> hPutStrLn stderr ("Parse error:\n" ++ show err) >> exitFailure
|
||||
Right prog -> putStr (prettyProgram prog)
|
||||
1
cabal.project
Normal file
1
cabal.project
Normal file
@@ -0,0 +1 @@
|
||||
packages: .
|
||||
95
examples/router.fwl
Normal file
95
examples/router.fwl
Normal file
@@ -0,0 +1,95 @@
|
||||
-- Example: home router firewall in FWL
|
||||
-- Compile with: fwlc compile examples/router.fwl
|
||||
|
||||
interface wan : WAN { dynamic; };
|
||||
interface lan : LAN { cidr4 = { 10.17.1.0/24 }; };
|
||||
interface wg0 : WireGuard {};
|
||||
|
||||
zone lan_zone = { lan, wg0 };
|
||||
|
||||
import rfc1918 : CIDRSet from "builtin:rfc1918";
|
||||
|
||||
let forwards : Map<(Protocol, Port), (IP, Port)> = {
|
||||
(tcp, :8080) -> (10.17.1.10, :80),
|
||||
(tcp, :2222) -> (10.17.1.11, :22)
|
||||
};
|
||||
|
||||
-- WireGuard handshake detection (compiles to ct mark state machine)
|
||||
pattern WGInitiation : (UDPHeader, Bytes) =
|
||||
(udp { length = 156 }, [0x01 _*]);
|
||||
|
||||
pattern WGResponse : (UDPHeader, Bytes) =
|
||||
(udp { length = 100 }, [0x02 _*]);
|
||||
|
||||
flow WireGuardHandshake : FlowPattern =
|
||||
WGInitiation . WGResponse within 5s;
|
||||
|
||||
-- Block LAN clients from tunnelling out via WireGuard
|
||||
rule blockOutboundWG : Frame -> <FlowMatch, Log> Action =
|
||||
\frame ->
|
||||
case frame of {
|
||||
| Frame(iif in lan_zone -> wan, IPv4(ip, UDP(udp, payload)))
|
||||
if matches(WGInitiation, (udp, payload)) ->
|
||||
case perform FlowMatch.check(flowOf(ip, wg), WireGuardHandshake) of {
|
||||
| Matched -> do {
|
||||
perform Log.emit(Warn, "WG blocked");
|
||||
Drop
|
||||
};
|
||||
| _ -> Continue;
|
||||
};
|
||||
| _ -> Continue;
|
||||
};
|
||||
|
||||
-- Inbound to router
|
||||
policy input : Frame
|
||||
on { hook = Input, table = Filter, priority = Filter }
|
||||
= {
|
||||
| _ if ct.state in { Established, Related } -> Allow;
|
||||
| Frame(lo, _) -> Allow;
|
||||
| Frame(_, IPv6(ip6, ICMPv6(_, _)))
|
||||
if ip6.src in fe80::/10 -> Allow;
|
||||
| Frame(_, IPv4(_, TCP(tcp, _)))
|
||||
if tcp.dport == :22 -> Allow;
|
||||
| Frame(_, IPv4(_, UDP(udp, _)))
|
||||
if udp.dport == :51944 -> Allow;
|
||||
| _ -> Drop;
|
||||
};
|
||||
|
||||
-- Forwarded traffic
|
||||
policy forward : Frame
|
||||
on { hook = Forward, table = Filter, priority = Filter }
|
||||
= {
|
||||
| _ if ct.state in { Established, Related } -> Allow;
|
||||
| frame if iif in lan_zone && oif == wan -> blockOutboundWG(frame);
|
||||
| _ if ct.status == DNAT -> Allow;
|
||||
| Frame(iif in lan_zone -> wan, _) -> Allow;
|
||||
| Frame(iif in lan_zone -> lan_zone, _) -> Allow;
|
||||
| Frame(wan -> lan_zone, IPv4(ip, TCP(tcp, _)))
|
||||
if (ip.dst, tcp.dport) in forwards -> Allow;
|
||||
| _ -> Drop;
|
||||
};
|
||||
|
||||
-- Outbound from router
|
||||
policy output : Frame
|
||||
on { hook = Output, table = Filter, priority = Filter }
|
||||
= {
|
||||
| _ -> Allow;
|
||||
};
|
||||
|
||||
-- NAT
|
||||
policy nat_prerouting : Frame
|
||||
on { hook = Prerouting, table = NAT, priority = DstNat }
|
||||
= {
|
||||
| Frame(_, IPv4(ip, _)) ->
|
||||
if perform FIB.daddrLocal(ip.dst)
|
||||
then DNATMap(forwards)
|
||||
else Allow;
|
||||
| _ -> Allow;
|
||||
};
|
||||
|
||||
policy nat_postrouting : Frame
|
||||
on { hook = Postrouting, table = NAT, priority = SrcNat }
|
||||
= {
|
||||
| Frame(_ -> wan, IPv4(ip, _)) if ip.src in rfc1918 -> Masquerade;
|
||||
| _ -> Allow;
|
||||
};
|
||||
58
fwl.cabal
Normal file
58
fwl.cabal
Normal file
@@ -0,0 +1,58 @@
|
||||
cabal-version: 3.0
|
||||
name: fwl
|
||||
version: 0.1.0.0
|
||||
synopsis: Firewall Language — MVP
|
||||
build-type: Simple
|
||||
|
||||
common shared
|
||||
ghc-options: -Wall
|
||||
default-language: Haskell2010
|
||||
|
||||
library
|
||||
import: shared
|
||||
hs-source-dirs: src
|
||||
exposed-modules:
|
||||
FWL.AST
|
||||
, FWL.Lexer
|
||||
, FWL.Parser
|
||||
, FWL.Pretty
|
||||
, FWL.Check
|
||||
, FWL.Compile
|
||||
build-depends:
|
||||
base >= 4.14
|
||||
, parsec >= 3.1
|
||||
, aeson >= 2.0
|
||||
, aeson-pretty >= 0.8
|
||||
, text >= 1.2
|
||||
, containers >= 0.6
|
||||
, mtl >= 2.2
|
||||
, prettyprinter >= 1.7
|
||||
, bytestring >= 0.11
|
||||
, word8 >= 0.1
|
||||
|
||||
executable fwlc
|
||||
import: shared
|
||||
main-is: Main.hs
|
||||
hs-source-dirs: app
|
||||
build-depends:
|
||||
base, fwl, text, aeson-pretty, bytestring
|
||||
|
||||
test-suite fwl-tests
|
||||
import: shared
|
||||
type: exitcode-stdio-1.0
|
||||
main-is: Spec.hs
|
||||
hs-source-dirs: test
|
||||
other-modules:
|
||||
FWL.Util
|
||||
, ParserTests
|
||||
, CheckTests
|
||||
, CompileTests
|
||||
build-depends:
|
||||
base, fwl
|
||||
, tasty >= 1.4
|
||||
, tasty-hunit >= 0.10
|
||||
, aeson >= 2.0
|
||||
, aeson-pretty >= 0.8
|
||||
, bytestring >= 0.11
|
||||
, parsec >= 3.1
|
||||
, vector >= 0.12
|
||||
233
src/FWL/AST.hs
Normal file
233
src/FWL/AST.hs
Normal file
@@ -0,0 +1,233 @@
|
||||
module FWL.AST where
|
||||
|
||||
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
|
||||
import Data.Word (Word8) -- Word8 still used for ByteElem/hex literals
|
||||
|
||||
type Name = String
|
||||
|
||||
-- ─── Program ────────────────────────────────────────────────────────────────
|
||||
|
||||
data Program = Program
|
||||
{ progConfig :: Config
|
||||
, progDecls :: [Decl]
|
||||
} deriving (Show)
|
||||
|
||||
data Config = Config
|
||||
{ configTable :: String -- default "fwl"
|
||||
} deriving (Show)
|
||||
|
||||
defaultConfig :: Config
|
||||
defaultConfig = Config { configTable = "fwl" }
|
||||
|
||||
-- ─── Declarations ───────────────────────────────────────────────────────────
|
||||
|
||||
data Decl
|
||||
= DInterface Name IfaceKind [IfaceProp]
|
||||
| DZone Name [Name]
|
||||
| DImport Name Type FilePath
|
||||
| DLet Name Type Expr
|
||||
| DPattern Name Type Pat
|
||||
| DFlow Name FlowExpr
|
||||
| DRule Name Type Expr
|
||||
| DPolicy Name Type PolicyMeta ArmBlock
|
||||
deriving (Show)
|
||||
|
||||
data PolicyMeta = PolicyMeta
|
||||
{ pmHook :: Hook
|
||||
, pmTable :: TableName
|
||||
, pmPriority :: Priority
|
||||
} deriving (Show)
|
||||
|
||||
data Hook = HInput | HForward | HOutput | HPrerouting | HPostrouting
|
||||
deriving (Show, Eq)
|
||||
data TableName = TFilter | TNAT
|
||||
deriving (Show, Eq)
|
||||
-- Priority is always an integer in the nftables JSON.
|
||||
-- Named constants are resolved to their numeric values at parse time.
|
||||
newtype Priority = Priority { priorityValue :: Int }
|
||||
deriving (Show, Eq)
|
||||
|
||||
-- Standard nftables priority constants
|
||||
pRaw, pConnTrackDefrag, pConnTrack, pMangle, pDstNat, pFilter, pSecurity, pSrcNat :: Priority
|
||||
pRaw = Priority (-300)
|
||||
pConnTrackDefrag = Priority (-400)
|
||||
pConnTrack = Priority (-200)
|
||||
pMangle = Priority (-150)
|
||||
pDstNat = Priority (-100)
|
||||
pFilter = Priority 0
|
||||
pSecurity = Priority 50
|
||||
pSrcNat = Priority 100
|
||||
|
||||
data IfaceKind = IWan | ILan | IWireGuard | IUser Name
|
||||
deriving (Show)
|
||||
|
||||
data IfaceProp
|
||||
= IPDynamic
|
||||
| IPCidr4 [CIDR]
|
||||
| IPCidr6 [CIDR]
|
||||
deriving (Show)
|
||||
|
||||
-- | A CIDR block: base address literal paired with prefix length.
|
||||
-- e.g. (LIPv4 (10,0,0,0), 8) represents 10.0.0.0/8
|
||||
type CIDR = (Literal, Int)
|
||||
|
||||
-- ─── Patterns ───────────────────────────────────────────────────────────────
|
||||
|
||||
data Pat
|
||||
= PWild
|
||||
| PVar Name
|
||||
| PNamed Name
|
||||
| PCtor Name [Pat]
|
||||
| PRecord Name [FieldPat]
|
||||
| PTuple [Pat]
|
||||
| PFrame (Maybe PathPat) Pat
|
||||
| PBytes [ByteElem]
|
||||
deriving (Show)
|
||||
|
||||
data FieldPat
|
||||
= FPEq Name Literal
|
||||
| FPBind Name
|
||||
| FPAs Name Name
|
||||
deriving (Show)
|
||||
|
||||
data PathPat = PathPat (Maybe EndpointPat) (Maybe EndpointPat)
|
||||
deriving (Show)
|
||||
|
||||
data EndpointPat
|
||||
= EPWild
|
||||
| EPName Name
|
||||
| EPMember Name Name
|
||||
deriving (Show)
|
||||
|
||||
data ByteElem
|
||||
= BEHex Word8
|
||||
| BEWild
|
||||
| BEWildStar
|
||||
deriving (Show)
|
||||
|
||||
-- ─── Flow ───────────────────────────────────────────────────────────────────
|
||||
|
||||
data FlowExpr
|
||||
= FAtom Name
|
||||
| FSeq FlowExpr FlowExpr (Maybe Duration)
|
||||
deriving (Show)
|
||||
|
||||
type Duration = (Int, TimeUnit)
|
||||
|
||||
-- Fix 1: TimeUnit must derive Eq because Literal (which embeds it via
|
||||
-- LDuration) derives Eq, requiring all constituent types to also have Eq.
|
||||
data TimeUnit = Seconds | Millis | Minutes | Hours
|
||||
deriving (Show, Eq)
|
||||
|
||||
-- ─── Types ──────────────────────────────────────────────────────────────────
|
||||
|
||||
data Type
|
||||
= TName Name [Type]
|
||||
| TTuple [Type]
|
||||
| TFun Type Type
|
||||
| TEffect [Name] Type
|
||||
deriving (Show)
|
||||
|
||||
-- ─── Expressions ────────────────────────────────────────────────────────────
|
||||
|
||||
data Expr
|
||||
= EVar Name
|
||||
| EQual [Name]
|
||||
| ELit Literal
|
||||
| ELam Name Expr
|
||||
| EApp Expr Expr
|
||||
| ECase Expr ArmBlock
|
||||
| EIf Expr Expr Expr
|
||||
| EDo [DoStmt]
|
||||
| ELet Name Expr Expr
|
||||
| ETuple [Expr]
|
||||
| ESet [Expr]
|
||||
| EMap [(Expr, Expr)]
|
||||
| EPerform [Name] [Expr]
|
||||
| EInfix InfixOp Expr Expr
|
||||
| ENot Expr
|
||||
deriving (Show)
|
||||
|
||||
data InfixOp
|
||||
= OpAnd | OpOr
|
||||
| OpEq | OpNeq | OpLt | OpLte | OpGt | OpGte
|
||||
| OpIn
|
||||
| OpConcat
|
||||
| OpThen
|
||||
| OpBind
|
||||
deriving (Show, Eq)
|
||||
|
||||
data DoStmt
|
||||
= DSBind Name Expr
|
||||
| DSExpr Expr
|
||||
deriving (Show)
|
||||
|
||||
type ArmBlock = [Arm]
|
||||
data Arm = Arm Pat (Maybe Expr) Expr
|
||||
deriving (Show)
|
||||
|
||||
-- ─── Literals ───────────────────────────────────────────────────────────────
|
||||
|
||||
-- IP addresses are stored as plain Integers for easy arithmetic,
|
||||
-- CIDR validation (mask host bits), and future subnet math.
|
||||
-- IPv4: 32-bit value in the low 32 bits.
|
||||
-- IPv6: 128-bit value.
|
||||
-- CIDR host-bit validation: (addr .&. hostMask prefix bits) == 0
|
||||
data IPVersion = IPv4 | IPv6
|
||||
deriving (Show, Eq)
|
||||
|
||||
data Literal
|
||||
= LInt Int
|
||||
| LString String
|
||||
| LBool Bool
|
||||
| LIP IPVersion Integer -- unified IP address representation
|
||||
| LCIDR Literal Int -- base address + prefix length
|
||||
| LPort Int
|
||||
| LDuration Int TimeUnit
|
||||
| LHex Word8
|
||||
deriving (Show, Eq)
|
||||
|
||||
-- ─── IP address helpers ──────────────────────────────────────────────────────
|
||||
|
||||
-- | Build an IPv4 literal from four octets.
|
||||
ipv4Lit :: Int -> Int -> Int -> Int -> Literal
|
||||
ipv4Lit a b c d =
|
||||
LIP IPv4 (fromIntegral a `shiftL` 24
|
||||
.|. fromIntegral b `shiftL` 16
|
||||
.|. fromIntegral c `shiftL` 8
|
||||
.|. fromIntegral d)
|
||||
|
||||
-- | Check that a CIDR has no host bits set.
|
||||
cidrHostBitsZero :: Integer -> Int -> Int -> Bool
|
||||
cidrHostBitsZero addr prefix bits =
|
||||
let hostBits = bits - prefix
|
||||
hostMask = (1 `shiftL` hostBits) - 1
|
||||
in (addr .&. hostMask) == 0
|
||||
|
||||
-- | Render an IPv4 integer as a dotted-decimal string.
|
||||
renderIPv4 :: Integer -> String
|
||||
renderIPv4 n =
|
||||
show ((n `shiftR` 24) .&. 0xff) ++ "." ++
|
||||
show ((n `shiftR` 16) .&. 0xff) ++ "." ++
|
||||
show ((n `shiftR` 8) .&. 0xff) ++ "." ++
|
||||
show (n .&. 0xff)
|
||||
|
||||
-- | Render an IPv6 integer as a condensed colon-hex string.
|
||||
renderIPv6 :: Integer -> String
|
||||
renderIPv6 n =
|
||||
let groups = [ fromIntegral ((n `shiftR` (i * 16)) .&. 0xffff) :: Int
|
||||
| i <- [7,6..0] ]
|
||||
hexGroups = map (`showHex` "") groups
|
||||
in concatIntersperse ":" hexGroups
|
||||
where
|
||||
showHex x s = let h = showHexInt x in h ++ s
|
||||
showHexInt x
|
||||
| x == 0 = "0"
|
||||
| otherwise = reverse (go x)
|
||||
where go 0 = []
|
||||
go v = let (q,r) = v `divMod` 16
|
||||
c = "0123456789abcdef" !! r
|
||||
in c : go q
|
||||
concatIntersperse _ [] = ""
|
||||
concatIntersperse _ [x] = x
|
||||
concatIntersperse s (x:xs) = x ++ s ++ concatIntersperse s xs
|
||||
207
src/FWL/Check.hs
Normal file
207
src/FWL/Check.hs
Normal file
@@ -0,0 +1,207 @@
|
||||
{- | Static checks for MVP:
|
||||
1. Undefined name detection (interfaces, zones, patterns, rules/policies)
|
||||
2. Policy arm termination: last arm of a policy must not be Continue
|
||||
3. Named pattern cycle detection
|
||||
4. CIDR exhaustiveness stub (warns but does not error for MVP)
|
||||
-}
|
||||
module FWL.Check
|
||||
( checkProgram
|
||||
, CheckError(..)
|
||||
) where
|
||||
|
||||
import Data.List (foldl', nub)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import qualified Data.Set as Set
|
||||
|
||||
import FWL.AST
|
||||
|
||||
data CheckError
|
||||
= UndefinedName String String -- kind, name
|
||||
| PolicyNoContinue String -- policy name
|
||||
| PatternCycle [String] -- cycle path
|
||||
| DuplicateDecl String String -- kind, name
|
||||
deriving (Show, Eq)
|
||||
|
||||
type Env = Map.Map String DeclKind
|
||||
data DeclKind = KInterface | KZone | KLet | KPattern | KFlow | KRule | KPolicy
|
||||
deriving (Show, Eq)
|
||||
|
||||
checkProgram :: Program -> [CheckError]
|
||||
checkProgram (Program _ decls) =
|
||||
dupErrs ++ nameErrs ++ policyErrs ++ cycleErrs
|
||||
where
|
||||
env = buildEnv decls
|
||||
dupErrs = findDups decls
|
||||
nameErrs = concatMap (checkDecl env) decls
|
||||
policyErrs = concatMap checkPolicyTermination decls
|
||||
cycleErrs = checkPatternCycles decls
|
||||
|
||||
-- ─── Environment ─────────────────────────────────────────────────────────────
|
||||
|
||||
buildEnv :: [Decl] -> Env
|
||||
buildEnv = foldl' addDecl Map.empty
|
||||
where
|
||||
addDecl m (DInterface n _ _) = Map.insert n KInterface m
|
||||
addDecl m (DZone n _) = Map.insert n KZone m
|
||||
addDecl m (DLet n _ _) = Map.insert n KLet m
|
||||
addDecl m (DPattern n _ _) = Map.insert n KPattern m
|
||||
addDecl m (DFlow n _) = Map.insert n KFlow m
|
||||
addDecl m (DRule n _ _) = Map.insert n KRule m
|
||||
addDecl m (DPolicy n _ _ _) = Map.insert n KPolicy m
|
||||
addDecl m _ = m
|
||||
|
||||
findDups :: [Decl] -> [CheckError]
|
||||
findDups decls = go [] Set.empty decls
|
||||
where
|
||||
go acc _ [] = acc
|
||||
go acc seen (d:ds) =
|
||||
let n = declName d in
|
||||
if Set.member n seen
|
||||
then go (DuplicateDecl (declKindStr d) n : acc) seen ds
|
||||
else go acc (Set.insert n seen) ds
|
||||
|
||||
declName :: Decl -> String
|
||||
declName (DInterface n _ _) = n
|
||||
declName (DZone n _) = n
|
||||
declName (DImport n _ _) = n
|
||||
declName (DLet n _ _) = n
|
||||
declName (DPattern n _ _) = n
|
||||
declName (DFlow n _) = n
|
||||
declName (DRule n _ _) = n
|
||||
declName (DPolicy n _ _ _) = n
|
||||
|
||||
declKindStr :: Decl -> String
|
||||
declKindStr (DInterface _ _ _) = "interface"
|
||||
declKindStr (DZone _ _) = "zone"
|
||||
declKindStr (DImport _ _ _) = "import"
|
||||
declKindStr (DLet _ _ _) = "let"
|
||||
declKindStr (DPattern _ _ _) = "pattern"
|
||||
declKindStr (DFlow _ _) = "flow"
|
||||
declKindStr (DRule _ _ _) = "rule"
|
||||
declKindStr (DPolicy _ _ _ _) = "policy"
|
||||
|
||||
-- ─── Name resolution ─────────────────────────────────────────────────────────
|
||||
|
||||
checkDecl :: Env -> Decl -> [CheckError]
|
||||
checkDecl env (DZone _ ns) = concatMap (checkName env "interface or zone") ns
|
||||
checkDecl env (DPattern _ _ p) = checkPat env p
|
||||
checkDecl env (DFlow _ fe) = checkFlow env fe
|
||||
checkDecl env (DRule _ _ e) = checkExpr env e
|
||||
checkDecl env (DPolicy _ _ _ ab) = concatMap (checkArm env) ab
|
||||
checkDecl env (DLet _ _ e) = checkExpr env e
|
||||
checkDecl _ _ = []
|
||||
|
||||
checkName :: Env -> String -> String -> [CheckError]
|
||||
checkName env kind n
|
||||
| Map.member n env = []
|
||||
| isBuiltin n = []
|
||||
| otherwise = [UndefinedName kind n]
|
||||
|
||||
isBuiltin :: String -> Bool
|
||||
isBuiltin n = n `elem`
|
||||
[ "ct", "iif", "oif", "lo", "wan", "lan"
|
||||
, "tcp", "udp", "ip", "ip6", "eth"
|
||||
, "Established", "Related", "DNAT"
|
||||
, "Allow", "Drop", "Continue", "Masquerade"
|
||||
, "Matched", "Unmatched"
|
||||
, "true", "false"
|
||||
]
|
||||
|
||||
checkPat :: Env -> Pat -> [CheckError]
|
||||
checkPat _ PWild = []
|
||||
checkPat _ (PVar _) = []
|
||||
checkPat env (PNamed n) = checkName env "pattern" n
|
||||
checkPat env (PCtor _ ps) = concatMap (checkPat env) ps
|
||||
checkPat env (PRecord _ fs) = concatMap (checkFP env) fs
|
||||
checkPat env (PTuple ps) = concatMap (checkPat env) ps
|
||||
checkPat env (PFrame mp inner)= maybe [] (checkPath env) mp ++ checkPat env inner
|
||||
checkPat _ (PBytes _) = []
|
||||
|
||||
checkFP :: Env -> FieldPat -> [CheckError]
|
||||
checkFP _ _ = [] -- field names checked by type-checker later
|
||||
|
||||
checkPath :: Env -> PathPat -> [CheckError]
|
||||
checkPath env (PathPat ms md) =
|
||||
maybe [] (checkEP env) ms ++ maybe [] (checkEP env) md
|
||||
|
||||
checkEP :: Env -> EndpointPat -> [CheckError]
|
||||
checkEP _ EPWild = []
|
||||
checkEP env (EPName n) = checkName env "interface or zone" n
|
||||
checkEP env (EPMember _ z) = checkName env "zone" z
|
||||
|
||||
checkFlow :: Env -> FlowExpr -> [CheckError]
|
||||
checkFlow env (FAtom n) = checkName env "pattern" n
|
||||
checkFlow env (FSeq a b _) = checkFlow env a ++ checkFlow env b
|
||||
|
||||
checkArm :: Env -> Arm -> [CheckError]
|
||||
checkArm env (Arm p mg e) =
|
||||
checkPat env p ++
|
||||
maybe [] (checkExpr env) mg ++
|
||||
checkExpr env e
|
||||
|
||||
checkExpr :: Env -> Expr -> [CheckError]
|
||||
checkExpr env (EVar n) = checkName env "name" n
|
||||
checkExpr _ (EQual _) = [] -- qualified names: deferred
|
||||
checkExpr _ (ELit _) = []
|
||||
checkExpr env (ELam _ e) = checkExpr env e
|
||||
checkExpr env (EApp f x) = checkExpr env f ++ checkExpr env x
|
||||
checkExpr env (ECase e ab) = checkExpr env e ++ concatMap (checkArm env) ab
|
||||
checkExpr env (EIf c t f) = concatMap (checkExpr env) [c,t,f]
|
||||
checkExpr env (EDo ss) = concatMap (checkStmt env) ss
|
||||
checkExpr env (ELet _ e1 e2) = checkExpr env e1 ++ checkExpr env e2
|
||||
checkExpr env (ETuple es) = concatMap (checkExpr env) es
|
||||
checkExpr env (ESet es) = concatMap (checkExpr env) es
|
||||
checkExpr env (EMap ms) = concatMap (\(k,v) -> checkExpr env k ++ checkExpr env v) ms
|
||||
checkExpr env (EPerform _ as_) = concatMap (checkExpr env) as_
|
||||
checkExpr env (EInfix _ l r) = checkExpr env l ++ checkExpr env r
|
||||
checkExpr env (ENot e) = checkExpr env e
|
||||
|
||||
checkStmt :: Env -> DoStmt -> [CheckError]
|
||||
checkStmt env (DSBind _ e) = checkExpr env e
|
||||
checkStmt env (DSExpr e) = checkExpr env e
|
||||
|
||||
-- ─── Policy termination ───────────────────────────────────────────────────────
|
||||
|
||||
-- The last arm of a policy block must not unconditionally return Continue.
|
||||
checkPolicyTermination :: Decl -> [CheckError]
|
||||
checkPolicyTermination (DPolicy n _ _ arms)
|
||||
| null arms = [PolicyNoContinue n]
|
||||
| isContinue (last arms) = [PolicyNoContinue n]
|
||||
| otherwise = []
|
||||
where
|
||||
isContinue (Arm PWild Nothing (EVar "Continue")) = True
|
||||
isContinue _ = False
|
||||
checkPolicyTermination _ = []
|
||||
|
||||
-- ─── Pattern cycle detection ─────────────────────────────────────────────────
|
||||
|
||||
checkPatternCycles :: [Decl] -> [CheckError]
|
||||
checkPatternCycles decls =
|
||||
[ PatternCycle c
|
||||
| c <- findCycles graph
|
||||
]
|
||||
where
|
||||
patDecls = [(n, p) | DPattern n _ p <- decls]
|
||||
graph = Map.fromList [(n, nub (refsInPat p)) | (n,p) <- patDecls]
|
||||
allPats = Set.fromList (map fst patDecls)
|
||||
|
||||
refsInPat :: Pat -> [String]
|
||||
refsInPat (PNamed r) = [r | Set.member r allPats]
|
||||
refsInPat (PCtor _ ps) = concatMap refsInPat ps
|
||||
refsInPat (PTuple ps) = concatMap refsInPat ps
|
||||
refsInPat (PFrame _ p) = refsInPat p
|
||||
refsInPat _ = []
|
||||
|
||||
findCycles :: Map.Map String [String] -> [[String]]
|
||||
findCycles graph = go Set.empty Set.empty [] (Map.keys graph)
|
||||
where
|
||||
go _ _ _ [] = []
|
||||
go visited onPath path (n:ns)
|
||||
| Set.member n visited = go visited onPath path ns
|
||||
| Set.member n onPath = [path]
|
||||
| otherwise =
|
||||
let onPath' = Set.insert n onPath
|
||||
path' = path ++ [n]
|
||||
deps = Map.findWithDefault [] n graph
|
||||
cycles = go visited onPath' path' deps
|
||||
in cycles ++ go (Set.insert n visited) onPath path ns
|
||||
313
src/FWL/Compile.hs
Normal file
313
src/FWL/Compile.hs
Normal file
@@ -0,0 +1,313 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{- | Compile a checked FWL program to nftables JSON using Aeson.
|
||||
All policies (Filter and NAT) go into one table named by Config.
|
||||
Layer stripping: Frame patterns that omit Ether compile identically
|
||||
to those that include it.
|
||||
-}
|
||||
module FWL.Compile
|
||||
( compileProgram
|
||||
, compileToJson
|
||||
) where
|
||||
|
||||
import Data.List (intercalate)
|
||||
import Data.Maybe (mapMaybe)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.Aeson ((.=), Value(..), object, toJSON)
|
||||
import qualified Data.Aeson as A
|
||||
import qualified Data.Text as T
|
||||
import qualified Data.ByteString.Lazy as BL
|
||||
import Data.Aeson.Encode.Pretty (encodePretty)
|
||||
|
||||
import FWL.AST
|
||||
|
||||
-- ─── Entry points ────────────────────────────────────────────────────────────
|
||||
|
||||
compileToJson :: Program -> BL.ByteString
|
||||
compileToJson = encodePretty . programToValue
|
||||
|
||||
-- exposed for tests
|
||||
compileProgram :: Program -> Value
|
||||
compileProgram = programToValue
|
||||
|
||||
programToValue :: Program -> Value
|
||||
programToValue (Program cfg decls) =
|
||||
object [ "nftables" .= toJSON
|
||||
(metainfo : tableObj : chainObjs ++ mapObjs ++ ruleObjs) ]
|
||||
where
|
||||
env = buildEnv decls
|
||||
tbl = configTable cfg
|
||||
|
||||
metainfo = object [ "metainfo" .= object
|
||||
[ "json_schema_version" .= (1 :: Int) ] ]
|
||||
tableObj = object [ "table" .= tableValue tbl ]
|
||||
|
||||
policies = [ (n, pm, ab) | DPolicy n _ pm ab <- decls ]
|
||||
chainObjs = map (\(n, pm, _ ) -> chainDeclValue tbl n pm) policies
|
||||
ruleObjs = concatMap
|
||||
(\(n, _, ab) -> concatMap (armToRuleValues env tbl n) ab)
|
||||
policies
|
||||
|
||||
letDecls = [ (n, t, e) | DLet n t e <- decls ]
|
||||
mapObjs = mapMaybe (\(n, _, e) -> letToMapValue tbl n e) letDecls
|
||||
|
||||
-- ─── Table / Chain declarations ──────────────────────────────────────────────
|
||||
|
||||
tableValue :: String -> Value
|
||||
tableValue tbl = object
|
||||
[ "family" .= ("inet" :: String)
|
||||
, "name" .= tbl
|
||||
]
|
||||
|
||||
chainDeclValue :: String -> Name -> PolicyMeta -> Value
|
||||
chainDeclValue tbl n pm = object
|
||||
[ "chain" .= object
|
||||
[ "family" .= ("inet" :: String)
|
||||
, "table" .= tbl
|
||||
, "name" .= n
|
||||
, "type" .= chainTypeStr (pmTable pm)
|
||||
, "hook" .= hookStr (pmHook pm)
|
||||
, "prio" .= priorityInt (pmPriority pm)
|
||||
, "policy" .= defaultPolicyStr (pmHook pm)
|
||||
]
|
||||
]
|
||||
|
||||
chainTypeStr :: TableName -> String
|
||||
chainTypeStr TFilter = "filter"
|
||||
chainTypeStr TNAT = "nat"
|
||||
|
||||
hookStr :: Hook -> String
|
||||
hookStr HInput = "input"
|
||||
hookStr HForward = "forward"
|
||||
hookStr HOutput = "output"
|
||||
hookStr HPrerouting = "prerouting"
|
||||
hookStr HPostrouting = "postrouting"
|
||||
|
||||
-- Priority is emitted as an integer in nftables JSON.
|
||||
priorityInt :: Priority -> Int
|
||||
priorityInt = priorityValue
|
||||
|
||||
defaultPolicyStr :: Hook -> String
|
||||
defaultPolicyStr HInput = "drop"
|
||||
defaultPolicyStr HForward = "drop"
|
||||
defaultPolicyStr _ = "accept"
|
||||
|
||||
-- ─── Arm → Rule objects ──────────────────────────────────────────────────────
|
||||
|
||||
armToRuleValues :: CompileEnv -> String -> Name -> Arm -> [Value]
|
||||
armToRuleValues env tbl chain (Arm p mg body) =
|
||||
case compileAction env body of
|
||||
Nothing -> []
|
||||
Just verdict ->
|
||||
let patExprs = compilePat env p
|
||||
guardExprs = maybe [] (compileGuard env) mg
|
||||
allExprs = patExprs ++ guardExprs ++ [verdict]
|
||||
in [ object
|
||||
[ "rule" .= object
|
||||
[ "family" .= ("inet" :: String)
|
||||
, "table" .= tbl
|
||||
, "chain" .= chain
|
||||
, "expr" .= toJSON allExprs
|
||||
]
|
||||
]
|
||||
]
|
||||
|
||||
-- ─── Pattern → [Value] ───────────────────────────────────────────────────────
|
||||
|
||||
type CompileEnv = Map.Map String Decl
|
||||
|
||||
buildEnv :: [Decl] -> CompileEnv
|
||||
buildEnv = foldr (\d m -> Map.insert (declNameOf d) d m) Map.empty
|
||||
where
|
||||
declNameOf (DInterface n _ _) = n
|
||||
declNameOf (DZone n _) = n
|
||||
declNameOf (DPattern n _ _) = n
|
||||
declNameOf (DFlow n _) = n
|
||||
declNameOf (DRule n _ _) = n
|
||||
declNameOf (DPolicy n _ _ _) = n
|
||||
declNameOf (DLet n _ _) = n
|
||||
declNameOf (DImport n _ _) = n
|
||||
|
||||
compilePat :: CompileEnv -> Pat -> [Value]
|
||||
compilePat _ PWild = []
|
||||
compilePat _ (PVar _) = []
|
||||
compilePat env (PNamed n) = expandNamedPat env n
|
||||
compilePat env (PFrame mp inner) =
|
||||
maybe [] (compilePathPat env) mp ++ compilePat env inner
|
||||
compilePat env (PCtor n ps) = compileCtorPat env n ps
|
||||
compilePat _ (PRecord n fs) = compileRecordPat n fs
|
||||
compilePat env (PTuple ps) = concatMap (compilePat env) ps
|
||||
compilePat _ (PBytes _) = []
|
||||
|
||||
expandNamedPat :: CompileEnv -> Name -> [Value]
|
||||
expandNamedPat env n =
|
||||
case Map.lookup n env of
|
||||
Just (DPattern _ _ p) -> compilePat env p
|
||||
_ -> []
|
||||
|
||||
compileCtorPat :: CompileEnv -> String -> [Pat] -> [Value]
|
||||
compileCtorPat env ctor ps = case ctor of
|
||||
"Ether" -> children
|
||||
"IPv4" -> matchMeta "nfproto" "ipv4" : children
|
||||
"IPv6" -> matchMeta "nfproto" "ipv6" : children
|
||||
"TCP" -> matchPayload "th" "protocol" "tcp" : children
|
||||
"UDP" -> matchPayload "th" "protocol" "udp" : children
|
||||
"ICMPv6" -> matchPayload "ip6" "nexthdr" "ipv6-icmp" : children
|
||||
"ICMP" -> matchPayload "ip" "protocol" "icmp" : children
|
||||
_ -> children
|
||||
where
|
||||
children = concatMap (compilePat env) ps
|
||||
|
||||
compileRecordPat :: String -> [FieldPat] -> [Value]
|
||||
compileRecordPat proto = mapMaybe go
|
||||
where
|
||||
go (FPEq field lit) = Just (matchPayload proto field (renderLit lit))
|
||||
go _ = Nothing
|
||||
|
||||
compilePathPat :: CompileEnv -> PathPat -> [Value]
|
||||
compilePathPat _ (PathPat ms md) =
|
||||
maybe [] (compileEndpoint "iifname") ms ++
|
||||
maybe [] (compileEndpoint "oifname") md
|
||||
|
||||
compileEndpoint :: String -> EndpointPat -> [Value]
|
||||
compileEndpoint _ EPWild = []
|
||||
compileEndpoint dir (EPName n) = [matchMeta dir n]
|
||||
compileEndpoint dir (EPMember _ z) = [matchInSet (metaVal dir) [z]]
|
||||
|
||||
-- ─── Guard → [Value] ─────────────────────────────────────────────────────────
|
||||
|
||||
compileGuard :: CompileEnv -> Expr -> [Value]
|
||||
compileGuard env (EInfix OpAnd l r) = compileGuard env l ++ compileGuard env r
|
||||
compileGuard _ (EInfix OpIn l r) = [compileInExpr l r]
|
||||
compileGuard _ (EInfix OpEq l r) = [matchExpr "==" (exprVal l) (exprVal r)]
|
||||
compileGuard _ (EInfix OpNeq l r) = [matchExpr "!=" (exprVal l) (exprVal r)]
|
||||
compileGuard _ _ = []
|
||||
|
||||
compileInExpr :: Expr -> Expr -> Value
|
||||
-- Fix 4: put the more-specific ct patterns BEFORE the generic 2-element
|
||||
-- EQual case to eliminate the overlapping pattern match warning.
|
||||
compileInExpr (EQual ["ct", "state"]) (ESet vs) = ctMatch "state" vs
|
||||
compileInExpr (EQual ["ct", "status"]) (ESet vs) = ctMatch "status" vs
|
||||
compileInExpr l (ESet vs) =
|
||||
matchExpr "in" (exprVal l) (setVal (map exprToStr vs))
|
||||
compileInExpr l r =
|
||||
matchExpr "==" (exprVal l) (exprVal r)
|
||||
|
||||
ctMatch :: String -> [Expr] -> Value
|
||||
ctMatch key vs = matchExpr "in"
|
||||
(object ["ct" .= object ["key" .= (key :: String)]])
|
||||
(setVal (map exprToStr vs))
|
||||
|
||||
-- ─── Action → Maybe Value ─────────────────────────────────────────────────────
|
||||
|
||||
compileAction :: CompileEnv -> Expr -> Maybe Value
|
||||
compileAction _ (EVar "Allow") = Just (object ["accept" .= Null])
|
||||
compileAction _ (EVar "Drop") = Just (object ["drop" .= Null])
|
||||
compileAction _ (EVar "Continue") = Nothing
|
||||
compileAction _ (EVar "Masquerade") = Just (object ["masquerade" .= Null])
|
||||
compileAction _ (EApp (EVar "DNAT") arg) =
|
||||
Just $ object ["dnat" .= object ["addr" .= exprToStr arg]]
|
||||
compileAction _ (EApp (EVar "DNATMap") arg) =
|
||||
Just $ object ["dnat" .= object ["addr" .= object
|
||||
[ "map" .= object [ "key" .= object ["concat" .= Array mempty]
|
||||
, "data" .= exprToStr arg ]]]]
|
||||
compileAction env (EApp (EVar rn) _) =
|
||||
case Map.lookup rn env of
|
||||
Just (DRule _ _ _) -> Just $ object ["jump" .= object ["target" .= rn]]
|
||||
_ -> Just (object ["accept" .= Null])
|
||||
compileAction _ _ = Just (object ["accept" .= Null])
|
||||
|
||||
-- ─── Let → Map object ────────────────────────────────────────────────────────
|
||||
|
||||
letToMapValue :: String -> Name -> Expr -> Maybe Value
|
||||
letToMapValue tbl n (EMap entries) = Just $ object
|
||||
[ "map" .= object
|
||||
[ "family" .= ("inet" :: String)
|
||||
, "table" .= tbl
|
||||
, "name" .= n
|
||||
, "type" .= ("inetproto . inetservice" :: String)
|
||||
, "map" .= ("ipv4_addr . inetservice" :: String)
|
||||
, "elem" .= toJSON (map renderMapElem entries)
|
||||
]
|
||||
]
|
||||
letToMapValue _ _ _ = Nothing
|
||||
|
||||
renderMapElem :: (Expr, Expr) -> Value
|
||||
renderMapElem (k, v) = toJSON
|
||||
[ object ["concat" .= toJSON [exprToStr k]]
|
||||
, A.String (toText (exprToStr v))
|
||||
]
|
||||
|
||||
-- ─── Aeson building blocks ───────────────────────────────────────────────────
|
||||
|
||||
matchExpr :: String -> Value -> Value -> Value
|
||||
matchExpr op l r = object
|
||||
[ "match" .= object
|
||||
[ "op" .= (op :: String)
|
||||
, "left" .= l
|
||||
, "right" .= r
|
||||
]
|
||||
]
|
||||
|
||||
matchMeta :: String -> String -> Value
|
||||
matchMeta key val = matchExpr "==" (metaVal key) (A.String (toText val))
|
||||
|
||||
matchPayload :: String -> String -> String -> Value
|
||||
matchPayload proto field val =
|
||||
matchExpr "==" (payloadVal proto field) (A.String (toText val))
|
||||
|
||||
matchInSet :: Value -> [String] -> Value
|
||||
matchInSet lhs vals = matchExpr "in" lhs (setVal vals)
|
||||
|
||||
metaVal :: String -> Value
|
||||
metaVal key = object ["meta" .= object ["key" .= (key :: String)]]
|
||||
|
||||
payloadVal :: String -> String -> Value
|
||||
payloadVal proto field =
|
||||
object ["payload" .= object
|
||||
[ "protocol" .= (proto :: String)
|
||||
, "field" .= (field :: String)
|
||||
]]
|
||||
|
||||
setVal :: [String] -> Value
|
||||
setVal vs = object ["set" .= toJSON vs]
|
||||
|
||||
-- ─── Expression helpers ───────────────────────────────────────────────────────
|
||||
|
||||
-- Fix 3 (overlap): specific ct pattern first, generic 2-element case second.
|
||||
exprVal :: Expr -> Value
|
||||
exprVal (EQual ["ct", k]) = object ["ct" .= object ["key" .= (k :: String)]]
|
||||
exprVal (EQual [p, f]) = payloadVal p f
|
||||
exprVal (EQual ns) = A.String (toText (intercalate "." ns))
|
||||
exprVal (EVar n) = metaVal n
|
||||
exprVal (ELit l) = A.String (toText (renderLit l))
|
||||
exprVal (ESet vs) = setVal (map exprToStr vs)
|
||||
exprVal e = A.String (toText (exprToStr e))
|
||||
|
||||
exprToStr :: Expr -> String
|
||||
exprToStr (EVar n) = n
|
||||
exprToStr (ELit l) = renderLit l
|
||||
exprToStr (EQual ns) = intercalate "." ns
|
||||
exprToStr (ETuple es) = intercalate " . " (map exprToStr es)
|
||||
exprToStr _ = "_"
|
||||
|
||||
-- Fix 2: Use Data.Text.pack via OverloadedStrings + fromString instead of
|
||||
-- the fragile read(show s) hack. With OverloadedStrings enabled, string
|
||||
-- literals already produce the correct Text/Key types; for runtime String
|
||||
toText :: String -> T.Text
|
||||
toText = T.pack
|
||||
|
||||
renderLit :: Literal -> String
|
||||
renderLit (LInt n) = show n
|
||||
renderLit (LString s) = s
|
||||
renderLit (LBool True) = "true"
|
||||
renderLit (LBool False) = "false"
|
||||
renderLit (LIPv4 (a, b, c, d)) =
|
||||
show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d
|
||||
renderLit (LIPv6 _) = "::1"
|
||||
renderLit (LCIDR ip p) = renderLit ip ++ "/" ++ show p
|
||||
renderLit (LPort p) = show p
|
||||
renderLit (LDuration n Seconds) = show n ++ "s"
|
||||
renderLit (LDuration n Millis) = show n ++ "ms"
|
||||
renderLit (LDuration n Minutes) = show n ++ "m"
|
||||
renderLit (LDuration n Hours) = show n ++ "h"
|
||||
renderLit (LHex b) = show b
|
||||
101
src/FWL/Lexer.hs
Normal file
101
src/FWL/Lexer.hs
Normal file
@@ -0,0 +1,101 @@
|
||||
module FWL.Lexer where
|
||||
|
||||
import Text.Parsec
|
||||
import Text.Parsec.String (Parser)
|
||||
import qualified Text.Parsec.Token as Tok
|
||||
import Text.Parsec.Language (emptyDef)
|
||||
|
||||
-- ─── Language definition ─────────────────────────────────────────────────────
|
||||
|
||||
fwlDef :: Tok.LanguageDef ()
|
||||
fwlDef = emptyDef
|
||||
{ Tok.commentLine = "--"
|
||||
, Tok.commentStart = "{-"
|
||||
, Tok.commentEnd = "-}"
|
||||
, Tok.identStart = letter <|> char '_'
|
||||
, Tok.identLetter = alphaNum <|> char '_'
|
||||
, Tok.reservedNames =
|
||||
-- Only genuine syntactic keywords belong here.
|
||||
-- Semantic values used as constructors, actions, type names, or
|
||||
-- pattern references (Allow, Drop, Log, Matched, Frame, etc.) must
|
||||
-- NOT be reserved so that `identifier` can consume them in those
|
||||
-- positions.
|
||||
[ "config", "table"
|
||||
, "interface", "zone", "import", "from"
|
||||
, "let", "in", "pattern", "flow", "rule", "policy", "on"
|
||||
, "case", "of", "if", "then", "else", "do", "perform"
|
||||
, "within", "as", "dynamic", "cidr4", "cidr6"
|
||||
, "hook", "priority"
|
||||
, "WAN", "LAN", "WireGuard"
|
||||
, "Input", "Forward", "Output", "Prerouting", "Postrouting"
|
||||
, "Filter", "NAT", "Mangle", "DstNat", "SrcNat", "Raw", "ConnTrack"
|
||||
, "true", "false"
|
||||
]
|
||||
, Tok.reservedOpNames =
|
||||
[ "->", "<-", "=>", "::", ":", "=", ".", ".."
|
||||
, "\\", "|", ","
|
||||
, "&&", "||", "!", "==" , "!=", "<", "<=", ">", ">="
|
||||
, "++", ">>", ">>="
|
||||
, "∈"
|
||||
]
|
||||
, Tok.caseSensitive = True
|
||||
}
|
||||
|
||||
lexer :: Tok.TokenParser ()
|
||||
lexer = Tok.makeTokenParser fwlDef
|
||||
|
||||
-- ─── Token helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
identifier :: Parser String
|
||||
identifier = Tok.identifier lexer
|
||||
|
||||
reserved :: String -> Parser ()
|
||||
reserved = Tok.reserved lexer
|
||||
|
||||
reservedOp :: String -> Parser ()
|
||||
reservedOp = Tok.reservedOp lexer
|
||||
|
||||
symbol :: String -> Parser String
|
||||
symbol = Tok.symbol lexer
|
||||
|
||||
parens :: Parser a -> Parser a
|
||||
parens = Tok.parens lexer
|
||||
|
||||
braces :: Parser a -> Parser a
|
||||
braces = Tok.braces lexer
|
||||
|
||||
angles :: Parser a -> Parser a
|
||||
angles = Tok.angles lexer
|
||||
|
||||
brackets :: Parser a -> Parser a
|
||||
brackets = Tok.brackets lexer
|
||||
|
||||
semi :: Parser String
|
||||
semi = Tok.semi lexer
|
||||
|
||||
comma :: Parser String
|
||||
comma = Tok.comma lexer
|
||||
|
||||
colon :: Parser String
|
||||
colon = Tok.colon lexer
|
||||
|
||||
dot :: Parser String
|
||||
dot = Tok.dot lexer
|
||||
|
||||
whiteSpace :: Parser ()
|
||||
whiteSpace = Tok.whiteSpace lexer
|
||||
|
||||
stringLit :: Parser String
|
||||
stringLit = Tok.stringLiteral lexer
|
||||
|
||||
natural :: Parser Integer
|
||||
natural = Tok.natural lexer
|
||||
|
||||
commaSep :: Parser a -> Parser [a]
|
||||
commaSep = Tok.commaSep lexer
|
||||
|
||||
commaSep1 :: Parser a -> Parser [a]
|
||||
commaSep1 = Tok.commaSep1 lexer
|
||||
|
||||
semiSep :: Parser a -> Parser [a]
|
||||
semiSep = Tok.semiSep lexer
|
||||
659
src/FWL/Parser.hs
Normal file
659
src/FWL/Parser.hs
Normal file
@@ -0,0 +1,659 @@
|
||||
module FWL.Parser
|
||||
( parseProgram
|
||||
, parseFile
|
||||
) where
|
||||
|
||||
import Control.Monad (void)
|
||||
import Data.Bits ((.&.), (.|.), shiftL)
|
||||
import Data.List (foldl')
|
||||
import Data.Word (Word8)
|
||||
import Numeric (readHex)
|
||||
import Text.Parsec
|
||||
import Text.Parsec.String (Parser)
|
||||
import Data.Functor.Identity (Identity)
|
||||
import qualified Text.Parsec.Expr as Ex
|
||||
|
||||
import FWL.AST
|
||||
import FWL.Lexer
|
||||
|
||||
-- ─── Entry points ────────────────────────────────────────────────────────────
|
||||
|
||||
parseProgram :: String -> String -> Either ParseError Program
|
||||
parseProgram src input = parse program src input
|
||||
|
||||
parseFile :: FilePath -> IO (Either ParseError Program)
|
||||
parseFile fp = parseProgram fp <$> readFile fp
|
||||
|
||||
-- ─── Top-level ───────────────────────────────────────────────────────────────
|
||||
|
||||
program :: Parser Program
|
||||
program = do
|
||||
whiteSpace
|
||||
cfg <- option defaultConfig configBlock
|
||||
ds <- many decl
|
||||
eof
|
||||
return (Program cfg ds)
|
||||
|
||||
configBlock :: Parser Config
|
||||
configBlock = do
|
||||
reserved "config"
|
||||
props <- braces (semiSep configProp)
|
||||
optional semi
|
||||
return $ foldr applyProp defaultConfig props
|
||||
where
|
||||
applyProp ("table", v) c = c { configTable = v }
|
||||
applyProp _ c = c
|
||||
|
||||
configProp :: Parser (String, String)
|
||||
configProp = do
|
||||
reserved "table"
|
||||
reservedOp "="
|
||||
v <- stringLit
|
||||
return ("table", v)
|
||||
|
||||
-- ─── Declarations ────────────────────────────────────────────────────────────
|
||||
|
||||
decl :: Parser Decl
|
||||
decl = interfaceDecl
|
||||
<|> zoneDecl
|
||||
<|> importDecl
|
||||
<|> letDecl
|
||||
<|> patternDecl
|
||||
<|> flowDecl
|
||||
<|> ruleDecl
|
||||
<|> policyDecl
|
||||
|
||||
interfaceDecl :: Parser Decl
|
||||
interfaceDecl = do
|
||||
reserved "interface"
|
||||
n <- identifier
|
||||
reservedOp ":"
|
||||
k <- ifaceKind
|
||||
ps <- braces (endBy ifaceProp semi)
|
||||
_ <- semi
|
||||
return (DInterface n k ps)
|
||||
|
||||
ifaceKind :: Parser IfaceKind
|
||||
ifaceKind = (reserved "WAN" >> return IWan)
|
||||
<|> (reserved "LAN" >> return ILan)
|
||||
<|> (reserved "WireGuard" >> return IWireGuard)
|
||||
<|> (IUser <$> identifier)
|
||||
|
||||
ifaceProp :: Parser IfaceProp
|
||||
ifaceProp = (reserved "dynamic" >> return IPDynamic)
|
||||
<|> (reserved "cidr4" >> reservedOp "=" >> IPCidr4 <$> cidrSet)
|
||||
<|> (reserved "cidr6" >> reservedOp "=" >> IPCidr6 <$> cidrSet)
|
||||
|
||||
cidrSet :: Parser [CIDR]
|
||||
cidrSet = braces (commaSep1 cidrLit)
|
||||
|
||||
zoneDecl :: Parser Decl
|
||||
zoneDecl = do
|
||||
reserved "zone"
|
||||
n <- identifier
|
||||
reservedOp "="
|
||||
ns <- braces (commaSep1 identifier)
|
||||
_ <- semi
|
||||
return (DZone n ns)
|
||||
|
||||
importDecl :: Parser Decl
|
||||
importDecl = do
|
||||
reserved "import"
|
||||
n <- identifier
|
||||
reservedOp ":"
|
||||
t <- typeP
|
||||
reserved "from"
|
||||
s <- stringLit
|
||||
_ <- semi
|
||||
return (DImport n t s)
|
||||
|
||||
letDecl :: Parser Decl
|
||||
letDecl = do
|
||||
reserved "let"
|
||||
n <- identifier
|
||||
reservedOp ":"
|
||||
t <- typeP
|
||||
reservedOp "="
|
||||
e <- expr
|
||||
_ <- semi
|
||||
return (DLet n t e)
|
||||
|
||||
patternDecl :: Parser Decl
|
||||
patternDecl = do
|
||||
reserved "pattern"
|
||||
n <- identifier
|
||||
reservedOp ":"
|
||||
t <- typeP
|
||||
reservedOp "="
|
||||
p <- pat
|
||||
_ <- semi
|
||||
return (DPattern n t p)
|
||||
|
||||
flowDecl :: Parser Decl
|
||||
flowDecl = do
|
||||
reserved "flow"
|
||||
n <- identifier
|
||||
reservedOp ":"
|
||||
reserved "FlowPattern"
|
||||
reservedOp "="
|
||||
f <- flowExpr
|
||||
_ <- semi
|
||||
return (DFlow n f)
|
||||
|
||||
ruleDecl :: Parser Decl
|
||||
ruleDecl = do
|
||||
reserved "rule"
|
||||
n <- identifier
|
||||
reservedOp ":"
|
||||
t <- typeP
|
||||
reservedOp "="
|
||||
e <- expr
|
||||
_ <- semi
|
||||
return (DRule n t e)
|
||||
|
||||
policyDecl :: Parser Decl
|
||||
policyDecl = do
|
||||
reserved "policy"
|
||||
n <- identifier
|
||||
reservedOp ":"
|
||||
t <- typeP
|
||||
reserved "on"
|
||||
pm <- braces policyMeta
|
||||
reservedOp "="
|
||||
ab <- armBlock
|
||||
_ <- semi
|
||||
return (DPolicy n t pm ab)
|
||||
|
||||
policyMeta :: Parser PolicyMeta
|
||||
policyMeta = do
|
||||
props <- commaSep1 metaProp
|
||||
let h = foldr (\p a -> case p of Left v -> v; _ -> a) HInput props
|
||||
tb = foldr (\p a -> case p of Right (Left v) -> v; _ -> a) TFilter props
|
||||
pr = foldr (\p a -> case p of Right (Right v) -> v; _ -> a) pFilter props
|
||||
return (PolicyMeta h tb pr)
|
||||
|
||||
metaProp :: Parser (Either Hook (Either TableName Priority))
|
||||
metaProp
|
||||
= (reserved "hook" >> reservedOp "=" >> fmap (Left) hookP)
|
||||
<|> (reserved "table" >> reservedOp "=" >> fmap (Right . Left) tableNameP)
|
||||
<|> (reserved "priority" >> reservedOp "=" >> fmap (Right . Right) priorityP)
|
||||
|
||||
hookP :: Parser Hook
|
||||
hookP = (reserved "Input" >> return HInput)
|
||||
<|> (reserved "Forward" >> return HForward)
|
||||
<|> (reserved "Output" >> return HOutput)
|
||||
<|> (reserved "Prerouting" >> return HPrerouting)
|
||||
<|> (reserved "Postrouting" >> return HPostrouting)
|
||||
|
||||
tableNameP :: Parser TableName
|
||||
tableNameP = (reserved "Filter" >> return TFilter)
|
||||
<|> (reserved "NAT" >> return TNAT)
|
||||
|
||||
priorityP :: Parser Priority
|
||||
priorityP
|
||||
= (reserved "Filter" >> return pFilter)
|
||||
<|> (reserved "DstNat" >> return pDstNat)
|
||||
<|> (reserved "SrcNat" >> return pSrcNat)
|
||||
<|> (reserved "Mangle" >> return pMangle)
|
||||
<|> (reserved "Raw" >> return pRaw)
|
||||
<|> (reserved "ConnTrack" >> return pConnTrack)
|
||||
<|> (Priority . fromIntegral <$> integerP)
|
||||
where
|
||||
-- Accept optional leading minus for negative priorities
|
||||
integerP = do
|
||||
neg <- option 1 (char '-' >> return (-1))
|
||||
n <- natural
|
||||
whiteSpace
|
||||
return (neg * fromIntegral n)
|
||||
|
||||
-- ─── Arm blocks ──────────────────────────────────────────────────────────────
|
||||
|
||||
armBlock :: Parser ArmBlock
|
||||
armBlock = braces (many arm)
|
||||
|
||||
arm :: Parser Arm
|
||||
arm = do
|
||||
_ <- symbol "|"
|
||||
p <- pat
|
||||
g <- optionMaybe (reserved "if" >> expr)
|
||||
reservedOp "->"
|
||||
e <- expr
|
||||
_ <- semi
|
||||
return (Arm p g e)
|
||||
|
||||
-- ─── Patterns ────────────────────────────────────────────────────────────────
|
||||
|
||||
pat :: Parser Pat
|
||||
pat = wildcardPat
|
||||
<|> try framePat
|
||||
<|> try tuplePat
|
||||
<|> bytesPat
|
||||
<|> try recordPat
|
||||
<|> try namedOrCtorPat
|
||||
|
||||
wildcardPat :: Parser Pat
|
||||
wildcardPat = symbol "_" >> return PWild
|
||||
|
||||
-- Frame(...) — optional path then inner pattern
|
||||
-- Layer stripping: if the inner pattern is not Ether/IPv4/IPv6/etc the
|
||||
-- type-checker will peel outer layers automatically. Parser just stores
|
||||
-- whatever the user wrote.
|
||||
framePat :: Parser Pat
|
||||
framePat = do
|
||||
reserved "Frame"
|
||||
(mp, inner) <- parens frameArgs
|
||||
return (PFrame mp inner)
|
||||
|
||||
frameArgs :: Parser (Maybe PathPat, Pat)
|
||||
frameArgs = try withPath <|> withoutPath
|
||||
where
|
||||
withPath = do
|
||||
pp <- pathPat
|
||||
_ <- comma
|
||||
inner <- pat
|
||||
return (Just pp, inner)
|
||||
withoutPath = do
|
||||
inner <- pat
|
||||
return (Nothing, inner)
|
||||
|
||||
pathPat :: Parser PathPat
|
||||
pathPat = do
|
||||
src <- optionMaybe (try endpointPat)
|
||||
dst <- optionMaybe (try (reservedOp "->" >> endpointPat))
|
||||
case (src, dst) of
|
||||
(Nothing, Nothing) -> fail "empty path pattern"
|
||||
_ -> return (PathPat src dst)
|
||||
|
||||
endpointPat :: Parser EndpointPat
|
||||
endpointPat
|
||||
= (symbol "_" >> return EPWild)
|
||||
<|> try (do n <- identifier
|
||||
memberOp
|
||||
z <- identifier
|
||||
return (EPMember n z))
|
||||
<|> (EPName <$> identifier)
|
||||
|
||||
memberOp :: Parser ()
|
||||
memberOp = (reservedOp "∈" <|> reserved "in") >> return ()
|
||||
|
||||
tuplePat :: Parser Pat
|
||||
tuplePat = do
|
||||
ps <- parens (commaSep2 pat)
|
||||
return (PTuple ps)
|
||||
|
||||
commaSep2 :: Parser a -> Parser [a]
|
||||
commaSep2 p = do
|
||||
x <- p
|
||||
_ <- comma
|
||||
xs <- commaSep1 p
|
||||
return (x:xs)
|
||||
|
||||
bytesPat :: Parser Pat
|
||||
bytesPat = brackets (PBytes <$> many byteElem)
|
||||
|
||||
byteElem :: Parser ByteElem
|
||||
byteElem
|
||||
= try (symbol "_*" >> return BEWildStar)
|
||||
<|> try (symbol "_" >> return BEWild)
|
||||
<|> (BEHex <$> hexByte)
|
||||
|
||||
hexByte :: Parser Word8
|
||||
hexByte = do
|
||||
void (string "0x")
|
||||
h1 <- hexDigit
|
||||
h2 <- hexDigit
|
||||
whiteSpace
|
||||
case (readHex [h1,h2] :: [(Integer, String)]) of
|
||||
[(v,"")] -> return (fromIntegral v)
|
||||
_ -> fail "invalid hex byte"
|
||||
|
||||
-- Record pattern: ident { fields }
|
||||
recordPat :: Parser Pat
|
||||
recordPat = do
|
||||
n <- identifier
|
||||
fs <- braces (commaSep fieldPat)
|
||||
return (PRecord n fs)
|
||||
|
||||
fieldPat :: Parser FieldPat
|
||||
fieldPat = do
|
||||
n <- identifier
|
||||
try (reservedOp "=" >> FPEq n <$> fieldLiteral)
|
||||
<|> try (reserved "as" >> FPAs n <$> identifier)
|
||||
<|> return (FPBind n)
|
||||
|
||||
-- Port literals (:22) are valid in record field position as well as plain literals.
|
||||
fieldLiteral :: Parser Literal
|
||||
fieldLiteral = try portLit <|> literal
|
||||
where
|
||||
portLit = do
|
||||
void (char ':')
|
||||
n <- fromIntegral <$> natural
|
||||
return (LPort n)
|
||||
|
||||
-- Named pattern reference OR constructor: starts with uppercase-ish ident
|
||||
namedOrCtorPat :: Parser Pat
|
||||
namedOrCtorPat = do
|
||||
n <- identifier
|
||||
args <- optionMaybe (try (parens (commaSep pat)))
|
||||
case args of
|
||||
Nothing -> return (PNamed n) -- bare name = named pattern ref
|
||||
Just ps -> return (PCtor n ps)
|
||||
|
||||
-- ─── Flow expressions ────────────────────────────────────────────────────────
|
||||
|
||||
flowExpr :: Parser FlowExpr
|
||||
flowExpr = do
|
||||
first <- FAtom <$> identifier
|
||||
rest <- many (reservedOp "." >> identifier)
|
||||
mw <- optionMaybe (reserved "within" >> durationLit)
|
||||
return $ buildSeq (first : map FAtom rest) mw
|
||||
where
|
||||
buildSeq [x] mw = case mw of
|
||||
Nothing -> x
|
||||
Just w -> FSeq x x (Just w) -- degenerate
|
||||
buildSeq (x:xs) mw = FSeq x (buildSeq xs mw) mw
|
||||
buildSeq [] _ = error "impossible"
|
||||
|
||||
durationLit :: Parser Duration
|
||||
durationLit = do
|
||||
n <- fromIntegral <$> natural
|
||||
u <- (char 's' >> return Seconds)
|
||||
<|> (string "ms" >> return Millis)
|
||||
<|> (char 'm' >> return Minutes)
|
||||
<|> (char 'h' >> return Hours)
|
||||
whiteSpace
|
||||
return (n, u)
|
||||
|
||||
-- ─── Types ───────────────────────────────────────────────────────────────────
|
||||
|
||||
typeP :: Parser Type
|
||||
typeP = do
|
||||
t <- baseType
|
||||
option t (reservedOp "->" >> TFun t <$> typeP)
|
||||
|
||||
baseType :: Parser Type
|
||||
baseType
|
||||
= effectType
|
||||
<|> try tupleTy
|
||||
<|> simpleTy
|
||||
|
||||
effectType :: Parser Type
|
||||
effectType = do
|
||||
effs <- angles (commaSep identifier)
|
||||
t <- simpleTy
|
||||
return (TEffect effs t)
|
||||
|
||||
tupleTy :: Parser Type
|
||||
tupleTy = TTuple <$> parens (commaSep2 typeP)
|
||||
|
||||
simpleTy :: Parser Type
|
||||
simpleTy = do
|
||||
n <- identifier
|
||||
args <- option [] (angles (commaSep typeP))
|
||||
return (TName n args)
|
||||
|
||||
-- ─── Expressions ─────────────────────────────────────────────────────────────
|
||||
|
||||
expr :: Parser Expr
|
||||
expr = lamExpr
|
||||
<|> ifExpr
|
||||
<|> doExpr
|
||||
<|> caseExpr
|
||||
<|> letExpr
|
||||
<|> infixExpr
|
||||
|
||||
lamExpr :: Parser Expr
|
||||
lamExpr = do
|
||||
reservedOp "\\"
|
||||
n <- identifier
|
||||
reservedOp "->"
|
||||
e <- expr
|
||||
return (ELam n e)
|
||||
|
||||
ifExpr :: Parser Expr
|
||||
ifExpr = do
|
||||
reserved "if"
|
||||
c <- expr
|
||||
reserved "then"
|
||||
t <- expr
|
||||
reserved "else"
|
||||
f <- expr
|
||||
return (EIf c t f)
|
||||
|
||||
doExpr :: Parser Expr
|
||||
doExpr = reserved "do" >> braces (EDo <$> semiSep doStmt)
|
||||
|
||||
doStmt :: Parser DoStmt
|
||||
doStmt = try bindStmt <|> (DSExpr <$> expr)
|
||||
|
||||
bindStmt :: Parser DoStmt
|
||||
bindStmt = do
|
||||
n <- identifier
|
||||
reservedOp "<-"
|
||||
e <- expr
|
||||
return (DSBind n e)
|
||||
|
||||
caseExpr :: Parser Expr
|
||||
caseExpr = do
|
||||
reserved "case"
|
||||
e <- expr
|
||||
reserved "of"
|
||||
ab <- armBlock
|
||||
return (ECase e ab)
|
||||
|
||||
letExpr :: Parser Expr
|
||||
letExpr = do
|
||||
reserved "let"
|
||||
n <- identifier
|
||||
reservedOp "="
|
||||
e1 <- expr
|
||||
reserved "in"
|
||||
e2 <- expr
|
||||
return (ELet n e1 e2)
|
||||
|
||||
-- Operator table for infix expressions
|
||||
infixExpr :: Parser Expr
|
||||
infixExpr = Ex.buildExpressionParser opTable appExpr
|
||||
|
||||
opTable :: Ex.OperatorTable String () Identity Expr
|
||||
opTable =
|
||||
[ [ prefix "!" ENot ]
|
||||
, [ infixL "==" OpEq, infixL "!=" OpNeq
|
||||
, infixL "<" OpLt, infixL "<=" OpLte
|
||||
, infixL ">" OpGt, infixL ">=" OpGte
|
||||
, infixIn ]
|
||||
, [ infixR "&&" OpAnd ]
|
||||
, [ infixR "||" OpOr ]
|
||||
, [ infixR "++" OpConcat ]
|
||||
, [ infixL ">>=" OpBind ]
|
||||
, [ infixL ">>" OpThen ]
|
||||
]
|
||||
where
|
||||
prefix op f = Ex.Prefix (reservedOp op >> return f)
|
||||
infixL op c = Ex.Infix (reservedOp op >> return (EInfix c)) Ex.AssocLeft
|
||||
infixR op c = Ex.Infix (reservedOp op >> return (EInfix c)) Ex.AssocRight
|
||||
infixIn = Ex.Infix
|
||||
((memberOp <|> reserved "in") >> return (EInfix OpIn))
|
||||
Ex.AssocNone
|
||||
|
||||
appExpr :: Parser Expr
|
||||
appExpr = do
|
||||
f <- atom
|
||||
args <- many atom
|
||||
return (foldl EApp f args)
|
||||
|
||||
atom :: Parser Expr
|
||||
atom
|
||||
= try performExpr
|
||||
<|> try mapLit
|
||||
<|> try setLit
|
||||
<|> try tupleLit
|
||||
<|> try (parens expr)
|
||||
<|> try litExpr
|
||||
<|> try portExpr
|
||||
<|> qualNameExpr
|
||||
|
||||
performExpr :: Parser Expr
|
||||
performExpr = do
|
||||
reserved "perform"
|
||||
parts <- sepBy1 identifier dot
|
||||
args <- parens (commaSep expr)
|
||||
return (EPerform parts args)
|
||||
|
||||
qualNameExpr :: Parser Expr
|
||||
qualNameExpr = do
|
||||
parts <- sepBy1 identifier (try (dot <* notFollowedBy digit))
|
||||
case parts of
|
||||
[n] -> return (EVar n)
|
||||
ns -> return (EQual ns)
|
||||
|
||||
litExpr :: Parser Expr
|
||||
litExpr = ELit <$> literal
|
||||
|
||||
portExpr :: Parser Expr
|
||||
portExpr = do
|
||||
void (char ':')
|
||||
n <- fromIntegral <$> natural
|
||||
return (ELit (LPort n))
|
||||
|
||||
tupleLit :: Parser Expr
|
||||
tupleLit = ETuple <$> parens (commaSep2 expr)
|
||||
|
||||
setLit :: Parser Expr
|
||||
setLit = braces $ do
|
||||
items <- commaSep expr
|
||||
return (ESet items)
|
||||
|
||||
-- map literal: { expr -> expr, ... }
|
||||
mapLit :: Parser Expr
|
||||
mapLit = braces $ do
|
||||
entries <- commaSep1 mapEntry
|
||||
return (EMap entries)
|
||||
|
||||
mapEntry :: Parser (Expr, Expr)
|
||||
mapEntry = do
|
||||
k <- expr
|
||||
reservedOp "->"
|
||||
v <- expr
|
||||
return (k, v)
|
||||
|
||||
-- ─── Literals ────────────────────────────────────────────────────────────────
|
||||
|
||||
literal :: Parser Literal
|
||||
literal
|
||||
= try ipOrCidrLit
|
||||
<|> try hexLit
|
||||
<|> try (LBool True <$ reserved "true")
|
||||
<|> try (LBool False <$ reserved "false")
|
||||
<|> try (LString <$> stringLit)
|
||||
<|> try (LInt . fromIntegral <$> natural)
|
||||
|
||||
hexLit :: Parser Literal
|
||||
hexLit = LHex <$> hexByte
|
||||
|
||||
-- ─── IP / CIDR parsing ───────────────────────────────────────────────────────
|
||||
|
||||
-- | Parse an IPv4 or IPv6 address, optionally followed by /prefix.
|
||||
-- Tries IPv6 first (it can start with hex digits too), then IPv4.
|
||||
ipOrCidrLit :: Parser Literal
|
||||
ipOrCidrLit = do
|
||||
ip <- try ipv6Lit <|> ipv4Lit_
|
||||
mPrefix <- optionMaybe (char '/' >> fromIntegral <$> natural)
|
||||
whiteSpace
|
||||
return $ case mPrefix of
|
||||
Nothing -> ip
|
||||
Just p -> LCIDR ip p
|
||||
|
||||
-- | IPv4: four decimal octets separated by dots → LIP IPv4 (32-bit Integer)
|
||||
ipv4Lit_ :: Parser Literal
|
||||
ipv4Lit_ = do
|
||||
a <- octet
|
||||
void (char '.')
|
||||
b <- octet
|
||||
void (char '.')
|
||||
c <- octet
|
||||
void (char '.')
|
||||
d <- octet
|
||||
return $ LIP IPv4
|
||||
( fromIntegral a `shiftL` 24
|
||||
.|. fromIntegral b `shiftL` 16
|
||||
.|. fromIntegral c `shiftL` 8
|
||||
.|. fromIntegral d)
|
||||
where
|
||||
octet = do
|
||||
n <- fromIntegral <$> natural
|
||||
if n > 255 then fail "octet out of range" else return n
|
||||
|
||||
-- | IPv6: full notation, :: abbreviation, and optional embedded IPv4.
|
||||
-- Stores as LIP IPv6 (128-bit Integer).
|
||||
ipv6Lit :: Parser Literal
|
||||
ipv6Lit = do
|
||||
(left, right) <- ipv6Groups
|
||||
let missing = 8 - length left - length right
|
||||
when (missing < 0) $ fail "too many groups in IPv6 address"
|
||||
let groups = left ++ replicate missing 0 ++ right
|
||||
when (length groups /= 8) $ fail "invalid IPv6 address"
|
||||
let val = foldl' (\acc g -> (acc `shiftL` 16) .|. fromIntegral g) (0::Integer) groups
|
||||
return (LIP IPv6 val)
|
||||
|
||||
-- Returns (left-of-::, right-of-::).
|
||||
-- If no :: present, left has all 8 groups and right is empty.
|
||||
ipv6Groups :: Parser ([Int], [Int])
|
||||
ipv6Groups = do
|
||||
-- must start with a hex digit or ':' (for ::)
|
||||
ahead <- lookAhead (hexDigit <|> char ':')
|
||||
case ahead of
|
||||
':' -> do
|
||||
void (string "::")
|
||||
right <- ipv6RightGroups
|
||||
return ([], right)
|
||||
_ -> do
|
||||
left <- ipv6LeftGroups
|
||||
mDbl <- optionMaybe (try (string "::"))
|
||||
case mDbl of
|
||||
Nothing -> return (left, [])
|
||||
Just _ -> do
|
||||
right <- ipv6RightGroups
|
||||
return (left, right)
|
||||
|
||||
-- Parse a run of hex16:hex16:... stopping before :: or end
|
||||
ipv6LeftGroups :: Parser [Int]
|
||||
ipv6LeftGroups = do
|
||||
first <- hex16
|
||||
rest <- many (try (char ':' >> notFollowedBy (char ':') >> hex16))
|
||||
return (first : rest)
|
||||
|
||||
-- Parse groups to the right of ::, including optional embedded IPv4
|
||||
ipv6RightGroups :: Parser [Int]
|
||||
ipv6RightGroups = option [] $
|
||||
try ipv4EmbeddedGroups <|> ipv6LeftGroups
|
||||
|
||||
-- IPv4-mapped groups: e.g. ffff:192.168.1.1 -> [0xffff, 0xc0a8, 0x0101]
|
||||
ipv4EmbeddedGroups :: Parser [Int]
|
||||
ipv4EmbeddedGroups = do
|
||||
prefix <- many (try (hex16 <* char ':' <* lookAhead digit))
|
||||
a <- octet_; void (char '.')
|
||||
b <- octet_; void (char '.')
|
||||
c <- octet_; void (char '.')
|
||||
d <- octet_
|
||||
let hi = (a `shiftL` 8) .|. b
|
||||
lo = (c `shiftL` 8) .|. d
|
||||
return (prefix ++ [hi, lo])
|
||||
where
|
||||
octet_ = do
|
||||
n <- fromIntegral <$> natural
|
||||
if n > 255 then fail "IPv4 octet out of range" else return n
|
||||
|
||||
hex16 :: Parser Int
|
||||
hex16 = do
|
||||
digits <- many1 hexDigit
|
||||
case (reads ("0x" ++ digits)) :: [(Int,String)] of
|
||||
[(v,"")] -> if v > 0xffff then fail "hex16 out of range" else return v
|
||||
_ -> fail "invalid hex group"
|
||||
|
||||
cidrLit :: Parser CIDR
|
||||
cidrLit = do
|
||||
l <- ipOrCidrLit
|
||||
case l of
|
||||
LCIDR ip p -> return (ip, p)
|
||||
_ -> fail "expected CIDR notation (address/prefix)"
|
||||
187
src/FWL/Pretty.hs
Normal file
187
src/FWL/Pretty.hs
Normal file
@@ -0,0 +1,187 @@
|
||||
-- | Pretty printer: round-trips the AST back to FWL source.
|
||||
module FWL.Pretty (prettyProgram) where
|
||||
|
||||
import Data.List (intercalate)
|
||||
import FWL.AST
|
||||
|
||||
prettyProgram :: Program -> String
|
||||
prettyProgram (Program cfg ds) =
|
||||
prettyConfig cfg ++ "\n" ++ unlines (map prettyDecl ds)
|
||||
|
||||
prettyConfig :: Config -> String
|
||||
prettyConfig (Config t)
|
||||
| t == "fwl" = ""
|
||||
| otherwise = "config { table = \"" ++ t ++ "\"; }\n"
|
||||
|
||||
prettyDecl :: Decl -> String
|
||||
prettyDecl (DInterface n k ps) =
|
||||
"interface " ++ n ++ " : " ++ prettyKind k ++ " {\n" ++
|
||||
concatMap (\p -> " " ++ prettyIfaceProp p ++ ";\n") ps ++
|
||||
"};"
|
||||
prettyDecl (DZone n ns) =
|
||||
"zone " ++ n ++ " = { " ++ intercalate ", " ns ++ " };"
|
||||
prettyDecl (DImport n t s) =
|
||||
"import " ++ n ++ " : " ++ prettyType t ++ " from \"" ++ s ++ "\";"
|
||||
prettyDecl (DLet n t e) =
|
||||
"let " ++ n ++ " : " ++ prettyType t ++ " = " ++ prettyExpr e ++ ";"
|
||||
prettyDecl (DPattern n t p) =
|
||||
"pattern " ++ n ++ " : " ++ prettyType t ++ " = " ++ prettyPat p ++ ";"
|
||||
prettyDecl (DFlow n f) =
|
||||
"flow " ++ n ++ " : FlowPattern = " ++ prettyFlow f ++ ";"
|
||||
prettyDecl (DRule n t e) =
|
||||
"rule " ++ n ++ " : " ++ prettyType t ++ " =\n " ++ prettyExpr e ++ ";"
|
||||
prettyDecl (DPolicy n t pm ab) =
|
||||
"policy " ++ n ++ " : " ++ prettyType t ++ "\n" ++
|
||||
" on { hook = " ++ prettyHook (pmHook pm) ++
|
||||
", table = " ++ prettyTable (pmTable pm) ++
|
||||
", priority = " ++ prettyPriority (pmPriority pm) ++ " }\n" ++
|
||||
" = " ++ prettyArmBlock ab ++ ";"
|
||||
|
||||
prettyKind :: IfaceKind -> String
|
||||
prettyKind IWan = "WAN"
|
||||
prettyKind ILan = "LAN"
|
||||
prettyKind IWireGuard = "WireGuard"
|
||||
prettyKind (IUser n) = n
|
||||
|
||||
prettyIfaceProp :: IfaceProp -> String
|
||||
prettyIfaceProp IPDynamic = "dynamic"
|
||||
prettyIfaceProp (IPCidr4 cs) = "cidr4 = { " ++ intercalate ", " (map prettyCidr cs) ++ " }"
|
||||
prettyIfaceProp (IPCidr6 cs) = "cidr6 = { " ++ intercalate ", " (map prettyCidr cs) ++ " }"
|
||||
|
||||
prettyCidr :: CIDR -> String
|
||||
prettyCidr (LIPv4 (a,b,c,d), p) =
|
||||
show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ "/" ++ show p
|
||||
prettyCidr (ip, p) = prettyLit ip ++ "/" ++ show p
|
||||
|
||||
prettyHook :: Hook -> String
|
||||
prettyHook HInput = "Input"
|
||||
prettyHook HForward = "Forward"
|
||||
prettyHook HOutput = "Output"
|
||||
prettyHook HPrerouting = "Prerouting"
|
||||
prettyHook HPostrouting = "Postrouting"
|
||||
|
||||
prettyTable :: TableName -> String
|
||||
prettyTable TFilter = "Filter"
|
||||
prettyTable TNAT = "NAT"
|
||||
|
||||
prettyPriority :: Priority -> String
|
||||
prettyPriority p = show (priorityValue p)
|
||||
|
||||
prettyType :: Type -> String
|
||||
prettyType (TName n []) = n
|
||||
prettyType (TName n ts) = n ++ "<" ++ intercalate ", " (map prettyType ts) ++ ">"
|
||||
prettyType (TTuple ts) = "(" ++ intercalate ", " (map prettyType ts) ++ ")"
|
||||
prettyType (TFun a b) = prettyType a ++ " -> " ++ prettyType b
|
||||
prettyType (TEffect es t) = "<" ++ intercalate ", " es ++ "> " ++ prettyType t
|
||||
|
||||
prettyPat :: Pat -> String
|
||||
prettyPat PWild = "_"
|
||||
prettyPat (PVar n) = n
|
||||
prettyPat (PNamed n) = n
|
||||
prettyPat (PCtor n ps) = n ++ "(" ++ intercalate ", " (map prettyPat ps) ++ ")"
|
||||
prettyPat (PRecord n fs) = n ++ " { " ++ intercalate ", " (map prettyFP fs) ++ " }"
|
||||
prettyPat (PTuple ps) = "(" ++ intercalate ", " (map prettyPat ps) ++ ")"
|
||||
prettyPat (PFrame mp inner)=
|
||||
"Frame(" ++ maybe "" (\pp -> prettyPath pp ++ ", ") mp ++ prettyPat inner ++ ")"
|
||||
prettyPat (PBytes bs) = "[" ++ unwords (map prettyBE bs) ++ "]"
|
||||
|
||||
prettyFP :: FieldPat -> String
|
||||
prettyFP (FPEq n l) = n ++ " = " ++ prettyLit l
|
||||
prettyFP (FPBind n) = n
|
||||
prettyFP (FPAs n v) = n ++ " as " ++ v
|
||||
|
||||
prettyPath :: PathPat -> String
|
||||
prettyPath (PathPat ms md) =
|
||||
maybe "_" prettyEP ms ++ maybe "" (\d -> " -> " ++ prettyEP d) md
|
||||
|
||||
prettyEP :: EndpointPat -> String
|
||||
prettyEP EPWild = "_"
|
||||
prettyEP (EPName n) = n
|
||||
prettyEP (EPMember n z) = n ++ " in " ++ z
|
||||
|
||||
prettyBE :: ByteElem -> String
|
||||
prettyBE (BEHex w) = "0x" ++ pad (show w) -- simplified
|
||||
where pad s = if length s < 2 then '0':s else s
|
||||
prettyBE BEWild = "_"
|
||||
prettyBE BEWildStar = "_*"
|
||||
|
||||
prettyFlow :: FlowExpr -> String
|
||||
prettyFlow (FAtom n) = n
|
||||
prettyFlow (FSeq a b mw) =
|
||||
prettyFlow a ++ " . " ++ prettyFlow b ++
|
||||
maybe "" (\(n,u) -> " within " ++ show n ++ prettyUnit u) mw
|
||||
|
||||
prettyUnit :: TimeUnit -> String
|
||||
prettyUnit Seconds = "s"
|
||||
prettyUnit Millis = "ms"
|
||||
prettyUnit Minutes = "m"
|
||||
prettyUnit Hours = "h"
|
||||
|
||||
prettyExpr :: Expr -> String
|
||||
prettyExpr (EVar n) = n
|
||||
prettyExpr (EQual ns) = intercalate "." ns
|
||||
prettyExpr (ELit l) = prettyLit l
|
||||
prettyExpr (ELam n e) = "\\" ++ n ++ " -> " ++ prettyExpr e
|
||||
prettyExpr (EApp f x) = prettyExpr f ++ " " ++ prettyAtom x
|
||||
prettyExpr (ECase e ab) =
|
||||
"case " ++ prettyExpr e ++ " of " ++ prettyArmBlock ab
|
||||
prettyExpr (EIf c t f) =
|
||||
"if " ++ prettyExpr c ++ " then " ++ prettyExpr t ++ " else " ++ prettyExpr f
|
||||
prettyExpr (EDo ss) =
|
||||
"do { " ++ intercalate "; " (map prettyStmt ss) ++ " }"
|
||||
prettyExpr (ELet n e1 e2) =
|
||||
"let " ++ n ++ " = " ++ prettyExpr e1 ++ " in " ++ prettyExpr e2
|
||||
prettyExpr (ETuple es) = "(" ++ intercalate ", " (map prettyExpr es) ++ ")"
|
||||
prettyExpr (ESet es) = "{ " ++ intercalate ", " (map prettyExpr es) ++ " }"
|
||||
prettyExpr (EMap ms) =
|
||||
"{ " ++ intercalate ", " (map (\(k,v) -> prettyExpr k ++ " -> " ++ prettyExpr v) ms) ++ " }"
|
||||
prettyExpr (EPerform ns as_) =
|
||||
"perform " ++ intercalate "." ns ++ "(" ++ intercalate ", " (map prettyExpr as_) ++ ")"
|
||||
prettyExpr (EInfix op l r) =
|
||||
prettyAtom l ++ " " ++ prettyOp op ++ " " ++ prettyAtom r
|
||||
prettyExpr (ENot e) = "!" ++ prettyAtom e
|
||||
|
||||
prettyAtom :: Expr -> String
|
||||
prettyAtom e@(EInfix _ _ _) = "(" ++ prettyExpr e ++ ")"
|
||||
prettyAtom e@(ELam _ _) = "(" ++ prettyExpr e ++ ")"
|
||||
prettyAtom e = prettyExpr e
|
||||
|
||||
prettyOp :: InfixOp -> String
|
||||
prettyOp OpAnd = "&&"
|
||||
prettyOp OpOr = "||"
|
||||
prettyOp OpEq = "=="
|
||||
prettyOp OpNeq = "!="
|
||||
prettyOp OpLt = "<"
|
||||
prettyOp OpLte = "<="
|
||||
prettyOp OpGt = ">"
|
||||
prettyOp OpGte = ">="
|
||||
prettyOp OpIn = "in"
|
||||
prettyOp OpConcat = "++"
|
||||
prettyOp OpThen = ">>"
|
||||
prettyOp OpBind = ">>="
|
||||
|
||||
prettyStmt :: DoStmt -> String
|
||||
prettyStmt (DSBind n e) = n ++ " <- " ++ prettyExpr e
|
||||
prettyStmt (DSExpr e) = prettyExpr e
|
||||
|
||||
prettyArmBlock :: ArmBlock -> String
|
||||
prettyArmBlock arms =
|
||||
"{\n" ++
|
||||
concatMap (\(Arm p mg e) ->
|
||||
" | " ++ prettyPat p ++
|
||||
maybe "" (\g -> " if " ++ prettyExpr g) mg ++
|
||||
" -> " ++ prettyExpr e ++ ";\n") arms ++
|
||||
" }"
|
||||
|
||||
prettyLit :: Literal -> String
|
||||
prettyLit (LInt n) = show n
|
||||
prettyLit (LString s) = "\"" ++ s ++ "\""
|
||||
prettyLit (LBool True) = "true"
|
||||
prettyLit (LBool False) = "false"
|
||||
prettyLit (LIPv4 (a,b,c,d)) =
|
||||
show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d
|
||||
prettyLit (LIPv6 _) = "<ipv6>"
|
||||
prettyLit (LCIDR ip p) = prettyLit ip ++ "/" ++ show p
|
||||
prettyLit (LPort p) = ":" ++ show p
|
||||
prettyLit (LDuration n u) = show n ++ prettyUnit u
|
||||
prettyLit (LHex b) = "0x" ++ show b
|
||||
224
test/CheckTests.hs
Normal file
224
test/CheckTests.hs
Normal file
@@ -0,0 +1,224 @@
|
||||
module CheckTests (tests) where
|
||||
|
||||
import Test.Tasty
|
||||
import Test.Tasty.HUnit
|
||||
|
||||
import FWL.Check
|
||||
import FWL.Util
|
||||
|
||||
tests :: TestTree
|
||||
tests = testGroup "Check"
|
||||
[ undefinedNameTests
|
||||
, duplicateTests
|
||||
, policyTerminationTests
|
||||
, patternCycleTests
|
||||
, cleanProgramTests
|
||||
]
|
||||
|
||||
-- ─── Helper ──────────────────────────────────────────────────────────────────
|
||||
|
||||
checkSrc :: String -> IO [CheckError]
|
||||
checkSrc src = do
|
||||
p <- parseOk src
|
||||
return (checkProgram p)
|
||||
|
||||
assertNoErrors :: String -> IO ()
|
||||
assertNoErrors src = do
|
||||
errs <- checkSrc src
|
||||
case errs of
|
||||
[] -> return ()
|
||||
_ -> assertFailure ("Unexpected errors: " ++ show errs)
|
||||
|
||||
assertHasError :: (CheckError -> Bool) -> String -> IO ()
|
||||
assertHasError p src = do
|
||||
errs <- checkSrc src
|
||||
if any p errs
|
||||
then return ()
|
||||
else assertFailure ("Expected error not found. Got: " ++ show errs)
|
||||
|
||||
isUndefined :: String -> CheckError -> Bool
|
||||
isUndefined n (UndefinedName _ m) = m == n
|
||||
isUndefined _ _ = False
|
||||
|
||||
isDuplicate :: String -> CheckError -> Bool
|
||||
isDuplicate n (DuplicateDecl _ m) = m == n
|
||||
isDuplicate _ _ = False
|
||||
|
||||
isNoContinue :: String -> CheckError -> Bool
|
||||
isNoContinue n (PolicyNoContinue m) = m == n
|
||||
isNoContinue _ _ = False
|
||||
|
||||
isCycle :: CheckError -> Bool
|
||||
isCycle (PatternCycle _) = True
|
||||
isCycle _ = False
|
||||
|
||||
-- ─── Undefined name tests ────────────────────────────────────────────────────
|
||||
|
||||
undefinedNameTests :: TestTree
|
||||
undefinedNameTests = testGroup "undefined names"
|
||||
[ testCase "zone references unknown interface" $
|
||||
assertHasError (isUndefined "ghost")
|
||||
"zone bad_zone = { lan, ghost };"
|
||||
|
||||
, testCase "zone references known interface — no error" $
|
||||
assertNoErrors
|
||||
"interface lan : LAN {}; \
|
||||
\zone good = { lan };"
|
||||
|
||||
, testCase "pattern references undefined named pattern" $
|
||||
assertHasError (isUndefined "Undefined")
|
||||
"pattern Bad : Frame = Frame(_, IPv4(ip, Undefined));"
|
||||
|
||||
, testCase "pattern references known named pattern — no error" $
|
||||
assertNoErrors
|
||||
"pattern WGInit : (UDPHeader,Bytes) = (udp { length = 156 }, [0x01 _*]); \
|
||||
\pattern Compound : Frame = Frame(_, IPv4(ip, WGInit));"
|
||||
|
||||
, testCase "flow references undefined pattern" $
|
||||
assertHasError (isUndefined "Ghost")
|
||||
"flow Bad : FlowPattern = Ghost;"
|
||||
|
||||
, testCase "flow references known pattern — no error" $
|
||||
assertNoErrors
|
||||
"pattern P : T = udp { length = 1 }; \
|
||||
\flow F : FlowPattern = P;"
|
||||
|
||||
, testCase "policy guard references undeclared zone" $
|
||||
-- 'unknown_zone' not declared; check should flag it
|
||||
assertHasError (isUndefined "unknown_zone")
|
||||
"policy fwd : Frame \
|
||||
\ on { hook = Forward, table = Filter, priority = Filter } \
|
||||
\ = { | Frame(iif in unknown_zone -> wan, _) -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
|
||||
, testCase "policy references known zone — no error" $
|
||||
assertNoErrors
|
||||
"interface lan : LAN {}; \
|
||||
\zone trusted = { lan }; \
|
||||
\policy fwd : Frame \
|
||||
\ on { hook = Forward, table = Filter, priority = Filter } \
|
||||
\ = { | Frame(iif in trusted -> wan, _) -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
]
|
||||
|
||||
-- ─── Duplicate declaration tests ─────────────────────────────────────────────
|
||||
|
||||
duplicateTests :: TestTree
|
||||
duplicateTests = testGroup "duplicates"
|
||||
[ testCase "duplicate interface" $
|
||||
assertHasError (isDuplicate "lan")
|
||||
"interface lan : LAN {}; \
|
||||
\interface lan : WAN {};"
|
||||
|
||||
, testCase "duplicate zone" $
|
||||
assertHasError (isDuplicate "z")
|
||||
"zone z = { a }; \
|
||||
\zone z = { b };"
|
||||
|
||||
, testCase "duplicate pattern" $
|
||||
assertHasError (isDuplicate "P")
|
||||
"pattern P : T = udp { length = 1 }; \
|
||||
\pattern P : T = udp { length = 2 };"
|
||||
|
||||
, testCase "duplicate policy" $
|
||||
assertHasError (isDuplicate "input")
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; }; \
|
||||
\policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Drop; };"
|
||||
|
||||
, testCase "distinct names — no error" $
|
||||
assertNoErrors
|
||||
"interface lan : LAN {}; \
|
||||
\interface wan : WAN { dynamic; }; \
|
||||
\zone z = { lan };"
|
||||
]
|
||||
|
||||
-- ─── Policy termination tests ────────────────────────────────────────────────
|
||||
|
||||
policyTerminationTests :: TestTree
|
||||
policyTerminationTests = testGroup "policy termination"
|
||||
[ testCase "last arm is Continue — error" $
|
||||
assertHasError (isNoContinue "bad_policy")
|
||||
"policy bad_policy : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Continue; };"
|
||||
|
||||
, testCase "last arm is Drop — ok" $
|
||||
assertNoErrors
|
||||
"policy good : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ if ct.state in { Established } -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
|
||||
, testCase "last arm is Allow — ok" $
|
||||
assertNoErrors
|
||||
"policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
|
||||
, testCase "Continue in non-last arm is fine" $
|
||||
assertNoErrors
|
||||
"rule r : Frame -> Action = \
|
||||
\ \\f -> case f of { \
|
||||
\ | Frame(_, IPv4(ip, _)) -> Continue; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
|
||||
, testCase "empty policy body — error" $
|
||||
assertHasError (isNoContinue "empty")
|
||||
"policy empty : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = {};"
|
||||
]
|
||||
|
||||
-- ─── Pattern cycle tests ─────────────────────────────────────────────────────
|
||||
|
||||
patternCycleTests :: TestTree
|
||||
patternCycleTests = testGroup "pattern cycles"
|
||||
[ testCase "direct self-reference — cycle error" $
|
||||
assertHasError isCycle
|
||||
"pattern Loop : T = Frame(_, Loop);"
|
||||
|
||||
, testCase "mutual cycle — cycle error" $
|
||||
assertHasError isCycle
|
||||
"pattern A : T = Frame(_, B); \
|
||||
\pattern B : T = Frame(_, A);"
|
||||
|
||||
, testCase "linear chain — no cycle" $
|
||||
assertNoErrors
|
||||
"pattern Base : T = udp { length = 1 }; \
|
||||
\pattern Mid : T = Frame(_, Base); \
|
||||
\pattern Top : T = Frame(_, Mid);"
|
||||
]
|
||||
|
||||
-- ─── Clean full programs ──────────────────────────────────────────────────────
|
||||
|
||||
cleanProgramTests :: TestTree
|
||||
cleanProgramTests = testGroup "clean programs"
|
||||
[ testCase "minimal router skeleton" $
|
||||
assertNoErrors
|
||||
"interface wan : WAN { dynamic; }; \
|
||||
\interface lan : LAN { cidr4 = { 10.17.1.0/24 }; }; \
|
||||
\interface wg0 : WireGuard {}; \
|
||||
\zone lan_zone = { lan, wg0 }; \
|
||||
\policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ if ct.state in { Established, Related } -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ }; \
|
||||
\policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
|
||||
, testCase "pattern and flow declarations" $
|
||||
assertNoErrors
|
||||
"pattern WGInit : (UDPHeader,Bytes) = (udp { length = 156 }, [0x01 _*]); \
|
||||
\pattern WGResp : (UDPHeader,Bytes) = (udp { length = 100 }, [0x02 _*]); \
|
||||
\flow WGHandshake : FlowPattern = WGInit . WGResp within 5s;"
|
||||
]
|
||||
384
test/CompileTests.hs
Normal file
384
test/CompileTests.hs
Normal file
@@ -0,0 +1,384 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
module CompileTests (tests) where
|
||||
|
||||
import Test.Tasty
|
||||
import Test.Tasty.HUnit
|
||||
import qualified Data.Aeson as A
|
||||
import qualified Data.Aeson.Key as AK
|
||||
import qualified Data.Aeson.KeyMap as AKM
|
||||
import qualified Data.Vector as V
|
||||
import qualified Data.ByteString.Lazy.Char8 as BL8
|
||||
|
||||
import FWL.AST
|
||||
import FWL.Compile
|
||||
import FWL.Util
|
||||
|
||||
tests :: TestTree
|
||||
tests = testGroup "Compile"
|
||||
[ jsonStructureTests
|
||||
, chainTests
|
||||
, ruleExprTests
|
||||
, verdictTests
|
||||
, layerStrippingTests
|
||||
, continueTests
|
||||
, configTests
|
||||
]
|
||||
|
||||
-- ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
compileToValue :: String -> IO A.Value
|
||||
compileToValue src = do
|
||||
p <- parseOk src
|
||||
case A.decode (compileToJson p) of
|
||||
Nothing -> assertFailure "Compiled output is not valid JSON" >> undefined
|
||||
Just v -> return v
|
||||
|
||||
-- Navigate a Value by a list of string keys / numeric indices.
|
||||
at :: [String] -> A.Value -> Maybe A.Value
|
||||
at [] v = Just v
|
||||
at (k:ks) (A.Object o) =
|
||||
case AKM.lookup (AK.fromString k) o of
|
||||
Nothing -> Nothing
|
||||
Just v -> at ks v
|
||||
at (k:ks) (A.Array arr) =
|
||||
case reads k of
|
||||
[(i,"")] | i < V.length arr -> at ks (arr V.! i)
|
||||
_ -> Nothing
|
||||
at _ _ = Nothing
|
||||
|
||||
nftArr :: A.Value -> IO [A.Value]
|
||||
nftArr v =
|
||||
case at ["nftables"] v of
|
||||
Just (A.Array arr) -> return (V.toList arr)
|
||||
_ -> assertFailure "Missing top-level 'nftables' array" >> undefined
|
||||
|
||||
withKey :: String -> [A.Value] -> [A.Value]
|
||||
withKey k = filter (\v -> case at [k] v of Just _ -> True; _ -> False)
|
||||
|
||||
-- ─── JSON structure tests ────────────────────────────────────────────────────
|
||||
|
||||
jsonStructureTests :: TestTree
|
||||
jsonStructureTests = testGroup "JSON structure"
|
||||
[ testCase "output is valid JSON" $ do
|
||||
_ <- compileToValue
|
||||
"policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
return ()
|
||||
|
||||
, testCase "top-level nftables array present" $ do
|
||||
v <- compileToValue "policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
_ <- nftArr v
|
||||
return ()
|
||||
|
||||
, testCase "metainfo is first element" $ do
|
||||
v <- compileToValue "policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
case arr of
|
||||
(first:_) -> case at ["metainfo"] first of
|
||||
Just _ -> return ()
|
||||
Nothing -> assertFailure "First element is not metainfo"
|
||||
[] -> assertFailure "Empty nftables array"
|
||||
|
||||
, testCase "table object present" $ do
|
||||
v <- compileToValue "policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
assertBool "Expected at least one table object"
|
||||
(not (null (withKey "table" arr)))
|
||||
|
||||
, testCase "default table name is fwl" $ do
|
||||
v <- compileToValue "policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
case withKey "table" arr of
|
||||
(t:_) -> at ["table","name"] t @?= Just (A.String "fwl")
|
||||
[] -> assertFailure "No table object"
|
||||
|
||||
, testCase "custom table name respected" $ do
|
||||
v <- compileToValue
|
||||
"config { table = \"custom\"; } \
|
||||
\policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
case withKey "table" arr of
|
||||
(t:_) -> at ["table","name"] t @?= Just (A.String "custom")
|
||||
[] -> assertFailure "No table object"
|
||||
]
|
||||
|
||||
-- ─── Chain declaration tests ─────────────────────────────────────────────────
|
||||
|
||||
chainTests :: TestTree
|
||||
chainTests = testGroup "chain declarations"
|
||||
[ testCase "filter input chain has correct hook" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Drop; };"
|
||||
arr <- nftArr v
|
||||
case withKey "chain" arr of
|
||||
(c:_) -> at ["chain","hook"] c @?= Just (A.String "input")
|
||||
[] -> assertFailure "No chain"
|
||||
|
||||
, testCase "filter chain type is filter" $ do
|
||||
v <- compileToValue
|
||||
"policy fwd : Frame \
|
||||
\ on { hook = Forward, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Drop; };"
|
||||
arr <- nftArr v
|
||||
case withKey "chain" arr of
|
||||
(c:_) -> at ["chain","type"] c @?= Just (A.String "filter")
|
||||
[] -> assertFailure "No chain"
|
||||
|
||||
, testCase "NAT chain type is nat" $ do
|
||||
v <- compileToValue
|
||||
"policy nat_post : Frame \
|
||||
\ on { hook = Postrouting, table = NAT, priority = SrcNat } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
case withKey "chain" arr of
|
||||
(c:_) -> at ["chain","type"] c @?= Just (A.String "nat")
|
||||
[] -> assertFailure "No chain"
|
||||
|
||||
, testCase "input chain default policy is drop" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Drop; };"
|
||||
arr <- nftArr v
|
||||
case withKey "chain" arr of
|
||||
(c:_) -> at ["chain","policy"] c @?= Just (A.String "drop")
|
||||
[] -> assertFailure "No chain"
|
||||
|
||||
, testCase "output chain default policy is accept" $ do
|
||||
v <- compileToValue
|
||||
"policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
case withKey "chain" arr of
|
||||
(c:_) -> at ["chain","policy"] c @?= Just (A.String "accept")
|
||||
[] -> assertFailure "No chain"
|
||||
|
||||
, testCase "chain name matches policy name" $ do
|
||||
v <- compileToValue
|
||||
"policy my_input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Drop; };"
|
||||
arr <- nftArr v
|
||||
case withKey "chain" arr of
|
||||
(c:_) -> at ["chain","name"] c @?= Just (A.String "my_input")
|
||||
[] -> assertFailure "No chain"
|
||||
|
||||
, testCase "two policies produce two chains" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Drop; }; \
|
||||
\policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
length (withKey "chain" arr) @?= 2
|
||||
]
|
||||
|
||||
-- ─── Rule expression tests ───────────────────────────────────────────────────
|
||||
|
||||
ruleExprs :: [A.Value] -> [A.Value]
|
||||
ruleExprs arr =
|
||||
[ e | r <- withKey "rule" arr
|
||||
, Just (A.Array es) <- [at ["rule","expr"] r]
|
||||
, e <- V.toList es ]
|
||||
|
||||
ruleExprTests :: TestTree
|
||||
ruleExprTests = testGroup "rule expressions"
|
||||
[ testCase "two arms produce two rules" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ if ct.state in { Established, Related } -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
arr <- nftArr v
|
||||
length (withKey "rule" arr) @?= 2
|
||||
|
||||
, testCase "arm without guard produces one rule" $ do
|
||||
v <- compileToValue
|
||||
"policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
length (withKey "rule" arr) @?= 1
|
||||
|
||||
, testCase "rule expr array is present" $ do
|
||||
v <- compileToValue
|
||||
"policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
case withKey "rule" arr of
|
||||
(r:_) -> case at ["rule","expr"] r of
|
||||
Just (A.Array _) -> return ()
|
||||
_ -> assertFailure "Missing or non-array 'expr'"
|
||||
[] -> assertFailure "No rule"
|
||||
|
||||
, testCase "IPv4 ctor emits nfproto match" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | Frame(_, IPv4(ip, _)) -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
arr <- nftArr v
|
||||
let matches = withKey "match" (ruleExprs arr)
|
||||
hasNfp = any (\m ->
|
||||
at ["match","left","meta","key"] m == Just (A.String "nfproto"))
|
||||
matches
|
||||
assertBool "Expected nfproto match for IPv4 ctor" hasNfp
|
||||
|
||||
, testCase "record field pat emits payload match" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | Frame(_, TCP(tcp { dport = :22 }, _)) -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
arr <- nftArr v
|
||||
let matches = withKey "match" (ruleExprs arr)
|
||||
hasPort = any (\m ->
|
||||
at ["match","right"] m == Just (A.String "22"))
|
||||
matches
|
||||
assertBool "Expected port 22 payload match" hasPort
|
||||
]
|
||||
|
||||
-- ─── Verdict tests ───────────────────────────────────────────────────────────
|
||||
|
||||
allExprs :: [A.Value] -> [A.Value]
|
||||
allExprs arr =
|
||||
concatMap (\r -> case at ["rule","expr"] r of
|
||||
Just (A.Array es) -> V.toList es; _ -> [])
|
||||
(withKey "rule" arr)
|
||||
|
||||
verdictTests :: TestTree
|
||||
verdictTests = testGroup "verdicts"
|
||||
[ testCase "Allow compiles to accept" $ do
|
||||
v <- compileToValue
|
||||
"policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
arr <- nftArr v
|
||||
assertBool "Expected accept verdict"
|
||||
(not (null (withKey "accept" (allExprs arr))))
|
||||
|
||||
, testCase "Drop compiles to drop" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Drop; };"
|
||||
arr <- nftArr v
|
||||
assertBool "Expected drop verdict"
|
||||
(not (null (withKey "drop" (allExprs arr))))
|
||||
|
||||
, testCase "Masquerade compiles to masquerade" $ do
|
||||
v <- compileToValue
|
||||
"policy nat_post : Frame \
|
||||
\ on { hook = Postrouting, table = NAT, priority = SrcNat } \
|
||||
\ = { | _ -> Masquerade; };"
|
||||
arr <- nftArr v
|
||||
assertBool "Expected masquerade verdict"
|
||||
(not (null (withKey "masquerade" (allExprs arr))))
|
||||
|
||||
, testCase "rule call compiles to jump" $ do
|
||||
v <- compileToValue
|
||||
"rule blockAll : Frame -> Action = \\f -> case f of { | _ -> Drop; }; \
|
||||
\policy fwd : Frame \
|
||||
\ on { hook = Forward, table = Filter, priority = Filter } \
|
||||
\ = { | frame -> blockAll(frame); };"
|
||||
arr <- nftArr v
|
||||
assertBool "Expected jump verdict for rule call"
|
||||
(not (null (withKey "jump" (allExprs arr))))
|
||||
]
|
||||
|
||||
-- ─── Layer stripping tests ───────────────────────────────────────────────────
|
||||
|
||||
layerStrippingTests :: TestTree
|
||||
layerStrippingTests = testGroup "layer stripping"
|
||||
[ testCase "Frame with and without Ether both emit nfproto match" $ do
|
||||
let withEther =
|
||||
"policy p1 : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | Frame(_, Ether(_, IPv4(ip, _))) -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
withoutEther =
|
||||
"policy p1 : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | Frame(_, IPv4(ip, _)) -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
v1 <- compileToValue withEther
|
||||
v2 <- compileToValue withoutEther
|
||||
arr1 <- nftArr v1
|
||||
arr2 <- nftArr v2
|
||||
let nfp arr = filter
|
||||
(\m -> at ["match","left","meta","key"] m == Just (A.String "nfproto"))
|
||||
(withKey "match" (ruleExprs arr))
|
||||
assertBool "Both should produce nfproto matches"
|
||||
(not (null (nfp arr1)) && not (null (nfp arr2)))
|
||||
]
|
||||
|
||||
-- ─── Continue tests ───────────────────────────────────────────────────────────
|
||||
|
||||
continueTests :: TestTree
|
||||
continueTests = testGroup "Continue"
|
||||
[ testCase "two terminal arms produce two rules" $ do
|
||||
v <- compileToValue
|
||||
"policy fwd : Frame \
|
||||
\ on { hook = Forward, table = Filter, priority = Filter } \
|
||||
\ = { | _ if ct.state in { Established } -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
arr <- nftArr v
|
||||
length (withKey "rule" arr) @?= 2
|
||||
|
||||
, testCase "non-Continue arms still produce rules" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ if ct.state in { Established } -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
arr <- nftArr v
|
||||
assertBool "Should have rules for non-Continue arms"
|
||||
(not (null (withKey "rule" arr)))
|
||||
]
|
||||
|
||||
-- ─── Config tests ─────────────────────────────────────────────────────────────
|
||||
|
||||
configTests :: TestTree
|
||||
configTests = testGroup "config"
|
||||
[ testCase "all rule objects reference correct table" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Drop; };"
|
||||
arr <- nftArr v
|
||||
mapM_ (\r -> at ["rule","table"] r @?= Just (A.String "fwl"))
|
||||
(withKey "rule" arr)
|
||||
|
||||
, testCase "chain objects reference correct table" $ do
|
||||
v <- compileToValue
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Drop; };"
|
||||
arr <- nftArr v
|
||||
mapM_ (\c -> at ["chain","table"] c @?= Just (A.String "fwl"))
|
||||
(withKey "chain" arr)
|
||||
]
|
||||
44
test/FWL/Util.hs
Normal file
44
test/FWL/Util.hs
Normal file
@@ -0,0 +1,44 @@
|
||||
-- | Shared test utilities.
|
||||
module FWL.Util where
|
||||
|
||||
import Test.Tasty.HUnit
|
||||
import Text.Parsec.String (Parser)
|
||||
import Text.Parsec (parse)
|
||||
|
||||
import FWL.Parser (parseProgram)
|
||||
import FWL.AST
|
||||
|
||||
-- | Assert a parser succeeds and return the result.
|
||||
shouldParse :: (Show a) => Parser a -> String -> IO a
|
||||
shouldParse p input =
|
||||
case parse p "<test>" input of
|
||||
Left err -> assertFailure ("Unexpected parse error:\n" ++ show err)
|
||||
>> undefined
|
||||
Right v -> return v
|
||||
|
||||
-- | Assert a parser fails.
|
||||
shouldFailParse :: (Show a) => Parser a -> String -> IO ()
|
||||
shouldFailParse p input =
|
||||
case parse p "<test>" input of
|
||||
Left _ -> return ()
|
||||
Right v -> assertFailure ("Expected parse failure but got: " ++ show v)
|
||||
|
||||
-- | Parse a full program, asserting success.
|
||||
parseOk :: String -> IO Program
|
||||
parseOk src =
|
||||
case parseProgram "<test>" src of
|
||||
Left err -> assertFailure ("Parse error:\n" ++ show err) >> undefined
|
||||
Right p -> return p
|
||||
|
||||
-- | Parse a full program, asserting failure.
|
||||
parseFail :: String -> IO ()
|
||||
parseFail src =
|
||||
case parseProgram "<test>" src of
|
||||
Left _ -> return ()
|
||||
Right p -> assertFailure ("Expected parse failure, got:\n" ++ show p)
|
||||
|
||||
-- | Extract the single declaration from a one-decl program.
|
||||
singleDecl :: Program -> IO Decl
|
||||
singleDecl (Program _ [d]) = return d
|
||||
singleDecl (Program _ ds) =
|
||||
assertFailure ("Expected 1 decl, got " ++ show (length ds)) >> undefined
|
||||
516
test/ParserTests.hs
Normal file
516
test/ParserTests.hs
Normal file
@@ -0,0 +1,516 @@
|
||||
module ParserTests (tests) where
|
||||
|
||||
import Test.Tasty
|
||||
import Test.Tasty.HUnit
|
||||
|
||||
import FWL.AST
|
||||
import FWL.Util
|
||||
|
||||
tests :: TestTree
|
||||
tests = testGroup "Parser"
|
||||
[ interfaceTests
|
||||
, zoneTests
|
||||
, importTests
|
||||
, letTests
|
||||
, patternTests
|
||||
, flowTests
|
||||
, typeTests
|
||||
, exprTests
|
||||
, policyTests
|
||||
, ruleTests
|
||||
, configTests
|
||||
, errorTests
|
||||
]
|
||||
|
||||
-- ─── Interface ───────────────────────────────────────────────────────────────
|
||||
|
||||
interfaceTests :: TestTree
|
||||
interfaceTests = testGroup "interface"
|
||||
[ testCase "WAN dynamic" $ do
|
||||
p <- parseOk "interface wan : WAN { dynamic; };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DInterface "wan" IWan [IPDynamic] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "LAN with cidr4" $ do
|
||||
p <- parseOk "interface lan : LAN { cidr4 = { 10.0.0.0/8 }; };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DInterface "lan" ILan [IPCidr4 [(LIPv4 (10,0,0,0), 8)]] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "LAN with cidr4 and cidr6" $ do
|
||||
p <- parseOk
|
||||
"interface lan : LAN { \
|
||||
\ cidr4 = { 10.17.1.0/24 }; \
|
||||
\ cidr6 = { 192.168.0.0/16 }; \
|
||||
\};"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DInterface "lan" ILan [IPCidr4 _, IPCidr6 _] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "WireGuard interface" $ do
|
||||
p <- parseOk "interface wg0 : WireGuard {};"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DInterface "wg0" IWireGuard [] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "user-defined kind" $ do
|
||||
p <- parseOk "interface eth0 : Bridge {};"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DInterface "eth0" (IUser "Bridge") [] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "multiple CIDRs in set" $ do
|
||||
p <- parseOk
|
||||
"interface lan : LAN { \
|
||||
\ cidr4 = { 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 }; \
|
||||
\};"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DInterface _ _ [IPCidr4 cidrs] -> length cidrs @?= 3
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Zone ────────────────────────────────────────────────────────────────────
|
||||
|
||||
zoneTests :: TestTree
|
||||
zoneTests = testGroup "zone"
|
||||
[ testCase "single member" $ do
|
||||
p <- parseOk "zone trusted = { lan };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DZone "trusted" ["lan"] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "multiple members" $ do
|
||||
p <- parseOk "zone lan_zone = { lan, wg0, vlan10 };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DZone "lan_zone" ["lan","wg0","vlan10"] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Import ──────────────────────────────────────────────────────────────────
|
||||
|
||||
importTests :: TestTree
|
||||
importTests = testGroup "import"
|
||||
[ testCase "basic import" $ do
|
||||
p <- parseOk "import rfc1918 : CIDRSet from \"builtin:rfc1918\";"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DImport "rfc1918" (TName "CIDRSet" []) "builtin:rfc1918" -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Let ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
letTests :: TestTree
|
||||
letTests = testGroup "let"
|
||||
[ testCase "simple integer" $ do
|
||||
p <- parseOk "let timeout : Int = 30;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet "timeout" (TName "Int" []) (ELit (LInt 30)) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "map literal" $ do
|
||||
p <- parseOk
|
||||
"let forwards : Map<(Protocol,Port),(IP,Port)> = { \
|
||||
\ (tcp, :8080) -> (10.0.0.1, :80) \
|
||||
\};"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet "forwards" _ (EMap [_]) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "string literal" $ do
|
||||
p <- parseOk "let name : String = \"hello\";"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet "name" _ (ELit (LString "hello")) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Pattern ─────────────────────────────────────────────────────────────────
|
||||
|
||||
patternTests :: TestTree
|
||||
patternTests = testGroup "pattern"
|
||||
[ testCase "tuple with record field" $ do
|
||||
p <- parseOk
|
||||
"pattern WGInitiation : (UDPHeader, Bytes) = \
|
||||
\ (udp { length = 156 }, [0x01 _*]);"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPattern "WGInitiation" _ (PTuple [PRecord "udp" _, PBytes _]) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "byte pattern elements" $ do
|
||||
p <- parseOk
|
||||
"pattern WGResponse : (UDPHeader, Bytes) = \
|
||||
\ (udp { length = 100 }, [0x02 _ _*]);"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPattern "WGResponse" _ (PTuple [_, PBytes [BEHex 0x02, BEWild, BEWildStar]]) ->
|
||||
return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "named pattern reference in ctor" $ do
|
||||
p <- parseOk
|
||||
"pattern Complex : Frame = \
|
||||
\ Frame(_, IPv4(ip, WGInitiation));"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPattern "Complex" _ (PFrame Nothing (PCtor "IPv4" [PVar "ip", PNamed "WGInitiation"])) ->
|
||||
return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "record with field bind" $ do
|
||||
p <- parseOk "pattern HasTCP : TCP = tcp { dport };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPattern "HasTCP" _ (PRecord "tcp" [FPBind "dport"]) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "record with field equality" $ do
|
||||
p <- parseOk "pattern SSH : TCP = tcp { dport = :22 };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPattern "SSH" _ (PRecord "tcp" [FPEq "dport" (LPort 22)]) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Flow ────────────────────────────────────────────────────────────────────
|
||||
|
||||
flowTests :: TestTree
|
||||
flowTests = testGroup "flow"
|
||||
[ testCase "two-step sequence with within" $ do
|
||||
p <- parseOk
|
||||
"flow WireGuardHandshake : FlowPattern = \
|
||||
\ WGInitiation . WGResponse within 5s;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DFlow "WireGuardHandshake" (FSeq (FAtom "WGInitiation") (FAtom "WGResponse") (Just (5, Seconds))) ->
|
||||
return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "single atom flow" $ do
|
||||
p <- parseOk "flow Simple : FlowPattern = Ping;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DFlow "Simple" (FAtom "Ping") -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "duration in milliseconds" $ do
|
||||
p <- parseOk "flow Fast : FlowPattern = A . B within 500ms;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DFlow "Fast" (FSeq _ _ (Just (500, Millis))) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Types ───────────────────────────────────────────────────────────────────
|
||||
|
||||
typeTests :: TestTree
|
||||
typeTests = testGroup "types"
|
||||
[ testCase "simple name" $ do
|
||||
p <- parseOk "let x : Frame = Allow;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ (TName "Frame" []) _ -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "generic type" $ do
|
||||
p <- parseOk "let x : Map<Int, String> = Allow;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ (TName "Map" [TName "Int" [], TName "String" []]) _ -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "function type" $ do
|
||||
p <- parseOk "let x : Frame -> Action = Allow;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ (TFun (TName "Frame" []) (TName "Action" [])) _ -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "effect type" $ do
|
||||
p <- parseOk "let x : <Log, FlowMatch> Action = Allow;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ (TEffect ["Log","FlowMatch"] (TName "Action" [])) _ -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "tuple type" $ do
|
||||
p <- parseOk "let x : (Int, String) = Allow;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ (TTuple [TName "Int" [], TName "String" []]) _ -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "function with effects" $ do
|
||||
p <- parseOk "let x : Frame -> <Log> Action = Allow;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ (TFun _ (TEffect ["Log"] _)) _ -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Expressions ─────────────────────────────────────────────────────────────
|
||||
|
||||
exprTests :: TestTree
|
||||
exprTests = testGroup "expressions"
|
||||
[ testCase "boolean and" $ do
|
||||
p <- parseOk "let x : Bool = a && b;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (EInfix OpAnd (EVar "a") (EVar "b")) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "set membership with 'in'" $ do
|
||||
p <- parseOk "let x : Bool = ct.state in { Established, Related };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (EInfix OpIn (EQual ["ct","state"]) (ESet _)) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "equality comparison" $ do
|
||||
p <- parseOk "let x : Bool = tcp.dport == :22;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (EInfix OpEq (EQual ["tcp","dport"]) (ELit (LPort 22))) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "if-then-else" $ do
|
||||
p <- parseOk "let x : Action = if a then Allow else Drop;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (EIf (EVar "a") (EVar "Allow") (EVar "Drop")) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "perform expression" $ do
|
||||
p <- parseOk "let x : Action = perform Log.emit(Info, \"msg\");"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (EPerform ["Log","emit"] [ELit (LString "Info"), ELit (LString "msg")]) -> return ()
|
||||
DLet _ _ (EPerform ["Log","emit"] _) -> return () -- arg parsing flexible
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "do block" $ do
|
||||
p <- parseOk "let x : Action = do { y <- foo; y };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (EDo [DSBind "y" _, DSExpr (EVar "y")]) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "nested case" $ do
|
||||
p <- parseOk
|
||||
"let x : Action = case e of { \
|
||||
\ | a -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\};"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (ECase (EVar "e") [Arm (PVar "a") Nothing _, Arm PWild Nothing _]) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "lambda" $ do
|
||||
p <- parseOk "let x : Frame -> Action = \\frame -> Allow;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (ELam "frame" (EVar "Allow")) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "string concat" $ do
|
||||
p <- parseOk "let x : String = \"hello\" ++ \" world\";"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (EInfix OpConcat _ _) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "negation" $ do
|
||||
p <- parseOk "let x : Bool = !flag;"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (ENot (EVar "flag")) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "set literal" $ do
|
||||
p <- parseOk "let x : Set<Int> = { 22, 80, 443 };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DLet _ _ (ESet [ELit (LInt 22), ELit (LInt 80), ELit (LInt 443)]) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Policy ──────────────────────────────────────────────────────────────────
|
||||
|
||||
policyTests :: TestTree
|
||||
policyTests = testGroup "policy"
|
||||
[ testCase "minimal policy" $ do
|
||||
p <- parseOk
|
||||
"policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPolicy "output" _ (PolicyMeta HOutput TFilter (Priority 0)) [_] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "NAT prerouting" $ do
|
||||
p <- parseOk
|
||||
"policy nat_pre : Frame \
|
||||
\ on { hook = Prerouting, table = NAT, priority = DstNat } \
|
||||
\ = { | _ -> Allow; };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPolicy _ _ (PolicyMeta HPrerouting TNAT (Priority (-100))) _ -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "arm with guard" $ do
|
||||
p <- parseOk
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { \
|
||||
\ | _ if ct.state in { Established, Related } -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPolicy _ _ _ [Arm PWild (Just _) _, Arm PWild Nothing _] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "Frame pattern with path" $ do
|
||||
p <- parseOk
|
||||
"policy forward : Frame \
|
||||
\ on { hook = Forward, table = Filter, priority = Filter } \
|
||||
\ = { \
|
||||
\ | Frame(iif in lan_zone -> wan, _) -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPolicy _ _ _ (Arm (PFrame (Just _) _) Nothing _ : _) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "Frame pattern without Ether (layer stripping)" $ do
|
||||
p <- parseOk
|
||||
"policy input : Frame \
|
||||
\ on { hook = Input, table = Filter, priority = Filter } \
|
||||
\ = { \
|
||||
\ | Frame(_, IPv4(ip, TCP(tcp, _))) if tcp.dport == :22 -> Allow; \
|
||||
\ | _ -> Drop; \
|
||||
\ };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPolicy _ _ _ (Arm (PFrame Nothing (PCtor "IPv4" _)) _ _ : _) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "policy arm calls rule" $ do
|
||||
p <- parseOk
|
||||
"policy forward : Frame \
|
||||
\ on { hook = Forward, table = Filter, priority = Filter } \
|
||||
\ = { \
|
||||
\ | frame -> blockOutboundWG(frame); \
|
||||
\ };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPolicy _ _ _ [Arm (PVar "frame") Nothing (EApp (EVar "blockOutboundWG") _)] ->
|
||||
return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "Continue arm is parsed" $ do
|
||||
p <- parseOk
|
||||
"rule r : Frame -> Action = \
|
||||
\ \\frame -> case frame of { \
|
||||
\ | _ -> Continue; \
|
||||
\ };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DRule _ _ _ -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Rule ────────────────────────────────────────────────────────────────────
|
||||
|
||||
ruleTests :: TestTree
|
||||
ruleTests = testGroup "rule"
|
||||
[ testCase "simple rule" $ do
|
||||
p <- parseOk
|
||||
"rule blockAll : Frame -> Action = \
|
||||
\ \\frame -> case frame of { | _ -> Drop; };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DRule "blockAll" _ (ELam "frame" (ECase _ _)) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "rule with effects in type" $ do
|
||||
p <- parseOk
|
||||
"rule logged : Frame -> <Log> Action = \
|
||||
\ \\f -> case f of { | _ -> Allow; };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DRule "logged" (TFun _ (TEffect ["Log"] _)) _ -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "nested case in rule" $ do
|
||||
p <- parseOk
|
||||
"rule check : Frame -> <FlowMatch> Action = \
|
||||
\ \\frame -> \
|
||||
\ case frame of { \
|
||||
\ | Frame(_, IPv4(ip, UDP(udp, _))) -> \
|
||||
\ case perform FlowMatch.check(ip, wg) of { \
|
||||
\ | Matched -> Drop; \
|
||||
\ | _ -> Continue; \
|
||||
\ }; \
|
||||
\ | _ -> Continue; \
|
||||
\ };"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DRule "check" _ (ELam _ (ECase _ _)) -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
]
|
||||
|
||||
-- ─── Config ──────────────────────────────────────────────────────────────────
|
||||
|
||||
configTests :: TestTree
|
||||
configTests = testGroup "config"
|
||||
[ testCase "default table name" $ do
|
||||
p <- parseOk "interface wan : WAN {};"
|
||||
configTable (progConfig p) @?= "fwl"
|
||||
|
||||
, testCase "custom table name" $ do
|
||||
p <- parseOk "config { table = \"myrules\"; } interface wan : WAN {};"
|
||||
configTable (progConfig p) @?= "myrules"
|
||||
]
|
||||
|
||||
-- ─── Error cases ─────────────────────────────────────────────────────────────
|
||||
|
||||
errorTests :: TestTree
|
||||
errorTests = testGroup "parse errors"
|
||||
[ testCase "missing semicolon" $
|
||||
parseFail "interface wan : WAN {}"
|
||||
|
||||
, testCase "unknown hook" $
|
||||
parseFail
|
||||
"policy p : Frame \
|
||||
\ on { hook = Bogus, table = Filter, priority = Filter } \
|
||||
\ = { | _ -> Allow; };"
|
||||
|
||||
, testCase "empty arm block with no arms is ok" $ do
|
||||
p <- parseOk
|
||||
"policy output : Frame \
|
||||
\ on { hook = Output, table = Filter, priority = Filter } \
|
||||
\ = {};"
|
||||
d <- singleDecl p
|
||||
case d of
|
||||
DPolicy _ _ _ [] -> return ()
|
||||
_ -> assertFailure (show d)
|
||||
|
||||
, testCase "CIDR without prefix fails" $
|
||||
parseFail "interface lan : LAN { cidr4 = { 10.0.0.1 }; };"
|
||||
]
|
||||
15
test/Spec.hs
Normal file
15
test/Spec.hs
Normal file
@@ -0,0 +1,15 @@
|
||||
module Main where
|
||||
|
||||
import Test.Tasty
|
||||
import Test.Tasty.HUnit
|
||||
|
||||
import qualified ParserTests
|
||||
import qualified CheckTests
|
||||
import qualified CompileTests
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMain $ testGroup "FWL"
|
||||
[ ParserTests.tests
|
||||
, CheckTests.tests
|
||||
, CompileTests.tests
|
||||
]
|
||||
Reference in New Issue
Block a user