From: trevor_hansen Date: Mon, 12 Mar 2012 04:01:56 +0000 (+0000) Subject: Improvements to the code for generating rewrite rules. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=8f51e0ca18f555ef5c6c450cb7ebff29edd2a642;p=francis%2Fstp.git Improvements to the code for generating rewrite rules. git-svn-id: https://stp-fast-prover.svn.sourceforge.net/svnroot/stp-fast-prover/trunk/stp@1586 e59a4935-1847-0410-ae03-e826735625c1 --- diff --git a/src/util/find_rewrites/Functionlist.h b/src/util/find_rewrites/Functionlist.h index 3e45a58..77ea4ee 100644 --- a/src/util/find_rewrites/Functionlist.h +++ b/src/util/find_rewrites/Functionlist.h @@ -5,20 +5,16 @@ #ifndef FUNCTIONLIST_H_ #define FUNCTIONLIST_H_ +#include "rewrite_system.h" +#include "misc.h" -extern const int bits; -extern Simplifier *simp; extern Rewrite_system rewrite_system; -ASTNode -widen(const ASTNode& w, int width); - -ASTNode -create(Kind k, const ASTNode& n0, const ASTNode& n1); - class Function_list { + private: + // Because v and w might come from "result", if "result" is resized, they will // be moved. So we can't use references to them. @@ -38,35 +34,6 @@ class Function_list } - ASTNode - rewriteThroughWithAIGS(const ASTNode &n_) - { - assert(n_.GetType() == BITVECTOR_TYPE); - ASTNode f = mgr->LookupOrCreateSymbol("rewriteThroughWithAIGS"); - f.SetValueWidth(n_.GetValueWidth()); - ASTNode n = create(EQ, n_, f); - - BBNodeManagerAIG nm; - BitBlaster bb(&nm, simp, mgr->defaultNodeFactory, &mgr->UserFlags); - ASTNodeMap fromTo; - ASTNodeMap equivs; - bb.getConsts(n, fromTo, equivs); - - ASTNode result = n_; - if (equivs.size() > 0) - { - ASTNodeMap cache; - result = SubstitutionMap::replace(result, equivs, cache, nf, false, true); - } - - if (fromTo.size() > 0) - { - ASTNodeMap cache; - result = SubstitutionMap::replace(result, fromTo, cache, nf); - } - return result; - } - void applyBigRewrite() { @@ -187,7 +154,7 @@ class Function_list if (mgr->ASTUndefined == widen(functions[i], bits + 1)) { - cerr << "Can't widen" << functions[i]; + //cerr << "Can't widen" << functions[i]; functions[i] = mgr->ASTUndefined; // We can't widen it later. So remove it. continue; } @@ -224,7 +191,7 @@ class Function_list if (i % 100000 == 0) cerr << "ApplyAigs:" << i << " of " << functions.size() << endl; - rewriteThroughWithAIGS(functions[i]); + functions[i] = rewriteThroughWithAIGS(functions[i]); } } @@ -255,9 +222,8 @@ public: allUnary(); - applyAIGs(); - applySpeculative(); + //applySpeculative(); applyRewritesToAll(functions); checkFunctions(); removeDuplicates(functions); @@ -276,24 +242,21 @@ public: for (int j = 0; j < size; j++) getAllFunctions(functions_copy[i], functions_copy[j], functions); - cerr << "Removing single variables" < + getVariables(const ASTNode& n); + + ASTNode + rewriteThroughWithAIGS(const ASTNode &n_); + +#endif diff --git a/src/util/find_rewrites/rewrite.cpp b/src/util/find_rewrites/rewrite.cpp index af38e34..7c46ca8 100644 --- a/src/util/find_rewrites/rewrite.cpp +++ b/src/util/find_rewrites/rewrite.cpp @@ -25,6 +25,7 @@ #include "rewrite_rule.h" #include "rewrite_system.h" #include "Functionlist.h" +#include "misc.h" extern int smt2parse(); @@ -51,7 +52,7 @@ const int mask = (1 << (bits)) - 1; volatile bool force_writeout = false; // Saves a little bit of time. The vectors are saved between invocations. -vector saved_array; +vector saved_array; // Stores the difficulties that have already been generated. map difficulty_cache; @@ -87,12 +88,14 @@ vector getVariables(const ASTNode& n); bool -unifyNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width); +matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width); typedef HASHMAP ASTNodeString; BEEV::STPMgr* mgr; NodeFactory* nf; +NodeFactory* simpNf; + SATSolver * ss; ASTNodeSet stored; // Store nodes so they aren't garbage collected. Simplifier *simp; @@ -106,6 +109,23 @@ int lastOutput = 0; bool checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& ass, bool& bad); +ASTNode +withNF(const ASTNode &n) +{ + if (n.isAtom()) + return n; + + ASTVec c; + for (int i=0; i< n.Degree();i++) + c.push_back(withNF(n[i])); + + if (n.GetType() == BOOLEAN_TYPE) + return nf->CreateNode(n.GetKind(), c); + else + return nf->CreateArrayTerm(n.GetKind(), n.GetIndexWidth(), n.GetValueWidth(), c); +} + + ASTNode renameVars(const ASTNode &n) { @@ -226,15 +246,16 @@ eval(const ASTNode &n, ASTNodeMap& map, int count = 0) // We have an array of arrays already created to store the children. // This reduces the number of objects created/destroyed. if (count >= saved_array.size()) - saved_array.push_back(ASTVec()); + saved_array.push_back(new ASTVec()); - ASTVec& new_children = saved_array[count]; + ASTVec& new_children = *saved_array[count]; new_children.clear(); for (int i = 0; i < n.Degree(); i++) new_children.push_back(eval(n[i], map, count + 1)); ASTNode r = NonMemberBVConstEvaluator(mgr, n.GetKind(), new_children, n.GetValueWidth()); + new_children.clear(); map.insert(make_pair(n, r)); return r; } @@ -284,7 +305,7 @@ checkProp(const ASTNode& n) // True if it's always true. Otherwise fills the assignment. bool -isConstant(const ASTNode& n, VariableAssignment& different) +isConstant(const ASTNode& n, VariableAssignment& different, const int bits) { if (isConstantToSat(n)) return true; @@ -293,11 +314,10 @@ isConstant(const ASTNode& n, VariableAssignment& different) mgr->ValidFlag = false; vector symbols = getVariables(n); - assert(symbols.size() > 0); // Both of them might not be contained in the assignment. - different.setV(mgr->CreateZeroConst(symbols[0].GetValueWidth())); - different.setW(mgr->CreateZeroConst(symbols[0].GetValueWidth())); + different.setV(mgr->CreateZeroConst(bits)); + different.setW(mgr->CreateZeroConst(bits)); // It might have been widened. for (int i = 0; i < symbols.size(); i++) @@ -398,15 +418,39 @@ widen(const ASTNode& w, int width) bool orderEquivalence(ASTNode& from, ASTNode& to) { + if(from.IsNull()) + return false; + if(from.GetKind() == UNDEFINED) + return false; + if(to.IsNull()) + return false; + if(to.GetKind() == UNDEFINED) + return false; + + { + ASTVec c; + c.push_back(from); + c.push_back(to); + ASTNode w = widen(mgr->hashingNodeFactory->CreateNode(EQ,c), widen_to); + + if (w.IsNull() || w.GetKind() == UNDEFINED) + return false; + } + vector s_from; // The variables in the from node. ASTNodeSet visited; getVariables(from, s_from, visited); std::sort(s_from.begin(), s_from.end()); + const int from_c = visited.size(); vector s_to; // The variables in the to node. visited.clear(); getVariables(to, s_to, visited); sort(s_to.begin(), s_to.end()); + const int to_c = visited.size(); + + if (from_c > 50 || to_c > 50) + return false; // not interested in giant rules. vector result(s_to.size() + s_from.size()); // We must map from most variables to fewer variables. @@ -453,28 +497,17 @@ orderEquivalence(ASTNode& from, ASTNode& to) if (s_from.size() > s_to.size()) return true; - if (getDifficulty(from) < getDifficulty(to)) + if ((getDifficulty(from)+5) < getDifficulty(to)) { swap(from, to); return true; } - if (getDifficulty(from) > getDifficulty(to)) + if (getDifficulty(from) > (getDifficulty(to)+5)) { return true; } - // Difficulty is equal. Order based on the number of nodes. - vector symbols; - visited.clear(); - getVariables(from, symbols, visited); - int from_c = visited.size(); - - symbols.clear(); - visited.clear(); - getVariables(to, symbols, visited); - int to_c = visited.size(); - if (to_c < from_c) { return true; @@ -524,8 +557,37 @@ getDifficulty(const ASTNode& n_) ToSATBase::ASTNodeToSATVar nodeToSATVar; toCNF.toCNF(BBFormula, cnfData, nodeToSATVar, false, nm); + // Send the clauses to Minisat, do unit propagation. + /////////////// + + // Create a new sat variable for each of the variables in the CNF. + assert(ss->nVars() ==0); + for (int i = 0; i < cnfData->nVars ; i++) + ss->newVar(); + + SATSolver::vec_literals satSolverClause; + for (int i = 0; i < cnfData->nClauses; i++) + { + satSolverClause.clear(); + for (int * pLit = cnfData->pClauses[i], *pStop = cnfData->pClauses[i + + 1]; pLit < pStop; pLit++) + { + SATSolver::Var var = (*pLit) >> 1; + assert ((var < ss->nVars())); + Minisat::Lit l = SATSolver::mkLit(var, (*pLit) & 1); + satSolverClause.push(l); + } + + ss->addClause(satSolverClause); + } + + ss->simplify(); + assert (ss->okay()); // should be satisfiable. + // Why we go to all this trouble. The number of clauses. - int score = cnfData->nClauses; + const int score = ss->nClauses(); + assert(score <= cnfData->nClauses); + ////////////// //Cnf_ClearMemory(); Cnf_DataFree(cnfData); @@ -586,6 +648,14 @@ void do_write_out(int ignore) force_writeout = true; } +volatile bool debug_usr2 = false; + +//toggle. +void do_usr2(int ignore) +{ + debug_usr2=!debug_usr2; +} + int startup() @@ -609,13 +679,9 @@ startup() GlobalSTP = new STP(mgr, simp, at, tosat, abs); -#ifndef NOTSIMPLIFYING_NF - nf = new SimplifyingNodeFactory(*(mgr->hashingNodeFactory), *mgr); - mgr->defaultNodeFactory = nf; -#else - nf = mgr->hashingNodeFactory; - mgr->defaultNodeFactory = mgr->hashingNodeFactory; -#endif + simpNf = new SimplifyingNodeFactory(*(mgr->hashingNodeFactory), *mgr); + nf = new TypeChecker(*simpNf, *mgr); + mgr->defaultNodeFactory = simpNf; mgr->UserFlags.stats_flag = false; mgr->UserFlags.optimize_flag = true; @@ -625,7 +691,7 @@ startup() // Prime the cache with 100.. for (int i = 0; i < 100; i++) { - saved_array.push_back(ASTVec()); + saved_array.push_back(new ASTVec()); } zero = mgr->CreateZeroConst(bits); @@ -643,6 +709,8 @@ startup() // Write out the work so far.. signal(SIGUSR1,do_write_out); + signal(SIGUSR2,do_usr2); + } void @@ -807,6 +875,15 @@ findRewrites(ASTVec& expressions, const vector& values, cons HASHMAP::iterator it2; cout << "Split into " << map.size() << " pieces\n"; + if (depth > 0) + { + if(map.size() ==1) + { + cerr << values[0].getV(); + cerr << values[0].getW(); + assert(false); + } + } int id = 1; for (it2 = map.begin(); it2 != map.end(); it2++) @@ -820,7 +897,6 @@ findRewrites(ASTVec& expressions, const vector& values, cons } ASTVec& equiv = expressions; - // Sort so that constants, and smaller expressions will be checked first. std::sort(equiv.begin(), equiv.end(), lessThan); @@ -837,39 +913,46 @@ findRewrites(ASTVec& expressions, const vector& values, cons if (equiv[i].GetKind() == UNDEFINED || equiv[j].GetKind() == UNDEFINED) continue; - equiv[j] = rewrite_system.rewriteNode(equiv[j]); - - ASTNode from = equiv[i]; - ASTNode to = equiv[j]; - bool r = orderEquivalence(from, to); - - if (!r) - { - if (from != to) - cout << "can't be ordered" << from << to; - continue; - } + ASTNode& from = equiv[i]; + ASTNode& to = equiv[j]; VariableAssignment different; bool bad = false; const int st = getCurrentTime(); + if (from.isConstant() && to.isConstant()) + continue; + if (checkRule(from, to, different, bad)) { - cout << "Discovered a new rule."; - cout << from; - cout << to; - cout << getDifficulty(from) << " to " << getDifficulty(to) << endl; - cout << "After rewriting"; - cout << rewrite_system.rewriteNode(from); - cout << rewrite_system.rewriteNode(to); - cout << "------"; + equiv[i] = rewrite_system.rewriteNode(equiv[i]); + equiv[j] = rewrite_system.rewriteNode(equiv[j]); - Rewrite_rule rr(mgr, from, to, getCurrentTime() - st); + equiv[i] = rewriteThroughWithAIGS(equiv[i]); + equiv[j] = rewriteThroughWithAIGS(equiv[j]); - if (!rr.timedCheck(10000)) + ASTNode& f = equiv[i]; + ASTNode& t = equiv[j]; + + bool r = orderEquivalence(f, t); + + if (!r) continue; + cout << "Discovered a new rule."; + cout << f << t; + cout << getDifficulty(f) << " to " << getDifficulty(t) << endl; + + Rewrite_rule rr(mgr, f, t, getCurrentTime() - st); + + if (!rr.timedCheck(1000)) + { + cout << "Rule failed extended verification."; + continue; + } + cout << "Verified Rule to: " << rr.getVerifiedToBits() << " bits" << endl; + cout << "------"; + rewrite_system.push_back(rr); // Remove the more difficult expression. @@ -898,7 +981,7 @@ findRewrites(ASTVec& expressions, const vector& values, cons } // Write out the rules intermitently. - if (force_writeout || lastOutput + 5000 < rewrite_system.size()) + if (force_writeout || lastOutput + 500 < rewrite_system.size()) { rewrite_system.rewriteAll(); writeOutRules(); @@ -1038,6 +1121,8 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme ASTVec children; children.push_back(from); children.push_back(to); + + // The simplifying node factory sometimes meant it couldn't be widended. const ASTNode n = mgr->hashingNodeFactory->CreateNode(EQ, children); assert(widen_to > bits); @@ -1056,7 +1141,7 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme } // Send it to the SAT solver to verify that the widening has the same answer. - bool result = isConstant(widened, assignment); + bool result = isConstant(widened, assignment, bits); if (!result) { @@ -1067,8 +1152,8 @@ checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignme } // Detected it's not a constant, or is constant FALSE. - cout << "*" << i - bits << "*"; + cout << "*" << i - bits << "*"; return false; } } @@ -1286,6 +1371,7 @@ template void writeOutRules() { + cerr << "Writing out: " << rewrite_system.size() << " rules" << endl; force_writeout = false; std::vector output; @@ -1293,12 +1379,6 @@ writeOutRules() for (Rewrite_system::RewriteRuleContainer::iterator it = rewrite_system.toWrite.begin() ; it != rewrite_system.toWrite.end(); it++) { - if (!it->isOK()) - { - rewrite_system.erase(it--); - continue; - } - ASTNode to = it->getTo(); ASTNode from = it->getFrom(); @@ -1344,22 +1424,19 @@ writeOutRules() if (dup.find(sofar) != dup.end()) { - cout << "-----"; + cout << "-----Writing out has found a duplicate rule-----"; cout << sofar; ASTNode f = it->getFrom(); - cout << f << std::endl; - cout << dup.find(sofar)->second.getFrom(); + cout << "This:" << f << std::endl; + cout << "Has the same text as this: " << dup.find(sofar)->second.getFrom(); + cout << "Rule " << it->getId() << " has the same text as " << dup.find(sofar)->second.getId() << endl; ASTNodeMap fromTo; - - cerr << f; f = renameVars(f); - //cerr << "renamed" << f; - bool result = unifyNode(f,dup.find(sofar)->second.getFrom(),fromTo,2) ; - cout << "unified" << result << endl; + bool result = commutative_matchNode(f,dup.find(sofar)->second.getFrom(),fromTo,2) ; + cout << "Has it unified:" << result << endl; ASTNodeMap seen; - cout << rewrite(f,*it,seen ); // The text of this rule is the same as another rule. rewrite_system.erase(it--); @@ -1381,8 +1458,13 @@ writeOutRules() bucket("n.GetKind() ==", output, buckets); ofstream outputFile; + + // Because we output the difficulty (i.e. the number of CNF clauses), + // this is very slow. + #ifdef OUTPUT_CPP_RULES outputFile.open("rewrite_data_new.cpp", ios::trunc); + // output the C++ code. hash_map, hashF >::const_iterator it; for (it = buckets.begin(); it != buckets.end(); it++) @@ -1396,6 +1478,7 @@ writeOutRules() outputFile << "}" << endl; } outputFile.close(); + #endif /////////////// outputFile.open("rules_new.smt2", ios::trunc); @@ -1429,40 +1512,44 @@ rename_then_rewrite(ASTNode n, const Rewrite_rule& original_rule) { n = renameVars(n); ASTNodeMap seen; - n = rewrite(n,original_rule,seen); + n = rewrite(n,original_rule,seen,0); return renameVarsBack(n); } // assumes the variables in n are two characters wide. ASTNode -rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen) +rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen, int depth) { - if (n.isAtom()) + if (depth > 50) return n; - // nb. won't rewrite through EQ etc. - if (n.GetType() != BITVECTOR_TYPE) + if (n.isAtom()) return n; ASTVec v; + v.reserve(n.Degree()); for (int i = 0; i < n.Degree(); i++) - v.push_back(rewrite(n[i],original_rule,seen)); + v.push_back(rewrite(n[i],original_rule,seen,depth+1)); assert(v.size() > 0); ASTNode n2; if (v!=n.GetChildren()) - n2 = mgr->CreateTerm(n.GetKind(), n.GetValueWidth(), v); + { + if (n.GetType() != BOOLEAN_TYPE) + n2 = mgr->CreateArrayTerm(n.GetKind(),n.GetIndexWidth(), n.GetValueWidth(), v); + else + n2 = mgr->CreateNode(n.GetKind(), v); + } else n2 = n; ASTNodeMap fromTo; - vector& rr = - n[0].Degree() > 0 ? - (rewrite_system.kind_kind_to_rr[n.GetKind()][n[0].GetKind()]) : - (rewrite_system.kind_to_rr[n.GetKind()]) ; + if (rewrite_system.lookups_invalid) + rewrite_system.buildLookupTable(); + vector& rr = rewrite_system.kind_to_rr[n.GetKind()]; for (int i = 0; i < rr.size(); i++) { @@ -1475,42 +1562,52 @@ rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen) const ASTNode& f = rr[i].getFrom(); - if (unifyNode(f, n2, fromTo,1)) + if (commutative_matchNode(f, n2, fromTo, 1)) { - /* - cerr << "Unifying" << f; - cerr << "with:" << n2; + if (debug_usr2) + { + cerr << "Original Rule(" << original_rule.getId() << ")"; + cerr << original_rule.getFrom(); + cerr << "->" << original_rule.getTo(); - cerr << "Now" << rr[i].getTo(); - cerr << "reducing rule" << rr[i].getN(); + cerr << "Matching Rule(" << rr[i].getId() << ")"; + cerr << rr[i].getFrom(); + cerr << "->" << rr[i].getTo(); - for (ASTNodeMap::iterator it = fromTo.begin(); it != fromTo.end(); it++) - { - cerr << it->first << "to" << it->second << endl; - } + cerr << "--------------"; + cerr << "Unifying" << f; + cerr << "with:" << n2; + cerr << "--------------"; - cerr << "--------------"; - */ + for (ASTNodeMap::iterator it = fromTo.begin(); it != fromTo.end(); it++) + { + cerr << it->first << "to" << it->second << endl; + } + + cerr << "--------------"; + } if (seen.find(n) != seen.end()) return seen.find(n)->second; - seen.insert(make_pair(n,rr[i].getTo())); + seen.insert(make_pair(n, rr[i].getTo())); ASTNodeMap cache; - ASTNode r= SubstitutionMap::replace(rr[i].getTo(), fromTo, cache, nf, false, true); + ASTNode r = SubstitutionMap::replace(rr[i].getTo(), fromTo, cache, nf, false, true); seen.erase(n); - seen.insert(make_pair(n,r)); - r= rewrite(r,original_rule,seen); + seen.insert(make_pair(n, r)); + r = rewrite(r, original_rule, seen, depth + 1); seen.erase(n); - seen.insert(make_pair(n,r)); + seen.insert(make_pair(n, r)); return r; + } } return n2; } + int smt2_scan_string(const char *yy_str); // http://stackoverflow.com/questions/3418231/c-replace-part-of-a-string-with-another-string @@ -1524,11 +1621,10 @@ bool replace(std::string& str, const std::string& from, const std::string& to) { void -loadNewRules() +load_new_rules(const string fileName = "rules_new.smt2") { FILE * in; bool opended= false; - string fileName = "rules_new.smt2" if(!ifstream(fileName.c_str())) /// use stdin if the default file is not found. in = stdin; @@ -1538,7 +1634,7 @@ loadNewRules() opended = true; // so we know to fclose it. } - // We store references to "v" and "w", so we need to removed the + // We store references to "v" and "w", so we need to remove the // definitions from the input we parse. v = mgr->LookupOrCreateSymbol("v"); @@ -1576,6 +1672,7 @@ loadNewRules() int rv = sscanf(line, ";id:%d\tverified_to:%d\ttime:%d\tfrom_difficulty:%d\tto_difficulty:%d\n", &id, &verified_to_bits, &time_used, &from_v, &to_v); if (rv !=5) { + cerr << line; done = true; break; } @@ -1602,8 +1699,11 @@ loadNewRules() assert(values.size() ==1); - ASTNode from = values[0][0]; - ASTNode to = values[0][1]; + // The nodes have been built with the hashing node factory. + // In practice we want to match nodes that are created with the simplifying NF. + // If we enabled the simplifying NF, the EQUALS checks would remove rules we want to keep. + ASTNode from = withNF(values[0][0]); + ASTNode to = withNF(values[0][1]); // Rule should be orderable. bool ok = orderEquivalence(from, to); @@ -1612,42 +1712,79 @@ loadNewRules() cout << "discarding rule that can't be ordered"; cout << from << to; cout << "----"; + mgr->PopQuery(); + parserInterface->popToFirstLevel(); continue; } Rewrite_rule r(mgr, from, to, 0, id); r.setVerified(verified_to_bits,time_used); - assert(r.isOK()); rewrite_system.push_back(r); mgr->PopQuery(); parserInterface->popToFirstLevel(); } + extern int smt2lex_destroy(void); + smt2lex_destroy(); + parserInterface->cleanUp(); if (opended) - fclose(in); + { + cout << "New Style Rules Loaded:" << rewrite_system.size() << endl; + fclose(in); + } + + // So we don't output as soon as one is discovered... + lastOutput = rewrite_system.size(); - cout << "New Style Rules Loaded:" << rewrite_system.size() << endl; } -//read from stdin, then tests it until the timeout. +//Reads in new format rules. And tests each of them for the allocated time. void -expandRules(int timeout_ms) +expandRules(int timeout_ms, const char* fileName = "") { - loadNewRules(); + load_new_rules(fileName); createVariables(); for (Rewrite_system::RewriteRuleContainer::iterator it = rewrite_system.begin(); it!= rewrite_system.end();it++) { if ((*it).timedCheck(timeout_ms)) - it->writeOut(cout); // omit failed. + { + it->writeOut(cout); // omit failed. + cerr << getDifficulty(it->getFrom()) <<" " << getDifficulty(it->getTo()); + } } } + +void t2() +{ + extern FILE *smt2in; + + smt2in = fopen("big_array.smt2", "r"); + TypeChecker nfTypeCheckDefault(*mgr->hashingNodeFactory, *mgr); + Cpp_interface piTypeCheckDefault(*mgr, &nfTypeCheckDefault); + parserInterface = &piTypeCheckDefault; + + mgr->GetRunTimes()->start(RunTimes::Parsing); + smt2parse(); + + ASTVec v = FlattenKind(AND, piTypeCheckDefault.GetAsserts()); + ASTNode n = nf->CreateNode(XOR, v); + + //rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen) + ASTNodeMap seen; + createVariables(); + ASTNode r = rename_then_rewrite(n,Rewrite_rule::getNullRule()); + cerr << r; + + +} + // loads the already existing rules. -void loadExistingRules(string fileName) +void load_old_rules(string fileName) { if(!ifstream(fileName.c_str())) return; // file doesn't exist. @@ -1693,8 +1830,7 @@ void loadExistingRules(string fileName) Rewrite_rule r(mgr, from, to, 0); - if (r.isOK()); - rewrite_system.push_back(r); + rewrite_system.push_back(r); } mgr->PopQuery(); @@ -1733,7 +1869,7 @@ testProps() int test() { // Test code. - loadExistingRules("test.smt2"); + load_old_rules("test.smt2"); v = mgr->LookupOrCreateSymbol("v"); v.SetValueWidth(bits); @@ -1768,14 +1904,43 @@ createVariables() w0.SetValueWidth(bits); } +void unit_test() +{ + + // Create the negation and not terms in different orders. This tests the commutative matching. + ASTVec c; + c.push_back(v); + ASTNode not_v = create(BEEV::BVNEG, c); + ASTNode neg_v = create(BEEV::BVUMINUS, c); + + ASTNode plus_v = create(BVPLUS,not_v,neg_v); + + c.clear(); + c.push_back(w); + ASTNode neg_w = create(BEEV::BVUMINUS, c); + ASTNode not_w = create(BEEV::BVNEG, c); + ASTNode plus_w = create(BVPLUS,not_w,neg_w); + + ASTNodeMap sub; + plus_w = renameVars(plus_w); + assert(commutative_matchNode(plus_w,plus_v,sub,2)); + sub.clear(); + + assert(commutative_matchNode(plus_v,plus_w,sub,1)); + + +} + + int main(int argc, const char* argv[]) { startup(); + if (argc == 1) // Read the current rule set, find new rules. { - loadExistingRules("array.smt2"); + load_new_rules(); createVariables(); //////////// rewrite_system.buildLookupTable(); @@ -1795,33 +1960,54 @@ main(int argc, const char* argv[]) rewrite_system.rewriteAll(); writeOutRules(); } - else if (argc == 3 && !strcmp("expand",argv[1])) // expand the bit-widths rules are tested at. + else if (argc ==2 && !strcmp("unit-test",argv[1])) { + load_new_rules(); + createVariables(); + unit_test(); + } + else if (argc ==2 && !strcmp("verify",argv[1])) + { + load_new_rules(); + rewrite_system.verifyAllwithSAT(); + } + else if ((argc == 4 || argc ==3) && !strcmp("expand",argv[1])) + { + // expand the bit-widths rules are tested at. int timeout_ms = atoi(argv[2]); assert(timeout_ms > 0); - expandRules(timeout_ms); + expandRules(timeout_ms, (argc == 4 ? argv[3]:"")); } - else if (argc == 2 && !strcmp("verify-all",argv[1])) + else if (argc == 2 && !strcmp("rewrite",argv[1])) { - loadNewRules(); + // load the rules and apply the rewrite system to itself. + load_new_rules(); createVariables(); - rewrite_system.verifyAllwithSAT(); - writeOutRules(); // have the times now.. + rewrite_system.rewriteAll(); + writeOutRules(); } else if (argc == 2 && !strcmp("write-out",argv[1])) { - loadNewRules(); + load_new_rules(); createVariables(); rewrite_system.rewriteAll(); writeOutRules(); // have the times now.. } + else if (argc == 2 && !strcmp("renumber",argv[1])) + { + // Intended to merge two sets of rules, then renumber them. + load_new_rules(); + createVariables(); + rewrite_system.renumber(); + writeOutRules(); + } else if (argc == 2 && !strcmp("test",argv[1])) { testProps(); } else if (argc == 2 && !strcmp("delete-failed",argv[1])) { - loadNewRules(); + load_new_rules(); ifstream fin; fin.open("failed.txt",ios::in); char line[256]; @@ -1836,13 +2022,20 @@ main(int argc, const char* argv[]) createVariables(); writeOutRules(); } -} - + else if (argc == 2 && !strcmp("test2",argv[1])) + { + load_new_rules(); + t2(); + } + for (int i=0; i< saved_array.size();i++) + delete saved_array[i]; +} +#if 0 // Term variables have a specified width!!! bool -unifyNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width) +matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width) { // Pointers to the same value. OK. if (n0 == n1) @@ -1851,10 +2044,10 @@ unifyNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int t if (n0.GetKind() == SYMBOL && strlen(n0.GetName()) == term_variable_width) { if (fromTo.find(n0) != fromTo.end()) - return unifyNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width); + return matchNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width); fromTo.insert(make_pair(n0, n1)); - return unifyNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width); + return matchNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width); } // Here: @@ -1869,9 +2062,319 @@ unifyNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int t for (int i = 0; i < n0.Degree(); i++) { - if (!unifyNode(n0[i], n1[i], fromTo, term_variable_width)) - return false; + if (!matchNode(n0[i], n1[i], fromTo, term_variable_width)) + return false; } return true; } +#endif + + +bool debug_matching = false; + +///////// +// Term variables have a specified width!!! +// "false" if it definately can't be matched with any possible commutative ordering. +// "true" can be matched, later you need to check if all the "commutative" can be matched. +bool +commutative_matchNode(const ASTNode& n0, const ASTNode& n1, const int term_variable_width, + deque >& commutative, ASTNode& vNode, ASTNode& wNode) +{ + // Pointers to the same value. OK. + if (n0 == n1) + return true; + + // If we try and match sub-terms of concatenations,e,g. 000::x = 000111, we want it to fail. + if(n0.GetValueWidth() != n1.GetValueWidth()) + return false; + + if (n0.GetKind() == SYMBOL && strlen(n0.GetName()) == term_variable_width) + { + if (n0.GetName()[0] == 'v') + { + if (vNode != mgr->ASTUndefined) + return commutative_matchNode(vNode, n1, term_variable_width, commutative, vNode, wNode); + else + { + vNode = n1; + return true; + } + } + else if (n0.GetName()[0] == 'w') + { + if (wNode != mgr->ASTUndefined) + return commutative_matchNode(wNode, n1, term_variable_width, commutative, vNode, wNode); + else + { + wNode = n1; + return true; + } + } + else + FatalError("nefeafs"); + } + + // Here: + // They could be different BVConsts, different symbols, or + // different functions. + + if (n0.Degree() != n1.Degree() || (n0.Degree() == 0)) + return false; + + if (n0.GetKind() != n1.GetKind()) + return false; + + // If it's commutative, check it specially / seprately later. + if (isCommutative(n0.GetKind()) && n0.Degree() > 1) + { + commutative.push_back(make_pair(n0,n1)); + return true; + } + else + { + for (int i = 0; i < n0.Degree(); i++) + { + if (!commutative_matchNode(n0[i], n1[i], term_variable_width,commutative,vNode,wNode)) + return false; + } + } + return true; +} + +// +// Term variables have a specified width!!! +bool +c_matchNode(const ASTNode& n0, const ASTNode& n1, const int term_variable_width, + deque >& commutative_to_check, ASTNode& vNode, ASTNode& wNode) +{ + ASTNode vNode_copy = vNode; + ASTNode wNode_copy = wNode; + + const int init_comm_size = commutative_to_check.size(); + + bool r = commutative_matchNode(n0, n1, term_variable_width, commutative_to_check, vNode,wNode); + assert(commutative_to_check.size() >= init_comm_size); // if anything, only pushed onto the back. + + if (debug_matching) + { + cerr << "======Commut-match=======" << r << endl; + cerr << "given" << n0 << n1; + cerr << "Commutative still to match:" << endl; + for (int j=0;j < commutative_to_check.size(); j++) + { + cerr << "++++++++++" << endl; + cerr << "first" << commutative_to_check[j].first; + cerr << "second" << commutative_to_check[j].second; + } + cerr << "From To Map is:" << endl; + cerr << "vNode" << vNode; + cerr << "wNode" << wNode; + cerr << "============="; + } + + if (!r) + { + // If it's bad we restore it all back. + commutative_to_check.erase(commutative_to_check.begin() + init_comm_size, commutative_to_check.end()); + vNode = vNode_copy; + wNode = wNode_copy; + return false; + } + + // base case. + if (commutative_to_check.size() == 0) + return r; + + pair p = commutative_to_check.back(); + commutative_to_check.pop_back(); + assert(p.first.GetKind() == p.second.GetKind()); + const ASTVec& f = p.first.GetChildren(); + ASTVec s = p.second.GetChildren(); // non-const, needs to be sorted later. + + if (f.size() != s.size()) + { + cerr << "different sized!!!"; + // If it's bad we restore it all back. + if (commutative_to_check.size() < init_comm_size) + commutative_to_check.push_back(p); + else + commutative_to_check.erase(commutative_to_check.begin() + init_comm_size, commutative_to_check.end()); + + vNode = vNode_copy; + wNode = wNode_copy; + + return false; + } + + // The next_permutation function requires this. + sort(s.begin(), s.end()); + + ASTNode vNode_copy2 = vNode; + ASTNode wNode_copy2 = wNode; + + //deque > todo_copy2 = commutative_to_check; + const int new_comm_size = commutative_to_check.size(); + + + // Check each permutation of the commutative operation's operands. + do + { + // Check each of the operands matches. Store Extra away in "commutative_to_check". + bool good= true; + for (int i=0;i < f.size(); i++) + { + if (!commutative_matchNode(f[i], s[i], term_variable_width, commutative_to_check, vNode, wNode)) + { + good = false; + break; + } + } + + // Empty out the "commutative_to_check". + if (good) + if (!c_matchNode(mgr->ASTTrue, mgr->ASTTrue, term_variable_width, commutative_to_check,vNode,wNode)) + good =false; + + if (good) + { + assert(commutative_to_check.size() ==0); + return true; // all match. + } + else + { + vNode = vNode_copy2; + wNode = wNode_copy2; + commutative_to_check.erase(commutative_to_check.begin() + new_comm_size, commutative_to_check.end()); + //assert(commutative_to_check == todo_copy2); + //commutative_to_check = todo_copy2; + } + } + while (next_permutation(s.begin(), s.end())); + + // None of the permutations match. We return the data unchanged. + + vNode = vNode_copy; + wNode = wNode_copy; + + + if (commutative_to_check.size() < init_comm_size) + commutative_to_check.push_back(p); + else + commutative_to_check.erase(commutative_to_check.begin() + init_comm_size, commutative_to_check.end()); + + + return false; +} + +/* This does commutative matching of nodes. A substitution to the term variables (which are the + * those with a name of the width specified), of n0 is found. That is if the variables of n0 are + * substituted with the "substitution", then it will equal n1. + * + * Initially I thought commutative matching was easy to get right. NO! + * + * NB: This uses a "statc" container so this can't be called recursively. + */ +bool +commutative_matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& substitution, const int term_variable_width) +{ + assert(substitution.size() ==0); + +#ifdef PEDANTIC_MATCHING_ASSERTS + { + // There shouldn't be any term variables on the RHS. + vector vars = getVariables(n1); + vector::iterator it = vars.begin(); + while (it != vars.end()) + { + assert(strlen(it->GetName()) != term_variable_width); + assert(it->GetName()[0] == 'v' || it->GetName()[0] == 'w'); + it++; + } + assert(vars.size() <=2); + + // All the LHS variables should be term variables. + vars = getVariables(n0); + it = vars.begin(); + while (it != vars.end()) + { + assert(strlen(it->GetName()) == term_variable_width); + it++; + } + assert(vars.size() <=2); + } + + +#endif + + deque > commutative; + commutative.clear(); + + ASTNode vNode = mgr->ASTUndefined; + ASTNode wNode = mgr->ASTUndefined; + bool r = c_matchNode(n0,n1,term_variable_width,commutative,vNode,wNode); + + if (r) + { + vector s = getVariables(n0); + for (vector::iterator it = s.begin(); it != s.end();it++) + { + assert(it->GetKind() ==SYMBOL); + assert(strlen(it->GetName()) == term_variable_width); + if (it->GetName()[0] == 'v') + { + assert(vNode != mgr->ASTUndefined); + assert(vNode.GetValueWidth() == it->GetValueWidth()); + substitution.insert(make_pair(*it,vNode)); + } + if (it->GetName()[0] == 'w') + { + assert(wNode != mgr->ASTUndefined); + assert(wNode.GetValueWidth() == it->GetValueWidth()); + substitution.insert(make_pair(*it,wNode)); + } + } + } + + if (debug_matching) + { + cerr << "=======" << endl << "The result is: " << r << "for the inputs" << n0 << n1 << "=-==="; + } + + if (!r) + { + assert(substitution.size() == 0); + assert(commutative.size() ==0); // should be none left to process. + } + + return r; +} + +ASTNode +rewriteThroughWithAIGS(const ASTNode &n_) +{ + assert(n_.GetType() == BITVECTOR_TYPE); + ASTNode f = mgr->LookupOrCreateSymbol("rewriteThroughWithAIGS"); + f.SetValueWidth(n_.GetValueWidth()); + ASTNode n = create(EQ, n_, f); + + BBNodeManagerAIG nm; + BitBlaster bb(&nm, simp, mgr->defaultNodeFactory, &mgr->UserFlags); + ASTNodeMap fromTo; + ASTNodeMap equivs; + bb.getConsts(n, fromTo, equivs); + + ASTNode result = n_; + if (equivs.size() > 0) + { + ASTNodeMap cache; + result = SubstitutionMap::replace(result, equivs, cache, nf, false, true); + } + + if (fromTo.size() > 0) + { + ASTNodeMap cache; + result = SubstitutionMap::replace(result, fromTo, cache, nf); + } + return result; +} diff --git a/src/util/find_rewrites/rewrite_rule.h b/src/util/find_rewrites/rewrite_rule.h index b97fb0b..cab1845 100644 --- a/src/util/find_rewrites/rewrite_rule.h +++ b/src/util/find_rewrites/rewrite_rule.h @@ -2,15 +2,7 @@ #define REWRITERULE_H #include "../../STPManager/STPManager.h" - -extern const int widen_to; -extern const int bits; - -ASTNode -widen(const ASTNode& w, int width); - -int -getDifficulty(const ASTNode& n_); +#include "misc.h" void soft_time_out(int ignored) { @@ -18,37 +10,46 @@ void soft_time_out(int ignored) } bool -isConstant(const ASTNode& n, VariableAssignment& different); +orderEquivalence(ASTNode& from, ASTNode& to); -vector -getVariables(const ASTNode& n); - class Rewrite_rule { -private: ASTNode from; ASTNode to; ASTNode n; - int id; static int static_id; int time_to_verify; int verified_to_bits; + // Only used to build the NULL rule + Rewrite_rule() + { + from = mgr->CreateZeroConst(1); + to = mgr->CreateZeroConst(1); + n = mgr->ASTTrue; + } + public: - void writeOut(ostream& outputFileSMT2) + static Rewrite_rule + getNullRule() + { + return Rewrite_rule(); + } + + int id; + void writeOut(ostream& outputFileSMT2) const { - assert(isOK()); outputFileSMT2 << ";id:" << getId() << "\tverified_to:" << verified_to_bits << "\ttime:" << getTime() << "\tfrom_difficulty:" << getDifficulty(getFrom()) << "\tto_difficulty:" << getDifficulty(getTo()) << "\n"; outputFileSMT2 << "(push 1)" << endl; - printer::SMTLIB2_PrintBack(outputFileSMT2, getN(), true, false); + printer::SMTLIB2_PrintBack(outputFileSMT2, getN(), true); outputFileSMT2 << "(exit)" << endl; } @@ -64,7 +65,7 @@ public: } int - getVerifiedToBits() + getVerifiedToBits() const { return verified_to_bits; } @@ -112,27 +113,7 @@ public: return (n == t.n); } - bool - isOK() - { - ASTNode w = widen(getN(), widen_to); - - if (w.IsNull() || w.GetKind() == UNDEFINED) - return false; - - assert(BVTypeCheckRecursive(n)); - assert(BVTypeCheckRecursive(w)); - - if (from.isAtom() && to.isAtom()) - return false; - - if (from == to) - return false; - - return true; - - } - + // The "from" and "to" should be ordered with the orderEquivalence function. Rewrite_rule(BEEV::STPMgr* bm, const BEEV::ASTNode& from_, const BEEV::ASTNode& to_, const int t, int _id=-1) : from(from_), to(to_) { @@ -151,49 +132,10 @@ public: c.push_back(from_); n = bm->hashingNodeFactory->CreateNode(BEEV::EQ,c); - - //// - assert(!from.IsNull()); - assert(from.GetKind() != UNDEFINED); - - //// - assert(!to.IsNull()); - assert(to.GetKind() != UNDEFINED); - - //// - assert(!n.IsNull()); - assert(n.GetKind() != UNDEFINED); - //// - - if (from.GetKind() == SYMBOL) - { - assert(to == from); // If it's a symbol. It should be the same one. - } - - if (from.isAtom()) - { - assert(to.isAtom()); // sometimes its easiest to make it 0->0 rather than deleting it. - } - - // only v or w - vector s_from= getVariables(from); - for (vector::iterator it = s_from.begin(); it != s_from.end() ;it++) - { - assert(strlen(it->GetName()) ==1); - assert(it->GetName()[0] =='v' || it->GetName()[0] =='w'); - assert(it->GetValueWidth() == bits); - } - - vector s_to= getVariables(to); - for (vector::iterator it = s_to.begin(); it != s_to.end() ;it++) - { - assert(strlen(it->GetName()) ==1); - assert(it->GetName()[0] =='v' || it->GetName()[0] =='w'); - assert(it->GetValueWidth() == bits); - } - - // The "to" side should have fewer nodes. - assert(s_from.size() >= s_to.size()); + assert(orderEquivalence(from,to)); + assert(from == from_); + assert(to == to_); + assert(BVTypeCheckRecursive(n)); } bool @@ -236,12 +178,16 @@ public: cerr << from << to; } - bool result = isConstant(widened, assignment); + bool result = isConstant(widened, assignment,i); if (!result && !mgr->soft_timeout_expired) { // not a constant, and not timed out! cerr << "FAILED:" << getId() << endl << i << from << to; writeOut(cerr); + + // The timer might not have expired yet. + setitimer(ITIMER_VIRTUAL, NULL, NULL); + mgr->soft_timeout_expired = false; return false; } if (mgr->soft_timeout_expired) @@ -253,6 +199,9 @@ public: if (getVerifiedToBits() <= checked_to) setVerified(checked_to, getTime() + (getCurrentTime() - st)); + // The timer might not have expired yet. + setitimer(ITIMER_VIRTUAL, NULL, NULL); + mgr->soft_timeout_expired = false; return true; } diff --git a/src/util/find_rewrites/rewrite_system.h b/src/util/find_rewrites/rewrite_system.h index 7c9c6c4..41c3a83 100644 --- a/src/util/find_rewrites/rewrite_system.h +++ b/src/util/find_rewrites/rewrite_system.h @@ -23,13 +23,17 @@ ASTNode widen(const ASTNode& w, int width); bool -unifyNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo); +matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width); + +bool +commutative_matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width); + ASTNode renameVars(const ASTNode &n); ASTNode -rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen); +rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen, int depth); bool checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignment, bool&bad); @@ -46,6 +50,10 @@ isConstantToSat(const ASTNode & query); bool isConstant(const ASTNode& n, VariableAssignment& different); +ASTNode +rewriteThroughWithAIGS(const ASTNode &n_); + + class Rewrite_system { public: @@ -59,17 +67,27 @@ private: void writeOutRules(); - - friend ASTNode rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen); + friend ASTNode rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen, int depth); RewriteRuleContainer toWrite; + std::map< Kind, vector > kind_to_rr; - std::map< Kind, std::map< Kind, vector > > kind_kind_to_rr; + bool lookups_invalid; // whether the above table is bad. + + void + addRuleToLookup(Rewrite_rule& r) + { + const ASTNode& from = r.getFrom(); + kind_to_rr[from.GetKind()].push_back(r); + assert(from.Degree() > 0); // Shouldn't map from a constant, nor from a variable. + } + public: Rewrite_system() { + lookups_invalid = false; } RewriteRuleContainer::iterator @@ -84,17 +102,18 @@ public: return toWrite.end(); } + bool areLookupsGood() + { + return lookups_invalid; + } void - addRuleToLookup(Rewrite_rule& r) + renumber() { - const ASTNode& from = r.getFrom(); - kind_to_rr[from.GetKind()].push_back(r); - - assert(from.Degree() > 0); // Shouldn't map from a constant, nor from a variable. - - if (from[0].Degree() > 0) - kind_kind_to_rr[from.GetKind()][from[0].GetKind()].push_back(r); + int id=0; + for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++) + it->id = id++; + lookups_invalid=true; } void @@ -106,6 +125,7 @@ public: { toWrite.erase(it--); cerr << "matched" << id; + lookups_invalid = true; } } } @@ -114,41 +134,26 @@ public: buildLookupTable() { kind_to_rr.clear(); - kind_kind_to_rr.clear(); for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++) { addRuleToLookup(*it); } - } - - void - removeBad() - { - for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++) - { - if (!it->isOK()) - { - cout << "Removing Rule that is bad"; - cout << it->getFrom(); - cout << it->getTo(); - cout << "----\n"; - - toWrite.erase(it--); - } - } + lookups_invalid =false; } void eraseDuplicates() { - removeBad(); toWrite.sort(); toWrite.unique(); + lookups_invalid = true; } + // Rewrite the "from" and "To", add the rule if it's still good. + // NB: Doesn't rebuild the lookup table. void - push_back(Rewrite_rule rr) + push_back(Rewrite_rule& rr) { toWrite.push_back(rr); addRuleToLookup(rr); @@ -158,6 +163,7 @@ public: erase(RewriteRuleContainer::iterator it) { toWrite.erase(it); + lookups_invalid = true; } int @@ -166,10 +172,9 @@ public: return toWrite.size(); } - ASTNode rewriteNode(ASTNode n) + static ASTNode rewriteNode(ASTNode n) { - Rewrite_rule null_rule( Rewrite_rule(mgr,mgr->CreateZeroConst(1), mgr->CreateZeroConst(1),0)); - return rename_then_rewrite(n,null_rule); + return rename_then_rewrite(n,Rewrite_rule::getNullRule()); } void @@ -186,76 +191,89 @@ public: if (i % 1000 == 0) cout << "rewrite all:" << i << " of " << toWrite.size() << endl; - // if not OK, should have been removed during duplicates. - // shouldn't add extra rules that aren't ok. - assert (it->isOK()); + ASTNode from_wide = renameVars(it->getFrom()); + ASTNode to_wide = renameVars(it->getTo()); + + // The renamed should match the original, and vice versa. + { + ASTNodeMap fromTo; + assert(commutative_matchNode(from_wide, it->getFrom(),fromTo,2)); + fromTo.clear(); + assert(commutative_matchNode(it->getFrom(), from_wide, fromTo, 1)); + fromTo.clear(); + assert(commutative_matchNode(to_wide, it->getTo(),fromTo,2)); + fromTo.clear(); + assert(commutative_matchNode(it->getTo(),to_wide, fromTo,1)); + } - ASTNode n = renameVars(it->getFrom()); ASTNodeMap seen; - ASTNode rewritten_from = rewrite(n, *it,seen); + ASTNode from_wide_rewritten = rewrite(from_wide, *it,seen,0); + seen = ASTNodeMap(); + ASTNode to_wide_rewritten = rewrite(to_wide, *it,seen,0); + seen = ASTNodeMap(); + + // Also apply the AIG rules. + to_wide_rewritten = rewriteThroughWithAIGS(to_wide_rewritten); - if (n != rewritten_from) + if ((from_wide != from_wide_rewritten) || (to_wide != to_wide_rewritten)) { - assert (isConstantToSat(create(EQ, rewritten_from,n))); + ASTNode from_rewritten = renameVarsBack(from_wide_rewritten); + ASTNode to_rewritten = renameVarsBack(to_wide_rewritten); + + assert(BVTypeCheckRecursive(from_rewritten)); + assert(BVTypeCheckRecursive(to_rewritten)); - rewritten_from = renameVarsBack(rewritten_from); - ASTNode to = it->getTo(); - bool ok = orderEquivalence(rewritten_from, to); + assert (isConstantToSat(create(EQ, from_wide_rewritten,from_wide))); + assert (isConstantToSat(create(EQ, to_wide_rewritten,to_wide))); + assert (isConstantToSat(create(EQ, it->getFrom(),from_rewritten))); + assert (isConstantToSat(create(EQ, it->getTo(),to_rewritten))); + + bool ok = orderEquivalence(from_rewritten, to_rewritten); if (ok) { - Rewrite_rule rr(mgr, rewritten_from, to, 0); - if (rr.isOK()) - { - cout << "Modifying Rule\n"; - cout << "Initially From"; - cout << it->getFrom(); - cout << "new From"; - cout << rewritten_from; - cout << "---"; - - *it= rr; - buildLookupTable(); // Otherwise two rules will remove each other? - } - else - { - cout << "Erasing rule"; - cout << "Initially From"; - cout << it->getFrom(); - cout << "new From"; - cout << rewritten_from; - cout << "---"; - - erase(it--); - i--; - buildLookupTable(); // Otherwise two rules will remove each other? - } + Rewrite_rule rr(mgr, from_rewritten, to_rewritten, 0); + cout << "Modifying Rule\n"; + cout << "Initially From"; + cout << it->getFrom(); + cout << "Initially To"; + cout << it->getTo(); + cout << "New From"; + cout << from_rewritten; + cout << "New To"; + cout << to_rewritten; + cout << "---"; + cout << getDifficulty(rr.getFrom()) << " --> " << getDifficulty(rr.getTo()) << endl; + cout << "replacing" << it->getId() << " with " << rr.getId() << endl; + + *it = rr; + lookups_invalid = true; + } - else - { - if (rewritten_from != to) - { - cout << "Mapped but couldn't order"; - cout << rewritten_from << to; - } + if (!ok) + { + cout << "Erasing bad rule.\n"; erase(it--); i--; - buildLookupTable(); // Otherwise two rules will remove each other? + lookups_invalid = true; } } } eraseDuplicates(); cout << "Size after rewriteAll:" << toWrite.size() << endl; + buildLookupTable(); } void clear() { toWrite.clear(); + buildLookupTable(); } void verifyAllwithSAT() { + cerr << "Started verifying all" << endl; for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++) { VariableAssignment assignment; @@ -270,12 +288,11 @@ public: assert(r); assert(!bad); } - if (bits > it->getVerifiedToBits()) + if (bits >= it->getVerifiedToBits()) it->setVerified(bits,getCurrentTime() - st); } } - void writeOut(ostream &o) {