module ParsedOFL where
-import Data.Map as Map
+import Data.Map as Map (Map, lookup, insertWithKey, empty)
-- The top-level types
-data BaseType = Real | Function | Integer deriving Show
+data BaseType = Real | Function | Integer deriving (Show, Eq)
data IndexType = FunctionIndex | SpinIndex | SpatialIndex deriving (Show, Eq)
-- Expressions
Derivative Expression String deriving Show
data Assignment = Assign Expression Expression deriving Show
-data ValidationResult = Valid | Invalid String
-- The symbol table
data SymbolType = ValueTag BaseType [IndexType] | IndexTag IndexType deriving Show
addAssignment :: OFL -> Expression -> Expression -> OFL
addAssignment ofl lhs rhs = let assignment = (Assign lhs rhs) in
case (validateAssignment ofl assignment) of
- Valid -> ofl { assignments = (Assign lhs rhs):(assignments ofl) }
- Invalid reason -> error reason
+ Right () -> ofl { assignments = (Assign lhs rhs):(assignments ofl) }
+ Left reason -> error $ show reason
errorOnDuplicate :: Show k => k -> a -> a -> a
errorOnDuplicate key _ _ = error $ "Attempted redefinition of symbol " ++ show key
-getIndices :: OFL -> String -> Maybe [IndexType]
+getIndices :: OFL -> String -> [IndexType]
getIndices ofl name =
case Map.lookup name (symbols ofl) of
- Nothing -> Nothing
- Just (IndexTag _) -> Nothing
- Just (ValueTag _ indices) -> Just indices
+ Just (ValueTag _ indices) -> indices
+ _ -> fail $ "Cannot find indices for " ++ show name
-getIndexType :: OFL -> String -> Maybe IndexType
+getIndexType :: OFL -> String -> IndexType
getIndexType ofl name = case Map.lookup name (symbols ofl) of
- Nothing -> Nothing
- Just (IndexTag indexType) -> Just indexType
- Just (ValueTag _ _) -> Nothing
+ Just (IndexTag indexType) -> indexType
+ _ -> error $ "Cannot find index type of " ++ show name
+
+getValueType :: OFL -> String -> BaseType
+getValueType ofl name = case Map.lookup name (symbols ofl) of
+ Just (ValueTag baseType _) -> baseType
+ _ -> error $ "Cannot find type of value " ++ show name
+
+hasIndex :: OFL -> String -> Bool
+hasIndex ofl name = case Map.lookup name (symbols ofl) of
+ Just (IndexTag _) -> True
+ _ -> False
+
+hasValue :: OFL -> String -> Bool
+hasValue ofl name = case Map.lookup name (symbols ofl) of
+ Just (ValueTag _ _) -> True
+ _ -> False
+
+promote :: BaseType -> BaseType -> BaseType
+promote Function _ = Function
+promote _ Function = Function
+promote Real _ = Real
+promote _ Real = Real
+promote t1 t2 | (t1 == t2) = t1
+
+getType :: OFL -> Expression -> BaseType
+getType ofl (IndexedIdentifier name _) = getValueType ofl name
+getType ofl (ConstReal _) = Real
+getType ofl (ConstInteger _) = Integer
+getType ofl (Negate e) = getType ofl e
+getType ofl (Inner _ _) = Function
+getType ofl (Laplacian _) = Function
+getType ofl (Sum e _) = getType ofl e
+getType ofl (Multiply a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Divide a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Add a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Sub a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Power a b) = promote (getType ofl a) (getType ofl b)
+getType ofl (Component e i) = Function
+getType ofl (Derivative e i) = getType ofl e
emptyOFL :: OFL
emptyOFL = OFL { assignments = [], symbols = Map.empty }
validateAssignment :: OFL -> Assignment -> ValidationResult
-validateAssignment ofl (Assign lhs rhs) = case (validateExpression ofl lhs) of
- Invalid reason -> Invalid $ "Invalid LHS of assignment" ++ reason
- Valid -> case (validateExpression ofl rhs) of
- Valid -> Valid
- Invalid reason -> Invalid $ "Invalid RHS of assignment: " ++ reason
+validateAssignment ofl (Assign lhs rhs) = do {
+ validateExpression ofl lhs;
+ validateExpression ofl rhs;
+ return ()
+}
+
+-- Validation
+
+data ValidationError = Message String deriving Show
+type ValidationResult = Either ValidationError ()
+validationSuccess = Right ()
+validationFailure = \x -> Left (Message x)
+
+indexExists :: OFL -> String -> ValidationResult
+indexExists ofl name = if (hasIndex ofl name) then validationSuccess else validationFailure $ "Unknown index " ++ name
+
+valueExists :: OFL -> String -> ValidationResult
+valueExists ofl name = if (hasValue ofl name) then validationSuccess else validationFailure $ "Unknownm value " ++ name
+
+isFunction :: OFL -> Expression -> ValidationResult
+isFunction ofl e = case (getType ofl e) of
+ Function -> validationSuccess
+ _ -> validationFailure $ "Expression " ++ show e ++ " is not a function"
validateExpression :: OFL -> Expression -> ValidationResult
-validateExpression ofl (IndexedIdentifier name indices) =
- case getIndices ofl name of
- Nothing -> Invalid $ "Cannot find declaration for value " ++ show name
- Just indexTypes -> let lengthMatch = (length indices) == (length indexTypes) in
- case lengthMatch of
- False -> Invalid $ "Value " ++ show name ++ " used with wrong number of indices."
- True -> let typeMatch = (Prelude.map (getIndexType ofl) indices) == (Prelude.map Just indexTypes) in
- case typeMatch of
- True -> Valid
- False -> Invalid $ "Value " ++ show name ++ "indexed with invalid indices."
-validateExpression ofl e = Invalid $ "Don't know how to validate " ++ show e
+
+validateExpression ofl (IndexedIdentifier name indices) = do {
+ valueExists ofl name;
+ foldl (>>) validationSuccess $ map (indexExists ofl) indices;
+ let indexTypes = map (getIndexType ofl) indices in
+ case indexTypes == (getIndices ofl name) of
+ True -> validationSuccess
+ False -> validationFailure $ "Incorrect number or type of indices used to index " ++ name
+}
+
+validateExpression ofl (ConstReal _) = validationSuccess
+
+validateExpression ofl (ConstInteger _) = validationSuccess
+
+validateExpression ofl (Negate e) = validateExpression ofl e
+
+validateExpression ofl (Inner a b) = do {
+ validateExpression ofl a;
+ validateExpression ofl b;
+ isFunction ofl a;
+ isFunction ofl b;
+}
+
+validateExpression ofl (Laplacian e) = do {
+ validateExpression ofl e;
+ isFunction ofl e;
+}
+
+validateExpression ofl (Sum e i) = do {
+ validateExpression ofl e;
+ indexExists ofl i;
+}
+
+validateExpression ofl (Multiply a b) = do {
+ validateExpression ofl a;
+ validateExpression ofl b;
+}
+
+validateExpression ofl (Divide a b) = do {
+ validateExpression ofl a;
+ validateExpression ofl b;
+}
+
+validateExpression ofl (Add a b) = do {
+ validateExpression ofl a;
+ validateExpression ofl b;
+}
+
+validateExpression ofl (Sub a b) = do {
+ validateExpression ofl a;
+ validateExpression ofl b;
+}
+
+validateExpression ofl (Power a b) = do {
+ validateExpression ofl a;
+ validateExpression ofl b;
+}
+
+
+validateExpression ofl (Component e i) = do {
+ validateExpression ofl e;
+ indexExists ofl i;
+}
+
+validateExpression ofl (Derivative e i) = do {
+ validateExpression ofl e;
+ indexExists ofl i;
+}