--- /dev/null
+{-# LANGUAGE EmptyDataDecls, FlexibleInstances, FlexibleContexts #-}
+
+module LTA.Symbolic
+ ( Expr(..)
+ , Constant(..)
+ , Literal(..)
+ , UnaryFunction(..)
+ , pow
+ , simplify
+ ) where
+
+import Control.Applicative ((<$>))
+import Data.Ratio (numerator, denominator, (%))
+import Data.List (foldl')
+import Data.Map (Map)
+import qualified Data.Map as Map
+
+data SumTag
+data ProductTag
+
+class PairSeqLike e where
+ empty :: e
+ rebuild :: e -> Expr
+ extractMultipliers :: e -> e
+ flatten :: e -> e
+
+instance PairSeqLike (PairSeq SumTag Literal) where
+ empty = PairSeq 0 Map.empty
+ rebuild (PairSeq overall pairs) = case (overall, Map.toList pairs) of
+ (n, []) -> Literal n
+ (0, [(a, 1)]) -> a
+ _ -> Sum $ PairSeq overall pairs
+ extractMultipliers (PairSeq overall pairs) =
+ PairSeq overall pairs' where
+ newTerms = rebuildTerm <$> (Map.toList pairs)
+ rebuildTerm (expr, coeff) = (expr', coeff * factor) where
+ (expr', factor) = extractMultiplier expr
+ pairs' = Map.fromListWith (+) newTerms
+ flatten pairSeq = mergeTerms empty 1 pairSeq where
+ mergeTerms (PairSeq overall terms) multiplier (PairSeq childOverall childTerms) = pairSeq'' where
+ pairSeq' = Map.foldlWithKey' mergeTerm (PairSeq overall terms) childTerms
+ pairSeq'' = mergeTerm pairSeq' (Literal childOverall) 1
+ mergeTerm oldSeq expr coeff =
+ let localMultiplier = multiplier * coeff in
+ case expr of
+ Sum childSeq -> mergeTerms oldSeq localMultiplier childSeq
+ Literal n -> transformOverall oldSeq (+ (localMultiplier * n))
+ _ -> addPair oldSeq (expr, localMultiplier)
+
+instance PairSeqLike (PairSeq ProductTag Expr) where
+ empty = PairSeq 1 Map.empty
+ rebuild (PairSeq overall pairs) = case (overall, Map.toList pairs) of
+ (n, []) -> Literal n
+ (1, [(a, 1)]) -> a
+ _ -> Product $ PairSeq overall pairs
+ extractMultipliers (PairSeq overall pairs) =
+ PairSeq overall' (if overall' == 0 then Map.empty else pairs') where
+ pairs' = Map.fromListWith (+) newTerms
+ newTerms = map (\(_, expr, coeff) -> (expr, coeff)) analysedTerms
+ overall' = foldl' (*) overall $ map (\(multiplier, _, _) -> multiplier) analysedTerms
+ analysedTerms = rebuildTerm <$> (Map.toList pairs)
+ rebuildTerm (expr, coeff) = let (expr', factor) = extractMultiplier expr in
+ case (Literal factor) `evalPow` coeff of
+ Just multiplier -> (multiplier, expr', coeff)
+ Nothing -> (1, expr, coeff)
+ flatten pairSeq = mergeTerms empty 1 pairSeq where
+ mergeTerms (PairSeq overall terms) multiplier (PairSeq childOverall childTerms) = pairSeq'' where
+ pairSeq' = Map.foldlWithKey' mergeTerm (PairSeq overall terms) childTerms
+ pairSeq'' = mergeTerm pairSeq' (Literal childOverall) 1
+ mergeTerm oldSeq expr coeff =
+ let localMultiplier = multiplier * coeff in
+ case (expr, localMultiplier) of
+ (Product childSeq, _) -> mergeTerms oldSeq localMultiplier childSeq
+ (a, b) -> case evalPow a b of
+ Just n -> transformOverall oldSeq (* n)
+ Nothing -> addPair oldSeq (a, b)
+
+data Expr
+ = IntegerSymbol String
+ | FloatSymbol String
+ | Constant Constant
+ | Literal Literal
+ | Summation String Expr Expr Expr
+ | Sum (PairSeq SumTag Literal)
+ | Product (PairSeq ProductTag Expr)
+ | UnaryFunction UnaryFunction Expr
+ | Div Expr Expr
+ | Mod Expr Expr
+ deriving (Eq, Ord, Show)
+
+data Literal
+ = RationalLiteral Rational
+ | FloatLiteral Double
+ deriving (Eq, Ord, Show)
+
+data Constant
+ = Pi
+ | Euler
+ | ImaginaryUnit
+ deriving (Eq, Ord, Show)
+
+data PairSeq tag coeff
+ = PairSeq Literal (Map Expr coeff)
+ deriving (Eq, Ord, Show)
+
+data UnaryFunction
+ = Abs
+ | Signum
+ deriving (Eq, Ord, Show)
+
+simplify :: Expr -> Expr
+simplify (Sum pairSeq) = rebuild $ normalise pairSeq
+simplify (Product pairSeq) = rebuild $ normalise pairSeq
+simplify e = e
+
+addPair :: Num c => PairSeq t c -> (Expr, c) -> PairSeq t c
+addPair (PairSeq overall pairs) (expr, coeff) =
+ PairSeq overall $ Map.insertWith' (+) expr coeff pairs
+
+transformOverall :: PairSeq t c -> (Literal -> Literal) -> PairSeq t c
+transformOverall (PairSeq overall terms) f = PairSeq (f overall) terms
+
+removeZeros :: (Num c, Eq c) => PairSeq t c -> PairSeq t c
+removeZeros (PairSeq overall terms) = PairSeq overall terms' where
+ terms' = Map.fromListWith (+) filteredTerms
+ filteredTerms = filter (not . isZero . snd) (Map.toList terms)
+
+normalise :: (PairSeqLike (PairSeq t c), Eq c, Num c) => PairSeq t c -> PairSeq t c
+normalise = removeZeros . flatten . extractMultipliers
+
+isZero :: (Num e, Eq e) => e -> Bool
+isZero n = n == 0
+
+extractMultiplier :: Expr -> (Expr, Literal)
+extractMultiplier (Literal literal) = (1, literal)
+extractMultiplier (Product pairSeq) = (expr, coeff) where
+ expr = rebuild $ (PairSeq 1 pairSeq' :: PairSeq ProductTag Expr)
+ (PairSeq coeff pairSeq') = normalise pairSeq
+extractMultiplier (Sum pairSeq) = (rebuild $ normalise pairSeq, 1)
+extractMultiplier e = (e, 1)
+
+pow :: Expr -> Expr -> Expr
+pow a b = simplify $ Product $ PairSeq 1 (Map.singleton a b)
+
+evalPow :: Expr -> Expr -> Maybe Literal
+evalPow (Literal a) (Literal b) = evalPow' a b where
+ evalPow' (FloatLiteral x) (FloatLiteral y) = Just . FloatLiteral $ x ** y
+ evalPow' (RationalLiteral x) (RationalLiteral y) = if (denominator y) == 1
+ then let power = numerator y in
+ Just . RationalLiteral $ (numerator x ^ power) % (denominator x ^ power)
+ else Nothing
+ evalPow' _ _ = Nothing
+evalPow _ _ = Nothing
+
+instance Num Literal where
+ (+) (RationalLiteral a) (RationalLiteral b) = RationalLiteral $ a + b
+ (+) (FloatLiteral a) (FloatLiteral b) = FloatLiteral $ a + b
+ (+) (RationalLiteral a) (FloatLiteral b) = FloatLiteral $ (fromRational a) + b
+ (+) (FloatLiteral a) (RationalLiteral b) = FloatLiteral $ a + (fromRational b)
+ (*) (RationalLiteral a) (RationalLiteral b) = RationalLiteral $ a * b
+ (*) (FloatLiteral a) (FloatLiteral b) = FloatLiteral $ a * b
+ (*) (RationalLiteral a) (FloatLiteral b) = FloatLiteral $ (fromRational a) * b
+ (*) (FloatLiteral a) (RationalLiteral b) = FloatLiteral $ a * (fromRational b)
+ abs (RationalLiteral a) = RationalLiteral $ abs a
+ abs (FloatLiteral a) = FloatLiteral $ abs a
+ signum (RationalLiteral a) = RationalLiteral $ signum a
+ signum (FloatLiteral a) = FloatLiteral $ signum a
+ fromInteger = RationalLiteral . fromInteger
+
+-- The special cases for multiplication by 1 are needed as a base case to avoid
+-- infinite recursion when simplifying products.
+instance Num Expr where
+ (+) n 0 = n
+ (+) 0 n = n
+ (+) a b = simplify $ Sum $ empty `addPair` (a, 1) `addPair` (b, 1)
+ (*) _ 0 = 0
+ (*) 0 _ = 0
+ (*) 1 n = n
+ (*) n 1 = n
+ (*) a b = simplify $ Product $ empty `addPair` (a, 1) `addPair` (b, 1)
+ fromInteger = Literal . fromInteger
+ abs = UnaryFunction Abs
+ signum = UnaryFunction Signum