import OFC.TopLevel (OFL)
import qualified OFC.TopLevel as TopLevel
import Data.Complex
+import Data.Monoid (Monoid(..))
import Text.PrettyPrint
import Data.Map (Map)
import Data.List (foldl', intersperse, intercalate)
ScalarT ScalarType |
PositionT PositionFieldType |
MomentumT MomentumFieldType
- deriving Show
+ deriving (Eq, Show)
data ScalarType =
RealType |
ComplexType |
IntegerType
- deriving Show
+ deriving (Eq, Show)
+
+instance Monoid ScalarType where
+ mempty = IntegerType
+ mappend IntegerType a = a
+ mappend a IntegerType = a
+ mappend RealType RealType = RealType
+ mappend ComplexType _ = ComplexType
+ mappend _ ComplexType = ComplexType
data PositionFieldType =
Psinc Integer |
AnalyticPositionType
- deriving Show
+ deriving (Eq, Show)
data MomentumFieldType =
PsincReciprocal Integer |
AnalyticMomentumType
- deriving Show
+ deriving (Eq, Show)
data PsincE
data PsincReciprocalE
binaryToDoc op lhs rhs = parens $ hcat [toDoc lhs, text op, toDoc rhs]
functionToDoc name params = text name <> (parens $ hcat (intersperse (text ", ") params))
--- getType :: OFL2 -> Expression -> ValueType
--- getType ofl2 (IndexedIdentifier name _) = getValueType ofl2 name
--- getType ofl2 (ToMomentum e) = case getType ofl2 e of
--- basis@(PsincReciprocal _) -> basis
--- Psinc i -> PsincReciprocal i
--- _ -> error "ToMomentum applied to non-Psinc value"
--- getType ofl2 (ToPosition e) = case getType ofl2 e of
--- basis@(Psinc _) -> basis
--- PsincReciprocal i -> Psinc i
--- _ -> error "ToPosition applied to non-Psinc value"
--- getType ofl2 (Upsample e) = case getType ofl2 e of
--- Psinc i -> Psinc $ i+1
--- _ -> error "Upsample applied to non-Psinc value"
--- getType ofl2 (Downsample e) = case getType ofl2 e of
--- Psinc i -> Psinc $ i-1
--- _ -> error "Downsample applied to non-Psinc value"
--- getType ofl2 (Integrate e) = case getType ofl2 e of
--- Psinc _ -> RealType
--- _ -> error "Integrate applied to non-Psinc value"
--- getType _ (ConstReal _) = RealType
--- getType _ (ConstInteger _) = IntegerType
--- getType _ (ConstComplex _) = ComplexType
--- getType _ (PositionComponent _) = PositionFunction
--- getType _ (MomentumComponent _) = MomentumFunction
--- getType ofl2 (Negate e) = getType ofl2 e
--- getType ofl2 (Sum e _) = getType ofl2 e
---getType ofl2 e = error "Unimplemented"
-
-getValueType :: OFL2 -> String -> ValueType
-getValueType ofl2 name = case Map.lookup name (symbols ofl2) of
+getImplementationType :: OFL2 -> Expression e -> Either String ValueType
+getImplementationType ofl2 expr = case expr of
+ IndexedPsincIdentifier name _ -> Right $ getIdentType ofl2 name
+ IndexedScalarIdentifier name _ -> Right $ getIdentType ofl2 name
+ ToMomentum e -> do
+ eType <- getImplementationType ofl2 e
+ case eType of
+ PositionT (Psinc i) -> return $ MomentumT $ PsincReciprocal i
+ PositionT (AnalyticPositionType) ->
+ Left $ "Cannot convert analytic position expression " ++ prettyPrint e ++ " to momentum representation."
+ _ -> Left $ "ToMomentum applied to incorrect expression " ++ prettyPrint e
+ ToPosition e -> do
+ eType <- getImplementationType ofl2 e
+ case eType of
+ MomentumT (PsincReciprocal i) -> return $ PositionT $ Psinc i
+ MomentumT (AnalyticMomentumType) ->
+ Left $ "Cannot convert analytic momentum expression " ++ prettyPrint e ++ " to position representation."
+ _ -> Left $ "ToPosition applied to incorrect expression " ++ prettyPrint e
+ Upsample e -> do
+ eType <- getImplementationType ofl2 e
+ case eType of
+ PositionT (Psinc i) -> return $ PositionT (Psinc $ i+1)
+ _ -> Left $ "Upsample applied to inappropriate expression " ++ prettyPrint e
+ Downsample e -> do
+ eType <- getImplementationType ofl2 e
+ case eType of
+ PositionT (Psinc i) -> if i>1
+ then return $ PositionT (Psinc $ i-1)
+ else Left $ "Cannot downsample expression " ++ prettyPrint e ++ " further."
+ _ -> Left $ "Downsample applied to inappropriate expression " ++ prettyPrint e
+ Integrate integrand -> do
+ eType <- getImplementationType ofl2 integrand
+ case eType of
+ PositionT (Psinc _) -> return $ ScalarT RealType
+ _ -> Left $ "Do not know how to integrate " ++ prettyPrint integrand
+ AnalyticMomentum _ -> return $ MomentumT AnalyticMomentumType
+ AnalyticPosition _ -> return $ PositionT AnalyticPositionType
+ AnalyticToPsinc e i -> do
+ eType <- getImplementationType ofl2 e
+ case eType of
+ PositionT AnalyticPositionType -> return $ PositionT $ Psinc i
+ _ -> Left $ "AnalyticToPsinc applied to incorrect expression " ++ prettyPrint e
+ Sum e _ -> getImplementationType ofl2 e
+ Add a b -> handleAdditiveExpression a b
+ Sub a b -> handleAdditiveExpression a b
+ Neg e -> getImplementationType ofl2 e
+ MulScalar e s -> handleMultiplicativeExpression e s
+ DivScalar e s -> handleMultiplicativeExpression e s
+ Power a b -> do
+ ScalarT aType <- getImplementationType ofl2 a
+ ScalarT bType <- getImplementationType ofl2 b
+ return $ ScalarT $ mappend aType bType
+ PsincProduct a b -> do
+ aType <- getImplementationType ofl2 a
+ bType <- getImplementationType ofl2 b
+ case (aType, bType) of
+ (PositionT (Psinc n), PositionT (Psinc n2)) -> if n == n2
+ then return $ PositionT $ Psinc n
+ else Left $ "Expressions " ++ prettyPrint a ++ " and " ++ prettyPrint b ++ " do not have matching basis sets"
+ _ -> Left $ "Do not know how to calculate product between " ++ prettyPrint a ++ " and " ++ prettyPrint b
+ PsincReciprocalProduct a b -> do
+ aType <- getImplementationType ofl2 a
+ bType <- getImplementationType ofl2 b
+ case (aType, bType) of
+ (MomentumT (PsincReciprocal n), MomentumT (PsincReciprocal n2)) -> if n == n2
+ then return $ MomentumT $ PsincReciprocal n
+ else Left $ "Expressions " ++ prettyPrint a ++ " and " ++ prettyPrint b ++ " do not have matching basis sets"
+ (MomentumT (PsincReciprocal n), MomentumT AnalyticMomentumType) -> return $ MomentumT (PsincReciprocal n)
+ (MomentumT AnalyticMomentumType, MomentumT (PsincReciprocal n)) -> return $ MomentumT (PsincReciprocal n)
+ _ -> Left $ "Do not know how to calculate product between " ++ prettyPrint a ++ " and " ++ prettyPrint b
+ ConstInteger _ -> Right $ ScalarT IntegerType
+ ConstReal _ -> Right $ ScalarT RealType
+ ConstComplex _ -> Right $ ScalarT ComplexType
+ where
+ handleAdditiveExpression a b = do
+ aType <- getImplementationType ofl2 a
+ bType <- getImplementationType ofl2 b
+ if aType == bType
+ then return aType
+ else Left $ "Cannot add/subtract expressions of mismatching types " ++ prettyPrint a ++ " and " ++ prettyPrint b
+ handleMultiplicativeExpression e s = do
+ eType <- getImplementationType ofl2 e
+ sType <- getImplementationType ofl2 s
+ case (eType, sType) of
+ (ScalarT a, ScalarT b) -> return $ ScalarT $ mappend a b
+ (PositionT (Psinc _), ScalarT IntegerType) -> return eType
+ (PositionT (Psinc _), ScalarT RealType) -> return eType
+ (MomentumT (PsincReciprocal _), ScalarT IntegerType) -> return eType
+ (MomentumT (PsincReciprocal _), ScalarT RealType) -> return eType
+ (MomentumT (PsincReciprocal _), ScalarT ComplexType) -> return eType
+ _ -> Left $ "Do not know how to implement multiplication/division of " ++ prettyPrint e ++ " by " ++ prettyPrint s
+
+getIdentType :: OFL2 -> String -> ValueType
+getIdentType ofl2 name = case Map.lookup name (symbols ofl2) of
Just (ValueTag baseType _) -> baseType
_ -> error $ "Could not find type of symbol " ++ name
fromInteger _ = error "Unimplemented: Bandwidth fromInteger"
theoreticalFrequency :: OFL2 -> Expression e -> Bandwidth
-theoreticalFrequency ofl2 (IndexedPsincIdentifier ident _) = case getValueType ofl2 ident of
+theoreticalFrequency ofl2 (IndexedPsincIdentifier ident _) = case getIdentType ofl2 ident of
PositionT (Psinc i) -> GMaxMultiple i
_ -> error $ "Expected identifier " ++ ident ++ " to be a Psinc field."
theoreticalFrequency _ (IndexedScalarIdentifier _ _) = GMaxMultiple 0