]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Attempt to compute and validate second level expression types.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 21 Nov 2012 21:20:49 +0000 (21:20 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 21 Nov 2012 21:20:49 +0000 (21:20 +0000)
OFC/SecondLevel.hs

index 164f2e87ed337c715e35fe3c0e239ed95ab60851..d094911355d365ba8f599e1afa3e34c8a5de932f 100644 (file)
@@ -16,6 +16,7 @@ import qualified OFC.TargetMapping as TM
 import OFC.TopLevel (OFL)
 import qualified OFC.TopLevel as TopLevel
 import Data.Complex
+import Data.Monoid (Monoid(..))
 import Text.PrettyPrint
 import Data.Map (Map)
 import Data.List (foldl', intersperse, intercalate)
@@ -37,23 +38,31 @@ data ValueType =
   ScalarT ScalarType |
   PositionT PositionFieldType |
   MomentumT MomentumFieldType
-  deriving Show
+  deriving (Eq, Show)
 
 data ScalarType =
   RealType | 
   ComplexType |
   IntegerType
-  deriving Show
+  deriving (Eq, Show)
+
+instance Monoid ScalarType where
+  mempty = IntegerType
+  mappend IntegerType a = a
+  mappend a IntegerType = a
+  mappend RealType RealType = RealType
+  mappend ComplexType _ = ComplexType
+  mappend _ ComplexType = ComplexType
 
 data PositionFieldType =
   Psinc Integer |
   AnalyticPositionType
-  deriving Show
+  deriving (Eq, Show)
 
 data MomentumFieldType =
   PsincReciprocal Integer |
   AnalyticMomentumType
-  deriving Show
+  deriving (Eq, Show)
 
 data PsincE
 data PsincReciprocalE
@@ -162,36 +171,100 @@ instance PrettyPrintable (Expression e) where
     binaryToDoc op lhs rhs = parens $ hcat [toDoc lhs, text op, toDoc rhs]
     functionToDoc name params = text name <> (parens $ hcat (intersperse (text ", ") params))
 
--- getType :: OFL2 -> Expression -> ValueType
--- getType ofl2 (IndexedIdentifier name _) = getValueType ofl2 name
--- getType ofl2 (ToMomentum e) = case getType ofl2 e of
---   basis@(PsincReciprocal _) -> basis
---   Psinc i -> PsincReciprocal i
---   _ -> error "ToMomentum applied to non-Psinc value"
--- getType ofl2 (ToPosition e) = case getType ofl2 e of
---   basis@(Psinc _) -> basis
---   PsincReciprocal i -> Psinc i
---   _ -> error "ToPosition applied to non-Psinc value"
--- getType ofl2 (Upsample e) = case getType ofl2 e of 
---   Psinc i -> Psinc $ i+1
---   _ -> error "Upsample applied to non-Psinc value"
--- getType ofl2 (Downsample e) = case getType ofl2 e of 
---   Psinc i -> Psinc $ i-1
---   _ -> error "Downsample applied to non-Psinc value"
--- getType ofl2 (Integrate e) = case getType ofl2 e of
---   Psinc _ -> RealType
---   _ -> error "Integrate applied to non-Psinc value"
--- getType _ (ConstReal _) = RealType
--- getType _ (ConstInteger _) = IntegerType
--- getType _ (ConstComplex _) = ComplexType
--- getType _ (PositionComponent _) = PositionFunction
--- getType _ (MomentumComponent _) = MomentumFunction
--- getType ofl2 (Negate e) = getType ofl2 e
--- getType ofl2 (Sum e _) = getType ofl2 e
---getType ofl2 e = error "Unimplemented"
-
-getValueType :: OFL2 -> String -> ValueType
-getValueType ofl2 name = case Map.lookup name (symbols ofl2) of
+getImplementationType :: OFL2 -> Expression e -> Either String ValueType
+getImplementationType ofl2 expr = case expr of
+  IndexedPsincIdentifier name _ -> Right $ getIdentType ofl2 name
+  IndexedScalarIdentifier name _ -> Right $ getIdentType ofl2 name
+  ToMomentum e -> do
+    eType <- getImplementationType ofl2 e
+    case eType of
+      PositionT (Psinc i) -> return $ MomentumT $ PsincReciprocal i
+      PositionT (AnalyticPositionType) -> 
+        Left $ "Cannot convert analytic position expression " ++ prettyPrint e ++ " to momentum representation."
+      _ -> Left $ "ToMomentum applied to incorrect expression " ++ prettyPrint e
+  ToPosition e -> do
+    eType <- getImplementationType ofl2  e
+    case eType of
+      MomentumT (PsincReciprocal i) -> return $ PositionT $ Psinc i
+      MomentumT (AnalyticMomentumType) -> 
+        Left $ "Cannot convert analytic momentum expression " ++ prettyPrint e ++ " to position representation."
+      _ -> Left $ "ToPosition applied to incorrect expression " ++ prettyPrint e
+  Upsample e -> do
+    eType <- getImplementationType ofl2 e
+    case eType of
+      PositionT (Psinc i) -> return $ PositionT (Psinc $ i+1)
+      _ -> Left $ "Upsample applied to inappropriate expression " ++ prettyPrint e
+  Downsample e -> do
+    eType <- getImplementationType ofl2 e
+    case eType of
+      PositionT (Psinc i) -> if i>1 
+        then return $ PositionT (Psinc $ i-1) 
+        else Left $ "Cannot downsample expression " ++ prettyPrint e ++ " further."
+      _ -> Left $ "Downsample applied to inappropriate expression " ++ prettyPrint e
+  Integrate integrand -> do
+    eType <- getImplementationType ofl2 integrand
+    case eType of
+     PositionT (Psinc _) -> return $ ScalarT RealType 
+     _ -> Left $ "Do not know how to integrate " ++ prettyPrint integrand
+  AnalyticMomentum _ -> return $ MomentumT AnalyticMomentumType
+  AnalyticPosition _ -> return $ PositionT AnalyticPositionType
+  AnalyticToPsinc e i -> do
+    eType <- getImplementationType ofl2 e
+    case eType of
+      PositionT AnalyticPositionType -> return $ PositionT $ Psinc i
+      _ -> Left $ "AnalyticToPsinc applied to incorrect expression " ++ prettyPrint e
+  Sum e _ -> getImplementationType ofl2 e
+  Add a b -> handleAdditiveExpression a b
+  Sub a b -> handleAdditiveExpression a b
+  Neg e -> getImplementationType ofl2 e
+  MulScalar e s -> handleMultiplicativeExpression e s
+  DivScalar e s -> handleMultiplicativeExpression e s
+  Power a b -> do
+    ScalarT aType <- getImplementationType ofl2 a
+    ScalarT bType <- getImplementationType ofl2 b
+    return $ ScalarT $ mappend aType bType
+  PsincProduct a b -> do
+    aType <- getImplementationType ofl2 a
+    bType <- getImplementationType ofl2 b
+    case (aType, bType) of
+      (PositionT (Psinc n), PositionT (Psinc n2)) -> if n == n2 
+        then return $ PositionT $ Psinc n
+        else Left $ "Expressions " ++ prettyPrint a ++ " and " ++ prettyPrint b ++ " do not have matching basis sets"
+      _ -> Left $ "Do not know how to calculate product between " ++ prettyPrint a ++ " and " ++ prettyPrint b
+  PsincReciprocalProduct a b -> do
+    aType <- getImplementationType ofl2 a
+    bType <- getImplementationType ofl2 b
+    case (aType, bType) of
+      (MomentumT (PsincReciprocal n), MomentumT (PsincReciprocal n2)) -> if n == n2 
+        then return $ MomentumT $ PsincReciprocal n
+        else Left $ "Expressions " ++ prettyPrint a ++ " and " ++ prettyPrint b ++ " do not have matching basis sets"
+      (MomentumT (PsincReciprocal n), MomentumT AnalyticMomentumType) -> return $ MomentumT (PsincReciprocal n) 
+      (MomentumT AnalyticMomentumType, MomentumT (PsincReciprocal n)) -> return $ MomentumT (PsincReciprocal n)
+      _ -> Left $ "Do not know how to calculate product between " ++ prettyPrint a ++ " and " ++ prettyPrint b
+  ConstInteger _ -> Right $ ScalarT IntegerType
+  ConstReal _ -> Right $ ScalarT RealType
+  ConstComplex _ -> Right $ ScalarT ComplexType
+  where
+    handleAdditiveExpression a b = do
+      aType <- getImplementationType ofl2 a
+      bType <- getImplementationType ofl2 b
+      if aType == bType
+        then return aType
+        else Left $ "Cannot add/subtract expressions of mismatching types " ++ prettyPrint a ++ " and " ++ prettyPrint b
+    handleMultiplicativeExpression e s = do
+      eType <- getImplementationType ofl2 e
+      sType <- getImplementationType ofl2 s
+      case (eType, sType) of
+        (ScalarT a, ScalarT b) -> return $ ScalarT $ mappend a b
+        (PositionT (Psinc _), ScalarT IntegerType) -> return eType
+        (PositionT (Psinc _), ScalarT RealType) -> return eType
+        (MomentumT (PsincReciprocal _), ScalarT IntegerType) -> return eType
+        (MomentumT (PsincReciprocal _), ScalarT RealType) -> return eType
+        (MomentumT (PsincReciprocal _), ScalarT ComplexType) -> return eType
+        _ -> Left $ "Do not know how to implement multiplication/division of " ++ prettyPrint e ++ " by " ++ prettyPrint s
+
+getIdentType :: OFL2 -> String -> ValueType
+getIdentType ofl2 name = case Map.lookup name (symbols ofl2) of
   Just (ValueTag baseType _) -> baseType
   _ -> error $ "Could not find type of symbol " ++ name
 
@@ -401,7 +474,7 @@ instance Num Bandwidth where
   fromInteger _ = error "Unimplemented: Bandwidth fromInteger"
 
 theoreticalFrequency :: OFL2 -> Expression e -> Bandwidth
-theoreticalFrequency ofl2 (IndexedPsincIdentifier ident _) = case getValueType ofl2 ident of
+theoreticalFrequency ofl2 (IndexedPsincIdentifier ident _) = case getIdentType ofl2 ident of
   PositionT (Psinc i) -> GMaxMultiple i
   _ -> error $ "Expected identifier " ++ ident ++ " to be a Psinc field."
 theoreticalFrequency _ (IndexedScalarIdentifier _ _) = GMaxMultiple 0