From: trevor_hansen Date: Wed, 22 Feb 2012 12:54:53 +0000 (+0000) Subject: Cleanup the utility code for off-line generation of candidate rewrite rules. X-Git-Url: https://git.unchartedbackwaters.co.uk/w/?a=commitdiff_plain;h=7dbf0b0d2a2bca5d342ae78b7aa80ff624d15b45;p=francis%2Fstp.git Cleanup the utility code for off-line generation of candidate rewrite rules. git-svn-id: https://stp-fast-prover.svn.sourceforge.net/svnroot/stp-fast-prover/trunk/stp@1572 e59a4935-1847-0410-ae03-e826735625c1 --- diff --git a/src/util/rewrite.cpp b/src/util/rewrite.cpp index d3617c2..aeb7eb1 100644 --- a/src/util/rewrite.cpp +++ b/src/util/rewrite.cpp @@ -15,28 +15,36 @@ #include "../to-sat/AIG/ToSATAIG.h" #include "../sat/MinisatCore.h" #include "../STPManager/STP.h" +#include "../STPManager/DifficultyScore.h" +#include "../simplifier/BigRewriter.h" using namespace std; using namespace BEEV; +// Asynchronously stop solving. bool finished = false; -extern int -smtparse(void*); -extern FILE *smtin; +// Holds the rewrite that was disproved at the largest bitwidth. +ASTNode highestDisproved; +int highestLevel =0; ////////////////////////////////// -const int bits = 8; -const int widen_to = 13; +const int bits = 6; +const int widen_to = 10; const int values_in_hash = 64 / bits; const int mask = (1 << (bits)) - 1; -const int unique_vars = 1 << bits; ////////////////////////////////// // Saves a little bit of time. The vectors are saved between invocations. vector saved_array; +// Stores the difficulties that have already been generated. +map difficulty_cache; + +void +clearSAT(); + bool isConstantToSat(const ASTNode & query); @@ -46,8 +54,7 @@ containsNode(const ASTNode& n, const ASTNode& hunting, string& current); void writeOutRules(); -bool -checkAndStoreRule(const ASTNode & n); +void applyBigRewrite(ASTVec& functions); typedef HASHMAP ASTNodeString; @@ -55,10 +62,53 @@ BEEV::STPMgr* mgr; NodeFactory* nf; SATSolver * ss; ASTNodeSet stored; // Store nodes so they aren't garbage collected. +Simplifier *simp; ASTNode zero, one, maxNode, v, w; -ASTVec toWrite; // Rules to write out when we get the chance. +struct ToWrite +{ + ASTNode from; + ASTNode to; + ASTNode n; + int time; + + ToWrite() + { + } + + ToWrite(ASTNode from_, ASTNode to_, int t) + { + from = from_; + to = to_; + time = t; + n = mgr->CreateNode(EQ,to,from); + } + + bool isEmpty() + { + return (n == mgr->ASTUndefined); + } + + bool + operator==(const ToWrite t) const + { + return (n == t.n); + } + + bool + operator<(const ToWrite t) const + { + return (n < t.n); + } +}; + +// Rules to write out when we get the chance. +vector toWrite; + +// Width of the rewrite rules that were output last time. +int lastOutput = 0; + struct Assignment { @@ -95,6 +145,11 @@ public: w = nW; } + bool isEmpty() + { + return (v == mgr->ASTUndefined || w == mgr->ASTUndefined); + } + Assignment() { } @@ -105,7 +160,7 @@ public: { setV(n); srand(v.GetUnsignedConst()); - w = BEEV::ParserBM->CreateBVConst(bits, rand()); + w = BEEV::ParserBM->CreateBVConst(n.GetValueWidth(), rand()); } Assignment(ASTNode & n0, ASTNode & n1) @@ -115,12 +170,16 @@ public: } }; +bool +checkAndStoreRule(const ASTNode & from, const ASTNode & to, Assignment& ass); + + // Helper functions. Don't need to specify the width. ASTNode create(Kind k, const ASTNode& n0, const ASTNode& n1) { if (is_Term_kind(k)) - return nf->CreateTerm(k, bits, n0, n1); + return nf->CreateTerm(k, n0.GetValueWidth(), n0, n1); else return nf->CreateNode(k, n0, n1); } @@ -129,7 +188,7 @@ ASTNode create(Kind k, ASTVec& c) { if (is_Term_kind(k)) - return nf->CreateTerm(k, bits, c); + return nf->CreateTerm(k, c[0].GetValueWidth(), c); else return nf->CreateNode(k, c); } @@ -138,7 +197,7 @@ create(Kind k, ASTVec& c) // If it's a constant it's the name of the constant, // otherwise it's the position of the lhs in the rhs. Otherwise empty. string -getLHSName(const ASTNode& lhs, const ASTNode& rhs) +getToName(const ASTNode& lhs, const ASTNode& rhs) { string name = "n"; if (!lhs.isConstant()) @@ -174,11 +233,22 @@ getVariables(const ASTNode& n, vector& symbols, ASTNodeSet& visited) ASTNode eval(const ASTNode &n, ASTNodeMap& map, int count = 0) { + assert(n != mgr->ASTUndefined); + if (n.isConstant()) return n; if (map.find(n) != map.end()) - return (*map.find(n)).second; + { + assert((*map.find(n)).second != mgr->ASTUndefined); + return (*map.find(n)).second; + } + + if(n.Degree() == 0 ) + { + cerr << n; + assert(false); + } // We have an array of arrays already created to store the children. // This reduces the number of objects created/destroyed. @@ -191,7 +261,7 @@ eval(const ASTNode &n, ASTNodeMap& map, int count = 0) for (int i = 0; i < n.Degree(); i++) new_children.push_back(eval(n[i], map, count + 1)); - ASTNode r = NonMemberBVConstEvaluator(mgr, n.GetKind(), new_children, bits); + ASTNode r = NonMemberBVConstEvaluator(mgr, n.GetKind(), new_children, n.GetValueWidth()); map.insert(make_pair(n, r)); return r; } @@ -247,26 +317,61 @@ isConstant(const ASTNode& n, Assignment& different) return true; else { - different.setV(GlobalSTP->Ctr_Example->GetCounterExample(true, v)); - different.setW(GlobalSTP->Ctr_Example->GetCounterExample(true, w)); + vector symbols; + ASTNodeSet visited; + getVariables(n,symbols,visited); + assert(symbols.size() > 0); + + // Both of them might not be contained in the assignment. + different.setV(mgr->CreateBVConst(symbols[0].GetValueWidth(),0)); + different.setW(mgr->CreateBVConst(symbols[0].GetValueWidth(),0)); + + // It might have been widened. + for (int i =0; i < symbols.size();i++) + { + if (strncmp(symbols[i].GetName(), "v", 1) ==0) + different.setV(GlobalSTP->Ctr_Example->GetCounterExample(true, symbols[i])); + else if (strncmp(symbols[i].GetName(), "w", 1) ==0) + different.setW(GlobalSTP->Ctr_Example->GetCounterExample(true, symbols[i])); + } return false; } } -// Intended to widen the problem from "bits" to "width". +// Widens terms from "bits" to "width". ASTNode widen(const ASTNode& w, int width) { - if (w.isConstant() && w.GetValueWidth() == bits && (w == one)) - return (BEEV::ParserBM->CreateOneConst(width)); + assert(bits >=3); + + if (w.isConstant() && w.GetValueWidth() == 1) + return w; + + if (w.isConstant() && w.GetValueWidth() == bits) + return nf->CreateTerm(BVSX, width, w); + + if (w.isConstant() && w.GetValueWidth() == bits - 1) + return nf->CreateTerm(BVSX, width - 1, w); + + if (w.isConstant() && w.GetValueWidth() == 32) // Extract DEFINATELY. + { + if (w == mgr->CreateZeroConst(32)) + return w; + + if (w == mgr->CreateOneConst(32)) + return w; + + if (w == mgr->CreateBVConst(32, bits)) + return mgr->CreateBVConst(32, width); - if (w.isConstant() && w.GetValueWidth() == bits && (w == zero)) - return (BEEV::ParserBM->CreateZeroConst(width)); + if (w == mgr->CreateBVConst(32, bits - 1)) + return mgr->CreateBVConst(32, width - 1); - if (w.isConstant() && w.GetValueWidth() == bits && (w == maxNode)) - return (BEEV::ParserBM->CreateMaxConst(width)); + if (w == mgr->CreateBVConst(32, bits - 2)) + return mgr->CreateBVConst(32, width - 2); + } - if (w.isConstant() /*&& w.GetValueWidth() == 1*/) + if (w.isConstant()) return mgr->ASTUndefined; if (w.GetKind() == SYMBOL && w.GetType() == BOOLEAN_TYPE) @@ -290,59 +395,96 @@ widen(const ASTNode& w, int width) return mgr->ASTUndefined; } + if (w.GetKind() == BVCONCAT && ((ch[0].GetValueWidth() + ch[1].GetValueWidth()) != width)) + return mgr->ASTUndefined; // Didn't widen properly. + ASTNode result; if (w.GetType() == BOOLEAN_TYPE) result = nf->CreateNode(w.GetKind(), ch); - else if (w.GetKind() == BVEXTRACT && w.GetValueWidth() == 1) - { - result = nf->CreateTerm(BVEXTRACT, 1, ch[0], BEEV::ParserBM->CreateBVConst(32, width - 1), - BEEV::ParserBM->CreateBVConst(32, width - 1)); - } - else if (w.GetKind() == BVCONCAT) + else if (w.GetKind() == BVEXTRACT) { - cerr << "don't do thisyerT" << w; - return mgr->ASTUndefined; + int new_width = ch[1].GetUnsignedConst() - ch[2].GetUnsignedConst() + 1; + result = nf->CreateTerm(w.GetKind(), new_width, ch); } else result = nf->CreateTerm(w.GetKind(), width, ch); - BVTypeCheck(result); + BVTypeCheck(result); return result; } -ASTNodeSet visited; -bool -lhsInRHS(const ASTNode& n, const ASTNode& lookFor) + +ASTNode +rewriteThroughWithAIGS(const ASTNode &n) { - if (lookFor == n) - return true; + assert(n.GetKind() == EQ); - if (visited.find(n) != visited.end()) - return false; + BBNodeManagerAIG nm; + BitBlaster bb(&nm, simp, mgr->defaultNodeFactory, &mgr->UserFlags); + ASTNode input = n; + ASTNodeMap fromTo; + ASTNodeMap equivs; + bb.getConsts(input, fromTo,equivs); - for (int i = 0; i < n.Degree(); i++) - if (lhsInRHS(n[i], lookFor)) - return true; + if (equivs.size() > 0) + { + ASTNodeMap cache; + input = SubstitutionMap::replace(input, equivs, cache,nf,false,true); + } - visited.insert(n); - return false; + if (fromTo.size() > 0) + { + ASTNodeMap cache; + input = SubstitutionMap:: replace(input, fromTo, cache,nf); + } + + return input; } -// Shortcut. Don't examine the rule if it isn't a candidate. -bool -isCandidate(const ASTNode& n) +int +getDifficulty(const ASTNode& n_) { - if (n.GetKind() != EQ) - return false; + assert(n_.GetType() == BITVECTOR_TYPE); - if (n[0].isConstant()) - return true; + if (difficulty_cache.find(n_) != difficulty_cache.end()) + return difficulty_cache.find(n_)->second; - visited.clear(); - if (lhsInRHS(n[1], n[0])) - return true; + // Calculate the difficulty over the widened version. + ASTNode n = widen(n_,widen_to); + if (n.GetKind() == UNDEFINED) + return -1; - return false; + if (n.GetValueWidth() != widen_to) + return -1; + + BBNodeManagerAIG nm; + BitBlaster bb(&nm, simp, mgr->defaultNodeFactory, &mgr->UserFlags); + + // equals fresh variable to convert to boolean type. + ASTNode f = mgr->CreateFreshVariable(0, widen_to, "ffff"); + ASTNode input = create(EQ, f, n); + + BBNodeAIG BBFormula = bb.BBForm(input); + + clearSAT(); + + Cnf_Dat_t* cnfData = NULL; + ToCNFAIG toCNF(mgr->UserFlags); + ToSATBase::ASTNodeToSATVar nodeToSATVar; + toCNF.toCNF(BBFormula, cnfData, nodeToSATVar, false, nm); + + // Why we go to all this trouble. The number of clauses. + int score = cnfData->nClauses; + + Cnf_ClearMemory(); + Cnf_DataFree(cnfData); + cnfData = NULL; + + // Free the memory in the AIGs. + BBFormula = BBNodeAIG(); // null node + + difficulty_cache.insert(make_pair(n_, score)); + return score; } // binary proposition. @@ -388,12 +530,13 @@ doIte(ASTNode a) } } + void getAllFunctions(ASTNode v, ASTNode w, ASTVec& result) { Kind types[] = - { BVMULT, BVDIV, SBVDIV, SBVREM, SBVMOD, BVPLUS, BVMOD, BVRIGHTSHIFT, BVLEFTSHIFT, BVOR, BVAND, BVXOR, BVSRSHIFT }; + { BVMULT , BVDIV, SBVDIV, SBVREM, SBVMOD, BVPLUS, BVMOD, BVRIGHTSHIFT, BVLEFTSHIFT, BVOR, BVAND, BVXOR, BVSRSHIFT }; int number_types = sizeof(types) / sizeof(Kind); // all two argument functions. @@ -416,14 +559,15 @@ startup() mgr->UserFlags.division_by_zero_returns_one_flag = true; - Simplifier * simplifier = new Simplifier(mgr); - ArrayTransformer * at = new ArrayTransformer(mgr, simplifier); - AbsRefine_CounterExample* abs = new AbsRefine_CounterExample(mgr, simplifier, at); + simp = new Simplifier(mgr); + ArrayTransformer * at = new ArrayTransformer(mgr, simp); + AbsRefine_CounterExample* abs = new AbsRefine_CounterExample(mgr, simp, at); ToSAT* tosat = new ToSAT(mgr); - GlobalSTP = new STP(mgr, simplifier, at, tosat, abs); + GlobalSTP = new STP(mgr, simp, at, tosat, abs); nf = new SimplifyingNodeFactory(*(mgr->hashingNodeFactory), *mgr); + mgr->defaultNodeFactory =nf; mgr->UserFlags.stats_flag = false; mgr->UserFlags.optimize_flag = true; @@ -444,7 +588,6 @@ startup() w = mgr->CreateSymbol("w", 0, bits); srand(time(NULL)); - } void @@ -466,18 +609,31 @@ isConstantToSat(const ASTNode & query) ASTNode query2 = nf->CreateNode(NOT, query); - query2 = GlobalSTP->arrayTransformer->TransformFormula_TopLevel(query2); SOLVER_RETURN_TYPE r = GlobalSTP->Ctr_Example->CallSAT_ResultCheck(*ss, query2, query2, GlobalSTP->tosat, false); cerr << "from"; return (r == SOLVER_VALID); // unsat, always true } + // Replaces the symbols in n, by each of the values, and concatenates them // to turn it into a single 64-bit value. uint64_t -getHash(const ASTNode& n, const vector& values) +getHash(const ASTNode& n_, const vector& values) { + assert(values.size() > 0); + const int ass_bitwidth =values[0].getV().GetValueWidth(); + assert (ass_bitwidth >= bits); + + ASTNode n = n_; + + // The values might be at a higher bit-width. + if (ass_bitwidth > bits) + n = widen(n_,ass_bitwidth); + + if (n == mgr->ASTUndefined) // Can't be widened. + return 0; + vector symbols; // The variables in the n node. ASTNodeSet visited; getVariables(n, symbols, visited); @@ -489,177 +645,252 @@ getHash(const ASTNode& n, const vector& values) ASTNodeMap mapToVal; for (int j = 0; j < symbols.size(); j++) { - if (symbols[j] == v) - mapToVal.insert(make_pair(v, values[i].getV())); - else if (symbols[j] == w) - mapToVal.insert(make_pair(w, values[i].getW())); + if (strncmp(symbols[j].GetName(), "v", 1) ==0) + { + mapToVal.insert(make_pair(symbols[j], values[i].getV())); + assert(symbols[j].GetValueWidth() == values[i].getV().GetValueWidth() ); + } + else if (strncmp(symbols[j].GetName(), "w", 1) ==0) + { + mapToVal.insert(make_pair(symbols[j], values[i].getW())); + assert(symbols[j].GetValueWidth() == values[i].getW().GetValueWidth() ); + } else - cerr << "Unknown symbol!" << symbols[j]; + { + cerr << "Unknown symbol!" << symbols[j]; + FatalError("f"); + } + assert(symbols[j].GetValueWidth() == ass_bitwidth ); } + ASTNode r = eval(n, mapToVal); assert(r.isConstant()); - hash <<= bits; + hash <<= ass_bitwidth; hash += r.GetUnsignedConst(); } return hash; } +// is from a sub-term of "to"? +bool +contained_in(ASTNode from, ASTNode to) +{ + if (from == to) + return true; + + for (int i = 0; i < to.Degree(); i++) + if (contained_in(from, to[i])) + return true; + + return false; +} + + +// Is mapping from "From" to "to" a rule we are interested in?? +bool +is_candidate(ASTNode from, ASTNode to) +{ + if (to.Degree() == 0) + return true; // Leaves are always good. + + if (contained_in(from, to)) + return true; // If what we are mapping to is contained in the "from", that's good too. + + return false; +} + bool lessThan(const ASTNode& n1, const ASTNode& n2) { return (((unsigned) n1.GetNodeNum()) < ((unsigned) n2.GetNodeNum())); } + + // Breaks the expressions into buckets recursively, then pairwise checks that they are equivalent. void -findRewrites(const ASTVec& expressions, const vector& values, const int depth = 0) +findRewrites(ASTVec& expressions, const vector& values, const int depth = 0) { + if (expressions.size() < 2) + return; + + cout << '\n' << "depth:" << depth << ", size:" << expressions.size() << " values:" << values.size() << " found: " << toWrite.size() << '\n'; + assert(expressions.size() >0); - // Put the functions in buckets based on their results on the values. - HASHMAP map; - for (int i = 0; i < expressions.size(); i++) + if (values.size() > 0) { - if (i % 50000 == 49999) - cerr << "."; - uint64_t hash = getHash(expressions[i], values); - if (map.find(hash) == map.end()) - map.insert(make_pair(hash, ASTVec())); - map[hash].push_back(expressions[i]); + // Put the functions in buckets based on their results on the values. + HASHMAP map; + for (int i = 0; i < expressions.size(); i++) + { + if (expressions[i] == mgr->ASTUndefined) + continue; // omit undefined. + + if (i % 50000 == 49999) + cerr << "."; + uint64_t hash = getHash(expressions[i], values); + if (map.find(hash) == map.end()) + map.insert(make_pair(hash, ASTVec())); + map[hash].push_back(expressions[i]); + } + expressions.clear(); + + HASHMAP::iterator it2; + + cout << "Split into " << map.size() << " pieces\n"; + + int id = 1; + for (it2 = map.begin(); it2 != map.end(); it2++) + { + ASTVec& equiv = it2->second; + vector a; + findRewrites(equiv, a, depth + 1); + equiv.clear(); + } + return; } + ASTVec& equiv = expressions; - HASHMAP::iterator it2; - static int lastOutput = 0; - int id = 1; - for (it2 = map.begin(); it2 != map.end(); it2++) - { - ASTVec& equiv = it2->second; + // Sort so that constants, and smaller expressions will be checked first. + sort(equiv.begin(), equiv.end(), lessThan); - // fast shortcut. - if (equiv.size() == 1) + for (int i = 0; i < equiv.size(); i++) + { + if (equiv[i].GetKind() == UNDEFINED) continue; - cerr << "[" << id++ << " of " << map.size() << "] depth:" << depth << ", size:" << equiv.size() << endl; - - // We don't want to keep splitting if it's having no effect. - // In particular we bound the maximum depth, and don't split again, - // if the last time we tried it didn't split at all.. - if (equiv.size() > 50 && depth <= 50 && (map.size() != 1)) + for (int j = i + 1; j < equiv.size(); j++) /// commutative so skip some. { - vector newValues; - - int max_iterations = std::max(values_in_hash * 2, (int) equiv.size() / 1000); + if (equiv[i].GetKind() == UNDEFINED || equiv[j].GetKind() == UNDEFINED) + continue; - for (int j = 0; (j < max_iterations) && (newValues.size() < values_in_hash); j++) - { - ASTNode lhs = equiv[rand() % equiv.size()]; - ASTNode rhs = equiv[rand() % equiv.size()]; - ASTNode n = mgr->CreateNode(EQ, lhs, rhs); + ASTNode n = nf->CreateNode(EQ, equiv[i], equiv[j]); + if (n.GetKind() != EQ) + continue; - Assignment different; - bool con = isConstant(n, different); + n = rewriteThroughWithAIGS(n); - if (con) - continue; // always same. + if (n.GetKind() != EQ) + continue; - // nb: We may find the same values multiple times, but don't currently care.. - newValues.push_back(different); + ASTNode from, to; + if (getDifficulty(n[0]) < getDifficulty(n[1])) + { + to = n[0]; + from = n[1]; } - cerr << "Found:" << newValues.size() << endl; - - if (newValues.size() > 0) + else if (getDifficulty(n[0]) > getDifficulty(n[1])) { - findRewrites(equiv, newValues, depth + 1); - continue; + from = n[0]; + to = n[1]; } - } - - // Sort so that constants, and smaller expressions will be checked first. - sort(equiv.begin(), equiv.end(), lessThan); - - for (int i = 0; i < equiv.size(); i++) - { - if (equiv[i].GetKind() == UNDEFINED) - continue; - - for (int j = i + 1; j < equiv.size(); j++) /// commutative so skip some. + else { - if (equiv[i].GetKind() == UNDEFINED || equiv[j].GetKind() == UNDEFINED) - continue; - - const ASTNode n = nf->CreateNode(EQ, equiv[i], equiv[j]); - if (isCandidate(n) && checkAndStoreRule(n)) + // Difficulty is equal. Try it both ways and see if it's a candidate. + if (is_candidate(n[0], n[1])) { - // We remove the LHS from the list. Other equivalent expressions will match - // the RHS anyway. - if (n[1] == equiv[i]) - equiv[i] = mgr->ASTUndefined; - if (n[1] == equiv[j]) - equiv[j] = mgr->ASTUndefined; + from = n[0]; + to = n[1]; } - - // Write out the rules intermitently. - if (lastOutput + 500 < toWrite.size()) + else { - lastOutput = toWrite.size(); - writeOutRules(); + from = n[1]; + to = n[0]; } + } + + Assignment different; + if (checkAndStoreRule(from,to, different)) + { + // Remove the more difficult expression. + if (from == equiv[i]) + equiv[i] = mgr->ASTUndefined; + if (from == equiv[j]) + equiv[j] = mgr->ASTUndefined; + } + else if (!different.isEmpty()) + { + vector ass; + ass.push_back(different); + // Discard the ones we've checked entirely. + ASTVec newEquiv(equiv.begin() + std::max(i - 1, 0), equiv.end()); + equiv.clear(); + + findRewrites(newEquiv, ass, depth + 1); + return; + } + + // Write out the rules intermitently. + if (lastOutput + 500 < toWrite.size()) + { + writeOutRules(); + lastOutput = toWrite.size(); } + } } } + // Converts the node into an IF statement that matches the node. void rule_to_string(const ASTNode & n, ASTNodeString& names, string& current, string& sofar) { + if (n.isConstant() && n.GetValueWidth() == 1 && n == mgr->CreateZeroConst(1)) { sofar += "&& " + current + " == bm->CreateZeroConst(1) "; return; } - if (n.isConstant() && n.GetValueWidth() == 1 && n == mgr->CreateOneConst(1)) { sofar += "&& " + current + " == bm->CreateOneConst(1) "; return; } - if (n.isConstant() && n.GetValueWidth() == 32 && n == mgr->CreateZeroConst(1)) // extract probably. + if (n.isConstant() && (n.GetValueWidth() == bits || n.GetValueWidth() == bits-1)) { - sofar += "&& " + current + " == bm->CreateZeroConst(32) "; + sofar += "&& " + current + " == "; + stringstream constant; + constant << "bm->CreateBVConst(" << bits << "," << n.GetUnsignedConst() << ")"; + sofar += "bm->CreateTerm(BVSX,width," + constant.str() + ")"; return; } - if (n.isConstant() && n.GetValueWidth() == 32 && n == mgr->CreateBVConst(32, bits - 1)) // extract probably. + if (n.isConstant() && n.GetValueWidth() == 32) // Extract DEFINATELY. { - sofar += "&& " + current + " == mgr->CreateBVConst(32, width-1) "; - return; - } + if (n == mgr->CreateZeroConst(32)) + { + sofar += "&& " + current + " == bm->CreateZeroConst(32) "; + return; + } - if (n.isConstant() && n == mgr->CreateMaxConst(n.GetValueWidth())) - { - sofar += "&& " + current + " == max "; - return; - } + if (n == mgr->CreateOneConst(32)) + { + sofar += "&& " + current + " == bm->CreateOneConst(32) "; + return; + } - if (n.isConstant() && n == mgr->CreateOneConst(n.GetValueWidth())) - { - sofar += "&& " + current + " == one "; - return; - } - if (n.isConstant() && n == mgr->CreateZeroConst(n.GetValueWidth())) - { - sofar += "&& " + current + " == zero"; - return; - } + if (n == mgr->CreateBVConst(32, bits)) + { + sofar += "&& " + current + " == bm->CreateBVConst(32, width) "; + return; + } - if (n.isConstant() && n == mgr->CreateBVConst(n.GetValueWidth(), n.GetValueWidth())) - { - sofar += "&& " + current + " == zero "; - return; + if (n == mgr->CreateBVConst(32, bits - 1)) + { + sofar += "&& " + current + " == bm->CreateBVConst(32, width-1) "; + return; + } + + if (n == mgr->CreateBVConst(32, bits - 2)) + { + sofar += "&& " + current + " == bm->CreateBVConst(32, width-2) "; + return; + } } if (n.isConstant()) @@ -677,6 +908,19 @@ rule_to_string(const ASTNode & n, ASTNodeString& names, string& current, string& sofar += "&& " + current + ".GetKind() == " + _kind_names[n.GetKind()] + " "; + // constrain to being == 2 for those that can be flattened. + //if (current != "n") + switch (n.GetKind()) + { + case BVXOR: + case BVMULT: + case BVPLUS: + case BVOR: + case BVAND: + sofar += "&& " + current + ".Degree() ==2 "; + break; + } + for (int i = 0; i < n.Degree(); i++) { char t[1000]; @@ -714,12 +958,15 @@ containsNode(const ASTNode& n, const ASTNode& hunting, string& current) // Check it holds at higher bit-widths. // If so, then save the rule for later. bool -checkAndStoreRule(const ASTNode & n) +checkAndStoreRule(const ASTNode & from, const ASTNode & to, Assignment& assignment) { - assert(n.GetKind() == BEEV::EQ); + const ASTNode n = create(EQ,from,to); + assert(n.GetKind() == BEEV::EQ); assert(widen_to > bits); + const int st = getCurrentTime(); + for (int i = bits; i < widen_to; i++) { const ASTNode& widened = widen(n, i); @@ -727,22 +974,29 @@ checkAndStoreRule(const ASTNode & n) // Can't widen (usually because of CONCAT or a BVCONST). if (widened == mgr->ASTUndefined) { - cerr << ")"; + cerr << "can't widen"; + cerr << n; return false; } // Send it to the SAT solver to verify that the widening has the same answer. - bool result = isConstantToSat(widened); + bool result = isConstant(widened, assignment); if (!result) { + if (i > highestLevel) + { + highestLevel = i; + highestDisproved = n; + } + // Detected it's not a constant, or is constant FALSE. cerr << "*" << i - bits << "*"; return false; } } - toWrite.push_back(n); + toWrite.push_back(ToWrite(from,to,getCurrentTime() - st)); return true; } @@ -752,7 +1006,7 @@ template { cerr << "Before removing duplicates:" << big.size(); std::sort(big.begin(), big.end()); - ASTVec::iterator it = std::unique(big.begin(), big.end()); + typename T::iterator it = std::unique(big.begin(), big.end()); big.resize(it - big.begin()); cerr << ".After removing duplicates:" << big.size() << endl; } @@ -784,54 +1038,248 @@ bucket(string substring, vector& inputs, hash_map { size_t to = current.find("&&", from); string val = current.substr(from, to - from); - current = current.replace(from, to - from + 2, "/*" + val + " && */"); // Remove what we've searched for. + //current = current.replace(from, to - from + 2, "/*" + val + " && */"); // Remove what we've searched for. + //buckets[val].push_back(current); buckets[val].push_back(current); } } } + +string +name(const ASTNode& n) +{ + assert(n.GetValueWidth() ==32); // Widen a constant used in an extract only. + + if (n == mgr->CreateBVConst(32, bits)) + return "width"; + if (n == mgr->CreateBVConst(32, bits - 1)) + return "width-1"; + if (n == mgr->CreateBVConst(32, bits - 2)) + return "width-2"; + if (n == mgr->CreateZeroConst(32)) + return "0"; + if (n == mgr->CreateOneConst(32)) + return "1"; + + FatalError("@!#$@#$@#"); +} + + +// Turns "n" into a statement in STP's C++ language to create it. +string +createString(ASTNode n, map& val) +{ + if (val.find(n) != val.end()) + return val.find(n)->second; + + string result =""; + + if (n.GetKind() == BVCONST) + { + if (n.isConstant() && n.GetValueWidth() == 1 && n == mgr->CreateZeroConst(1)) + { + result = "bm->CreateZeroConst(1"; + + } + if (n.isConstant() && n.GetValueWidth() == 1 && n == mgr->CreateOneConst(1)) + { + result = "bm->CreateOneConst(1"; + } + + if (n.isConstant() && (n.GetValueWidth() == bits || n.GetValueWidth() == bits-1)) + { + stringstream constant; + constant << "bm->CreateBVConst(" << bits << "," << n.GetUnsignedConst() << ")"; + result += "bm->CreateTerm(BVSX,width," + constant.str() + ""; + } + + if (n.isConstant() && n.GetValueWidth() == 32) // Extract DEFINATELY. + { + if (n == mgr->CreateZeroConst(32)) + result += " bm->CreateZeroConst(32 "; + + if (n == mgr->CreateOneConst(32)) + result += " bm->CreateOneConst(32 "; + + if (n == mgr->CreateBVConst(32, bits)) + result = " bm->CreateBVConst(32, width "; + + if (n == mgr->CreateBVConst(32, bits - 1)) + result = " bm->CreateBVConst(32, width-1 "; + + if (n == mgr->CreateBVConst(32, bits - 2)) + result = " bm->CreateBVConst(32, width-2 "; + } + + if (result =="") + { + // uh oh. + result = "~~~~~~~!!!!!!!!~~~~~~~~~~~"; + } + + } + + else if (n.GetType() == BOOLEAN_TYPE) + { + char buf[100]; + sprintf(buf, "bm->CreateNode(%s,", _kind_names[n.GetKind()]); + result += buf; + + } + else if (n.GetKind() == BVEXTRACT) + { + std::stringstream ss; + ss << "bm->CreateTerm(BVEXTRACT,"; + + ss << name(n[2]) << " +1 - (" << name(n[1]) << "),"; // width. + ss << createString(n[0], val) << ","; + ss << "bm->CreateBVConst(32," << name(n[1]) << "),"; // top then bottom. + ss << "bm->CreateBVConst(32," << name(n[2]) << ")"; + + result += ss.str(); + } + else if (n.GetType() == BITVECTOR_TYPE) + { + char buf[100]; + sprintf(buf, "bm->CreateTerm(%s,width,", _kind_names[n.GetKind()]); + result += buf; + } + else + { + cerr << n; + cerr << "never here"; + exit(1); + } + + if (n.GetKind() != BVEXTRACT) + for (int i = 0; i < n.Degree(); i++) + { + if (i > 0) + result += ","; + + result += createString(n[i], val); + } + result += ")"; + + val.insert(make_pair(n, result)); + return result; +} + +// loads all the expressions in "n" into the list of available expressions. +void +visit_all(const ASTNode& n, map& visited, string current) +{ + if (visited.find(n) != visited.end()) + return; + + visited.insert(make_pair(n, current)); + + for (int i = 0; i < n.Degree(); i++) + { + char t[1000]; + sprintf(t, "%s[%d]", current.c_str(), i); + string s(t); + visit_all(n[i], visited, s); + } +} + +template +std::string to_string(T i) +{ + std::stringstream ss; + ss << i; + return ss.str(); +} + + + // Write out all the rules that have been discovered to file. void writeOutRules() { - vector output; - removeDuplicates(toWrite); - cerr << "Writing out " << toWrite.size() << " rules." << endl; - - ofstream outputFile; - outputFile.open("rewrite_data.cpp", ios::trunc); + vector output; for (int i = 0; i < toWrite.size(); i++) { - const ASTNode& n = toWrite[i]; - const ASTNode& lhs = n[0]; - const ASTNode& rhs = n[1]; + if (toWrite[i].isEmpty()) + continue; - string name = getLHSName(n[0], n[1]); - if (name == "") + ASTNode to = toWrite[i].to; + ASTNode from = toWrite[i].from; + + if (getDifficulty(to) > getDifficulty(from)) { - cerr << "Attempting to write out non name!" << n; - continue; + // Want the easier one on the lhs. Which is the opposite of what you expect.. + ASTNode t = to; + to = from; + from = t; + } + + // If the RHS is just part of the LHS, then we output something like children[0][1][0][1] as the RHS. + string to_name = getToName(to, from); + + if (to_name == "") + { + // The name is not contained in the rhs. + ASTNodeSet visited; + vector symbols; + + getVariables(to, symbols, visited); + map val; + for (int i = 0; i < symbols.size(); i++) + val.insert(make_pair(symbols[i], getToName(symbols[i], from))); + + val.insert(make_pair(one, "one")); + val.insert(make_pair(maxNode, "max")); + val.insert(make_pair(zero, "zero")); + + // loads all the expressions in the rhs into the list of available expressions. + visit_all(from, val, "children"); + + to_name = createString(to, val); } + ASTNodeString names; string current = "n"; - string sofar = "if (1==1 "; - rule_to_string(n[1], names, current, sofar); - sofar += " && 1==1) set(output," + name + ", __LINE__ );\n"; + string sofar = "if ( width >= " + to_string(bits) + " " ; - if (sofar.find("!!!") == std::string::npos && sofar.length() < 500) + rule_to_string(from, names, current, sofar); + sofar += ") set(result, " + to_name + ");"; + +// if (sofar.find("!!!") == std::string::npos && sofar.length() < 500) { - output.push_back(sofar); - //printer::SMTLIB2_PrintBack(outputFile,toWrite[i]); + assert(getDifficulty(from) >= getDifficulty(to)); + + if (mgr->ASTTrue == rewriteThroughWithAIGS(toWrite[i].n)) + { + toWrite[i] = ToWrite(mgr->ASTUndefined,mgr->ASTUndefined,0); + continue; + } + + { + char buf[100]; + sprintf(buf, "//%d -> %d | %d ms\n", getDifficulty(from), getDifficulty(to), 0 /*toWrite[i].time*/); + sofar += buf; + output.push_back(sofar); + } } } + // Remove the duplicates from output. + removeDuplicates(output); + + cerr << "Rules Discovered in total: " << toWrite.size() << endl; + // Group functions of the same kind all together. hash_map, hashF > buckets; bucket("n.GetKind() ==", output, buckets); + ofstream outputFile; + outputFile.open("rewrite_data_new.cpp", ios::trunc); + // output the C++ code. hash_map, hashF >::const_iterator it; for (it = buckets.begin(); it != buckets.end(); it++) @@ -844,8 +1292,44 @@ writeOutRules() outputFile << "}" << endl; } - outputFile.close(); + + ofstream outputFileSMT2; + outputFileSMT2.open("rewrite_data.smt2", ios::trunc); + + for (int i = 0; i < toWrite.size(); i++) + { + if (toWrite[i].isEmpty()) + continue; + + outputFileSMT2 << "; " << "bits:" << bits << "->" << widen_to << " time to verify:" << toWrite[i].time << '\n'; + for (int j= widen_to; j < widen_to+ 5;j++) + { + outputFileSMT2 << "(push 1)\n"; + printer::SMTLIB2_PrintBack(outputFileSMT2, mgr->CreateNode(NOT, widen(toWrite[i].n,j)),true,false); + outputFileSMT2 << "(pop 1)\n"; + } + } + + outputFileSMT2.close(); + + outputFileSMT2.open("array.smt2", ios::trunc); + ASTVec v; + for (int i = 0; i < toWrite.size(); i++) + { + if (toWrite[i].isEmpty()) + continue; + + v.push_back(toWrite[i].n); + } + + if (v.size() > 0) + { + ASTNode n = mgr->CreateNode(AND,v); + printer::SMTLIB2_PrintBack(outputFileSMT2, n,true); + } + outputFileSMT2.close(); + } // Triples the number of functions by adding all the unary ones. @@ -854,22 +1338,26 @@ allUnary(ASTVec& functions) { for (int i = 0, size = functions.size(); i < size; i++) { + if (functions[i] == mgr->ASTUndefined) + continue; + functions.push_back(nf->CreateTerm(BEEV::BVNEG, bits, functions[i])); functions.push_back(nf->CreateTerm(BEEV::BVUMINUS, bits, functions[i])); } - } -// If we can't widen it remove it. Very slow. void removeNonWidened(ASTVec & functions) { for (int i = 0; i < functions.size(); i++) { + if (mgr->ASTUndefined == functions[i]) + continue; + if (mgr->ASTUndefined == widen(functions[i], bits + 1)) { - functions.erase(functions.begin() +i); - i--; + functions[i] = mgr->ASTUndefined; // We can't widen it later. So remove it. + continue; } } } @@ -892,10 +1380,40 @@ removeSingleVariable(ASTVec & functions) continue; } } +} - removeDuplicates(functions); +void +removeSingleUndefined(ASTVec& functions) +{ + for (int i = 0; i < functions.size(); i++) + { + if (functions[i] == mgr->ASTUndefined) + { + functions.erase(functions.begin() + i); + break; + } + } } +void applyBigRewrite(ASTVec& functions) +{ + BEEV::BigRewriter b; + + for (int i = 0; i < functions.size(); i++) + { + if (functions[i] == mgr->ASTUndefined) + continue; + + ASTNodeMap fromTo; + ASTNode s = b.rewrite(functions[i], fromTo, nf, mgr); + if (s != functions[i]) + { + functions[i] = s; + } + } +} + + int main(void) { @@ -923,6 +1441,7 @@ main(void) functions.push_back(v); functions.push_back(mgr->CreateBVConst(bits, 0)); functions.push_back(mgr->CreateBVConst(bits, 1)); + functions.push_back(mgr->CreateMaxConst(bits)); // All unary of the leaves. allUnary(functions); @@ -937,11 +1456,20 @@ main(void) getAllFunctions(functions[i], functions[j], functions); allUnary(functions); + + // Duplicates removed, rewrite rules applied, non-widenable removed, + //removeNonWidened(functions); + //applyBigRewrite(functions); removeDuplicates(functions); - removeNonWidened(functions); + removeSingleUndefined(functions); + cerr << "One Level:" << functions.size() << endl; + applyBigRewrite(functions); + removeDuplicates(functions); + cerr << "After rewrite:" << functions.size() << endl; const bool two_level = true; + if (two_level) { int last = 0; @@ -949,69 +1477,128 @@ main(void) size = functions_copy.size(); for (int i = 0; i < size; i++) for (int j = 0; j < size; j++) - { - getAllFunctions(functions_copy[i], functions_copy[j], functions); - if (functions.size() > last + 5000000) // lots are duplicates. - { - removeSingleVariable(functions); - last = functions.size(); - } - } + getAllFunctions(functions_copy[i], functions_copy[j], functions); - // All the unary combinations of the binaries. - allUnary(functions); - - // This is an agressive filter. + //applyBigRewrite(functions); removeSingleVariable(functions); + removeDuplicates(functions); + removeSingleUndefined(functions); + + // All the unary combinations of the binaries. + //allUnary(functions); + //removeNonWidened(functions); + //removeDuplicates(functions); cerr << "Two Level:" << functions.size() << endl; } - BBNodeManagerAIG bbnm; - SimplifyingNodeFactory nf(*(mgr->hashingNodeFactory), *mgr); + // The hash is generated on these values. + vector values; + findRewrites(functions, values); + writeOutRules(); + + cerr << "Initial:" << bits << " widening to :" << widen_to << endl; + cerr << "Highest disproved @ level: " << highestLevel << endl; + cerr << highestDisproved << endl; + + return 0; +} - BitBlaster bb(&bbnm, GlobalSTP->simp, &nf, &(mgr->UserFlags)); - { - SimplifyingNodeFactory nf(*(mgr->hashingNodeFactory), *mgr); #if 0 - BEEV::BigRewriter b; +// Shortcut. Don't examine the rule if it isn't a candidate. +bool +isCandidateSizePreserving(const ASTNode& n) +{ + if (n.GetKind() != EQ) + return false; - for (int i = 0; i < functions.size(); i++) - { - if (false) - { - ASTNodeMap fromTo; - ASTNode s = b.rewrite(functions[i], fromTo, &nf, mgr); - if (s != functions[i]) - functions[i] = s; - } + if (n[0].isConstant()) + return true; - } -#endif - removeDuplicates(functions); + visited.clear(); + if (lhsInRHS(n[1], n[0])) + return true; + return false; +} - // There may be a single undefined element now.. remove it. - for (int i = 0; i < functions.size(); i++) - { - if (functions[i] == mgr->ASTUndefined) - { - functions.erase(functions.begin() + i); - break; - } - } - } +ASTNodeSet visited; - // The hash is generated on these values. - vector values; - values.push_back(Assignment(BEEV::ParserBM->CreateMaxConst(bits))); - values.push_back(Assignment(BEEV::ParserBM->CreateZeroConst(bits))); - while (values.size() < values_in_hash) - values.push_back(Assignment(BEEV::ParserBM->CreateBVConst(bits, rand()))); +bool +lhsInRHS(const ASTNode& n, const ASTNode& lookFor) +{ + if (lookFor == n) + return true; - findRewrites(functions, values); - writeOutRules(); + if (visited.find(n) != visited.end()) + return false; + + for (int i = 0; i < n.Degree(); i++) + if (lhsInRHS(n[i], lookFor)) + return true; + + visited.insert(n); + return false; +} + +int +getDifficulty_approximate(const ASTNode&n) +{ + if (difficulty_cache.find(n) != difficulty_cache.end()) + return difficulty_cache.find(n)->second; + + DifficultyScore ds; + int score = ds.score(n); + difficulty_cache.insert(make_pair(n, score)); + return score; +} + +// Shortcut. Don't examine the rule if it isn't a candidate. +bool +isCandidateDifficultyPreserving(const ASTNode& n) +{ + if (n.GetKind() != EQ) + return false; - return 1; + if (getDifficulty(n[0]) != getDifficulty(n[1])) + return true; + + return false; +} + +void +getSomeFunctions(ASTNode v, ASTNode w, ASTVec& result) +{ + + Kind types[] = + { BVMULT, BVDIV, SBVDIV, SBVREM, SBVMOD, BVPLUS, BVMOD }; + int number_types = sizeof(types) / sizeof(Kind); + + // all two argument functions. + for (int i = 0; i < number_types; i++) + result.push_back(create(types[i], v, w)); } +// True if "to" is a single function of "n" +bool +single_fn_of(ASTNode from, ASTNode to) +{ + for (int i = 0; i < to.Degree(); i++) + { + if (to[i].isConstant()) + continue; + + // Special case equalities are cheap so allow them through. + if (to[i].GetKind() == EQ && to[i][0].isConstant()) + { + if (!contained_in(to[i][1], from)) + return false; + } + else if (!contained_in(to[i], from)) + return false; + } + return true; +} + + +#endif