]> git.unchartedbackwaters.co.uk Git - francis/stp.git/commitdiff
Improvements to the utility for generating rewrite rules.
authortrevor_hansen <trevor_hansen@e59a4935-1847-0410-ae03-e826735625c1>
Sat, 3 Mar 2012 12:15:42 +0000 (12:15 +0000)
committertrevor_hansen <trevor_hansen@e59a4935-1847-0410-ae03-e826735625c1>
Sat, 3 Mar 2012 12:15:42 +0000 (12:15 +0000)
git-svn-id: https://stp-fast-prover.svn.sourceforge.net/svnroot/stp-fast-prover/trunk/stp@1579 e59a4935-1847-0410-ae03-e826735625c1

src/util/find_rewrites/Functionlist.h
src/util/find_rewrites/rewrite.cpp
src/util/find_rewrites/rewrite_rule.h
src/util/find_rewrites/rewrite_system.h

index 9086793f88df79b71891e45d96d27dcec794de23..07d10648c63d9f0d84b8f55b75f01c20f70bff3f 100644 (file)
@@ -26,7 +26,8 @@ class Function_list
   getAllFunctions(const ASTNode v, const ASTNode w, ASTVec& result)
   {
 
-    Kind types[] = {BVMULT, BVPLUS, BVXOR, BVAND, BVOR ,BVRIGHTSHIFT, BVLEFTSHIFT};
+    Kind types[] = {BVMULT, BVPLUS, BVXOR, BVAND};
+
 
     //Kind types[] = {BVMULT, BVDIV, SBVDIV, SBVREM, SBVMOD, BVPLUS, BVMOD, BVRIGHTSHIFT, BVLEFTSHIFT, BVOR, BVAND, BVXOR, BVSRSHIFT};
     const int number_types = sizeof(types) / sizeof(Kind);
@@ -91,7 +92,7 @@ class Function_list
   void
   applyRewritesToAll(ASTVec& functions)
   {
-    to_write.buildRules();
+    to_write.buildLookupTable();
     cerr << "Applying:" << to_write.size()  <<"rewrite rules" << endl;
 
     for (int i = 0; i < functions.size(); i++)
@@ -264,7 +265,7 @@ public:
 
     cerr << "One Level:" << functions.size() << endl;
 
-   const bool two_level = false;
+   const bool two_level = true;
 
     if (two_level)
       {
index 8572082480c242424ec08198a87f3977be85bcd3..f1344cb1e28729577f83288301803d954da0c43a 100644 (file)
@@ -319,7 +319,7 @@ isConstant(const ASTNode& n, VariableAssignment& different)
 ASTNode
 widen(const ASTNode& w, int width)
 {
-  assert(bits >=3);
+  assert(bits >=4);
 
   if (w.isConstant() && w.GetValueWidth() == 1)
     return w;
@@ -421,6 +421,12 @@ orderEquivalence(ASTNode& from, ASTNode& to)
   if (intersection != s_from.size() && intersection != s_to.size())
     return false;
 
+  if (to.isAtom() && from.isAtom())
+    return false; // no such rules
+
+  if (to == from)
+    return false; // no such rules
+
   if (to.isAtom())
     return true;
 
@@ -473,12 +479,21 @@ orderEquivalence(ASTNode& from, ASTNode& to)
   getVariables(to, symbols, visited);
   int to_c = visited.size();
 
+  if (to_c < from_c)
+    {
+      return true;
+    }
+
   if (to_c > from_c)
     {
       swap(from, to);
+      return true;
     }
 
-  return true;
+
+
+  // Can't order they have the same number of nodes and the same AIG size.
+  return false;
 }
 
 int
@@ -740,7 +755,19 @@ is_candidate(ASTNode from, ASTNode to)
 bool
 lessThan(const ASTNode& n1, const ASTNode& n2)
 {
-  return (((unsigned) n1.GetNodeNum()) < ((unsigned) n2.GetNodeNum()));
+  bool n1_bad = n1.IsNull() || (n1.GetKind() == UNDEFINED);
+  bool n2_bad = n2.IsNull() || (n2.GetKind() == UNDEFINED);
+
+  if (n1_bad && !n2_bad)
+    return true;
+
+  if (!n1_bad && n2_bad)
+      return false;
+
+  if (n1_bad && n2_bad)
+    return false;
+
+  return getDifficulty(n1) < getDifficulty(n2);
 }
 
 // Breaks the expressions into buckets recursively, then pairwise checks that they are equivalent.
@@ -760,6 +787,9 @@ findRewrites(ASTVec& expressions, const vector<VariableAssignment>& values, cons
 
   if (values.size() > 0)
     {
+      if (values.size() > 10)
+        removeDuplicates(expressions);
+
       // Put the functions in buckets based on their results on the values.
       HASHMAP<uint64_t, ASTVec> map;
       for (int i = 0; i < expressions.size(); i++)
@@ -768,7 +798,7 @@ findRewrites(ASTVec& expressions, const vector<VariableAssignment>& values, cons
             continue; // omit undefined.
 
           if (i % 50000 == 49999)
-            cerr << ".";
+            cout << ".";
           uint64_t hash = getHash(expressions[i], values);
           if (map.find(hash) == map.end())
             map.insert(make_pair(hash, ASTVec()));
@@ -792,8 +822,9 @@ findRewrites(ASTVec& expressions, const vector<VariableAssignment>& values, cons
     }
   ASTVec& equiv = expressions;
 
+
   // Sort so that constants, and smaller expressions will be checked first.
-  sort(equiv.begin(), equiv.end(), lessThan);
+  std::sort(equiv.begin(), equiv.end(), lessThan);
 
   for (int i = 0; i < equiv.size(); i++)
     {
@@ -810,40 +841,46 @@ findRewrites(ASTVec& expressions, const vector<VariableAssignment>& values, cons
 
           equiv[j] = to_write.rewriteNode(equiv[j]);
 
-          ASTNode n = nf->CreateNode(EQ, equiv[i], equiv[j]);
-          if (n.GetKind() != EQ)
-            continue;
-
           ASTNode from = equiv[i];
           ASTNode to = equiv[j];
           bool r = orderEquivalence(from, to);
 
+          if (!r)
+          {
+            if (from != to)
+                cout << "can't be ordered" << from << to;
+            continue;
+          }
+
           VariableAssignment different;
           bool bad = false;
           const int st = getCurrentTime();
 
           if (checkRule(from, to, different, bad))
             {
+              cout << "Discovered a new rule.";
+              cout << from;
+              cout << to;
+              cout << getDifficulty(from) << " to " << getDifficulty(to) << endl;
+              cout << "After rewriting";
+              cout << to_write.rewriteNode(from);
+              cout << to_write.rewriteNode(to);
+              cout << "------";
+
               to_write.push_back(Rewrite_rule(mgr, from, to, getCurrentTime() - st));
 
               // Remove the more difficult expression.
               if (from == equiv[i])
                 {
-                  cerr << ".";
+                  cout << ".";
                   equiv[i] = mgr->ASTUndefined;
                 }
               if (from == equiv[j])
                 {
-                  cerr << ".";
+                  cout << ".";
                   equiv[j] = mgr->ASTUndefined;
                 }
             }
-          else if (!r)
-            {
-              // It probably shouldn't get to here..
-              cerr << "can't be ordered" << from << to;
-              continue; // can't be ordered.
-            }
           else if (!bad)
             {
               vector<VariableAssignment> ass;
@@ -993,9 +1030,11 @@ containsNode(const ASTNode& n, const ASTNode& hunting, string& current)
 bool
 checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignment, bool&bad)
 {
-  const ASTNode n = create(EQ, from, to);
+  ASTVec children;
+  children.push_back(from);
+  children.push_back(to);
+  const ASTNode n = mgr->hashingNodeFactory->CreateNode(EQ, children);
 
-  assert(n.GetKind() == BEEV::EQ);
   assert(widen_to > bits);
 
   for (int i = bits; i < widen_to; i++)
@@ -1005,8 +1044,8 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme
       // Can't widen (usually because of CONCAT or a BVCONST).
       if (widened == mgr->ASTUndefined)
         {
-          cerr << "can't widen";
-          cerr << n;
+          cout << "cannot widen";
+          //cerr << n;
           bad = true;
           return false;
         }
@@ -1023,8 +1062,7 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme
             }
 
           // Detected it's not a constant, or is constant FALSE.
-          if (i-bits > 0)
-            cerr << "*" << i - bits << "*";
+           cout << "*" << i - bits << "*";
 
           return false;
         }
@@ -1037,11 +1075,11 @@ template<class T>
   void
   removeDuplicates(T & big)
   {
-    cerr << "Before removing duplicates:" << big.size();
+    cout << "Before removing duplicates:" << big.size();
     std::sort(big.begin(), big.end());
     typename T::iterator it = std::unique(big.begin(), big.end());
     big.erase(it, big.end());
-    cerr << ".After removing duplicates:" << big.size() << endl;
+    cout << ".After removing duplicates:" << big.size() << endl;
   }
 
 // Hash function for the hash_map of a string..
@@ -1237,7 +1275,6 @@ void
 writeOutRules(string fileName)
 {
   to_write.rewriteAll();
-  to_write.eraseDuplicates();
 
   std::vector<string> output;
   std::map<string, Rewrite_rule> dup;
@@ -1246,7 +1283,7 @@ writeOutRules(string fileName)
     {
       if (!it->isOK())
         {
-          to_write.toWrite.erase(it--);
+          to_write.erase(it--);
           continue;
         }
 
@@ -1295,24 +1332,24 @@ writeOutRules(string fileName)
 
               if (dup.find(sofar) != dup.end())
                 {
-                  cerr << "-----";
-                  cerr << sofar;
+                  cout << "-----";
+                  cout << sofar;
 
                   ASTNode f = it->getFrom();
-                  cerr << f << std::endl;
-                  cerr << dup.find(sofar)->second.getFrom();
+                  cout << f << std::endl;
+                  cout << dup.find(sofar)->second.getFrom();
 
                   ASTNodeMap fromTo;
 
                   f = renameVars(f);
                   //cerr << "renamed" << f;
                   bool result = unifyNode(f,dup.find(sofar)->second.getFrom(),fromTo,2) ;
-                  cerr << "unified" << result << endl;
+                  cout << "unified" << result << endl;
                   ASTNodeMap seen;
-                  cerr << rewrite(f,*it,seen );
+                  cout << rewrite(f,*it,seen );
 
                   // The text of this rule is the same as another rule.
-                  to_write.toWrite.erase(it--);
+                  to_write.erase(it--);
                   continue;
                 }
               else
@@ -1324,7 +1361,7 @@ writeOutRules(string fileName)
   // Remove the duplicates from output.
   removeDuplicates(output);
 
-  cerr << "Rules Discovered in total: " << to_write.size() << endl;
+  cout << "Rules Discovered in total: " << to_write.size() << endl;
 
   // Group functions of the same kind all together.
   hash_map<string, vector<string>, hashF<std::string> > buckets;
@@ -1452,16 +1489,19 @@ rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen)
           cerr << "--------------";
            */
 
-          // This doesn't distinguish between the second time it's seen in the term, and seeing it again.
-          ASTNodeMap cache;
           if (seen.find(n) != seen.end())
             return seen.find(n)->second;
+
+          seen.insert(make_pair(n,rr[i].getTo()));
+          ASTNodeMap cache;
           ASTNode r=  SubstitutionMap::replace(rr[i].getTo(), fromTo, cache, nf, false, true);
+          seen.erase(n);
+
           seen.insert(make_pair(n,r));
-          ASTNode r2= rewrite(r,original_rule,seen);
+          r= rewrite(r,original_rule,seen);
           seen.erase(n);
-          seen.insert(make_pair(n,r2));
-          return r2;
+          seen.insert(make_pair(n,r));
+          return r;
         }
     }
 
@@ -1487,14 +1527,14 @@ void loadExistingRules(string fileName)
   ASTVec values = piTypeCheckDefault.GetAsserts();
   values = FlattenKind(AND, values);
 
-  cerr << "Rewrite rule size:" << values.size() << endl;
+  cout << "Rewrite rule size:" << values.size() << endl;
 
   for (int i = 0; i < values.size(); i++)
     {
       if ((values[i].GetKind() != EQ))
         {
-          cerr << "Not equality??";
-          cerr << values[i];
+          cout << "Not equality??";
+          cout << values[i];
           continue;
         }
 
@@ -1503,24 +1543,27 @@ void loadExistingRules(string fileName)
 
       // Rule should be orderable.
       bool ok = orderEquivalence(from, to);
-      assert(ok);
-      Rewrite_rule r(mgr, from, to, 0);
+      if (!ok)
+        {
+          cout << "discarding rule that can't be ordere";
+          continue;
+        }
 
+      Rewrite_rule r(mgr, from, to, 0);
 
       if (r.isOK());
         to_write.push_back(r);
-
     }
 
   mgr->PopQuery();
   parserInterface->popToFirstLevel();
   parserInterface->cleanUp();
 
-  to_write.buildRules();
+  to_write.buildLookupTable();
 
   ASTVec vvv = mgr->GetAsserts();
   for (int i=0; i < vvv.size() ;i++)
-    cerr << vvv[i];
+    cout << vvv[i];
 
   // So we don't output as soon as one is discovered...
   lastOutput = to_write.size();
@@ -1593,22 +1636,24 @@ main()
   findNewRewrites();
   writeOutRules("array.smt2");
   to_write.verifyAllwithSAT();
+  writeOutRules("array-with-times.smt2"); // verifyingallwithsat gives us the times..
 }
 
 int
 findNewRewrites()
 {
+  to_write.buildLookupTable();
+
   Function_list functionList;
   functionList.buildAll();
 
   // The hash is generated on these values.
   vector<VariableAssignment> values;
   findRewrites(functionList.functions, values);
-  writeOutRules("array.smt2");
 
-  cerr << "Initial:" << bits << " widening to :" << widen_to << endl;
-  cerr << "Highest disproved @ level: " << highestLevel << endl;
-  cerr << highestDisproved << endl;
+  cout << "Initial:" << bits << " widening to :" << widen_to << endl;
+  cout << "Highest disproved @ level: " << highestLevel << endl;
+  cout << highestDisproved << endl;
   return 0;
 }
 
index 2f0b2a62f41ebe5f141bfbcdb315b434bb2b3322..36c08d84dd05d14870d73fae2b5d8d506b7cb241 100644 (file)
@@ -65,17 +65,14 @@ public:
   bool
   isOK()
   {
-    if (getN().GetKind() != EQ)
-      return false;
-
     ASTNode w = widen(getN(), widen_to);
 
-    BVTypeCheckRecursive(n);
-    BVTypeCheckRecursive(w);
-
     if  (w.IsNull() || w.GetKind() == UNDEFINED)
       return false;
 
+    assert(BVTypeCheckRecursive(n));
+    assert(BVTypeCheckRecursive(w));
+
     if (from.isAtom() && to.isAtom())
         return false;
 
@@ -87,12 +84,17 @@ public:
   }
 
   Rewrite_rule(BEEV::STPMgr* bm, const BEEV::ASTNode& from_, const BEEV::ASTNode& to_, const int t)
-  : from(from_), to(to_), n ( bm->CreateNode(BEEV::EQ,to_,from_))
+  : from(from_), to(to_)
     {
       id = static_id++;
-
       time = t;
 
+      ASTVec c;
+      c.push_back(to_);
+      c.push_back(from_);
+      n =  bm->hashingNodeFactory->CreateNode(BEEV::EQ,c);
+
+
       ////
       assert(!from.IsNull());
       assert(from.GetKind() != UNDEFINED);
index 220aefa813ca99e1dd55a973218f93163e690abc..683208a7e6e802e88c73e6b7aace12f056d9d43a 100644 (file)
@@ -50,7 +50,7 @@ private:
 
   friend void writeOutRules(string fileName);
 
-  friend ASTNode  rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen);
+  friend ASTNode rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen);
 
   // Rules to write out when we get the chance.
   typedef list<Rewrite_rule> RewriteRuleContainer;
@@ -66,24 +66,50 @@ public:
   }
 
   void
-  buildRules()
+  addRuleToLookup(Rewrite_rule& r)
+  {
+    const ASTNode& from = r.getFrom();
+    kind_to_rr[from.GetKind()].push_back(r);
+
+    assert(from.Degree() > 0); // Shouldn't map from a constant, nor from a variable.
+
+    if (from[0].Degree() > 0)
+      kind_kind_to_rr[from.GetKind()][from[0].GetKind()].push_back(r);
+  }
+
+  void
+  buildLookupTable()
   {
     kind_to_rr.clear();
     kind_kind_to_rr.clear();
 
     for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
       {
-        ASTNode from = it->getFrom();
-        kind_to_rr[from.GetKind()].push_back(*it);
+        addRuleToLookup(*it);
+      }
+  }
 
-        if (from[0].Degree() > 0)
-          kind_kind_to_rr[from.GetKind()][from[0].GetKind()].push_back(*it);
+  void
+  removeBad()
+  {
+    for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
+      {
+        if (!it->isOK())
+          {
+            cout << "Removing Rule that is bad";
+            cout << it->getFrom();
+            cout << it->getTo();
+            cout << "----\n";
+
+            toWrite.erase(it--);
+          }
       }
   }
 
   void
   eraseDuplicates()
   {
+    removeBad();
     toWrite.sort();
     toWrite.unique();
   }
@@ -92,6 +118,13 @@ public:
   push_back(Rewrite_rule rr)
   {
     toWrite.push_back(rr);
+    addRuleToLookup(rr);
+  }
+
+  void
+  erase(RewriteRuleContainer::iterator it)
+  {
+    toWrite.erase(it);
   }
 
   int
@@ -110,21 +143,19 @@ public:
   rewriteAll()
   {
     eraseDuplicates();
-    cerr << "Size before rewriteAll:" << toWrite.size() << endl;
+    cout << "Size before rewriteAll:" << toWrite.size() << endl;
 
-    buildRules();
+    buildLookupTable();
 
     int i=0;
     for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++, i++)
       {
         if (i % 1000 == 0)
-          cerr << "rewrite all:" << i << " of " << toWrite.size() << endl;
+          cout << "rewrite all:" << i << " of " << toWrite.size() << endl;
 
-        if (!it->isOK())
-          {
-            toWrite.erase(it--);
-            continue;
-          }
+        // if not OK, should have been removed during duplicates.
+        // shouldn't add extra rules that aren't ok.
+        assert (it->isOK());
 
         ASTNode n = renameVars(it->getFrom());
         ASTNodeMap seen;
@@ -136,30 +167,48 @@ public:
 
             rewritten_from = renameVarsBack(rewritten_from);
             ASTNode to = it->getTo();
-            bool r = orderEquivalence(rewritten_from, to);
-            if (r)
+            bool ok = orderEquivalence(rewritten_from, to);
+            if (ok)
               {
                 Rewrite_rule rr(mgr, rewritten_from, to, 0);
                 if (rr.isOK())
                   {
+                    cout << "Modifying Rule\n";
+                    cout << "Initially From";
+                    cout << it->getFrom();
+                    cout << "new From";
+                    cout << rewritten_from;
+                    cout << "---";
+
                     *it= rr;
-                    buildRules(); // Otherwise two rules will remove each other?
+                    buildLookupTable(); // Otherwise two rules will remove each other?
                   }
                 else
                   {
                     cout << "Erasing rule";
-                    toWrite.erase(it--);                  }
+                    cout << "Initially From";
+                    cout << it->getFrom();
+                    cout << "new From";
+                    cout << rewritten_from;
+                    cout << "---";
+
+                    erase(it--);
+                    i--;
+                    buildLookupTable(); // Otherwise two rules will remove each other?
+                  }
               }
             else
               {
-                cerr << "Mapped but couldn't order";
-                cerr << rewritten_from << to;
+                cout << "Mapped but couldn't order";
+                cout << rewritten_from << to;
+                erase(it--);
+                i--;
               }
           }
       }
 
     eraseDuplicates();
-    cerr << "Size after rewriteAll:" << toWrite.size() << endl;
+    cout << "Size after rewriteAll:" << toWrite.size() << endl;
   }
 
   void clear()
@@ -178,9 +227,9 @@ public:
         bool r = checkRule(it->getFrom(), it->getTo(), assignment, bad);
         if (!r || bad)
           {
-            cerr << "Bad to, then from" << endl;
-            cerr << it->getFrom();
-            cerr << it->getTo();
+            cout << "Bad to, then from" << endl;
+            cout << it->getFrom();
+            cout << it->getTo();
             assert(r);
             assert(!bad);
           }