]> git.unchartedbackwaters.co.uk Git - francis/lta.git/commitdiff
Work on finding and renaming symbols.
authorFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 17 Apr 2013 12:43:05 +0000 (13:43 +0100)
committerFrancis Russell <francis@unchartedbackwaters.co.uk>
Wed, 17 Apr 2013 12:43:05 +0000 (13:43 +0100)
LTA/Symbolic.hs

index 5136006fd7adcf0277978d592c985b90a0114163..107d3bd9344e3b421794fecaf066dc1a3462d2d1 100644 (file)
@@ -34,6 +34,10 @@ class PairSeqLike e where
   extractMultipliers :: e -> e
   flatten :: e -> e
 
+class ContainsSymbols e where
+  rename :: String -> String -> e -> e
+  findSymbols :: e -> Set String
+
 instance PairSeqLike (PairSeq SumTag Literal) where
   empty = PairSeq 0 Map.empty
   rebuild (PairSeq overall pairs) = case (overall, Map.toList pairs) of
@@ -156,6 +160,69 @@ buildConditional base ((cond, expr) : cases) =
 (~>) :: Expr -> Expr -> Cond
 (~>) = Compare GreaterThan
 
+instance ContainsSymbols Expr where
+  rename from to e = case e of
+    IntegerSymbol name -> if name == from
+      then IntegerSymbol to
+      else IntegerSymbol name
+    FloatSymbol name -> if name == from
+      then FloatSymbol to
+      else FloatSymbol name
+    (Constant _) -> e
+    (Literal _) -> e
+    (Summation var low high summand) -> if var == from
+      then e
+      else Summation var (rename' low) (rename' high) (rename' summand)
+    (UnaryFunction f expr) -> UnaryFunction f $ rename' expr
+    (Div e1 e2) -> Div (rename' e1) (rename' e2)
+    (Mod e1 e2) -> Mod (rename' e1) (rename' e2)
+    (Sum (PairSeq overall terms)) ->
+      Sum $ Map.foldlWithKey' mergeTerm (PairSeq overall Map.empty) terms
+    (Product (PairSeq overall terms)) ->
+      Product $ Map.foldlWithKey' mergeTerm (PairSeq overall Map.empty) terms
+    (Conditional cond e1 e2) -> Conditional (rename' cond) (rename' e1) (rename' e2)
+    where
+    rename' :: (ContainsSymbols e) => e -> e
+    rename' = rename from to
+    mergeTerm oldSeq expr coeff = oldSeq `addPair` (rename' expr, coeff)
+  findSymbols e = case e of
+    IntegerSymbol name -> Set.singleton name
+    FloatSymbol name -> Set.singleton name
+    (Constant _) -> Set.empty
+    (Literal _) -> Set.empty
+    (Summation var low high summand) ->
+      foldl' Set.union (Set.singleton var) (findSymbols <$> [low, high, summand])
+    (UnaryFunction _ expr) -> findSymbols expr
+    (Div e1 e2) -> (findSymbols e1) `Set.union` (findSymbols e2)
+    (Mod e1 e2) -> (findSymbols e1) `Set.union` (findSymbols e2)
+    (Sum (PairSeq _ terms)) ->
+      foldl' Set.union Set.empty ((findSymbols . fst) <$> Map.toList terms)
+    (Product (PairSeq _ terms)) ->
+      foldl' Set.union Set.empty ((findSymbols . fst) <$> Map.toList terms)
+    (Conditional cond e1 e2) ->
+      (findSymbols cond) `Set.union` (findSymbols e1) `Set.union` (findSymbols e2)
+
+
+instance ContainsSymbols Cond where
+  rename from to expr = case expr of
+    TrueC -> TrueC
+    FalseC -> FalseC
+    Compare op e1 e2 -> Compare op (rename' e1) (rename' e2)
+    And e1 e2 -> And (rename' e1) (rename' e2)
+    Or e1 e2 -> And (rename' e1) (rename' e2)
+    Not e -> Not $ rename' e
+    where
+    rename' :: (ContainsSymbols e) => e -> e
+    rename' = rename from to
+  findSymbols expr = case expr of
+    TrueC -> Set.empty
+    FalseC -> Set.empty
+    Compare _ e1 e2 -> Set.union (findSymbols e1) (findSymbols e2)
+    And e1 e2 -> Set.union (findSymbols e1) (findSymbols e2)
+    Or e1 e2 -> Set.union (findSymbols e1) (findSymbols e2)
+    Not e -> findSymbols e
+
+
 simplify :: Expr -> Expr
 simplify (Sum pairSeq) = rebuild $ normalise pairSeq
 simplify (Product pairSeq) = rebuild $ normalise pairSeq