more compiler fixes

This commit is contained in:
2026-05-04 00:14:47 -07:00
parent 8a508ad7cc
commit d136bd62f7
7 changed files with 87 additions and 47 deletions

View File

@@ -165,6 +165,7 @@ pat ::= wildcardPat -- _
| bytesPat -- [ byteElem* ] | bytesPat -- [ byteElem* ]
| recordPat -- Ctor { field = lit, ... } | recordPat -- Ctor { field = lit, ... }
| namedOrCtorPat -- Ctor(p,...) or bare identifier | namedOrCtorPat -- Ctor(p,...) or bare identifier
| pat "|" pat -- Or-pattern
wildcardPat ::= "_" wildcardPat ::= "_"
framePat ::= "Frame" "(" frameArgs ")" framePat ::= "Frame" "(" frameArgs ")"

View File

@@ -7,7 +7,7 @@ interface wg0 : WireGuard {};
zone lan_zone = { lan, wg0 }; zone lan_zone = { lan, wg0 };
import rfc1918 : CIDRSet from "builtin:rfc1918"; let rfc1918 : Set<IPv4> = { 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 };
let forwards : Map<(Protocol, Port), (IP, Port)> = { let forwards : Map<(Protocol, Port), (IP, Port)> = {
(tcp, :8080) -> (10.17.1.10, :80), (tcp, :8080) -> (10.17.1.10, :80),
@@ -64,8 +64,8 @@ policy forward : Frame
| _ if ct.status == DNAT -> Allow; | _ if ct.status == DNAT -> Allow;
| Frame(iif in lan_zone -> wan, _) -> Allow; | Frame(iif in lan_zone -> wan, _) -> Allow;
| Frame(iif in lan_zone -> lan_zone, _) -> Allow; | Frame(iif in lan_zone -> lan_zone, _) -> Allow;
| Frame(wan -> lan_zone, IPv4(ip, TCP(tcp, _))) | Frame(wan -> lan_zone, IPv4(ip, TCP(th, _) | UDP(th, _)))
if (ip.dst, tcp.dport) in forwards -> Allow; if (ip.protocol, th.dport) in forwards -> Allow;
| _ -> Drop; | _ -> Drop;
}; };
@@ -80,9 +80,9 @@ policy output : Frame
policy nat_prerouting : Frame policy nat_prerouting : Frame
on { hook = Prerouting, table = NAT, priority = DstNat } on { hook = Prerouting, table = NAT, priority = DstNat }
= { = {
| Frame(_, IPv4(ip, _)) -> | Frame(_, IPv4(ip, TCP(th, _) | UDP(th, _))) ->
if perform FIB.daddrLocal(ip.dst) if perform FIB.daddrLocal(ip.dst)
then DNATMap(forwards) then DNATMap((ip.protocol, th.dport), forwards)
else Allow; else Allow;
| _ -> Allow; | _ -> Allow;
}; };

View File

@@ -82,6 +82,7 @@ data Pat
| PTuple [Pat] | PTuple [Pat]
| PFrame (Maybe PathPat) Pat | PFrame (Maybe PathPat) Pat
| PBytes [ByteElem] | PBytes [ByteElem]
| POr Pat Pat
deriving (Show) deriving (Show)
data FieldPat data FieldPat

View File

@@ -20,6 +20,7 @@ data CheckError
| PolicyNoContinue String -- policy name | PolicyNoContinue String -- policy name
| PatternCycle [String] -- cycle path | PatternCycle [String] -- cycle path
| DuplicateDecl String String -- kind, name | DuplicateDecl String String -- kind, name
| OrPatternMismatch [String] [String]
deriving (Show, Eq) deriving (Show, Eq)
type Env = Map.Map String DeclKind type Env = Map.Map String DeclKind
@@ -117,6 +118,25 @@ checkPat env (PRecord _ fs) = concatMap (checkFP env) fs
checkPat env (PTuple ps) = concatMap (checkPat env) ps checkPat env (PTuple ps) = concatMap (checkPat env) ps
checkPat env (PFrame mp inner)= maybe [] (checkPath env) mp ++ checkPat env inner checkPat env (PFrame mp inner)= maybe [] (checkPath env) mp ++ checkPat env inner
checkPat _ (PBytes _) = [] checkPat _ (PBytes _) = []
checkPat env (POr p1 p2) =
let v1 = boundVars p1
v2 = boundVars p2
errs = if Set.fromList v1 == Set.fromList v2 then [] else [OrPatternMismatch v1 v2]
in errs ++ checkPat env p1 ++ checkPat env p2
boundVars :: Pat -> [String]
boundVars (PVar n) = [n]
boundVars (PCtor _ ps) = concatMap boundVars ps
boundVars (PRecord _ fs) = concatMap boundFP fs
boundVars (PTuple ps) = concatMap boundVars ps
boundVars (PFrame _ p) = boundVars p
boundVars (POr p1 p2) = boundVars p1
boundVars _ = []
boundFP :: FieldPat -> [String]
boundFP (FPBind n) = [n]
boundFP (FPAs _ v) = [v]
boundFP _ = []
checkFP :: Env -> FieldPat -> [CheckError] checkFP :: Env -> FieldPat -> [CheckError]
checkFP _ _ = [] -- field names checked by type-checker later checkFP _ _ = [] -- field names checked by type-checker later
@@ -153,6 +173,7 @@ addPat env (PFrame mp inner) =
in case md of Just (EPName n) -> Map.insert n KLet env1; _ -> env1 in case md of Just (EPName n) -> Map.insert n KLet env1; _ -> env1
Nothing -> env Nothing -> env
in addPat env' inner in addPat env' inner
addPat env (POr p1 _) = addPat env p1
addPat env _ = env addPat env _ = env
addFP :: Env -> FieldPat -> Env addFP :: Env -> FieldPat -> Env
@@ -211,6 +232,7 @@ checkPatternCycles decls =
refsInPat (PCtor _ ps) = concatMap refsInPat ps refsInPat (PCtor _ ps) = concatMap refsInPat ps
refsInPat (PTuple ps) = concatMap refsInPat ps refsInPat (PTuple ps) = concatMap refsInPat ps
refsInPat (PFrame _ p) = refsInPat p refsInPat (PFrame _ p) = refsInPat p
refsInPat (POr p1 p2) = refsInPat p1 ++ refsInPat p2
refsInPat _ = [] refsInPat _ = []
findCycles :: Map.Map String [String] -> [[String]] findCycles :: Map.Map String [String] -> [[String]]

View File

@@ -98,18 +98,17 @@ armToRuleValues env tbl chain (Arm p mg body) =
case compileAction env body of case compileAction env body of
Nothing -> [] Nothing -> []
Just verdict -> Just verdict ->
let patExprs = compilePat env p let patExprsAlts = compilePat env p
guardExprs = maybe [] (compileGuard env) mg guardExprs = maybe [] (compileGuard env) mg
allExprs = patExprs ++ guardExprs ++ [verdict]
in [ object in [ object
[ "rule" .= object [ "rule" .= object
[ "family" .= ("inet" :: String) [ "family" .= ("inet" :: String)
, "table" .= tbl , "table" .= tbl
, "chain" .= chain , "chain" .= chain
, "expr" .= toJSON allExprs , "expr" .= toJSON (patExprs ++ guardExprs ++ [verdict])
]
] ]
] ]
| patExprs <- patExprsAlts ]
-- ─── Pattern → [Value] ─────────────────────────────────────────────────────── -- ─── Pattern → [Value] ───────────────────────────────────────────────────────
@@ -127,54 +126,57 @@ buildEnv = foldr (\d m -> Map.insert (declNameOf d) d m) Map.empty
declNameOf (DLet n _ _) = n declNameOf (DLet n _ _) = n
declNameOf (DImport n _ _) = n declNameOf (DImport n _ _) = n
compilePat :: CompileEnv -> Pat -> [Value] compilePat :: CompileEnv -> Pat -> [[Value]]
compilePat _ PWild = [] compilePat _ PWild = [[]]
compilePat _ (PVar _) = [] compilePat _ (PVar _) = [[]]
compilePat env (PNamed n) = expandNamedPat env n compilePat env (PNamed n) = expandNamedPat env n
compilePat env (PFrame mp inner) = compilePat env (PFrame mp inner) = do
maybe [] (compilePathPat env) mp ++ compilePat env inner pathConds <- maybe [[]] (compilePathPat env) mp
innerConds <- compilePat env inner
return (pathConds ++ innerConds)
compilePat env (PCtor n ps) = compileCtorPat env n ps compilePat env (PCtor n ps) = compileCtorPat env n ps
compilePat _ (PRecord n fs) = compileRecordPat n fs compilePat _ (PRecord n fs) = compileRecordPat n fs
compilePat env (PTuple ps) = concatMap (compilePat env) ps compilePat env (PTuple ps) = map concat (sequence (map (compilePat env) ps))
compilePat _ (PBytes _) = [] compilePat _ (PBytes _) = [[]]
compilePat env (POr p1 p2) = compilePat env p1 ++ compilePat env p2
expandNamedPat :: CompileEnv -> Name -> [Value] expandNamedPat :: CompileEnv -> Name -> [[Value]]
expandNamedPat env n = expandNamedPat env n =
case Map.lookup n env of case Map.lookup n env of
Just (DPattern _ _ p) -> compilePat env p Just (DPattern _ _ p) -> compilePat env p
_ -> [] _ -> []
compileCtorPat :: CompileEnv -> String -> [Pat] -> [Value] compileCtorPat :: CompileEnv -> String -> [Pat] -> [[Value]]
compileCtorPat env ctor ps = case ctor of compileCtorPat env ctor ps = case ctor of
"Ether" -> children "Ether" -> children
"IPv4" -> matchMeta "nfproto" "ipv4" : children "IPv4" -> map (matchMeta "nfproto" "ipv4" :) children
"IPv6" -> matchMeta "nfproto" "ipv6" : children "IPv6" -> map (matchMeta "nfproto" "ipv6" :) children
"TCP" -> matchMeta "l4proto" "tcp" : children "TCP" -> map (matchMeta "l4proto" "tcp" :) children
"UDP" -> matchMeta "l4proto" "udp" : children "UDP" -> map (matchMeta "l4proto" "udp" :) children
"ICMPv6" -> matchPayload "ip6" "nexthdr" "ipv6-icmp" : children "ICMPv6" -> map (matchPayload "ip6" "nexthdr" "ipv6-icmp" :) children
"ICMP" -> matchPayload "ip" "protocol" "icmp" : children "ICMP" -> map (matchPayload "ip" "protocol" "icmp" :) children
_ -> children _ -> children
where where
children = concatMap (compilePat env) ps children = map concat (sequence (map (compilePat env) ps))
compileRecordPat :: String -> [FieldPat] -> [Value] compileRecordPat :: String -> [FieldPat] -> [[Value]]
compileRecordPat proto = mapMaybe go compileRecordPat proto fs = [mapMaybe go fs]
where where
go (FPEq field lit) = Just (matchPayload proto field (renderLit lit)) go (FPEq field lit) = Just (matchPayload proto field (renderLit lit))
go _ = Nothing go _ = Nothing
compilePathPat :: CompileEnv -> PathPat -> [Value] compilePathPat :: CompileEnv -> PathPat -> [[Value]]
compilePathPat env (PathPat ms md) = compilePathPat env (PathPat ms md) =
maybe [] (compileEndpoint env "iifname") ms ++ [ maybe [] (compileEndpoint env "iifname") ms ++
maybe [] (compileEndpoint env "oifname") md maybe [] (compileEndpoint env "oifname") md ]
compileEndpoint :: CompileEnv -> String -> EndpointPat -> [Value] compileEndpoint :: CompileEnv -> String -> EndpointPat -> [Value]
compileEndpoint _ _ EPWild = [] compileEndpoint _ _ EPWild = []
compileEndpoint _ dir (EPName n) = [matchMeta dir n] compileEndpoint _ dir (EPName n) = [matchMeta dir n]
compileEndpoint env dir (EPMember _ z) = compileEndpoint env dir (EPMember _ z) =
case Map.lookup z env of case Map.lookup z env of
Just (DZone _ ns) -> [matchInSet (metaVal dir) ns] Just (DZone _ ns) -> [matchInSet (metaVal dir) (map (A.String . toText) ns)]
_ -> [matchInSet (metaVal dir) [z]] _ -> [matchInSet (metaVal dir) [A.String (toText z)]]
-- ─── Guard → [Value] ───────────────────────────────────────────────────────── -- ─── Guard → [Value] ─────────────────────────────────────────────────────────
@@ -188,20 +190,20 @@ compileGuard _ _ = []
compileInExpr :: CompileEnv -> Expr -> Expr -> Value compileInExpr :: CompileEnv -> Expr -> Expr -> Value
-- Fix 4: put the more-specific ct patterns BEFORE the generic 2-element -- Fix 4: put the more-specific ct patterns BEFORE the generic 2-element
-- EQual case to eliminate the overlapping pattern match warning. -- EQual case to eliminate the overlapping pattern match warning.
compileInExpr _ (EQual ["ct", "state"]) (ESet vs) = ctMatch "state" vs compileInExpr env (EQual ["ct", "state"]) (ESet vs) = ctMatch env "state" vs
compileInExpr _ (EQual ["ct", "status"]) (ESet vs) = ctMatch "status" vs compileInExpr env (EQual ["ct", "status"]) (ESet vs) = ctMatch env "status" vs
compileInExpr env l (ESet vs) = compileInExpr env l (ESet vs) =
matchExpr "in" (exprVal env l) (setVal (map exprToStr vs)) matchExpr "in" (exprVal env l) (setVal (map (exprVal env) vs))
compileInExpr env l (EVar z) compileInExpr env l (EVar z)
| Just (DZone _ ns) <- Map.lookup z env = | Just (DZone _ ns) <- Map.lookup z env =
matchExpr "in" (exprVal env l) (setVal ns) matchExpr "in" (exprVal env l) (setVal (map (A.String . toText) ns))
compileInExpr env l r = compileInExpr env l r =
matchExpr "==" (exprVal env l) (exprVal env r) matchExpr "==" (exprVal env l) (exprVal env r)
ctMatch :: String -> [Expr] -> Value ctMatch :: CompileEnv -> String -> [Expr] -> Value
ctMatch key vs = matchExpr "in" ctMatch env key vs = matchExpr "in"
(object ["ct" .= object ["key" .= (key :: String)]]) (object ["ct" .= object ["key" .= (key :: String)]])
(setVal (map exprToStr vs)) (setVal (map (exprVal env) vs))
-- ─── Action → Maybe Value ───────────────────────────────────────────────────── -- ─── Action → Maybe Value ─────────────────────────────────────────────────────
@@ -212,10 +214,10 @@ compileAction _ (EVar "Continue") = Nothing
compileAction _ (EVar "Masquerade") = Just (object ["masquerade" .= Null]) compileAction _ (EVar "Masquerade") = Just (object ["masquerade" .= Null])
compileAction _ (EApp (EVar "DNAT") arg) = compileAction _ (EApp (EVar "DNAT") arg) =
Just $ object ["dnat" .= object ["addr" .= exprToStr arg]] Just $ object ["dnat" .= object ["addr" .= exprToStr arg]]
compileAction _ (EApp (EVar "DNATMap") arg) = compileAction env (EApp (EVar "DNATMap") (ETuple [key, arg])) =
Just $ object ["dnat" .= object ["addr" .= object Just $ object ["dnat" .= object ["addr" .= object
[ "map" .= object [ "key" .= object ["concat" .= Array mempty] [ "map" .= object [ "key" .= exprVal env key
, "data" .= exprToStr arg ]]]] , "data" .= A.String ("@" <> toText (exprToStr arg)) ]]]]
compileAction env (EApp (EVar rn) _) = compileAction env (EApp (EVar rn) _) =
case Map.lookup rn env of case Map.lookup rn env of
Just (DRule _ _ _) -> Just $ object ["jump" .= object ["target" .= rn]] Just (DRule _ _ _) -> Just $ object ["jump" .= object ["target" .= rn]]
@@ -300,7 +302,7 @@ matchPayload :: String -> String -> String -> Value
matchPayload proto field val = matchPayload proto field val =
matchExpr "==" (payloadVal proto field) (A.String (toText val)) matchExpr "==" (payloadVal proto field) (A.String (toText val))
matchInSet :: Value -> [String] -> Value matchInSet :: Value -> [Value] -> Value
matchInSet lhs vals = matchExpr "in" lhs (setVal vals) matchInSet lhs vals = matchExpr "in" lhs (setVal vals)
metaVal :: String -> Value metaVal :: String -> Value
@@ -313,7 +315,7 @@ payloadVal proto field =
, "field" .= (field :: String) , "field" .= (field :: String)
]] ]]
setVal :: [String] -> Value setVal :: [Value] -> Value
setVal vs = object ["set" .= toJSON vs] setVal vs = object ["set" .= toJSON vs]
-- ─── Expression helpers ─────────────────────────────────────────────────────── -- ─── Expression helpers ───────────────────────────────────────────────────────
@@ -332,6 +334,8 @@ mapField f = f
-- Fix 3 (overlap): specific ct pattern first, generic 2-element case second. -- Fix 3 (overlap): specific ct pattern first, generic 2-element case second.
exprVal :: CompileEnv -> Expr -> Value exprVal :: CompileEnv -> Expr -> Value
exprVal _ (EQual ["ct", k]) = object ["ct" .= object ["key" .= (k :: String)]] exprVal _ (EQual ["ct", k]) = object ["ct" .= object ["key" .= (k :: String)]]
exprVal _ (EQual ["meta", k])= metaVal k
exprVal _ (EQual ["th", k]) = payloadVal "th" k
exprVal _ (EQual [p, f]) = payloadVal p (mapField f) exprVal _ (EQual [p, f]) = payloadVal p (mapField f)
exprVal _ (EQual ns) = A.String (toText (intercalate "." ns)) exprVal _ (EQual ns) = A.String (toText (intercalate "." ns))
exprVal env (EVar n) exprVal env (EVar n)
@@ -343,8 +347,14 @@ exprVal env (EVar n)
| n == "Established" = A.String "established" | n == "Established" = A.String "established"
| n == "Related" = A.String "related" | n == "Related" = A.String "related"
| otherwise = metaVal n | otherwise = metaVal n
exprVal _ (ELit (LCIDR ip p)) = object
[ "prefix" .= object
[ "addr" .= A.String (toText (renderLit ip))
, "len" .= p
]
]
exprVal _ (ELit l) = A.String (toText (renderLit l)) exprVal _ (ELit l) = A.String (toText (renderLit l))
exprVal _ (ESet vs) = setVal (map exprToStr vs) exprVal env (ESet vs) = setVal (map (exprVal env) vs)
exprVal env (ETuple es) = object ["concat" .= toJSON (map (exprVal env) es)] exprVal env (ETuple es) = object ["concat" .= toJSON (map (exprVal env) es)]
exprVal _ e = A.String (toText (exprToStr e)) exprVal _ e = A.String (toText (exprToStr e))

View File

@@ -225,7 +225,12 @@ arm = do
-- ─── Patterns ──────────────────────────────────────────────────────────────── -- ─── Patterns ────────────────────────────────────────────────────────────────
pat :: Parser Pat pat :: Parser Pat
pat = wildcardPat pat = Ex.buildExpressionParser patTable patAtom <?> "pattern"
where
patTable = [ [Ex.Infix (reservedOp "|" >> return POr) Ex.AssocLeft] ]
patAtom :: Parser Pat
patAtom = wildcardPat
<|> try framePat <|> try framePat
<|> try tuplePat <|> try tuplePat
<|> bytesPat <|> bytesPat

View File

@@ -82,6 +82,7 @@ prettyPat (PTuple ps) = "(" ++ intercalate ", " (map prettyPat ps) ++ ")"
prettyPat (PFrame mp inner)= prettyPat (PFrame mp inner)=
"Frame(" ++ maybe "" (\pp -> prettyPath pp ++ ", ") mp ++ prettyPat inner ++ ")" "Frame(" ++ maybe "" (\pp -> prettyPath pp ++ ", ") mp ++ prettyPat inner ++ ")"
prettyPat (PBytes bs) = "[" ++ unwords (map prettyBE bs) ++ "]" prettyPat (PBytes bs) = "[" ++ unwords (map prettyBE bs) ++ "]"
prettyPat (POr p1 p2) = prettyPat p1 ++ " | " ++ prettyPat p2
prettyFP :: FieldPat -> String prettyFP :: FieldPat -> String
prettyFP (FPEq n l) = n ++ " = " ++ prettyLit l prettyFP (FPEq n l) = n ++ " = " ++ prettyLit l