]> git.unchartedbackwaters.co.uk Git - francis/stp.git/commitdiff
Improvements to the code for generating rewrite rules.
authortrevor_hansen <trevor_hansen@e59a4935-1847-0410-ae03-e826735625c1>
Mon, 12 Mar 2012 04:01:56 +0000 (04:01 +0000)
committertrevor_hansen <trevor_hansen@e59a4935-1847-0410-ae03-e826735625c1>
Mon, 12 Mar 2012 04:01:56 +0000 (04:01 +0000)
git-svn-id: https://stp-fast-prover.svn.sourceforge.net/svnroot/stp-fast-prover/trunk/stp@1586 e59a4935-1847-0410-ae03-e826735625c1

src/util/find_rewrites/Functionlist.h
src/util/find_rewrites/Makefile
src/util/find_rewrites/misc.h [new file with mode: 0644]
src/util/find_rewrites/rewrite.cpp
src/util/find_rewrites/rewrite_rule.h
src/util/find_rewrites/rewrite_system.h

index 3e45a58c7290d269f32bd11111da812a2e0ba370..77ea4eedbd3b321abefc2a3b25c27ac4ba785738 100644 (file)
@@ -5,20 +5,16 @@
 
 #ifndef FUNCTIONLIST_H_
 #define FUNCTIONLIST_H_
+#include "rewrite_system.h"
+#include "misc.h"
 
-extern const int bits;
-extern Simplifier *simp;
 extern Rewrite_system rewrite_system;
 
-ASTNode
-widen(const ASTNode& w, int width);
-
-ASTNode
-create(Kind k, const ASTNode& n0, const ASTNode& n1);
-
 
 class Function_list
 {
+  private:
+
 
   // Because v and w might come from "result", if "result" is resized, they will
   // be moved. So we can't use references to them.
@@ -38,35 +34,6 @@ class Function_list
   }
 
 
-  ASTNode
-  rewriteThroughWithAIGS(const ASTNode &n_)
-  {
-    assert(n_.GetType() == BITVECTOR_TYPE);
-    ASTNode f = mgr->LookupOrCreateSymbol("rewriteThroughWithAIGS");
-    f.SetValueWidth(n_.GetValueWidth());
-    ASTNode n = create(EQ, n_, f);
-
-    BBNodeManagerAIG nm;
-    BitBlaster<BBNodeAIG, BBNodeManagerAIG> bb(&nm, simp, mgr->defaultNodeFactory, &mgr->UserFlags);
-    ASTNodeMap fromTo;
-    ASTNodeMap equivs;
-    bb.getConsts(n, fromTo, equivs);
-
-    ASTNode result = n_;
-    if (equivs.size() > 0)
-      {
-        ASTNodeMap cache;
-        result = SubstitutionMap::replace(result, equivs, cache, nf, false, true);
-      }
-
-    if (fromTo.size() > 0)
-      {
-        ASTNodeMap cache;
-        result = SubstitutionMap::replace(result, fromTo, cache, nf);
-      }
-    return result;
-  }
-
   void
   applyBigRewrite()
   {
@@ -187,7 +154,7 @@ class Function_list
 
         if (mgr->ASTUndefined == widen(functions[i], bits + 1))
           {
-            cerr << "Can't widen" << functions[i];
+            //cerr << "Can't widen" << functions[i];
             functions[i] = mgr->ASTUndefined; // We can't widen it later. So remove it.
             continue;
           }
@@ -224,7 +191,7 @@ class Function_list
         if (i % 100000 == 0)
           cerr << "ApplyAigs:" << i << " of " << functions.size() << endl;
 
-        rewriteThroughWithAIGS(functions[i]);
+        functions[i] = rewriteThroughWithAIGS(functions[i]);
       }
   }
 
@@ -255,9 +222,8 @@ public:
 
     allUnary();
 
-
     applyAIGs();
-    applySpeculative();
+    //applySpeculative();
     applyRewritesToAll(functions);
     checkFunctions();
     removeDuplicates(functions);
@@ -276,24 +242,21 @@ public:
           for (int j = 0; j < size; j++)
             getAllFunctions(functions_copy[i], functions_copy[j], functions);
 
-        cerr << "Removing single variables" <<endl;
-        checkFunctions();
         removeNonWidened();
 
         removeSingleVariable();
         removeDuplicates(functions);
-        applySpeculative();
-
-        //applyRewritesToAll(functions);
+        //applySpeculative();
 
         applyAIGs();
 
-        removeDuplicates(functions);
-        removeSingleUndefined();
-
         // All the unary combinations of the binaries.
         allUnary();
 
+        removeDuplicates(functions);
+        removeSingleUndefined();
+        checkFunctions();
+
         cerr << "Two Level:" << functions.size() << endl;
       }
     else
index 9b892b759692c5f0a2b464f62be4799eb9fcb975..49a3dfe631bc3d9b2c3adb8c9c55f77a8af226bb 100644 (file)
@@ -8,7 +8,7 @@ CXXFLAGS += -L../../../lib/
 .PHONY: clean
 
 rewrite: $(OBJS)  $(TOP)lib/libstp.a  
-       $(CXX)   $(CXXFLAGS) $@.o -o $@ -lstp --static
+       $(CXX)   $(CXXFLAGS) $@.o -o $@ -lstp #--static
 
 rewrite.o: rewrite.cpp Functionlist.h  rewrite_rule.h  rewrite_system.h  VariableAssignment.h
        $(CXX)   $(CXXFLAGS) rewrite.cpp -c 
diff --git a/src/util/find_rewrites/misc.h b/src/util/find_rewrites/misc.h
new file mode 100644 (file)
index 0000000..f9cbeb6
--- /dev/null
@@ -0,0 +1,27 @@
+#ifndef MISC_H
+#define MISC_H
+
+  extern const int bits;
+  extern const int widen_to;
+
+  extern Simplifier *simp;
+
+  ASTNode
+  widen(const ASTNode& w, int width);
+
+  ASTNode
+  create(Kind k, const ASTNode& n0, const ASTNode& n1);
+
+  int
+  getDifficulty(const ASTNode& n_);
+
+  bool
+  isConstant(const ASTNode& n, VariableAssignment& different, const int bits);
+
+  vector<ASTNode>
+  getVariables(const ASTNode& n);
+
+  ASTNode
+  rewriteThroughWithAIGS(const ASTNode &n_);
+
+#endif
index af38e34883d8d6214eea1ed7969293e044b5426b..7c46ca80423a58c0d688a29efbc3998f1677de0a 100644 (file)
@@ -25,6 +25,7 @@
 #include "rewrite_rule.h"
 #include "rewrite_system.h"
 #include "Functionlist.h"
+#include "misc.h"
 
 extern int
 smt2parse();
@@ -51,7 +52,7 @@ const int mask = (1 << (bits)) - 1;
 volatile bool force_writeout = false;
 
 // Saves a little bit of time. The vectors are saved between invocations.
-vector<ASTVec> saved_array;
+vector<ASTVec*> saved_array;
 
 // Stores the difficulties that have already been generated.
 map<ASTNode, int> difficulty_cache;
@@ -87,12 +88,14 @@ vector<ASTNode>
 getVariables(const ASTNode& n);
 
 bool
-unifyNode(const ASTNode& n0, const ASTNode& n1,  ASTNodeMap& fromTo, const int term_variable_width);
+matchNode(const ASTNode& n0, const ASTNode& n1,  ASTNodeMap& fromTo, const int term_variable_width);
 
 typedef HASHMAP<ASTNode, string, ASTNode::ASTNodeHasher, ASTNode::ASTNodeEqual> ASTNodeString;
 
 BEEV::STPMgr* mgr;
 NodeFactory* nf;
+NodeFactory* simpNf;
+
 SATSolver * ss;
 ASTNodeSet stored; // Store nodes so they aren't garbage collected.
 Simplifier *simp;
@@ -106,6 +109,23 @@ int lastOutput = 0;
 bool
 checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& ass, bool& bad);
 
+ASTNode
+withNF(const ASTNode &n)
+{
+  if (n.isAtom())
+    return n;
+
+  ASTVec c;
+  for (int i=0; i< n.Degree();i++)
+    c.push_back(withNF(n[i]));
+
+  if (n.GetType() == BOOLEAN_TYPE)
+    return nf->CreateNode(n.GetKind(), c);
+  else
+    return nf->CreateArrayTerm(n.GetKind(), n.GetIndexWidth(), n.GetValueWidth(), c);
+}
+
+
 ASTNode
 renameVars(const ASTNode &n)
 {
@@ -226,15 +246,16 @@ eval(const ASTNode &n, ASTNodeMap& map, int count = 0)
   // We have an array of arrays already created to store the children.
   // This reduces the number of objects created/destroyed.
   if (count >= saved_array.size())
-    saved_array.push_back(ASTVec());
+    saved_array.push_back(new ASTVec());
 
-  ASTVec& new_children = saved_array[count];
+  ASTVec& new_children = *saved_array[count];
   new_children.clear();
 
   for (int i = 0; i < n.Degree(); i++)
     new_children.push_back(eval(n[i], map, count + 1));
 
   ASTNode r = NonMemberBVConstEvaluator(mgr, n.GetKind(), new_children, n.GetValueWidth());
+  new_children.clear();
   map.insert(make_pair(n, r));
   return r;
 }
@@ -284,7 +305,7 @@ checkProp(const ASTNode& n)
 
 // True if it's always true. Otherwise fills the assignment.
 bool
-isConstant(const ASTNode& n, VariableAssignment& different)
+isConstant(const ASTNode& n, VariableAssignment& different, const int bits)
 {
   if (isConstantToSat(n))
     return true;
@@ -293,11 +314,10 @@ isConstant(const ASTNode& n, VariableAssignment& different)
       mgr->ValidFlag = false;
 
       vector<ASTNode> symbols = getVariables(n);
-      assert(symbols.size() > 0);
 
       // Both of them might not be contained in the assignment.
-      different.setV(mgr->CreateZeroConst(symbols[0].GetValueWidth()));
-      different.setW(mgr->CreateZeroConst(symbols[0].GetValueWidth()));
+      different.setV(mgr->CreateZeroConst(bits));
+      different.setW(mgr->CreateZeroConst(bits));
 
       // It might have been widened.
       for (int i = 0; i < symbols.size(); i++)
@@ -398,15 +418,39 @@ widen(const ASTNode& w, int width)
 bool
 orderEquivalence(ASTNode& from, ASTNode& to)
 {
+  if(from.IsNull())
+    return false;
+  if(from.GetKind() == UNDEFINED)
+    return false;
+  if(to.IsNull())
+    return false;
+  if(to.GetKind() == UNDEFINED)
+    return false;
+
+  {
+    ASTVec c;
+    c.push_back(from);
+    c.push_back(to);
+      ASTNode w = widen(mgr->hashingNodeFactory->CreateNode(EQ,c), widen_to);
+
+    if  (w.IsNull() || w.GetKind() == UNDEFINED)
+      return false;
+  }
+
   vector<ASTNode> s_from; // The variables in the from node.
   ASTNodeSet visited;
   getVariables(from, s_from, visited);
   std::sort(s_from.begin(), s_from.end());
+  const int from_c = visited.size();
 
   vector<ASTNode> s_to;  // The variables in the to node.
   visited.clear();
   getVariables(to, s_to, visited);
   sort(s_to.begin(), s_to.end());
+  const int to_c = visited.size();
+
+  if (from_c > 50 || to_c > 50)
+    return false; // not interested in giant rules.
 
   vector<ASTNode> result(s_to.size() + s_from.size());
   // We must map from most variables to fewer variables.
@@ -453,28 +497,17 @@ orderEquivalence(ASTNode& from, ASTNode& to)
   if (s_from.size() > s_to.size())
     return true;
 
-  if (getDifficulty(from) < getDifficulty(to))
+  if ((getDifficulty(from)+5) < getDifficulty(to))
     {
       swap(from, to);
       return true;
     }
 
-  if (getDifficulty(from) > getDifficulty(to))
+  if (getDifficulty(from) > (getDifficulty(to)+5))
     {
       return true;
     }
 
-  // Difficulty is equal. Order based on the number of nodes.
-  vector<ASTNode> symbols;
-  visited.clear();
-  getVariables(from, symbols, visited);
-  int from_c = visited.size();
-
-  symbols.clear();
-  visited.clear();
-  getVariables(to, symbols, visited);
-  int to_c = visited.size();
-
   if (to_c < from_c)
     {
       return true;
@@ -524,8 +557,37 @@ getDifficulty(const ASTNode& n_)
   ToSATBase::ASTNodeToSATVar nodeToSATVar;
   toCNF.toCNF(BBFormula, cnfData, nodeToSATVar, false, nm);
 
+  // Send the clauses to Minisat, do unit propagation.
+  ///////////////
+
+  // Create a new sat variable for each of the variables in the CNF.
+  assert(ss->nVars() ==0);
+  for (int i = 0; i < cnfData->nVars ; i++)
+    ss->newVar();
+
+  SATSolver::vec_literals satSolverClause;
+  for (int i = 0; i < cnfData->nClauses; i++)
+    {
+      satSolverClause.clear();
+      for (int * pLit = cnfData->pClauses[i], *pStop = cnfData->pClauses[i
+          + 1]; pLit < pStop; pLit++)
+        {
+          SATSolver::Var var = (*pLit) >> 1;
+          assert ((var < ss->nVars()));
+          Minisat::Lit l = SATSolver::mkLit(var, (*pLit) & 1);
+          satSolverClause.push(l);
+        }
+
+      ss->addClause(satSolverClause);
+   }
+
+  ss->simplify();
+  assert (ss->okay()); // should be satisfiable.
+
   // Why we go to all this trouble. The number of clauses.
-  int score = cnfData->nClauses;
+  const int score = ss->nClauses();
+  assert(score <= cnfData->nClauses);
+  //////////////
 
   //Cnf_ClearMemory();
   Cnf_DataFree(cnfData);
@@ -586,6 +648,14 @@ void do_write_out(int ignore)
   force_writeout = true;
 }
 
+volatile bool debug_usr2 = false;
+
+//toggle.
+void do_usr2(int ignore)
+{
+  debug_usr2=!debug_usr2;
+}
+
 
 int
 startup()
@@ -609,13 +679,9 @@ startup()
 
   GlobalSTP = new STP(mgr, simp, at, tosat, abs);
 
-#ifndef NOTSIMPLIFYING_NF
-  nf = new SimplifyingNodeFactory(*(mgr->hashingNodeFactory), *mgr);
-  mgr->defaultNodeFactory = nf;
-#else
-  nf =  mgr->hashingNodeFactory;
-  mgr->defaultNodeFactory = mgr->hashingNodeFactory;
-#endif
+  simpNf = new SimplifyingNodeFactory(*(mgr->hashingNodeFactory), *mgr);
+  nf = new TypeChecker(*simpNf, *mgr);
+  mgr->defaultNodeFactory = simpNf;
 
   mgr->UserFlags.stats_flag = false;
   mgr->UserFlags.optimize_flag = true;
@@ -625,7 +691,7 @@ startup()
   // Prime the cache with 100..
   for (int i = 0; i < 100; i++)
     {
-      saved_array.push_back(ASTVec());
+      saved_array.push_back(new ASTVec());
     }
 
   zero = mgr->CreateZeroConst(bits);
@@ -643,6 +709,8 @@ startup()
   // Write out the work so far..
   signal(SIGUSR1,do_write_out);
 
+  signal(SIGUSR2,do_usr2);
+
 }
 
 void
@@ -807,6 +875,15 @@ findRewrites(ASTVec& expressions, const vector<VariableAssignment>& values, cons
       HASHMAP<uint64_t, ASTVec>::iterator it2;
 
       cout << "Split into " << map.size() << " pieces\n";
+      if (depth > 0)
+        {
+          if(map.size() ==1)
+            {
+              cerr << values[0].getV();
+              cerr << values[0].getW();
+              assert(false);
+            }
+        }
 
       int id = 1;
       for (it2 = map.begin(); it2 != map.end(); it2++)
@@ -820,7 +897,6 @@ findRewrites(ASTVec& expressions, const vector<VariableAssignment>& values, cons
     }
   ASTVec& equiv = expressions;
 
-
   // Sort so that constants, and smaller expressions will be checked first.
   std::sort(equiv.begin(), equiv.end(), lessThan);
 
@@ -837,39 +913,46 @@ findRewrites(ASTVec& expressions, const vector<VariableAssignment>& values, cons
           if (equiv[i].GetKind() == UNDEFINED || equiv[j].GetKind() == UNDEFINED)
             continue;
 
-          equiv[j] = rewrite_system.rewriteNode(equiv[j]);
-
-          ASTNode from = equiv[i];
-          ASTNode to = equiv[j];
-          bool r = orderEquivalence(from, to);
-
-          if (!r)
-          {
-            if (from != to)
-                cout << "can't be ordered" << from << to;
-            continue;
-          }
+          ASTNode& from = equiv[i];
+          ASTNode& to = equiv[j];
 
           VariableAssignment different;
           bool bad = false;
           const int st = getCurrentTime();
 
+          if (from.isConstant() && to.isConstant())
+            continue;
+
           if (checkRule(from, to, different, bad))
             {
-              cout << "Discovered a new rule.";
-              cout << from;
-              cout << to;
-              cout << getDifficulty(from) << " to " << getDifficulty(to) << endl;
-              cout << "After rewriting";
-              cout << rewrite_system.rewriteNode(from);
-              cout << rewrite_system.rewriteNode(to);
-              cout << "------";
+              equiv[i] = rewrite_system.rewriteNode(equiv[i]);
+              equiv[j] = rewrite_system.rewriteNode(equiv[j]);
 
-              Rewrite_rule rr(mgr, from, to, getCurrentTime() - st);
+              equiv[i] = rewriteThroughWithAIGS(equiv[i]);
+              equiv[j] = rewriteThroughWithAIGS(equiv[j]);
 
-              if (!rr.timedCheck(10000))
+              ASTNode& f = equiv[i];
+              ASTNode& t = equiv[j];
+
+              bool r = orderEquivalence(f, t);
+
+              if (!r)
                 continue;
 
+              cout << "Discovered a new rule.";
+              cout << f << t;
+              cout << getDifficulty(f) << " to " << getDifficulty(t) << endl;
+
+              Rewrite_rule rr(mgr, f, t, getCurrentTime() - st);
+
+              if (!rr.timedCheck(1000))
+                {
+                  cout << "Rule failed extended verification.";
+                  continue;
+                }
+              cout << "Verified Rule to: " << rr.getVerifiedToBits() << " bits" << endl;
+              cout << "------";
+
               rewrite_system.push_back(rr);
 
               // Remove the more difficult expression.
@@ -898,7 +981,7 @@ findRewrites(ASTVec& expressions, const vector<VariableAssignment>& values, cons
             }
 
           // Write out the rules intermitently.
-          if (force_writeout || lastOutput + 5000 < rewrite_system.size())
+          if (force_writeout || lastOutput + 500 < rewrite_system.size())
             {
               rewrite_system.rewriteAll();
               writeOutRules();
@@ -1038,6 +1121,8 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme
   ASTVec children;
   children.push_back(from);
   children.push_back(to);
+
+  // The simplifying node factory sometimes meant it couldn't be widended.
   const ASTNode n = mgr->hashingNodeFactory->CreateNode(EQ, children);
 
   assert(widen_to > bits);
@@ -1056,7 +1141,7 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme
         }
 
       // Send it to the SAT solver to verify that the widening has the same answer.
-      bool result = isConstant(widened, assignment);
+      bool result = isConstant(widened, assignment, bits);
 
       if (!result)
         {
@@ -1067,8 +1152,8 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme
             }
 
           // Detected it's not a constant, or is constant FALSE.
-           cout << "*" << i - bits << "*";
 
+           cout << "*" << i - bits << "*";
           return false;
         }
     }
@@ -1286,6 +1371,7 @@ template<class T>
 void
 writeOutRules()
 {
+  cerr << "Writing out: " << rewrite_system.size() << " rules" << endl;
   force_writeout = false;
 
   std::vector<string> output;
@@ -1293,12 +1379,6 @@ writeOutRules()
 
   for (Rewrite_system::RewriteRuleContainer::iterator it = rewrite_system.toWrite.begin() ; it != rewrite_system.toWrite.end(); it++)
     {
-      if (!it->isOK())
-        {
-          rewrite_system.erase(it--);
-          continue;
-        }
-
       ASTNode to = it->getTo();
       ASTNode from = it->getFrom();
 
@@ -1344,22 +1424,19 @@ writeOutRules()
 
               if (dup.find(sofar) != dup.end())
                 {
-                  cout << "-----";
+                  cout << "-----Writing out has found a duplicate rule-----";
                   cout << sofar;
 
                   ASTNode f = it->getFrom();
-                  cout << f << std::endl;
-                  cout << dup.find(sofar)->second.getFrom();
+                  cout << "This:" << f << std::endl;
+                  cout << "Has the same text as this: " << dup.find(sofar)->second.getFrom();
+                  cout << "Rule " << it->getId() << " has the same text as " <<  dup.find(sofar)->second.getId() << endl;
 
                   ASTNodeMap fromTo;
-
-                  cerr << f;
                   f = renameVars(f);
-                  //cerr << "renamed" << f;
-                  bool result = unifyNode(f,dup.find(sofar)->second.getFrom(),fromTo,2) ;
-                  cout << "unified" << result << endl;
+                  bool result = commutative_matchNode(f,dup.find(sofar)->second.getFrom(),fromTo,2) ;
+                  cout << "Has it unified:" << result << endl;
                   ASTNodeMap seen;
-                  cout << rewrite(f,*it,seen );
 
                   // The text of this rule is the same as another rule.
                   rewrite_system.erase(it--);
@@ -1381,8 +1458,13 @@ writeOutRules()
   bucket("n.GetKind() ==", output, buckets);
 
   ofstream outputFile;
+
+  // Because we output the difficulty (i.e. the number of CNF clauses),
+  // this is very slow.
+  #ifdef OUTPUT_CPP_RULES
   outputFile.open("rewrite_data_new.cpp", ios::trunc);
 
+
   // output the C++ code.
   hash_map<string, vector<string>, hashF<std::string> >::const_iterator it;
   for (it = buckets.begin(); it != buckets.end(); it++)
@@ -1396,6 +1478,7 @@ writeOutRules()
       outputFile << "}" << endl;
     }
   outputFile.close();
+  #endif
 
   ///////////////
   outputFile.open("rules_new.smt2", ios::trunc);
@@ -1429,40 +1512,44 @@ rename_then_rewrite(ASTNode n, const Rewrite_rule& original_rule)
 {
   n = renameVars(n);
   ASTNodeMap seen;
-  n = rewrite(n,original_rule,seen);
+  n = rewrite(n,original_rule,seen,0);
   return renameVarsBack(n);
 }
 
 // assumes the variables in n are two characters wide.
 ASTNode
-rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen)
+rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen, int depth)
 {
-  if (n.isAtom())
+  if (depth > 50)
     return n;
 
-  // nb. won't rewrite through EQ etc.
-  if (n.GetType() != BITVECTOR_TYPE)
+  if (n.isAtom())
     return n;
 
   ASTVec v;
+  v.reserve(n.Degree());
   for (int i = 0; i < n.Degree(); i++)
-    v.push_back(rewrite(n[i],original_rule,seen));
+    v.push_back(rewrite(n[i],original_rule,seen,depth+1));
 
   assert(v.size() > 0);
   ASTNode n2;
 
   if (v!=n.GetChildren())
-    n2 = mgr->CreateTerm(n.GetKind(), n.GetValueWidth(), v);
+    {
+      if (n.GetType() != BOOLEAN_TYPE)
+        n2 = mgr->CreateArrayTerm(n.GetKind(),n.GetIndexWidth(), n.GetValueWidth(), v);
+      else
+        n2 = mgr->CreateNode(n.GetKind(), v);
+    }
   else
     n2 = n;
 
   ASTNodeMap fromTo;
 
-  vector<Rewrite_rule>& rr =
-      n[0].Degree() > 0 ?
-      (rewrite_system.kind_kind_to_rr[n.GetKind()][n[0].GetKind()]) :
-      (rewrite_system.kind_to_rr[n.GetKind()]) ;
+  if (rewrite_system.lookups_invalid)
+    rewrite_system.buildLookupTable();
 
+  vector<Rewrite_rule>& rr = rewrite_system.kind_to_rr[n.GetKind()];
 
   for (int i = 0; i < rr.size(); i++)
     {
@@ -1475,42 +1562,52 @@ rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen)
 
       const ASTNode& f = rr[i].getFrom();
 
-      if (unifyNode(f, n2, fromTo,1))
+      if (commutative_matchNode(f, n2, fromTo, 1))
         {
-          /*
-          cerr << "Unifying" << f;
-          cerr << "with:" << n2;
+          if (debug_usr2)
+            {
+              cerr << "Original Rule(" << original_rule.getId() << ")";
+              cerr << original_rule.getFrom();
+              cerr << "->" << original_rule.getTo();
 
-          cerr << "Now" << rr[i].getTo();
-          cerr << "reducing rule" << rr[i].getN();
+              cerr << "Matching Rule(" << rr[i].getId() << ")";
+              cerr << rr[i].getFrom();
+              cerr << "->" << rr[i].getTo();
 
-          for (ASTNodeMap::iterator it = fromTo.begin(); it != fromTo.end(); it++)
-            {
-              cerr << it->first << "to" << it->second << endl;
-            }
+              cerr << "--------------";
+              cerr << "Unifying" << f;
+              cerr << "with:" << n2;
+              cerr << "--------------";
 
-          cerr << "--------------";
-           */
+              for (ASTNodeMap::iterator it = fromTo.begin(); it != fromTo.end(); it++)
+                {
+                  cerr << it->first << "to" << it->second << endl;
+                }
+
+              cerr << "--------------";
+            }
 
           if (seen.find(n) != seen.end())
             return seen.find(n)->second;
 
-          seen.insert(make_pair(n,rr[i].getTo()));
+          seen.insert(make_pair(n, rr[i].getTo()));
           ASTNodeMap cache;
-          ASTNode r SubstitutionMap::replace(rr[i].getTo(), fromTo, cache, nf, false, true);
+          ASTNode r = SubstitutionMap::replace(rr[i].getTo(), fromTo, cache, nf, false, true);
           seen.erase(n);
 
-          seen.insert(make_pair(n,r));
-          r= rewrite(r,original_rule,seen);
+          seen.insert(make_pair(n, r));
+          r = rewrite(r, original_rule, seen, depth + 1);
           seen.erase(n);
-          seen.insert(make_pair(n,r));
+          seen.insert(make_pair(n, r));
           return r;
+
         }
     }
 
   return n2;
 }
 
+
 int smt2_scan_string(const char *yy_str);
 
 // http://stackoverflow.com/questions/3418231/c-replace-part-of-a-string-with-another-string
@@ -1524,11 +1621,10 @@ bool replace(std::string& str, const std::string& from, const std::string& to) {
 
 
 void
-loadNewRules()
+load_new_rules(const string fileName = "rules_new.smt2")
 {
   FILE * in;
   bool opended= false;
-  string fileName = "rules_new.smt2"
 
   if(!ifstream(fileName.c_str())) /// use stdin if the default file is not found.
     in = stdin;
@@ -1538,7 +1634,7 @@ loadNewRules()
       opended = true; // so we know to fclose it.
     }
 
-  // We store references to "v" and "w", so we need to removed the
+  // We store references to "v" and "w", so we need to remove the
   // definitions from the input we parse.
 
   v = mgr->LookupOrCreateSymbol("v");
@@ -1576,6 +1672,7 @@ loadNewRules()
             int rv = sscanf(line, ";id:%d\tverified_to:%d\ttime:%d\tfrom_difficulty:%d\tto_difficulty:%d\n", &id, &verified_to_bits, &time_used, &from_v, &to_v);
             if (rv !=5)
               {
+                cerr << line;
                 done = true;
                 break;
               }
@@ -1602,8 +1699,11 @@ loadNewRules()
 
       assert(values.size() ==1);
 
-      ASTNode from = values[0][0];
-      ASTNode to = values[0][1];
+      // The nodes have been built with the hashing node factory.
+      // In practice we want to match nodes that are created with the simplifying NF.
+      // If we enabled the simplifying NF, the EQUALS checks would remove rules we want to keep.
+      ASTNode from = withNF(values[0][0]);
+      ASTNode to = withNF(values[0][1]);
 
       // Rule should be orderable.
       bool ok = orderEquivalence(from, to);
@@ -1612,42 +1712,79 @@ loadNewRules()
           cout << "discarding rule that can't be ordered";
           cout << from << to;
           cout << "----";
+          mgr->PopQuery();
+          parserInterface->popToFirstLevel();
           continue;
         }
 
       Rewrite_rule r(mgr, from, to, 0, id);
       r.setVerified(verified_to_bits,time_used);
 
-      assert(r.isOK());
       rewrite_system.push_back(r);
 
       mgr->PopQuery();
       parserInterface->popToFirstLevel();
     }
 
+  extern int smt2lex_destroy(void);
+  smt2lex_destroy();
+
   parserInterface->cleanUp();
   if (opended)
-   fclose(in);
+    {
+      cout << "New Style Rules Loaded:" << rewrite_system.size() << endl;
+      fclose(in);
+    }
+
+  // So we don't output as soon as one is discovered...
+  lastOutput = rewrite_system.size();
 
-  cout << "New Style Rules Loaded:" << rewrite_system.size() << endl;
 }
 
-//read from stdin, then tests it until the timeout.
+//Reads in new format rules. And tests each of them for the allocated time.
 void
-expandRules(int timeout_ms)
+expandRules(int timeout_ms, const char* fileName = "")
 {
-  loadNewRules();
+  load_new_rules(fileName);
   createVariables();
 
   for (Rewrite_system::RewriteRuleContainer::iterator it = rewrite_system.begin(); it!= rewrite_system.end();it++)
     {
       if ((*it).timedCheck(timeout_ms))
-        it->writeOut(cout); // omit failed.
+        {
+          it->writeOut(cout); // omit failed.
+          cerr  << getDifficulty(it->getFrom()) <<" " <<  getDifficulty(it->getTo());
+        }
     }
 }
 
+
+void t2()
+{
+  extern FILE *smt2in;
+
+  smt2in = fopen("big_array.smt2", "r");
+  TypeChecker nfTypeCheckDefault(*mgr->hashingNodeFactory, *mgr);
+  Cpp_interface piTypeCheckDefault(*mgr, &nfTypeCheckDefault);
+  parserInterface = &piTypeCheckDefault;
+
+  mgr->GetRunTimes()->start(RunTimes::Parsing);
+  smt2parse();
+
+  ASTVec v = FlattenKind(AND,  piTypeCheckDefault.GetAsserts());
+  ASTNode n = nf->CreateNode(XOR, v);
+
+  //rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen)
+  ASTNodeMap seen;
+  createVariables();
+  ASTNode r = rename_then_rewrite(n,Rewrite_rule::getNullRule());
+  cerr << r;
+
+
+}
+
 // loads the already existing rules.
-void loadExistingRules(string fileName)
+void load_old_rules(string fileName)
 {
   if(!ifstream(fileName.c_str()))
      return; // file doesn't exist.
@@ -1693,8 +1830,7 @@ void loadExistingRules(string fileName)
 
       Rewrite_rule r(mgr, from, to, 0);
 
-      if (r.isOK());
-        rewrite_system.push_back(r);
+      rewrite_system.push_back(r);
     }
 
   mgr->PopQuery();
@@ -1733,7 +1869,7 @@ testProps()
 int test()
 {
   // Test code.
-  loadExistingRules("test.smt2");
+  load_old_rules("test.smt2");
 
   v = mgr->LookupOrCreateSymbol("v");
   v.SetValueWidth(bits);
@@ -1768,14 +1904,43 @@ createVariables()
   w0.SetValueWidth(bits);
 }
 
+void unit_test()
+{
+
+  // Create the negation and not terms in different orders. This tests the commutative matching.
+  ASTVec c;
+  c.push_back(v);
+  ASTNode not_v = create(BEEV::BVNEG, c);
+  ASTNode neg_v = create(BEEV::BVUMINUS, c);
+
+  ASTNode plus_v = create(BVPLUS,not_v,neg_v);
+
+  c.clear();
+  c.push_back(w);
+  ASTNode neg_w = create(BEEV::BVUMINUS, c);
+  ASTNode not_w = create(BEEV::BVNEG, c);
+  ASTNode plus_w = create(BVPLUS,not_w,neg_w);
+
+  ASTNodeMap sub;
+  plus_w = renameVars(plus_w);
+  assert(commutative_matchNode(plus_w,plus_v,sub,2));
+  sub.clear();
+
+  assert(commutative_matchNode(plus_v,plus_w,sub,1));
+
+
+}
+
+
 int
 main(int argc, const char* argv[])
 {
   startup();
 
+
   if (argc == 1) // Read the current rule set, find new rules.
     {
-        loadExistingRules("array.smt2");
+        load_new_rules();
         createVariables();
         ////////////
         rewrite_system.buildLookupTable();
@@ -1795,33 +1960,54 @@ main(int argc, const char* argv[])
         rewrite_system.rewriteAll();
         writeOutRules();
     }
-  else if (argc == 3 && !strcmp("expand",argv[1])) // expand the bit-widths rules are tested at.
+  else if (argc ==2 && !strcmp("unit-test",argv[1]))
     {
+      load_new_rules();
+      createVariables();
+      unit_test();
+    }
+  else if (argc ==2 && !strcmp("verify",argv[1]))
+      {
+      load_new_rules();
+      rewrite_system.verifyAllwithSAT();
+      }
+  else if ((argc == 4 || argc ==3) && !strcmp("expand",argv[1]))
+    {
+      // expand the bit-widths rules are tested at.
       int timeout_ms = atoi(argv[2]);
       assert(timeout_ms > 0);
-      expandRules(timeout_ms);
+      expandRules(timeout_ms, (argc == 4 ? argv[3]:""));
     }
-  else if (argc == 2 && !strcmp("verify-all",argv[1]))
+  else if (argc == 2 && !strcmp("rewrite",argv[1]))
     {
-      loadNewRules();
+      // load the rules and apply the rewrite system to itself.
+      load_new_rules();
       createVariables();
-      rewrite_system.verifyAllwithSAT();
-      writeOutRules(); // have the times now..
+      rewrite_system.rewriteAll();
+      writeOutRules();
     }
   else if (argc == 2 && !strcmp("write-out",argv[1]))
     {
-      loadNewRules();
+      load_new_rules();
       createVariables();
       rewrite_system.rewriteAll();
       writeOutRules(); // have the times now..
     }
+  else if (argc == 2 && !strcmp("renumber",argv[1]))
+    {
+      // Intended to merge two sets of rules, then renumber them.
+      load_new_rules();
+      createVariables();
+      rewrite_system.renumber();
+      writeOutRules();
+    }
   else if (argc == 2 && !strcmp("test",argv[1]))
     {
       testProps();
     }
   else if (argc == 2 && !strcmp("delete-failed",argv[1]))
     {
-      loadNewRules();
+      load_new_rules();
       ifstream fin;
       fin.open("failed.txt",ios::in);
       char line[256];
@@ -1836,13 +2022,20 @@ main(int argc, const char* argv[])
       createVariables();
       writeOutRules();
     }
-}
-
+  else if (argc == 2 && !strcmp("test2",argv[1]))
+    {
+      load_new_rules();
+      t2();
+    }
 
+  for (int i=0;  i< saved_array.size();i++)
+    delete saved_array[i];
+}
 
+#if 0
 // Term variables have a specified width!!!
 bool
-unifyNode(const ASTNode& n0, const ASTNode& n1,  ASTNodeMap& fromTo, const int term_variable_width)
+matchNode(const ASTNode& n0, const ASTNode& n1,  ASTNodeMap& fromTo, const int term_variable_width)
 {
   // Pointers to the same value. OK.
   if (n0 == n1)
@@ -1851,10 +2044,10 @@ unifyNode(const ASTNode& n0, const ASTNode& n1,  ASTNodeMap& fromTo, const int t
   if (n0.GetKind() == SYMBOL && strlen(n0.GetName()) == term_variable_width)
     {
       if (fromTo.find(n0) != fromTo.end())
-        return unifyNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width);
+        return matchNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width);
 
       fromTo.insert(make_pair(n0, n1));
-      return unifyNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width);
+      return matchNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width);
     }
 
   // Here:
@@ -1869,9 +2062,319 @@ unifyNode(const ASTNode& n0, const ASTNode& n1,  ASTNodeMap& fromTo, const int t
 
   for (int i = 0; i < n0.Degree(); i++)
     {
-      if (!unifyNode(n0[i], n1[i], fromTo, term_variable_width))
-        return false;
+       if (!matchNode(n0[i], n1[i], fromTo, term_variable_width))
+          return false;
     }
 
   return true;
 }
+#endif
+
+
+bool debug_matching = false;
+
+/////////
+// Term variables have a specified width!!!
+// "false" if it definately can't be matched with any possible commutative ordering.
+// "true" can be matched, later you need to check if all the "commutative" can be matched.
+bool
+commutative_matchNode(const ASTNode& n0, const ASTNode& n1,   const int term_variable_width,
+    deque<pair< ASTNode, ASTNode> >&  commutative, ASTNode& vNode, ASTNode& wNode)
+{
+  // Pointers to the same value. OK.
+  if (n0 == n1)
+    return true;
+
+  // If we try and match sub-terms of concatenations,e,g. 000::x = 000111, we want it to fail.
+  if(n0.GetValueWidth() != n1.GetValueWidth())
+    return false;
+
+  if (n0.GetKind() == SYMBOL && strlen(n0.GetName()) == term_variable_width)
+    {
+      if (n0.GetName()[0] == 'v')
+        {
+          if (vNode != mgr->ASTUndefined)
+            return commutative_matchNode(vNode, n1, term_variable_width, commutative, vNode, wNode);
+          else
+            {
+              vNode = n1;
+              return true;
+            }
+        }
+      else if (n0.GetName()[0] == 'w')
+        {
+          if (wNode != mgr->ASTUndefined)
+            return commutative_matchNode(wNode, n1, term_variable_width, commutative, vNode, wNode);
+          else
+            {
+              wNode = n1;
+              return true;
+            }
+        }
+      else
+        FatalError("nefeafs");
+    }
+
+  // Here:
+  // They could be different BVConsts, different symbols, or
+  // different functions.
+
+  if (n0.Degree() != n1.Degree() || (n0.Degree() == 0))
+      return false;
+
+  if (n0.GetKind() != n1.GetKind())
+      return false;
+
+  // If it's commutative, check it specially / seprately later.
+  if (isCommutative(n0.GetKind()) && n0.Degree() > 1)
+    {
+      commutative.push_back(make_pair(n0,n1));
+      return true;
+    }
+  else
+    {
+    for (int i = 0; i < n0.Degree(); i++)
+      {
+         if (!commutative_matchNode(n0[i], n1[i], term_variable_width,commutative,vNode,wNode))
+            return false;
+      }
+    }
+  return true;
+}
+
+//
+// Term variables have a specified width!!!
+bool
+c_matchNode(const ASTNode& n0, const ASTNode& n1, const int term_variable_width,
+    deque<pair<ASTNode, ASTNode> >& commutative_to_check, ASTNode& vNode, ASTNode& wNode)
+{
+  ASTNode vNode_copy = vNode;
+  ASTNode wNode_copy = wNode;
+
+  const int init_comm_size = commutative_to_check.size();
+
+  bool r = commutative_matchNode(n0, n1, term_variable_width, commutative_to_check, vNode,wNode);
+  assert(commutative_to_check.size() >= init_comm_size);  // if anything, only pushed onto the back.
+
+  if (debug_matching)
+    {
+      cerr << "======Commut-match=======" << r << endl;
+      cerr << "given" << n0 << n1;
+      cerr << "Commutative still to match:" << endl;
+      for (int j=0;j < commutative_to_check.size(); j++)
+        {
+          cerr << "++++++++++" << endl;
+          cerr << "first" << commutative_to_check[j].first;
+          cerr << "second" << commutative_to_check[j].second;
+        }
+      cerr << "From To Map is:" << endl;
+      cerr << "vNode" << vNode;
+      cerr << "wNode" << wNode;
+      cerr << "=============";
+    }
+
+  if (!r)
+    {
+      // If it's bad we restore it all back.
+      commutative_to_check.erase(commutative_to_check.begin() + init_comm_size, commutative_to_check.end());
+      vNode = vNode_copy;
+      wNode = wNode_copy;
+      return false;
+    }
+
+  // base case.
+  if (commutative_to_check.size() == 0)
+    return r;
+
+  pair<ASTNode, ASTNode> p = commutative_to_check.back();
+  commutative_to_check.pop_back();
+  assert(p.first.GetKind() == p.second.GetKind());
+  const ASTVec& f = p.first.GetChildren();
+  ASTVec s = p.second.GetChildren(); // non-const, needs to be sorted later.
+
+  if (f.size() != s.size())
+    {
+      cerr << "different sized!!!";
+      // If it's bad we restore it all back.
+      if (commutative_to_check.size() < init_comm_size)
+        commutative_to_check.push_back(p);
+      else
+        commutative_to_check.erase(commutative_to_check.begin() + init_comm_size, commutative_to_check.end());
+
+      vNode = vNode_copy;
+      wNode = wNode_copy;
+
+      return false;
+    }
+
+  // The next_permutation function requires this.
+  sort(s.begin(), s.end());
+
+  ASTNode vNode_copy2 = vNode;
+  ASTNode wNode_copy2 = wNode;
+
+  //deque<pair<ASTNode, ASTNode> > todo_copy2 = commutative_to_check;
+  const int new_comm_size = commutative_to_check.size();
+
+
+  // Check each permutation of the commutative operation's operands.
+  do
+    {
+      // Check each of the operands matches. Store Extra away in "commutative_to_check".
+      bool good= true;
+      for (int i=0;i < f.size(); i++)
+        {
+          if (!commutative_matchNode(f[i], s[i], term_variable_width, commutative_to_check, vNode, wNode))
+            {
+              good = false;
+              break;
+            }
+        }
+
+      // Empty out the "commutative_to_check".
+      if (good)
+        if (!c_matchNode(mgr->ASTTrue, mgr->ASTTrue,  term_variable_width, commutative_to_check,vNode,wNode))
+          good =false;
+
+      if (good)
+        {
+         assert(commutative_to_check.size() ==0);
+         return true; // all match.
+       }
+      else
+        {
+          vNode = vNode_copy2;
+          wNode = wNode_copy2;
+          commutative_to_check.erase(commutative_to_check.begin() + new_comm_size, commutative_to_check.end());
+          //assert(commutative_to_check == todo_copy2);
+          //commutative_to_check = todo_copy2;
+        }
+    }
+  while (next_permutation(s.begin(), s.end()));
+
+  // None of the permutations match. We return the data unchanged.
+
+  vNode = vNode_copy;
+  wNode = wNode_copy;
+
+
+  if (commutative_to_check.size() < init_comm_size)
+    commutative_to_check.push_back(p);
+  else
+    commutative_to_check.erase(commutative_to_check.begin() + init_comm_size, commutative_to_check.end());
+
+
+  return false;
+}
+
+/* This does commutative matching of nodes. A substitution to the term variables (which are the
+ * those with a name of the width specified), of n0 is found. That is if the variables of n0 are
+ * substituted with the "substitution", then it will equal n1.
+ *
+ * Initially I thought commutative matching was easy to get right. NO!
+ *
+ * NB: This uses a "statc" container so this can't be called recursively.
+ */
+bool
+commutative_matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& substitution, const int term_variable_width)
+{
+  assert(substitution.size() ==0);
+
+#ifdef PEDANTIC_MATCHING_ASSERTS
+  {
+  // There shouldn't be any term variables on the RHS.
+  vector<ASTNode> vars = getVariables(n1);
+  vector<ASTNode>::iterator it = vars.begin();
+  while (it != vars.end())
+    {
+      assert(strlen(it->GetName()) != term_variable_width);
+      assert(it->GetName()[0] == 'v' || it->GetName()[0] == 'w');
+      it++;
+    }
+  assert(vars.size() <=2);
+
+  // All the LHS variables should be term variables.
+  vars = getVariables(n0);
+  it = vars.begin();
+  while (it != vars.end())
+    {
+      assert(strlen(it->GetName()) == term_variable_width);
+      it++;
+    }
+  assert(vars.size() <=2);
+  }
+
+
+#endif
+
+  deque<pair<ASTNode, ASTNode> > commutative;
+  commutative.clear();
+
+  ASTNode vNode = mgr->ASTUndefined;
+  ASTNode wNode = mgr->ASTUndefined;
+  bool r =  c_matchNode(n0,n1,term_variable_width,commutative,vNode,wNode);
+
+  if (r)
+    {
+      vector<ASTNode> s = getVariables(n0);
+      for (vector<ASTNode>::iterator it = s.begin(); it != s.end();it++)
+        {
+          assert(it->GetKind() ==SYMBOL);
+          assert(strlen(it->GetName()) == term_variable_width);
+          if (it->GetName()[0] == 'v')
+            {
+              assert(vNode != mgr->ASTUndefined);
+              assert(vNode.GetValueWidth() == it->GetValueWidth());
+              substitution.insert(make_pair(*it,vNode));
+            }
+          if (it->GetName()[0] == 'w')
+            {
+              assert(wNode != mgr->ASTUndefined);
+              assert(wNode.GetValueWidth() == it->GetValueWidth());
+              substitution.insert(make_pair(*it,wNode));
+            }
+        }
+    }
+
+  if (debug_matching)
+    {
+      cerr << "=======" << endl << "The result is: " << r << "for the inputs" << n0 << n1 << "=-===";
+    }
+
+  if (!r)
+    {
+      assert(substitution.size() == 0);
+      assert(commutative.size() ==0);  // should be none left to process.
+    }
+
+  return r;
+}
+
+ASTNode
+rewriteThroughWithAIGS(const ASTNode &n_)
+{
+  assert(n_.GetType() == BITVECTOR_TYPE);
+  ASTNode f = mgr->LookupOrCreateSymbol("rewriteThroughWithAIGS");
+  f.SetValueWidth(n_.GetValueWidth());
+  ASTNode n = create(EQ, n_, f);
+
+  BBNodeManagerAIG nm;
+  BitBlaster<BBNodeAIG, BBNodeManagerAIG> bb(&nm, simp, mgr->defaultNodeFactory, &mgr->UserFlags);
+  ASTNodeMap fromTo;
+  ASTNodeMap equivs;
+  bb.getConsts(n, fromTo, equivs);
+
+  ASTNode result = n_;
+  if (equivs.size() > 0)
+    {
+      ASTNodeMap cache;
+      result = SubstitutionMap::replace(result, equivs, cache, nf, false, true);
+    }
+
+  if (fromTo.size() > 0)
+    {
+      ASTNodeMap cache;
+      result = SubstitutionMap::replace(result, fromTo, cache, nf);
+    }
+  return result;
+}
index b97fb0b3220a269b6d91104312796b44eeac47b5..cab1845287c780b533c57e3c360dd7011d42eb22 100644 (file)
@@ -2,15 +2,7 @@
 #define REWRITERULE_H
 
 #include "../../STPManager/STPManager.h"
-
-extern const int widen_to;
-extern const int bits;
-
-ASTNode
-widen(const ASTNode& w, int width);
-
-int
-getDifficulty(const ASTNode& n_);
+#include "misc.h"
 
 void soft_time_out(int ignored)
 {
@@ -18,37 +10,46 @@ void soft_time_out(int ignored)
 }
 
 bool
-isConstant(const ASTNode& n, VariableAssignment& different);
+orderEquivalence(ASTNode& from, ASTNode& to);
 
 
-vector<ASTNode>
-getVariables(const ASTNode& n);
-
 class Rewrite_rule
 {
-private:
    ASTNode from;
    ASTNode to;
    ASTNode n;
 
-   int id;
    static int  static_id;
 
    int time_to_verify;
    int verified_to_bits;
 
+   // Only used to build the NULL rule
+   Rewrite_rule()
+   {
+     from = mgr->CreateZeroConst(1);
+     to = mgr->CreateZeroConst(1);
+     n = mgr->ASTTrue;
+   }
+
 public:
 
-  void writeOut(ostream& outputFileSMT2)
+   static Rewrite_rule
+   getNullRule()
+   {
+     return Rewrite_rule();
+   }
+
+   int id;
+  void writeOut(ostream& outputFileSMT2) const
   {
-    assert(isOK());
     outputFileSMT2 << ";id:" << getId()
         << "\tverified_to:" << verified_to_bits << "\ttime:" << getTime()
         << "\tfrom_difficulty:" << getDifficulty(getFrom())
         << "\tto_difficulty:"   << getDifficulty(getTo())
         << "\n";
     outputFileSMT2 << "(push 1)" << endl;
-    printer::SMTLIB2_PrintBack(outputFileSMT2, getN(), true, false);
+    printer::SMTLIB2_PrintBack(outputFileSMT2, getN(), true);
     outputFileSMT2 << "(exit)" << endl;
   }
 
@@ -64,7 +65,7 @@ public:
   }
 
   int
-  getVerifiedToBits()
+  getVerifiedToBits() const
   {
     return verified_to_bits;
   }
@@ -112,27 +113,7 @@ public:
     return (n == t.n);
   }
 
-  bool
-  isOK()
-  {
-    ASTNode w = widen(getN(), widen_to);
-
-    if  (w.IsNull() || w.GetKind() == UNDEFINED)
-      return false;
-
-    assert(BVTypeCheckRecursive(n));
-    assert(BVTypeCheckRecursive(w));
-
-    if (from.isAtom() && to.isAtom())
-        return false;
-
-    if (from == to)
-      return false;
-
-    return true;
-
-  }
-
+  // The "from" and "to" should be ordered with the orderEquivalence function.
   Rewrite_rule(BEEV::STPMgr* bm, const BEEV::ASTNode& from_, const BEEV::ASTNode& to_, const int t, int _id=-1)
   : from(from_), to(to_)
     {
@@ -151,49 +132,10 @@ public:
       c.push_back(from_);
       n =  bm->hashingNodeFactory->CreateNode(BEEV::EQ,c);
 
-
-      ////
-      assert(!from.IsNull());
-      assert(from.GetKind() != UNDEFINED);
-
-      ////
-      assert(!to.IsNull());
-      assert(to.GetKind() != UNDEFINED);
-
-      ////
-      assert(!n.IsNull());
-      assert(n.GetKind() != UNDEFINED);
-      ////
-
-      if (from.GetKind() == SYMBOL)
-        {
-          assert(to == from); // If it's a symbol. It should be the same one.
-        }
-
-      if (from.isAtom())
-        {
-          assert(to.isAtom()); // sometimes its easiest to make it 0->0 rather than deleting it.
-        }
-
-      // only v or w
-      vector<ASTNode> s_from= getVariables(from);
-      for (vector<ASTNode>::iterator it = s_from.begin(); it != s_from.end() ;it++)
-        {
-          assert(strlen(it->GetName()) ==1);
-          assert(it->GetName()[0] =='v' || it->GetName()[0] =='w');
-          assert(it->GetValueWidth() == bits);
-        }
-
-      vector<ASTNode> s_to= getVariables(to);
-      for (vector<ASTNode>::iterator it = s_to.begin(); it != s_to.end() ;it++)
-        {
-          assert(strlen(it->GetName()) ==1);
-          assert(it->GetName()[0] =='v' || it->GetName()[0] =='w');
-          assert(it->GetValueWidth() == bits);
-        }
-
-      // The "to" side should have fewer nodes.
-      assert(s_from.size() >= s_to.size());
+      assert(orderEquivalence(from,to));
+      assert(from == from_);
+      assert(to == to_);
+      assert(BVTypeCheckRecursive(n));
     }
 
   bool
@@ -236,12 +178,16 @@ public:
             cerr << from << to;
           }
 
-        bool result = isConstant(widened, assignment);
+        bool result = isConstant(widened, assignment,i);
         if (!result && !mgr->soft_timeout_expired)
           {
             // not a constant, and not timed out!
             cerr << "FAILED:" << getId() << endl << i << from << to;
             writeOut(cerr);
+
+            // The timer might not have expired yet.
+            setitimer(ITIMER_VIRTUAL, NULL, NULL);
+            mgr->soft_timeout_expired = false;
             return false;
           }
         if (mgr->soft_timeout_expired)
@@ -253,6 +199,9 @@ public:
     if (getVerifiedToBits() <= checked_to)
       setVerified(checked_to, getTime() + (getCurrentTime() - st));
 
+    // The timer might not have expired yet.
+    setitimer(ITIMER_VIRTUAL, NULL, NULL);
+    mgr->soft_timeout_expired = false;
     return true;
   }
 
index 7c9c6c4709537245dab2c75552dbed19121e0188..41c3a831cd070fe6c3f0a659d46bdac2d46a8054 100644 (file)
@@ -23,13 +23,17 @@ ASTNode
 widen(const ASTNode& w, int width);
 
 bool
-unifyNode(const ASTNode& n0, const ASTNode& n1,  ASTNodeMap& fromTo);
+matchNode(const ASTNode& n0, const ASTNode& n1,  ASTNodeMap& fromTo, const int term_variable_width);
+
+bool
+commutative_matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width);
+
 
 ASTNode
 renameVars(const ASTNode &n);
 
 ASTNode
-rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen);
+rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen, int depth);
 
 bool
 checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignment, bool&bad);
@@ -46,6 +50,10 @@ isConstantToSat(const ASTNode & query);
 bool
 isConstant(const ASTNode& n, VariableAssignment& different);
 
+ASTNode
+rewriteThroughWithAIGS(const ASTNode &n_);
+
+
 class Rewrite_system
 {
 public:
@@ -59,17 +67,27 @@ private:
   void
   writeOutRules();
 
-
-  friend ASTNode rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen);
+  friend ASTNode rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen, int depth);
 
   RewriteRuleContainer toWrite;
+
   std::map< Kind, vector<Rewrite_rule> > kind_to_rr;
-  std::map< Kind, std::map< Kind, vector<Rewrite_rule> > > kind_kind_to_rr;
+   bool lookups_invalid; // whether the above table is bad.
+
+  void
+  addRuleToLookup(Rewrite_rule& r)
+  {
+    const ASTNode& from = r.getFrom();
+    kind_to_rr[from.GetKind()].push_back(r);
+    assert(from.Degree() > 0); // Shouldn't map from a constant, nor from a variable.
+  }
+
 
 public:
 
   Rewrite_system()
   {
+    lookups_invalid = false;
   }
 
   RewriteRuleContainer::iterator
@@ -84,17 +102,18 @@ public:
     return toWrite.end();
   }
 
+  bool areLookupsGood()
+  {
+    return lookups_invalid;
+  }
 
   void
-  addRuleToLookup(Rewrite_rule& r)
+  renumber()
   {
-    const ASTNode& from = r.getFrom();
-    kind_to_rr[from.GetKind()].push_back(r);
-
-    assert(from.Degree() > 0); // Shouldn't map from a constant, nor from a variable.
-
-    if (from[0].Degree() > 0)
-      kind_kind_to_rr[from.GetKind()][from[0].GetKind()].push_back(r);
+    int id=0;
+    for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
+      it->id = id++;
+    lookups_invalid=true;
   }
 
   void
@@ -106,6 +125,7 @@ public:
           {
             toWrite.erase(it--);
             cerr << "matched" << id;
+            lookups_invalid = true;
           }
       }
   }
@@ -114,41 +134,26 @@ public:
   buildLookupTable()
   {
     kind_to_rr.clear();
-    kind_kind_to_rr.clear();
 
     for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
       {
         addRuleToLookup(*it);
       }
-  }
-
-  void
-  removeBad()
-  {
-    for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
-      {
-        if (!it->isOK())
-          {
-            cout << "Removing Rule that is bad";
-            cout << it->getFrom();
-            cout << it->getTo();
-            cout << "----\n";
-
-            toWrite.erase(it--);
-          }
-      }
+    lookups_invalid =false;
   }
 
   void
   eraseDuplicates()
   {
-    removeBad();
     toWrite.sort();
     toWrite.unique();
+    lookups_invalid = true;
   }
 
+  // Rewrite the "from" and "To", add the rule if it's still good.
+  // NB: Doesn't rebuild the lookup table.
   void
-  push_back(Rewrite_rule rr)
+  push_back(Rewrite_rule& rr)
   {
     toWrite.push_back(rr);
     addRuleToLookup(rr);
@@ -158,6 +163,7 @@ public:
   erase(RewriteRuleContainer::iterator it)
   {
     toWrite.erase(it);
+    lookups_invalid = true;
   }
 
   int
@@ -166,10 +172,9 @@ public:
     return toWrite.size();
   }
 
-  ASTNode rewriteNode(ASTNode n)
+  static ASTNode rewriteNode(ASTNode n)
   {
-    Rewrite_rule  null_rule( Rewrite_rule(mgr,mgr->CreateZeroConst(1), mgr->CreateZeroConst(1),0));
-    return rename_then_rewrite(n,null_rule);
+    return rename_then_rewrite(n,Rewrite_rule::getNullRule());
   }
 
   void
@@ -186,76 +191,89 @@ public:
         if (i % 1000 == 0)
           cout << "rewrite all:" << i << " of " << toWrite.size() << endl;
 
-        // if not OK, should have been removed during duplicates.
-        // shouldn't add extra rules that aren't ok.
-        assert (it->isOK());
+        ASTNode from_wide = renameVars(it->getFrom());
+        ASTNode to_wide = renameVars(it->getTo());
+
+        // The renamed should match the original, and vice versa.
+          {
+            ASTNodeMap fromTo;
+            assert(commutative_matchNode(from_wide, it->getFrom(),fromTo,2));
+            fromTo.clear();
+            assert(commutative_matchNode(it->getFrom(), from_wide, fromTo, 1));
+            fromTo.clear();
+            assert(commutative_matchNode(to_wide, it->getTo(),fromTo,2));
+            fromTo.clear();
+            assert(commutative_matchNode(it->getTo(),to_wide, fromTo,1));
+          }
 
-        ASTNode n = renameVars(it->getFrom());
         ASTNodeMap seen;
-        ASTNode rewritten_from = rewrite(n, *it,seen);
+        ASTNode from_wide_rewritten = rewrite(from_wide, *it,seen,0);
+        seen = ASTNodeMap();
+        ASTNode to_wide_rewritten = rewrite(to_wide, *it,seen,0);
+        seen = ASTNodeMap();
+
+        // Also apply the AIG rules.
+        to_wide_rewritten = rewriteThroughWithAIGS(to_wide_rewritten);
 
-        if (n != rewritten_from)
+        if ((from_wide != from_wide_rewritten) || (to_wide != to_wide_rewritten))
           {
-            assert (isConstantToSat(create(EQ, rewritten_from,n)));
+            ASTNode from_rewritten = renameVarsBack(from_wide_rewritten);
+            ASTNode to_rewritten = renameVarsBack(to_wide_rewritten);
+
+            assert(BVTypeCheckRecursive(from_rewritten));
+            assert(BVTypeCheckRecursive(to_rewritten));
 
-            rewritten_from = renameVarsBack(rewritten_from);
-            ASTNode to = it->getTo();
-            bool ok = orderEquivalence(rewritten_from, to);
+            assert (isConstantToSat(create(EQ, from_wide_rewritten,from_wide)));
+            assert (isConstantToSat(create(EQ, to_wide_rewritten,to_wide)));
+            assert (isConstantToSat(create(EQ, it->getFrom(),from_rewritten)));
+            assert (isConstantToSat(create(EQ, it->getTo(),to_rewritten)));
+
+            bool ok = orderEquivalence(from_rewritten, to_rewritten);
             if (ok)
               {
-                Rewrite_rule rr(mgr, rewritten_from, to, 0);
-                if (rr.isOK())
-                  {
-                    cout << "Modifying Rule\n";
-                    cout << "Initially From";
-                    cout << it->getFrom();
-                    cout << "new From";
-                    cout << rewritten_from;
-                    cout << "---";
-
-                    *it= rr;
-                    buildLookupTable(); // Otherwise two rules will remove each other?
-                  }
-                else
-                  {
-                    cout << "Erasing rule";
-                    cout << "Initially From";
-                    cout << it->getFrom();
-                    cout << "new From";
-                    cout << rewritten_from;
-                    cout << "---";
-
-                    erase(it--);
-                    i--;
-                    buildLookupTable(); // Otherwise two rules will remove each other?
-                  }
+                Rewrite_rule rr(mgr, from_rewritten, to_rewritten, 0);
+                cout << "Modifying Rule\n";
+                cout << "Initially From";
+                cout << it->getFrom();
+                cout << "Initially To";
+                cout << it->getTo();
+                cout << "New From";
+                cout << from_rewritten;
+                cout << "New To";
+                cout << to_rewritten;
+                cout << "---";
+                cout << getDifficulty(rr.getFrom()) << " --> " << getDifficulty(rr.getTo()) << endl;
+                cout << "replacing" << it->getId() << " with " << rr.getId() << endl;
+
+                *it = rr;
+                lookups_invalid = true;
+
               }
-            else
-              {
-                if (rewritten_from != to)
-                  {
-                      cout << "Mapped but couldn't order";
-                      cout << rewritten_from << to;
-                  }
+            if (!ok)
+            {
+                cout << "Erasing bad rule.\n";
                 erase(it--);
                 i--;
-                buildLookupTable(); // Otherwise two rules will remove each other?
+                lookups_invalid = true;
               }
           }
       }
 
     eraseDuplicates();
     cout << "Size after rewriteAll:" << toWrite.size() << endl;
+    buildLookupTable();
   }
 
   void clear()
   {
     toWrite.clear();
+    buildLookupTable();
   }
 
   void
   verifyAllwithSAT()
   {
+    cerr << "Started verifying all" << endl;
     for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
       {
         VariableAssignment assignment;
@@ -270,12 +288,11 @@ public:
             assert(r);
             assert(!bad);
           }
-        if (bits > it->getVerifiedToBits())
+        if (bits >= it->getVerifiedToBits())
           it->setVerified(bits,getCurrentTime() - st);
       }
   }
 
-
   void
   writeOut(ostream &o)
   {