From: Francis Russell Date: Wed, 21 Nov 2012 21:20:49 +0000 (+0000) Subject: Attempt to compute and validate second level expression types. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=4542fc1ef55dd95b73ec5dea7adefcf222ebaa13;p=francis%2Fofc.git Attempt to compute and validate second level expression types. --- diff --git a/OFC/SecondLevel.hs b/OFC/SecondLevel.hs index 164f2e8..d094911 100644 --- a/OFC/SecondLevel.hs +++ b/OFC/SecondLevel.hs @@ -16,6 +16,7 @@ import qualified OFC.TargetMapping as TM import OFC.TopLevel (OFL) import qualified OFC.TopLevel as TopLevel import Data.Complex +import Data.Monoid (Monoid(..)) import Text.PrettyPrint import Data.Map (Map) import Data.List (foldl', intersperse, intercalate) @@ -37,23 +38,31 @@ data ValueType = ScalarT ScalarType | PositionT PositionFieldType | MomentumT MomentumFieldType - deriving Show + deriving (Eq, Show) data ScalarType = RealType | ComplexType | IntegerType - deriving Show + deriving (Eq, Show) + +instance Monoid ScalarType where + mempty = IntegerType + mappend IntegerType a = a + mappend a IntegerType = a + mappend RealType RealType = RealType + mappend ComplexType _ = ComplexType + mappend _ ComplexType = ComplexType data PositionFieldType = Psinc Integer | AnalyticPositionType - deriving Show + deriving (Eq, Show) data MomentumFieldType = PsincReciprocal Integer | AnalyticMomentumType - deriving Show + deriving (Eq, Show) data PsincE data PsincReciprocalE @@ -162,36 +171,100 @@ instance PrettyPrintable (Expression e) where binaryToDoc op lhs rhs = parens $ hcat [toDoc lhs, text op, toDoc rhs] functionToDoc name params = text name <> (parens $ hcat (intersperse (text ", ") params)) --- getType :: OFL2 -> Expression -> ValueType --- getType ofl2 (IndexedIdentifier name _) = getValueType ofl2 name --- getType ofl2 (ToMomentum e) = case getType ofl2 e of --- basis@(PsincReciprocal _) -> basis --- Psinc i -> PsincReciprocal i --- _ -> error "ToMomentum applied to non-Psinc value" --- getType ofl2 (ToPosition e) = case getType ofl2 e of --- basis@(Psinc _) -> basis --- PsincReciprocal i -> Psinc i --- _ -> error "ToPosition applied to non-Psinc value" --- getType ofl2 (Upsample e) = case getType ofl2 e of --- Psinc i -> Psinc $ i+1 --- _ -> error "Upsample applied to non-Psinc value" --- getType ofl2 (Downsample e) = case getType ofl2 e of --- Psinc i -> Psinc $ i-1 --- _ -> error "Downsample applied to non-Psinc value" --- getType ofl2 (Integrate e) = case getType ofl2 e of --- Psinc _ -> RealType --- _ -> error "Integrate applied to non-Psinc value" --- getType _ (ConstReal _) = RealType --- getType _ (ConstInteger _) = IntegerType --- getType _ (ConstComplex _) = ComplexType --- getType _ (PositionComponent _) = PositionFunction --- getType _ (MomentumComponent _) = MomentumFunction --- getType ofl2 (Negate e) = getType ofl2 e --- getType ofl2 (Sum e _) = getType ofl2 e ---getType ofl2 e = error "Unimplemented" - -getValueType :: OFL2 -> String -> ValueType -getValueType ofl2 name = case Map.lookup name (symbols ofl2) of +getImplementationType :: OFL2 -> Expression e -> Either String ValueType +getImplementationType ofl2 expr = case expr of + IndexedPsincIdentifier name _ -> Right $ getIdentType ofl2 name + IndexedScalarIdentifier name _ -> Right $ getIdentType ofl2 name + ToMomentum e -> do + eType <- getImplementationType ofl2 e + case eType of + PositionT (Psinc i) -> return $ MomentumT $ PsincReciprocal i + PositionT (AnalyticPositionType) -> + Left $ "Cannot convert analytic position expression " ++ prettyPrint e ++ " to momentum representation." + _ -> Left $ "ToMomentum applied to incorrect expression " ++ prettyPrint e + ToPosition e -> do + eType <- getImplementationType ofl2 e + case eType of + MomentumT (PsincReciprocal i) -> return $ PositionT $ Psinc i + MomentumT (AnalyticMomentumType) -> + Left $ "Cannot convert analytic momentum expression " ++ prettyPrint e ++ " to position representation." + _ -> Left $ "ToPosition applied to incorrect expression " ++ prettyPrint e + Upsample e -> do + eType <- getImplementationType ofl2 e + case eType of + PositionT (Psinc i) -> return $ PositionT (Psinc $ i+1) + _ -> Left $ "Upsample applied to inappropriate expression " ++ prettyPrint e + Downsample e -> do + eType <- getImplementationType ofl2 e + case eType of + PositionT (Psinc i) -> if i>1 + then return $ PositionT (Psinc $ i-1) + else Left $ "Cannot downsample expression " ++ prettyPrint e ++ " further." + _ -> Left $ "Downsample applied to inappropriate expression " ++ prettyPrint e + Integrate integrand -> do + eType <- getImplementationType ofl2 integrand + case eType of + PositionT (Psinc _) -> return $ ScalarT RealType + _ -> Left $ "Do not know how to integrate " ++ prettyPrint integrand + AnalyticMomentum _ -> return $ MomentumT AnalyticMomentumType + AnalyticPosition _ -> return $ PositionT AnalyticPositionType + AnalyticToPsinc e i -> do + eType <- getImplementationType ofl2 e + case eType of + PositionT AnalyticPositionType -> return $ PositionT $ Psinc i + _ -> Left $ "AnalyticToPsinc applied to incorrect expression " ++ prettyPrint e + Sum e _ -> getImplementationType ofl2 e + Add a b -> handleAdditiveExpression a b + Sub a b -> handleAdditiveExpression a b + Neg e -> getImplementationType ofl2 e + MulScalar e s -> handleMultiplicativeExpression e s + DivScalar e s -> handleMultiplicativeExpression e s + Power a b -> do + ScalarT aType <- getImplementationType ofl2 a + ScalarT bType <- getImplementationType ofl2 b + return $ ScalarT $ mappend aType bType + PsincProduct a b -> do + aType <- getImplementationType ofl2 a + bType <- getImplementationType ofl2 b + case (aType, bType) of + (PositionT (Psinc n), PositionT (Psinc n2)) -> if n == n2 + then return $ PositionT $ Psinc n + else Left $ "Expressions " ++ prettyPrint a ++ " and " ++ prettyPrint b ++ " do not have matching basis sets" + _ -> Left $ "Do not know how to calculate product between " ++ prettyPrint a ++ " and " ++ prettyPrint b + PsincReciprocalProduct a b -> do + aType <- getImplementationType ofl2 a + bType <- getImplementationType ofl2 b + case (aType, bType) of + (MomentumT (PsincReciprocal n), MomentumT (PsincReciprocal n2)) -> if n == n2 + then return $ MomentumT $ PsincReciprocal n + else Left $ "Expressions " ++ prettyPrint a ++ " and " ++ prettyPrint b ++ " do not have matching basis sets" + (MomentumT (PsincReciprocal n), MomentumT AnalyticMomentumType) -> return $ MomentumT (PsincReciprocal n) + (MomentumT AnalyticMomentumType, MomentumT (PsincReciprocal n)) -> return $ MomentumT (PsincReciprocal n) + _ -> Left $ "Do not know how to calculate product between " ++ prettyPrint a ++ " and " ++ prettyPrint b + ConstInteger _ -> Right $ ScalarT IntegerType + ConstReal _ -> Right $ ScalarT RealType + ConstComplex _ -> Right $ ScalarT ComplexType + where + handleAdditiveExpression a b = do + aType <- getImplementationType ofl2 a + bType <- getImplementationType ofl2 b + if aType == bType + then return aType + else Left $ "Cannot add/subtract expressions of mismatching types " ++ prettyPrint a ++ " and " ++ prettyPrint b + handleMultiplicativeExpression e s = do + eType <- getImplementationType ofl2 e + sType <- getImplementationType ofl2 s + case (eType, sType) of + (ScalarT a, ScalarT b) -> return $ ScalarT $ mappend a b + (PositionT (Psinc _), ScalarT IntegerType) -> return eType + (PositionT (Psinc _), ScalarT RealType) -> return eType + (MomentumT (PsincReciprocal _), ScalarT IntegerType) -> return eType + (MomentumT (PsincReciprocal _), ScalarT RealType) -> return eType + (MomentumT (PsincReciprocal _), ScalarT ComplexType) -> return eType + _ -> Left $ "Do not know how to implement multiplication/division of " ++ prettyPrint e ++ " by " ++ prettyPrint s + +getIdentType :: OFL2 -> String -> ValueType +getIdentType ofl2 name = case Map.lookup name (symbols ofl2) of Just (ValueTag baseType _) -> baseType _ -> error $ "Could not find type of symbol " ++ name @@ -401,7 +474,7 @@ instance Num Bandwidth where fromInteger _ = error "Unimplemented: Bandwidth fromInteger" theoreticalFrequency :: OFL2 -> Expression e -> Bandwidth -theoreticalFrequency ofl2 (IndexedPsincIdentifier ident _) = case getValueType ofl2 ident of +theoreticalFrequency ofl2 (IndexedPsincIdentifier ident _) = case getIdentType ofl2 ident of PositionT (Psinc i) -> GMaxMultiple i _ -> error $ "Expected identifier " ++ ident ++ " to be a Psinc field." theoreticalFrequency _ (IndexedScalarIdentifier _ _) = GMaxMultiple 0