{-# LANGUAGE OverloadedStrings #-}

-- | Tests for the FWL compiler: compile small FWL programs with
-- 'compileToJson', decode the result as JSON and inspect the objects found in
-- its top-level @nftables@ array.
module CompileTests (tests) where

import Data.Maybe (isJust)
import GHC.Stack (HasCallStack)

import Test.Tasty
import Test.Tasty.HUnit

import qualified Data.Aeson as A
import qualified Data.Aeson.Key as AK
import qualified Data.Aeson.KeyMap as AKM
import qualified Data.ByteString.Lazy.Char8 as BL8
import qualified Data.Vector as V

import FWL.AST
import FWL.Compile
import FWL.Util

tests :: TestTree
tests =
  testGroup
    "Compile"
    [ jsonStructureTests
    , chainTests
    , ruleExprTests
    , verdictTests
    , layerStrippingTests
    , continueTests
    , configTests
    , filterInjectionTests
    , portforwardCompileTests
    , masqueradeCompileTests
    ]

-- ─── Helpers ─────────────────────────────────────────────────────────────────

-- | Fail the current test case with @msg@.
--
-- Polymorphic in its result so it can stand in for a value in a failing
-- branch. 'assertFailure' throws before the trailing 'error' is reached.
failTest :: HasCallStack => String -> IO a
failTest msg =
  assertFailure msg
    >> error ("CompileTests.failTest: unreachable (" ++ msg ++ ")")

-- | First element of a list, or fail the test with @msg@ when it is empty.
firstOr :: HasCallStack => String -> [a] -> IO a
firstOr msg xs = case xs of
  (x : _) -> return x
  [] -> failTest msg

-- | Minimal Output-hook policy that allows everything.
outputAllowSrc :: String
outputAllowSrc = "policy output : Frame hook Output = { | _ -> Allow; };"

-- | Minimal Input-hook policy that drops everything.
inputDropSrc :: String
inputDropSrc = "policy input : Frame hook Input = { | _ -> Drop; };"

-- | One TCP port forward on @wan@ plus a drop-all Forward-hook policy.
portforwardSrc :: String
portforwardSrc =
  "portforward wan_forwards on wan via Map<(Protocol, Port), (IPv4, Port)> = { \
  \ (tcp, :8080) -> (10.0.0.10, :80) \
  \}; \
  \policy forward : Frame hook Forward = { | _ -> Drop; };"

-- | A masquerade declaration on @wan@ whose source set is @rfc1918@.
masqueradeSrc :: String
masqueradeSrc =
  "let rfc1918 : Set = { 10.0.0.0/8 }; \
  \masquerade wan_snat on wan src rfc1918;"

-- | Parse and compile FWL source, then decode the compiler output as JSON.
--
-- Fails the test if the output does not decode. Parse errors are handled by
-- 'parseOk' (from "FWL.Util"), which presumably fails the test itself.
compileToValue :: String -> IO A.Value
compileToValue src = do
  p <- parseOk src
  maybe
    (failTest "Compiled output is not valid JSON")
    return
    (A.decode (compileToJson p))

-- | Navigate a JSON value along a path.
--
-- Each step is an object key or, when the current value is an array, a
-- decimal index. Returns 'Nothing' for any of these:
--
-- * a missing key;
-- * an index that does not parse;
-- * an index that is out of range, including a negative one;
-- * a step into a scalar.
at :: [String] -> A.Value -> Maybe A.Value
at [] v = Just v
at (k : ks) (A.Object o) = AKM.lookup (AK.fromString k) o >>= at ks
at (k : ks) (A.Array arr) = case reads k of
  -- '(V.!?)' is total, so a negative index yields Nothing instead of a crash.
  [(i, "")] -> arr V.!? i >>= at ks
  _ -> Nothing
at _ _ = Nothing

-- | The elements of the top-level @nftables@ array; fails the test if absent.
nftArr :: A.Value -> IO [A.Value]
nftArr v = case at ["nftables"] v of
  Just (A.Array arr) -> return (V.toList arr)
  _ -> failTest "Missing top-level 'nftables' array"

-- | Compile FWL source and return the elements of its @nftables@ array.
compileNft :: String -> IO [A.Value]
compileNft src = compileToValue src >>= nftArr

-- | Keep only the values that are objects containing key @k@.
withKey :: String -> [A.Value] -> [A.Value]
withKey k = filter (isJust . at [k])

-- | Objects of kind @kind@ (e.g. @chain@, @map@) whose @name@ field is @name@.
objectsNamed :: String -> String -> [A.Value] -> [A.Value]
objectsNamed kind name =
  filter (\o -> at [kind, "name"] o == Just (A.toJSON name)) . withKey kind

-- | Rule objects whose @rule.chain@ field equals @chain@.
rulesInChain :: String -> [A.Value] -> [A.Value]
rulesInChain chain =
  filter (\r -> at ["rule", "chain"] r == Just (A.toJSON chain)) . withKey "rule"

-- | Every expression of every rule object, flattened in document order.
-- Rules whose @expr@ is missing or not an array contribute nothing.
ruleExprs :: [A.Value] -> [A.Value]
ruleExprs arr =
  [ e
  | r <- withKey "rule" arr
  , Just (A.Array es) <- [at ["rule", "expr"] r]
  , e <- V.toList es
  ]

-- | Does this rule object's @expr@ array contain an expression satisfying @p@?
-- 'False' when @expr@ is missing or not an array.
ruleHasExpr :: (A.Value -> Bool) -> A.Value -> Bool
ruleHasExpr p r = case at ["rule", "expr"] r of
  Just (A.Array es) -> V.any p es
  _ -> False

-- | Does any rule expression carry a top-level key @k@ (e.g. a verdict)?
anyExprWithKey :: String -> [A.Value] -> Bool
anyExprWithKey k = any (isJust . at [k]) . ruleExprs

-- | Is this a @match@ expression whose left side is @ct@ key @key@?
isCtMatch :: String -> A.Value -> Bool
isCtMatch key e = at ["match", "left", "ct", "key"] e == Just (A.toJSON key)

-- | Is this a @match@ expression whose left side is @meta@ key @key@?
isMetaMatch :: String -> A.Value -> Bool
isMetaMatch key e = at ["match", "left", "meta", "key"] e == Just (A.toJSON key)

-- | Is this a @match@ expression whose right side is the JSON string @val@?
hasRight :: String -> A.Value -> Bool
hasRight val e = at ["match", "right"] e == Just (A.toJSON val)

-- ─── JSON structure tests ────────────────────────────────────────────────────

jsonStructureTests :: TestTree
jsonStructureTests =
  testGroup
    "JSON structure"
    [ testCase "output is valid JSON" $ do
        _ <- compileToValue outputAllowSrc
        return ()
    , testCase "top-level nftables array present" $ do
        _ <- compileNft outputAllowSrc
        return ()
    , testCase "metainfo is first element" $ do
        arr <- compileNft outputAllowSrc
        case arr of
          (first : _) ->
            assertBool "First element is not metainfo" (isJust (at ["metainfo"] first))
          [] -> assertFailure "Empty nftables array"
    , testCase "table object present" $ do
        arr <- compileNft outputAllowSrc
        assertBool "Expected at least one table object" (not (null (withKey "table" arr)))
    , testCase "default table name is fwl" $ do
        arr <- compileNft outputAllowSrc
        t <- firstOr "No table object" (withKey "table" arr)
        at ["table", "name"] t @?= Just (A.String "fwl")
    , testCase "custom table name respected" $ do
        arr <-
          compileNft
            "config { table = \"custom\"; } \
            \policy output : Frame hook Output = { | _ -> Allow; };"
        t <- firstOr "No table object" (withKey "table" arr)
        at ["table", "name"] t @?= Just (A.String "custom")
    ]

-- ─── Chain declaration tests ─────────────────────────────────────────────────

-- | Compile @src@ and assert that field @field@ of the first chain object
-- equals @expected@. Fails with "No chain" when no chain object is emitted.
assertFirstChainField :: HasCallStack => String -> String -> A.Value -> Assertion
assertFirstChainField src field expected = do
  arr <- compileNft src
  c <- firstOr "No chain" (withKey "chain" arr)
  at ["chain", field] c @?= Just expected

chainTests :: TestTree
chainTests =
  testGroup
    "chain declarations"
    [ testCase "filter input chain has correct hook" $
        assertFirstChainField inputDropSrc "hook" (A.String "input")
    , testCase "filter chain type is filter" $
        assertFirstChainField
          "policy fwd : Frame hook Forward = { | _ -> Drop; };"
          "type"
          (A.String "filter")
    , testCase "NAT chain type is nat" $
        assertFirstChainField
          "policy nat_post : Frame hook Postrouting = { | _ -> Allow; };"
          "type"
          (A.String "nat")
    , testCase "input chain default policy is drop" $
        assertFirstChainField inputDropSrc "policy" (A.String "drop")
    , testCase "output chain default policy is accept" $
        assertFirstChainField outputAllowSrc "policy" (A.String "accept")
    , testCase "chain name matches policy name" $
        assertFirstChainField
          "policy my_input : Frame hook Input = { | _ -> Drop; };"
          "name"
          (A.String "my_input")
    , testCase "two policies produce two chains" $ do
        arr <-
          compileNft
            "policy input : Frame hook Input = { | _ -> Drop; }; \
            \policy output : Frame hook Output = { | _ -> Allow; };"
        length (withKey "chain" arr) @?= 2
    ]

-- ─── Rule expression tests ───────────────────────────────────────────────────

ruleExprTests :: TestTree
ruleExprTests =
  testGroup
    "rule expressions"
    [ testCase "arm without guard produces rule" $ do
        arr <- compileNft outputAllowSrc
        assertBool "Should have at least one rule" (not (null (withKey "rule" arr)))
    , testCase "rule expr array is present" $ do
        arr <- compileNft outputAllowSrc
        r <- firstOr "No rule" (withKey "rule" arr)
        case at ["rule", "expr"] r of
          Just (A.Array _) -> return ()
          _ -> assertFailure "Missing or non-array 'expr'"
    , testCase "IPv4 ctor emits nfproto match" $ do
        arr <-
          compileNft
            "policy input : Frame hook Input = \
            \ { | Frame(_, IPv4(ip, _)) -> Allow; \
            \ | _ -> Drop; \
            \ };"
        assertBool
          "Expected nfproto match for IPv4 ctor"
          (any (isMetaMatch "nfproto") (ruleExprs arr))
    , testCase "record field pat emits payload match" $ do
        arr <-
          compileNft
            "policy input : Frame hook Input = \
            \ { | Frame(_, TCP(tcp { dport = :22 }, _)) -> Allow; \
            \ | _ -> Drop; \
            \ };"
        assertBool "Expected port 22 payload match" (any (hasRight "22") (ruleExprs arr))
    ]

-- ─── Verdict tests ───────────────────────────────────────────────────────────

verdictTests :: TestTree
verdictTests =
  testGroup
    "verdicts"
    [ testCase "Allow compiles to accept" $ do
        arr <- compileNft outputAllowSrc
        assertBool "Expected accept verdict" (anyExprWithKey "accept" arr)
    , testCase "Drop compiles to drop" $ do
        arr <- compileNft inputDropSrc
        assertBool "Expected drop verdict" (anyExprWithKey "drop" arr)
    , testCase "Masquerade compiles to masquerade" $ do
        arr <- compileNft "policy nat_post : Frame hook Postrouting = { | _ -> Masquerade; };"
        assertBool "Expected masquerade verdict" (anyExprWithKey "masquerade" arr)
    , testCase "rule call compiles to jump" $ do
        arr <-
          compileNft
            "rule blockAll : Frame -> Action = \\f -> case f of { | _ -> Drop; }; \
            \policy fwd : Frame hook Forward = { | frame -> blockAll(frame); };"
        assertBool "Expected jump verdict for rule call" (anyExprWithKey "jump" arr)
    ]

-- ─── Layer stripping tests ───────────────────────────────────────────────────

layerStrippingTests :: TestTree
layerStrippingTests =
  testGroup
    "layer stripping"
    [ testCase "Frame with and without Ether both emit nfproto match" $ do
        let withEther =
              "policy p1 : Frame hook Input = \
              \ { | Frame(_, Ether(_, IPv4(ip, _))) -> Allow; \
              \ | _ -> Drop; \
              \ };"
            withoutEther =
              "policy p1 : Frame hook Input = \
              \ { | Frame(_, IPv4(ip, _)) -> Allow; \
              \ | _ -> Drop; \
              \ };"
            hasNfproto = any (isMetaMatch "nfproto") . ruleExprs
        arr1 <- compileNft withEther
        arr2 <- compileNft withoutEther
        assertBool
          "Both should produce nfproto matches"
          (hasNfproto arr1 && hasNfproto arr2)
    ]

-- ─── Continue tests ──────────────────────────────────────────────────────────

-- NOTE(review): the policy below contains no Continue arm, so this test does
-- not actually exercise Continue handling. Consider adding one.
continueTests :: TestTree
continueTests =
  testGroup
    "Continue"
    [ testCase "non-Continue arms still produce rules" $ do
        arr <-
          compileNft
            "policy input : Frame hook Input = \
            \ { | _ if ct.state in { Established } -> Allow; \
            \ | _ -> Drop; \
            \ };"
        assertBool
          "Should have rules for non-Continue arms"
          (not (null (withKey "rule" arr)))
    ]

-- ─── Config tests ────────────────────────────────────────────────────────────

configTests :: TestTree
configTests =
  testGroup
    "config"
    [ testCase "all rule objects reference correct table" $ do
        arr <- compileNft inputDropSrc
        let rules = withKey "rule" arr
        -- Without this guard the check below passes vacuously on no rules.
        assertBool "Expected at least one rule object" (not (null rules))
        mapM_ (\r -> at ["rule", "table"] r @?= Just (A.String "fwl")) rules
    , testCase "chain objects reference correct table" $ do
        arr <- compileNft inputDropSrc
        let chains = withKey "chain" arr
        assertBool "Expected at least one chain object" (not (null chains))
        mapM_ (\c -> at ["chain", "table"] c @?= Just (A.String "fwl")) chains
    ]

-- ─── Filter-hook injection tests ─────────────────────────────────────────────

filterInjectionTests :: TestTree
filterInjectionTests =
  testGroup
    "filter hook injections"
    [ testCase "Input chain first rule is stateful ct state" $ do
        arr <- compileNft inputDropSrc
        case rulesInChain "input" arr of
          (r : _) -> case at ["rule", "expr"] r of
            Just (A.Array es) ->
              assertBool
                "First rule should have ct state match"
                (V.any (isCtMatch "state") es)
            _ -> assertFailure "No expr in first rule"
          [] -> assertFailure "No rules for input chain"
    , testCase "Input chain has loopback rule (iifname lo)" $ do
        arr <- compileNft inputDropSrc
        assertBool
          "Input chain should have iifname lo rule"
          (any (ruleHasExpr (hasRight "lo")) (rulesInChain "input" arr))
    , testCase "Forward chain first rule is stateful ct state" $ do
        arr <- compileNft "policy forward : Frame hook Forward = { | _ -> Drop; };"
        case rulesInChain "forward" arr of
          (r : _) -> case at ["rule", "expr"] r of
            Just (A.Array es) ->
              assertBool
                "First forward rule should have ct state match"
                (V.any (isCtMatch "state") es)
            _ -> assertFailure "No expr"
          [] -> assertFailure "No rules for forward chain"
    , testCase "Output chain has stateful rule but no loopback" $ do
        arr <- compileNft outputAllowSrc
        let outRules = rulesInChain "output" arr
        assertBool
          "Output chain should have ct state rule"
          (any (ruleHasExpr (isCtMatch "state")) outRules)
        assertBool
          "Output chain should NOT have loopback rule"
          (not (any (ruleHasExpr (hasRight "lo")) outRules))
    ]

-- ─── PortForward compile tests ───────────────────────────────────────────────

portforwardCompileTests :: TestTree
portforwardCompileTests =
  testGroup
    "portforward compilation"
    [ testCase "portforward produces a map object with the decl name" $ do
        arr <- compileNft portforwardSrc
        assertBool
          "Should have a map named wan_forwards"
          (not (null (objectsNamed "map" "wan_forwards" arr)))
    , testCase "portforward produces prerouting chain" $ do
        arr <- compileNft portforwardSrc
        c <-
          firstOr
            "Should have wan_forwards_prerouting chain"
            (objectsNamed "chain" "wan_forwards_prerouting" arr)
        at ["chain", "type"] c @?= Just (A.String "nat")
        at ["chain", "hook"] c @?= Just (A.String "prerouting")
    , testCase "portforward injects ct status dnat accept into Forward chain" $ do
        arr <- compileNft portforwardSrc
        assertBool
          "Forward chain should have ct status dnat rule when portforward present"
          (any (ruleHasExpr (isCtMatch "status")) (rulesInChain "forward" arr))
    ]

-- ─── Masquerade compile tests ────────────────────────────────────────────────

masqueradeCompileTests :: TestTree
masqueradeCompileTests =
  testGroup
    "masquerade compilation"
    [ testCase "masquerade produces postrouting chain" $ do
        arr <- compileNft masqueradeSrc
        c <-
          firstOr
            "Should have wan_snat_postrouting chain"
            (objectsNamed "chain" "wan_snat_postrouting" arr)
        at ["chain", "type"] c @?= Just (A.String "nat")
        at ["chain", "hook"] c @?= Just (A.String "postrouting")
    , testCase "masquerade rule has oifname match and masquerade verdict" $ do
        arr <- compileNft masqueradeSrc
        let snatRules = rulesInChain "wan_snat_postrouting" arr
        assertBool
          "Masquerade rule should match oifname"
          (any (ruleHasExpr (isMetaMatch "oifname")) snatRules)
        assertBool
          "Masquerade rule should have masquerade verdict"
          (any (ruleHasExpr (isJust . at ["masquerade"])) snatRules)
    ]