From f61a54ab17c6da659bbffdf0b26ad0290e4c35a9 Mon Sep 17 00:00:00 2001 From: trevor_hansen Date: Sat, 3 Mar 2012 12:15:42 +0000 Subject: [PATCH] Improvements to the utility for generating rewrite rules. git-svn-id: https://stp-fast-prover.svn.sourceforge.net/svnroot/stp-fast-prover/trunk/stp@1579 e59a4935-1847-0410-ae03-e826735625c1 --- src/util/find_rewrites/Functionlist.h | 7 +- src/util/find_rewrites/rewrite.cpp | 149 +++++++++++++++--------- src/util/find_rewrites/rewrite_rule.h | 18 +-- src/util/find_rewrites/rewrite_system.h | 97 +++++++++++---- 4 files changed, 184 insertions(+), 87 deletions(-) diff --git a/src/util/find_rewrites/Functionlist.h b/src/util/find_rewrites/Functionlist.h index 9086793..07d1064 100644 --- a/src/util/find_rewrites/Functionlist.h +++ b/src/util/find_rewrites/Functionlist.h @@ -26,7 +26,8 @@ class Function_list getAllFunctions(const ASTNode v, const ASTNode w, ASTVec& result) { - Kind types[] = {BVMULT, BVPLUS, BVXOR, BVAND, BVOR ,BVRIGHTSHIFT, BVLEFTSHIFT}; + Kind types[] = {BVMULT, BVPLUS, BVXOR, BVAND}; + //Kind types[] = {BVMULT, BVDIV, SBVDIV, SBVREM, SBVMOD, BVPLUS, BVMOD, BVRIGHTSHIFT, BVLEFTSHIFT, BVOR, BVAND, BVXOR, BVSRSHIFT}; const int number_types = sizeof(types) / sizeof(Kind); @@ -91,7 +92,7 @@ class Function_list void applyRewritesToAll(ASTVec& functions) { - to_write.buildRules(); + to_write.buildLookupTable(); cerr << "Applying:" << to_write.size() <<"rewrite rules" << endl; for (int i = 0; i < functions.size(); i++) @@ -264,7 +265,7 @@ public: cerr << "One Level:" << functions.size() << endl; - const bool two_level = false; + const bool two_level = true; if (two_level) { diff --git a/src/util/find_rewrites/rewrite.cpp b/src/util/find_rewrites/rewrite.cpp index 8572082..f1344cb 100644 --- a/src/util/find_rewrites/rewrite.cpp +++ b/src/util/find_rewrites/rewrite.cpp @@ -319,7 +319,7 @@ isConstant(const ASTNode& n, VariableAssignment& different) ASTNode widen(const ASTNode& w, int width) { - assert(bits >=3); + assert(bits >=4); if (w.isConstant() && w.GetValueWidth() == 1) return w; @@ -421,6 +421,12 @@ orderEquivalence(ASTNode& from, ASTNode& to) if (intersection != s_from.size() && intersection != s_to.size()) return false; + if (to.isAtom() && from.isAtom()) + return false; // no such rules + + if (to == from) + return false; // no such rules + if (to.isAtom()) return true; @@ -473,12 +479,21 @@ orderEquivalence(ASTNode& from, ASTNode& to) getVariables(to, symbols, visited); int to_c = visited.size(); + if (to_c < from_c) + { + return true; + } + if (to_c > from_c) { swap(from, to); + return true; } - return true; + + + // Can't order they have the same number of nodes and the same AIG size. + return false; } int @@ -740,7 +755,19 @@ is_candidate(ASTNode from, ASTNode to) bool lessThan(const ASTNode& n1, const ASTNode& n2) { - return (((unsigned) n1.GetNodeNum()) < ((unsigned) n2.GetNodeNum())); + bool n1_bad = n1.IsNull() || (n1.GetKind() == UNDEFINED); + bool n2_bad = n2.IsNull() || (n2.GetKind() == UNDEFINED); + + if (n1_bad && !n2_bad) + return true; + + if (!n1_bad && n2_bad) + return false; + + if (n1_bad && n2_bad) + return false; + + return getDifficulty(n1) < getDifficulty(n2); } // Breaks the expressions into buckets recursively, then pairwise checks that they are equivalent. @@ -760,6 +787,9 @@ findRewrites(ASTVec& expressions, const vector& values, cons if (values.size() > 0) { + if (values.size() > 10) + removeDuplicates(expressions); + // Put the functions in buckets based on their results on the values. HASHMAP map; for (int i = 0; i < expressions.size(); i++) @@ -768,7 +798,7 @@ findRewrites(ASTVec& expressions, const vector& values, cons continue; // omit undefined. if (i % 50000 == 49999) - cerr << "."; + cout << "."; uint64_t hash = getHash(expressions[i], values); if (map.find(hash) == map.end()) map.insert(make_pair(hash, ASTVec())); @@ -792,8 +822,9 @@ findRewrites(ASTVec& expressions, const vector& values, cons } ASTVec& equiv = expressions; + // Sort so that constants, and smaller expressions will be checked first. - sort(equiv.begin(), equiv.end(), lessThan); + std::sort(equiv.begin(), equiv.end(), lessThan); for (int i = 0; i < equiv.size(); i++) { @@ -810,40 +841,46 @@ findRewrites(ASTVec& expressions, const vector& values, cons equiv[j] = to_write.rewriteNode(equiv[j]); - ASTNode n = nf->CreateNode(EQ, equiv[i], equiv[j]); - if (n.GetKind() != EQ) - continue; - ASTNode from = equiv[i]; ASTNode to = equiv[j]; bool r = orderEquivalence(from, to); + if (!r) + { + if (from != to) + cout << "can't be ordered" << from << to; + continue; + } + VariableAssignment different; bool bad = false; const int st = getCurrentTime(); if (checkRule(from, to, different, bad)) { + cout << "Discovered a new rule."; + cout << from; + cout << to; + cout << getDifficulty(from) << " to " << getDifficulty(to) << endl; + cout << "After rewriting"; + cout << to_write.rewriteNode(from); + cout << to_write.rewriteNode(to); + cout << "------"; + to_write.push_back(Rewrite_rule(mgr, from, to, getCurrentTime() - st)); // Remove the more difficult expression. if (from == equiv[i]) { - cerr << "."; + cout << "."; equiv[i] = mgr->ASTUndefined; } if (from == equiv[j]) { - cerr << "."; + cout << "."; equiv[j] = mgr->ASTUndefined; } } - else if (!r) - { - // It probably shouldn't get to here.. - cerr << "can't be ordered" << from << to; - continue; // can't be ordered. - } else if (!bad) { vector ass; @@ -993,9 +1030,11 @@ containsNode(const ASTNode& n, const ASTNode& hunting, string& current) bool checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignment, bool&bad) { - const ASTNode n = create(EQ, from, to); + ASTVec children; + children.push_back(from); + children.push_back(to); + const ASTNode n = mgr->hashingNodeFactory->CreateNode(EQ, children); - assert(n.GetKind() == BEEV::EQ); assert(widen_to > bits); for (int i = bits; i < widen_to; i++) @@ -1005,8 +1044,8 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme // Can't widen (usually because of CONCAT or a BVCONST). if (widened == mgr->ASTUndefined) { - cerr << "can't widen"; - cerr << n; + cout << "cannot widen"; + //cerr << n; bad = true; return false; } @@ -1023,8 +1062,7 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme } // Detected it's not a constant, or is constant FALSE. - if (i-bits > 0) - cerr << "*" << i - bits << "*"; + cout << "*" << i - bits << "*"; return false; } @@ -1037,11 +1075,11 @@ template void removeDuplicates(T & big) { - cerr << "Before removing duplicates:" << big.size(); + cout << "Before removing duplicates:" << big.size(); std::sort(big.begin(), big.end()); typename T::iterator it = std::unique(big.begin(), big.end()); big.erase(it, big.end()); - cerr << ".After removing duplicates:" << big.size() << endl; + cout << ".After removing duplicates:" << big.size() << endl; } // Hash function for the hash_map of a string.. @@ -1237,7 +1275,6 @@ void writeOutRules(string fileName) { to_write.rewriteAll(); - to_write.eraseDuplicates(); std::vector output; std::map dup; @@ -1246,7 +1283,7 @@ writeOutRules(string fileName) { if (!it->isOK()) { - to_write.toWrite.erase(it--); + to_write.erase(it--); continue; } @@ -1295,24 +1332,24 @@ writeOutRules(string fileName) if (dup.find(sofar) != dup.end()) { - cerr << "-----"; - cerr << sofar; + cout << "-----"; + cout << sofar; ASTNode f = it->getFrom(); - cerr << f << std::endl; - cerr << dup.find(sofar)->second.getFrom(); + cout << f << std::endl; + cout << dup.find(sofar)->second.getFrom(); ASTNodeMap fromTo; f = renameVars(f); //cerr << "renamed" << f; bool result = unifyNode(f,dup.find(sofar)->second.getFrom(),fromTo,2) ; - cerr << "unified" << result << endl; + cout << "unified" << result << endl; ASTNodeMap seen; - cerr << rewrite(f,*it,seen ); + cout << rewrite(f,*it,seen ); // The text of this rule is the same as another rule. - to_write.toWrite.erase(it--); + to_write.erase(it--); continue; } else @@ -1324,7 +1361,7 @@ writeOutRules(string fileName) // Remove the duplicates from output. removeDuplicates(output); - cerr << "Rules Discovered in total: " << to_write.size() << endl; + cout << "Rules Discovered in total: " << to_write.size() << endl; // Group functions of the same kind all together. hash_map, hashF > buckets; @@ -1452,16 +1489,19 @@ rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen) cerr << "--------------"; */ - // This doesn't distinguish between the second time it's seen in the term, and seeing it again. - ASTNodeMap cache; if (seen.find(n) != seen.end()) return seen.find(n)->second; + + seen.insert(make_pair(n,rr[i].getTo())); + ASTNodeMap cache; ASTNode r= SubstitutionMap::replace(rr[i].getTo(), fromTo, cache, nf, false, true); + seen.erase(n); + seen.insert(make_pair(n,r)); - ASTNode r2= rewrite(r,original_rule,seen); + r= rewrite(r,original_rule,seen); seen.erase(n); - seen.insert(make_pair(n,r2)); - return r2; + seen.insert(make_pair(n,r)); + return r; } } @@ -1487,14 +1527,14 @@ void loadExistingRules(string fileName) ASTVec values = piTypeCheckDefault.GetAsserts(); values = FlattenKind(AND, values); - cerr << "Rewrite rule size:" << values.size() << endl; + cout << "Rewrite rule size:" << values.size() << endl; for (int i = 0; i < values.size(); i++) { if ((values[i].GetKind() != EQ)) { - cerr << "Not equality??"; - cerr << values[i]; + cout << "Not equality??"; + cout << values[i]; continue; } @@ -1503,24 +1543,27 @@ void loadExistingRules(string fileName) // Rule should be orderable. bool ok = orderEquivalence(from, to); - assert(ok); - Rewrite_rule r(mgr, from, to, 0); + if (!ok) + { + cout << "discarding rule that can't be ordere"; + continue; + } + Rewrite_rule r(mgr, from, to, 0); if (r.isOK()); to_write.push_back(r); - } mgr->PopQuery(); parserInterface->popToFirstLevel(); parserInterface->cleanUp(); - to_write.buildRules(); + to_write.buildLookupTable(); ASTVec vvv = mgr->GetAsserts(); for (int i=0; i < vvv.size() ;i++) - cerr << vvv[i]; + cout << vvv[i]; // So we don't output as soon as one is discovered... lastOutput = to_write.size(); @@ -1593,22 +1636,24 @@ main() findNewRewrites(); writeOutRules("array.smt2"); to_write.verifyAllwithSAT(); + writeOutRules("array-with-times.smt2"); // verifyingallwithsat gives us the times.. } int findNewRewrites() { + to_write.buildLookupTable(); + Function_list functionList; functionList.buildAll(); // The hash is generated on these values. vector values; findRewrites(functionList.functions, values); - writeOutRules("array.smt2"); - cerr << "Initial:" << bits << " widening to :" << widen_to << endl; - cerr << "Highest disproved @ level: " << highestLevel << endl; - cerr << highestDisproved << endl; + cout << "Initial:" << bits << " widening to :" << widen_to << endl; + cout << "Highest disproved @ level: " << highestLevel << endl; + cout << highestDisproved << endl; return 0; } diff --git a/src/util/find_rewrites/rewrite_rule.h b/src/util/find_rewrites/rewrite_rule.h index 2f0b2a6..36c08d8 100644 --- a/src/util/find_rewrites/rewrite_rule.h +++ b/src/util/find_rewrites/rewrite_rule.h @@ -65,17 +65,14 @@ public: bool isOK() { - if (getN().GetKind() != EQ) - return false; - ASTNode w = widen(getN(), widen_to); - BVTypeCheckRecursive(n); - BVTypeCheckRecursive(w); - if (w.IsNull() || w.GetKind() == UNDEFINED) return false; + assert(BVTypeCheckRecursive(n)); + assert(BVTypeCheckRecursive(w)); + if (from.isAtom() && to.isAtom()) return false; @@ -87,12 +84,17 @@ public: } Rewrite_rule(BEEV::STPMgr* bm, const BEEV::ASTNode& from_, const BEEV::ASTNode& to_, const int t) - : from(from_), to(to_), n ( bm->CreateNode(BEEV::EQ,to_,from_)) + : from(from_), to(to_) { id = static_id++; - time = t; + ASTVec c; + c.push_back(to_); + c.push_back(from_); + n = bm->hashingNodeFactory->CreateNode(BEEV::EQ,c); + + //// assert(!from.IsNull()); assert(from.GetKind() != UNDEFINED); diff --git a/src/util/find_rewrites/rewrite_system.h b/src/util/find_rewrites/rewrite_system.h index 220aefa..683208a 100644 --- a/src/util/find_rewrites/rewrite_system.h +++ b/src/util/find_rewrites/rewrite_system.h @@ -50,7 +50,7 @@ private: friend void writeOutRules(string fileName); - friend ASTNode rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen); + friend ASTNode rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen); // Rules to write out when we get the chance. typedef list RewriteRuleContainer; @@ -66,24 +66,50 @@ public: } void - buildRules() + addRuleToLookup(Rewrite_rule& r) + { + const ASTNode& from = r.getFrom(); + kind_to_rr[from.GetKind()].push_back(r); + + assert(from.Degree() > 0); // Shouldn't map from a constant, nor from a variable. + + if (from[0].Degree() > 0) + kind_kind_to_rr[from.GetKind()][from[0].GetKind()].push_back(r); + } + + void + buildLookupTable() { kind_to_rr.clear(); kind_kind_to_rr.clear(); for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++) { - ASTNode from = it->getFrom(); - kind_to_rr[from.GetKind()].push_back(*it); + addRuleToLookup(*it); + } + } - if (from[0].Degree() > 0) - kind_kind_to_rr[from.GetKind()][from[0].GetKind()].push_back(*it); + void + removeBad() + { + for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++) + { + if (!it->isOK()) + { + cout << "Removing Rule that is bad"; + cout << it->getFrom(); + cout << it->getTo(); + cout << "----\n"; + + toWrite.erase(it--); + } } } void eraseDuplicates() { + removeBad(); toWrite.sort(); toWrite.unique(); } @@ -92,6 +118,13 @@ public: push_back(Rewrite_rule rr) { toWrite.push_back(rr); + addRuleToLookup(rr); + } + + void + erase(RewriteRuleContainer::iterator it) + { + toWrite.erase(it); } int @@ -110,21 +143,19 @@ public: rewriteAll() { eraseDuplicates(); - cerr << "Size before rewriteAll:" << toWrite.size() << endl; + cout << "Size before rewriteAll:" << toWrite.size() << endl; - buildRules(); + buildLookupTable(); int i=0; for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++, i++) { if (i % 1000 == 0) - cerr << "rewrite all:" << i << " of " << toWrite.size() << endl; + cout << "rewrite all:" << i << " of " << toWrite.size() << endl; - if (!it->isOK()) - { - toWrite.erase(it--); - continue; - } + // if not OK, should have been removed during duplicates. + // shouldn't add extra rules that aren't ok. + assert (it->isOK()); ASTNode n = renameVars(it->getFrom()); ASTNodeMap seen; @@ -136,30 +167,48 @@ public: rewritten_from = renameVarsBack(rewritten_from); ASTNode to = it->getTo(); - bool r = orderEquivalence(rewritten_from, to); - if (r) + bool ok = orderEquivalence(rewritten_from, to); + if (ok) { Rewrite_rule rr(mgr, rewritten_from, to, 0); if (rr.isOK()) { + cout << "Modifying Rule\n"; + cout << "Initially From"; + cout << it->getFrom(); + cout << "new From"; + cout << rewritten_from; + cout << "---"; + *it= rr; - buildRules(); // Otherwise two rules will remove each other? + buildLookupTable(); // Otherwise two rules will remove each other? } else { cout << "Erasing rule"; - toWrite.erase(it--); } + cout << "Initially From"; + cout << it->getFrom(); + cout << "new From"; + cout << rewritten_from; + cout << "---"; + + erase(it--); + i--; + buildLookupTable(); // Otherwise two rules will remove each other? + } } else { - cerr << "Mapped but couldn't order"; - cerr << rewritten_from << to; + cout << "Mapped but couldn't order"; + cout << rewritten_from << to; + erase(it--); + i--; } } } eraseDuplicates(); - cerr << "Size after rewriteAll:" << toWrite.size() << endl; + cout << "Size after rewriteAll:" << toWrite.size() << endl; } void clear() @@ -178,9 +227,9 @@ public: bool r = checkRule(it->getFrom(), it->getTo(), assignment, bad); if (!r || bad) { - cerr << "Bad to, then from" << endl; - cerr << it->getFrom(); - cerr << it->getTo(); + cout << "Bad to, then from" << endl; + cout << it->getFrom(); + cout << it->getTo(); assert(r); assert(!bad); } -- 2.47.3