]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Add AST validation.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 14 Sep 2012 12:20:46 +0000 (13:20 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Fri, 14 Sep 2012 12:20:46 +0000 (13:20 +0100)
src/ParsedOFL.hs

index aea3dfff9d029dda5c734e8072dcda08f7e58835..db234b5b03cf3a3d34280473d597ff6c760873be 100644 (file)
@@ -1,8 +1,8 @@
 module ParsedOFL where
-import Data.Map as Map
+import Data.Map as Map (Map, lookup, insertWithKey, empty)
 
 -- The top-level types
-data BaseType = Real | Function | Integer deriving Show
+data BaseType = Real | Function | Integer deriving (Show, Eq)
 data IndexType = FunctionIndex | SpinIndex | SpatialIndex deriving (Show, Eq)
 
 -- Expressions
@@ -22,7 +22,6 @@ data Expression = IndexedIdentifier String [String] |
                   Derivative Expression String deriving Show
 
 data Assignment = Assign Expression Expression deriving Show
-data ValidationResult = Valid | Invalid String
 
 -- The symbol table
 data SymbolType = ValueTag BaseType [IndexType] | IndexTag IndexType deriving Show
@@ -44,44 +43,155 @@ addIndexDeclaration ofl name indexType = ofl { symbols = symbols' }
 addAssignment :: OFL -> Expression -> Expression -> OFL
 addAssignment ofl lhs rhs = let assignment = (Assign lhs rhs) in
   case (validateAssignment ofl assignment) of
-    Valid -> ofl { assignments = (Assign lhs rhs):(assignments ofl) } 
-    Invalid reason -> error reason
+    Right () -> ofl { assignments = (Assign lhs rhs):(assignments ofl) } 
+    Left reason -> error $ show reason
 
 errorOnDuplicate :: Show k => k -> a -> a -> a
 errorOnDuplicate key _ _ = error $ "Attempted redefinition of symbol " ++ show key
 
-getIndices :: OFL -> String -> Maybe [IndexType]
+getIndices :: OFL -> String -> [IndexType]
 getIndices ofl name = 
  case Map.lookup name (symbols ofl) of
- Nothing -> Nothing
- Just (IndexTag _) -> Nothing
- Just (ValueTag _ indices) -> Just indices
+ Just (ValueTag _ indices) -> indices
+ _ -> fail $ "Cannot find indices for " ++ show name
 
-getIndexType :: OFL -> String -> Maybe IndexType
+getIndexType :: OFL -> String -> IndexType
 getIndexType ofl name = case Map.lookup name (symbols ofl) of
- Nothing -> Nothing
- Just (IndexTag indexType) -> Just indexType
- Just (ValueTag _ _) -> Nothing
+ Just (IndexTag indexType) -> indexType
+ _ -> error $ "Cannot find index type of " ++ show name
+
+getValueType :: OFL -> String -> BaseType
+getValueType ofl name = case Map.lookup name (symbols ofl) of
+ Just (ValueTag baseType _) -> baseType
+ _ -> error $ "Cannot find type of value " ++ show name
+
+hasIndex :: OFL -> String -> Bool
+hasIndex ofl name = case Map.lookup name (symbols ofl) of
+ Just (IndexTag _) -> True
+ _ -> False
+
+hasValue :: OFL -> String -> Bool
+hasValue ofl name = case Map.lookup name (symbols ofl) of
+ Just (ValueTag _ _) -> True
+ _ -> False
+
+promote :: BaseType -> BaseType -> BaseType
+promote Function _ = Function
+promote _ Function = Function
+promote Real _ = Real
+promote _ Real = Real
+promote t1 t2 | (t1 == t2) = t1
+
+getType :: OFL -> Expression -> BaseType
+getType ofl (IndexedIdentifier name _) = getValueType ofl name
+getType ofl (ConstReal _) = Real
+getType ofl (ConstInteger _) = Integer
+getType ofl (Negate e) = getType ofl e 
+getType ofl (Inner _ _) = Function
+getType ofl (Laplacian _) = Function 
+getType ofl (Sum e _) = getType ofl e
+getType ofl (Multiply a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Divide a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Add a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Sub a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Power a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Component e i) = Function
+getType ofl (Derivative e i) = getType ofl e
 
 emptyOFL :: OFL
 emptyOFL = OFL { assignments = [], symbols = Map.empty }
 
 validateAssignment :: OFL -> Assignment -> ValidationResult
-validateAssignment ofl (Assign lhs rhs) = case (validateExpression ofl lhs) of
- Invalid reason -> Invalid $ "Invalid LHS of assignment" ++ reason
- Valid -> case (validateExpression ofl rhs) of
-   Valid -> Valid
-   Invalid reason -> Invalid $ "Invalid RHS of assignment: " ++ reason
+validateAssignment ofl (Assign lhs rhs) = do {
+  validateExpression ofl lhs;
+  validateExpression ofl rhs;
+  return ()
+}
+
+-- Validation
+
+data ValidationError = Message String deriving Show
+type ValidationResult = Either ValidationError ()
+validationSuccess = Right ()
+validationFailure = \x -> Left (Message x)
+
+indexExists :: OFL -> String -> ValidationResult
+indexExists ofl name = if (hasIndex ofl name) then validationSuccess else validationFailure $ "Unknown index " ++ name
+
+valueExists :: OFL -> String -> ValidationResult
+valueExists ofl name = if (hasValue ofl name) then validationSuccess else validationFailure $ "Unknownm value " ++ name
+
+isFunction :: OFL -> Expression -> ValidationResult
+isFunction ofl e = case (getType ofl e) of
+  Function -> validationSuccess
+  _ -> validationFailure $ "Expression " ++ show e ++ " is not a function"
 
 validateExpression :: OFL -> Expression -> ValidationResult
-validateExpression ofl (IndexedIdentifier name indices) = 
-  case getIndices ofl name of
-   Nothing -> Invalid $ "Cannot find declaration for value " ++ show name
-   Just indexTypes -> let lengthMatch = (length indices) == (length indexTypes) in
-     case lengthMatch of 
-      False -> Invalid $ "Value " ++ show name ++ " used with wrong number of indices."
-      True -> let typeMatch = (Prelude.map (getIndexType ofl) indices) == (Prelude.map Just indexTypes) in
-        case typeMatch of
-          True -> Valid
-          False -> Invalid $ "Value " ++ show name ++ "indexed with invalid indices." 
-validateExpression ofl e = Invalid $ "Don't know how to validate " ++ show e
+
+validateExpression ofl (IndexedIdentifier name indices) = do {
+  valueExists ofl name;
+  foldl (>>) validationSuccess $ map (indexExists ofl) indices;
+  let indexTypes = map (getIndexType ofl) indices in
+    case indexTypes == (getIndices ofl name) of
+      True -> validationSuccess
+      False -> validationFailure $ "Incorrect number or type of indices used to index " ++ name
+}
+
+validateExpression ofl (ConstReal _) = validationSuccess
+
+validateExpression ofl (ConstInteger _) = validationSuccess
+
+validateExpression ofl (Negate e) = validateExpression ofl e
+
+validateExpression ofl (Inner a b) = do {
+  validateExpression ofl a;
+  validateExpression ofl b;
+  isFunction ofl a;
+  isFunction ofl b;
+}
+
+validateExpression ofl (Laplacian e) = do {
+  validateExpression ofl e;
+  isFunction ofl e;
+}
+
+validateExpression ofl (Sum e i) = do {
+  validateExpression ofl e;
+  indexExists ofl i;
+}
+
+validateExpression ofl (Multiply a b) = do {
+  validateExpression ofl a;
+  validateExpression ofl b;
+}
+
+validateExpression ofl (Divide a b) = do {
+  validateExpression ofl a;
+  validateExpression ofl b;
+}
+
+validateExpression ofl (Add a b) = do {
+  validateExpression ofl a;
+  validateExpression ofl b;
+}
+
+validateExpression ofl (Sub a b) = do {
+  validateExpression ofl a;
+  validateExpression ofl b;
+}
+
+validateExpression ofl (Power a b) = do {
+  validateExpression ofl a;
+  validateExpression ofl b;
+}
+
+
+validateExpression ofl (Component e i) = do {
+  validateExpression ofl e;
+  indexExists ofl i;
+}
+
+validateExpression ofl (Derivative e i) = do {
+  validateExpression ofl e;
+  indexExists ofl i;
+}