]> git.unchartedbackwaters.co.uk Git - francis/ofc.git/commitdiff
Add naïve cost function for comparing second-level expressions.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Mon, 26 Nov 2012 17:52:01 +0000 (17:52 +0000)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Mon, 26 Nov 2012 17:52:01 +0000 (17:52 +0000)
OFC/SecondLevel.hs

index 250c15b45af44f530bcd455212822d55f3ab99f5..1a5d4b6434dea90256513d6a424021c5fd15d654 100644 (file)
@@ -19,7 +19,7 @@ import Data.Complex
 import Data.Monoid (Monoid(..))
 import Text.PrettyPrint
 import Data.Map (Map)
-import Data.List (foldl', intersperse, intercalate)
+import Data.List (foldl', intersperse, intercalate, sortBy)
 import qualified Data.Map as Map
 
 minPsincDensity, maxPsincDensity :: Integer
@@ -292,6 +292,44 @@ getIdentType ofl2 name = case Map.lookup name (symbols ofl2) of
   Just (ValueTag baseType _) -> baseType
   _ -> error $ "Could not find type of symbol " ++ name
 
+-- Note: this is an extremely basic cost metric that takes no account
+-- of repeated operations or expression re-use.
+numFFTs :: OFL2 -> Expression e -> Integer
+numFFTs ofl2 expr = case expr of
+  IndexedPsincIdentifier _ _ -> 0
+  IndexedScalarIdentifier _ _ -> 0
+  ToMomentum e -> 1 + unary e
+  ToPosition e -> 1 + unary e
+  Upsample e -> 2 + unary e
+  Downsample e -> 2 + unary e
+  Integrate integrand -> unary integrand
+  AnalyticMomentum _ -> 0
+  AnalyticPosition _ -> 0
+  AnalyticToPsinc _ _ -> 0
+  Sum e _ -> unary e
+  Add a b -> binary a b
+  Sub a b -> binary a b
+  Neg e -> unary e
+  MulScalar e s -> binary e s
+  DivScalar e s -> binary e s
+  Power a b -> binary a b
+  PsincProduct a b -> binary a b
+  PsincReciprocalProduct a b -> binary a b
+  ConstInteger _ -> 0
+  ConstReal _ -> 0
+  ConstComplex _ -> 0
+  where
+  unary :: Expression a -> Integer
+  unary = numFFTs ofl2
+  binary :: Expression a -> Expression b -> Integer
+  binary a b = unary a + unary b
+
+compareByNumFFTs :: OFL2 -> Expression e -> Expression e -> Ordering
+compareByNumFFTs ofl2 a b = compare aFFTs bFFTs
+  where
+  aFFTs = numFFTs ofl2 a
+  bFFTs = numFFTs ofl2 b
+
 generateVariants :: (Resamplable e, e ~ Expression a1) => OFL2 -> e -> [e]
 generateVariants ofl2 expr = case expr of
   IndexedPsincIdentifier _ _ -> resampled ofl2 expr
@@ -361,8 +399,8 @@ instance PrettyPrintable OFL2 where
     assignmentsDoc = text "Assignments: "
       $$ nest 1 (vcat $ map toDoc $ assignments ofl2) 
       $$ text "Variants:" $$ nest 2 (vcat $ map variantsDoc $ assignments ofl2)
-    variantsDoc (AssignPsinc _ b) = vcat $ map toDoc $ generateVariants ofl2 b
-    variantsDoc (AssignScalar _ b) = vcat $ map toDoc $ generateVariants ofl2 b
+    variantsDoc (AssignPsinc _ b) = vcat $ map toDoc $ sortBy (compareByNumFFTs ofl2) (generateVariants ofl2 b)
+    variantsDoc (AssignScalar _ b) = vcat $ map toDoc $ sortBy (compareByNumFFTs ofl2) (generateVariants ofl2 b)
     symAssocToDoc (name, ValueTag baseType indexTypes) = text $
       (show baseType) ++ " " ++ name ++ case indexTypes of
         [] -> ""