]> git.unchartedbackwaters.co.uk Git - francis/stp.git/commitdiff
Speedup. The bvsolver now uses a reduced version of the ASTNode graph which contains...
authortrevor_hansen <trevor_hansen@e59a4935-1847-0410-ae03-e826735625c1>
Wed, 23 Jun 2010 03:53:02 +0000 (03:53 +0000)
committertrevor_hansen <trevor_hansen@e59a4935-1847-0410-ae03-e826735625c1>
Wed, 23 Jun 2010 03:53:02 +0000 (03:53 +0000)
git-svn-id: https://stp-fast-prover.svn.sourceforge.net/svnroot/stp-fast-prover/trunk/stp@859 e59a4935-1847-0410-ae03-e826735625c1

src/simplifier/CountOfSymbols.h
src/simplifier/Symbols.h [new file with mode: 0644]
src/simplifier/bvsolver.cpp
src/simplifier/bvsolver.h

index 4041f56e5a28285b6dc5e645b6bcda6bc40a32c1..c7f489daedc009dcbd2c2ae157fdf4de39722395 100644 (file)
@@ -3,6 +3,7 @@
 
 #include "../AST/AST.h"
 #include <cassert>
+#include "Symbols.h"
 
 // Count the number of times each symbol appears in the input term.
 // This can be expensive to build for large terms, so it's built lazily.
@@ -15,57 +16,34 @@ class CountOfSymbols {
                        ASTNode::ASTNodeEqual> ASTNodeToIntMap;
 
        ASTNodeToIntMap Vars;
-       const ASTNode& top;
+       const Symbols* top;
        bool loaded;
 
-       ASTNodeSet visited;
-
        // If no variables are found in "term", then it's cached---we don't need to visit there
        // again. However, if it's true, we need to revisit (and hence recount), the next time it's
        // encountered.
 
-       bool VarsInTheTerm(const ASTNode& term) {
-
-               if (visited.find(term) != visited.end())
-                       return false;
-
-               bool found = false;
+       void VarsInTheTerm(const Symbols* term) {
+               assert(!term->empty());
 
-               switch (term.GetKind()) {
-               case BVCONST:
-                       return false;
-               case SYMBOL:
-                       //cerr << "debugging: symbol added: " << term << endl;
-                       Vars[term]++;
-                       return true;
-               case READ:
-                       //skip the arrayname, provided the arrayname is a SYMBOL
-                       //But we don't skip it if it's a WRITE function??
-                       if (SYMBOL == term[0].GetKind()) {
-                               found |= VarsInTheTerm(term[1]);
-                       } else {
-                               found |= VarsInTheTerm(term[0]);
-                               found |= VarsInTheTerm(term[1]);
-                       }
-                       break;
-               default: {
-                       const ASTVec& c = term.GetChildren();
-                       for (ASTVec::const_iterator it = c.begin(), itend = c.end(); it
-                                       != itend; it++) {
-                               found |= VarsInTheTerm(*it);
-                       }
-                       break;
+               if (!term->found.IsNull())
+               {
+                       Vars[term->found]++;
                }
+               else
+               {
+                       const vector<Symbols*>& c = term->children;
+                       for (vector<Symbols*>::const_iterator it = c.begin(), itend = c.end(); it
+                                       != itend; it++)
+                       {
+                               VarsInTheTerm(*it);
+                       }
                }
-
-               if (!found)
-                       visited.insert(term);
-               return found;
        } //end of VarsInTheTerm()
 
 public:
 
-       CountOfSymbols(const ASTNode& top_v) :
+       CountOfSymbols(const Symbols* top_v) :
                top(top_v) {
                loaded = false;
        }
@@ -75,7 +53,6 @@ public:
                {
                        VarsInTheTerm(top);
                        loaded = true;
-                       visited.clear();
                }
 
                ASTNodeToIntMap::const_iterator it = Vars.find(m);
diff --git a/src/simplifier/Symbols.h b/src/simplifier/Symbols.h
new file mode 100644 (file)
index 0000000..1a46ed9
--- /dev/null
@@ -0,0 +1,74 @@
+#ifndef SYMBOLS_H
+#define SYMBOLS_H
+
+// Each node is either: empty, an ASTNode, or a vector of more than one child nodes.
+
+class Symbols {
+       private:
+               Symbols& operator =(const Symbols& other) { /*..*/}
+               Symbols(const Symbols& other) {/*..*/}
+
+//             pair<ASTNode,bool> cache;
+       public:
+
+               const ASTNode found;
+               const vector<Symbols*> children;
+
+               Symbols() {
+               }
+
+               Symbols(const ASTNode& n): found(n)
+               {
+               }
+
+               // This will create an "empty" node if the array is empty.
+               Symbols(const vector<Symbols*>& s):
+                       children(s.begin(), s.end())
+               {
+                       // Children should never be empty. They shouldn't be children.
+                       for(vector<Symbols*>::const_iterator it = children.begin(); it!= children.end(); it++)
+                       {
+                               assert(!(*it)->empty());
+                       }
+
+                       assert(children.size() != 1);
+               }
+
+               bool isContained(const ASTNode& n) {
+//                     if (cache.first == n)
+//                             return cache.second;
+
+                       bool result = false;
+                       if (!found.IsNull())
+                               result =  (found == n);
+                       else {
+                               for (int i = 0; i < children.size(); i++)
+                                       if (children[i]->isContained(n))
+                                       {
+                                               result =  true;
+                                               break;
+                                       }
+                       }
+//                     cache = make_pair(n,result);
+                       return result;
+               }
+
+               bool empty() const {
+                       return (found.IsNull() && children.size() == 0);
+               }
+
+
+       };
+
+class SymbolPtrHasher
+{
+public:
+  size_t operator()(const Symbols * n) const
+  {
+    return (size_t) n;
+  }
+  ;
+}; //End of ASTNodeHasher
+
+
+#endif
index 0ef386a3386035a63f389be18a0a8785d8728bef..a227b1e2d1acb3a61b8994c25cefcb4c4e80b2d6 100644 (file)
@@ -190,7 +190,9 @@ namespace BEEV
     const ASTNode& rhs = eq[1];
 
     //collect all the vars in the lhs and rhs
-    CountOfSymbols count(eq);
+
+    BuildSymbolGraph(eq);
+    CountOfSymbols count(symbol_graph[eq]);
 
     //handle BVPLUS case
     const ASTVec& c = lhs.GetChildren();
@@ -935,36 +937,84 @@ namespace BEEV
     return output;
   } //end of BVSolve_Even()
 
-  bool BVSolver::VarSeenInTerm(const ASTNode& var, const ASTNode& term)
-  {
-    ASTNodeMap::iterator it;
-    if ((it = TermsAlreadySeenMap.find(term)) != TermsAlreadySeenMap.end())
-      {
-        if (it->second == var)
-          {
-            return false;
-          }
-      }
 
-    if (var == term)
-      {
-        return true;
-      }
+       // This builds a reduced version of a graph, where there
+    // is only a new node if the number of non-array SYMBOLS
+    // in the descendents changes. For example (EXTRACT 0 1 n)
+    // will have the same "Symbols" node as n, because there is
+    // no new symbols are introduced.
+       Symbols* BVSolver::BuildSymbolGraph(const ASTNode& n)
+       {
+       if (symbol_graph.find(n) != symbol_graph.end())
+       {
+               return symbol_graph[n];
+       }
 
-    if (term.isConstant())
-       return false;
+       Symbols* node;
+
+       // Note we skip array variables. We never solve for them so
+       // can ignore them.
+       if (n.GetKind() == SYMBOL && n.GetIndexWidth() == 0) {
+               node = new Symbols(n);
+               symbol_graph.insert(make_pair(n, node));
+               return node;
+       }
+
+       vector<Symbols*> children;
+       for (int i = 0; i < n.Degree(); i++) {
+               Symbols* v = BuildSymbolGraph(n[i]);
+               if (!v->empty())
+                       children.push_back(v);
+       }
+
+       if (children.size() == 1) {
+               // If there is only a single child with a symbol. Then jump to it.
+               node = children.back();
+       }
+       else
+               node = new Symbols(children);
+
+       symbol_graph.insert(make_pair(n, node));
+
+       return node;
+       }
+
+
+         bool BVSolver::VarSeenInTerm(const ASTNode& var, Symbols* term)
+         {
+                 SymbolPtrToNode::iterator it;
+           if ((it = TermsAlreadySeenMap.find(term)) != TermsAlreadySeenMap.end())
+             {
+               if (it->second == var)
+                 {
+                       return false;
+                 }
+             }
+
+           if (var == term->found)
+             {
+               return true;
+             }
+
+           for (vector<Symbols*>::const_iterator
+                  it = term->children.begin(), itend = term->children.end();
+                it != itend; it++)
+             {
+               if (VarSeenInTerm(var, *it))
+                 {
+                   return true;
+                 }
+             }
+
+           TermsAlreadySeenMap[term] = var;
+           return false;
+         }//End of VarSeenInTerm
+
+         bool BVSolver::VarSeenInTerm(const ASTNode& var, const ASTNode& term)
+         {
+                 BuildSymbolGraph(term);
+                 return VarSeenInTerm(var,symbol_graph[term]);
+         }
 
-    for (ASTVec::const_iterator 
-           it = term.begin(), itend = term.end();
-         it != itend; it++)
-      {
-        if (VarSeenInTerm(var, *it))
-          {
-            return true;
-          }
-      }
 
-    TermsAlreadySeenMap[term] = var;
-    return false;
-  }//End of VarSeenInTerm
 };//end of namespace BEEV
index 3dad93a9620aaad6932afb331ae6f7affa37b4f6..c5de8e5cd8ef86b18ae8f74eebe6d55165c81aad 100644 (file)
@@ -11,6 +11,7 @@
 #define BVSOLVER_H
 
 #include "simplifier.h"
+#include "Symbols.h"
 
 namespace BEEV
 {
@@ -60,8 +61,14 @@ namespace BEEV
     //this map is useful while traversing terms and uniquely
     //identifying variables in the those terms. Prevents double
     //counting.
-    ASTNodeMap TermsAlreadySeenMap;
-    //ASTNodeMap TermsAlreadySeenMap_ForArrays;
+       typedef HASHMAP<
+         Symbols*,
+         ASTNode,
+         SymbolPtrHasher
+         > SymbolPtrToNode;
+       SymbolPtrToNode TermsAlreadySeenMap;
+
+       //ASTNodeMap TermsAlreadySeenMap_ForArrays;
 
     //solved variables list: If a variable has been solved for then do
     //not solve for it again
@@ -97,6 +104,8 @@ namespace BEEV
     //this function return true if the var occurs in term, else the
     //function returns false
     bool VarSeenInTerm(const ASTNode& var, const ASTNode& term);
+       bool VarSeenInTerm(const ASTNode& var, Symbols* term);
+
 
     //takes an even number "in" as input, and returns an odd number
     //(return value) and a power of 2 (as number_shifts by reference),
@@ -130,6 +139,15 @@ namespace BEEV
     //else returns FALSE
     bool CheckAlreadySolvedMap(const ASTNode& key, ASTNode& output);
 
+       typedef HASHMAP<
+         ASTNode,
+         Symbols*,
+         ASTNode::ASTNodeHasher,
+         ASTNode::ASTNodeEqual> ASTNodeToNodes;
+         ASTNodeToNodes symbol_graph;
+
+         Symbols* BuildSymbolGraph(const ASTNode& n);
+
   public:
     //constructor
   BVSolver(STPMgr * bm, Simplifier * simp) : _bm(bm), _simp(simp)       
@@ -147,6 +165,10 @@ namespace BEEV
         DoNotSolve_TheseVars.clear();
         FormulasAlreadySolvedMap.clear();
         //TermsAlreadySeenMap_ForArrays.clear();
+        for (ASTNodeToNodes::iterator it = symbol_graph.begin(); it != symbol_graph.end(); it++)
+         {
+               delete it->second;
+        }
       }
 
     //Top Level Solver: Goes over the input DAG, identifies the