#include "rewrite_rule.h"
#include "rewrite_system.h"
#include "Functionlist.h"
+#include "misc.h"
extern int
smt2parse();
volatile bool force_writeout = false;
// Saves a little bit of time. The vectors are saved between invocations.
-vector<ASTVec> saved_array;
+vector<ASTVec*> saved_array;
// Stores the difficulties that have already been generated.
map<ASTNode, int> difficulty_cache;
getVariables(const ASTNode& n);
bool
-unifyNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width);
+matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width);
typedef HASHMAP<ASTNode, string, ASTNode::ASTNodeHasher, ASTNode::ASTNodeEqual> ASTNodeString;
BEEV::STPMgr* mgr;
NodeFactory* nf;
+NodeFactory* simpNf;
+
SATSolver * ss;
ASTNodeSet stored; // Store nodes so they aren't garbage collected.
Simplifier *simp;
bool
checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& ass, bool& bad);
+ASTNode
+withNF(const ASTNode &n)
+{
+ if (n.isAtom())
+ return n;
+
+ ASTVec c;
+ for (int i=0; i< n.Degree();i++)
+ c.push_back(withNF(n[i]));
+
+ if (n.GetType() == BOOLEAN_TYPE)
+ return nf->CreateNode(n.GetKind(), c);
+ else
+ return nf->CreateArrayTerm(n.GetKind(), n.GetIndexWidth(), n.GetValueWidth(), c);
+}
+
+
ASTNode
renameVars(const ASTNode &n)
{
// We have an array of arrays already created to store the children.
// This reduces the number of objects created/destroyed.
if (count >= saved_array.size())
- saved_array.push_back(ASTVec());
+ saved_array.push_back(new ASTVec());
- ASTVec& new_children = saved_array[count];
+ ASTVec& new_children = *saved_array[count];
new_children.clear();
for (int i = 0; i < n.Degree(); i++)
new_children.push_back(eval(n[i], map, count + 1));
ASTNode r = NonMemberBVConstEvaluator(mgr, n.GetKind(), new_children, n.GetValueWidth());
+ new_children.clear();
map.insert(make_pair(n, r));
return r;
}
// True if it's always true. Otherwise fills the assignment.
bool
-isConstant(const ASTNode& n, VariableAssignment& different)
+isConstant(const ASTNode& n, VariableAssignment& different, const int bits)
{
if (isConstantToSat(n))
return true;
mgr->ValidFlag = false;
vector<ASTNode> symbols = getVariables(n);
- assert(symbols.size() > 0);
// Both of them might not be contained in the assignment.
- different.setV(mgr->CreateZeroConst(symbols[0].GetValueWidth()));
- different.setW(mgr->CreateZeroConst(symbols[0].GetValueWidth()));
+ different.setV(mgr->CreateZeroConst(bits));
+ different.setW(mgr->CreateZeroConst(bits));
// It might have been widened.
for (int i = 0; i < symbols.size(); i++)
bool
orderEquivalence(ASTNode& from, ASTNode& to)
{
+ if(from.IsNull())
+ return false;
+ if(from.GetKind() == UNDEFINED)
+ return false;
+ if(to.IsNull())
+ return false;
+ if(to.GetKind() == UNDEFINED)
+ return false;
+
+ {
+ ASTVec c;
+ c.push_back(from);
+ c.push_back(to);
+ ASTNode w = widen(mgr->hashingNodeFactory->CreateNode(EQ,c), widen_to);
+
+ if (w.IsNull() || w.GetKind() == UNDEFINED)
+ return false;
+ }
+
vector<ASTNode> s_from; // The variables in the from node.
ASTNodeSet visited;
getVariables(from, s_from, visited);
std::sort(s_from.begin(), s_from.end());
+ const int from_c = visited.size();
vector<ASTNode> s_to; // The variables in the to node.
visited.clear();
getVariables(to, s_to, visited);
sort(s_to.begin(), s_to.end());
+ const int to_c = visited.size();
+
+ if (from_c > 50 || to_c > 50)
+ return false; // not interested in giant rules.
vector<ASTNode> result(s_to.size() + s_from.size());
// We must map from most variables to fewer variables.
if (s_from.size() > s_to.size())
return true;
- if (getDifficulty(from) < getDifficulty(to))
+ if ((getDifficulty(from)+5) < getDifficulty(to))
{
swap(from, to);
return true;
}
- if (getDifficulty(from) > getDifficulty(to))
+ if (getDifficulty(from) > (getDifficulty(to)+5))
{
return true;
}
- // Difficulty is equal. Order based on the number of nodes.
- vector<ASTNode> symbols;
- visited.clear();
- getVariables(from, symbols, visited);
- int from_c = visited.size();
-
- symbols.clear();
- visited.clear();
- getVariables(to, symbols, visited);
- int to_c = visited.size();
-
if (to_c < from_c)
{
return true;
ToSATBase::ASTNodeToSATVar nodeToSATVar;
toCNF.toCNF(BBFormula, cnfData, nodeToSATVar, false, nm);
+ // Send the clauses to Minisat, do unit propagation.
+ ///////////////
+
+ // Create a new sat variable for each of the variables in the CNF.
+ assert(ss->nVars() ==0);
+ for (int i = 0; i < cnfData->nVars ; i++)
+ ss->newVar();
+
+ SATSolver::vec_literals satSolverClause;
+ for (int i = 0; i < cnfData->nClauses; i++)
+ {
+ satSolverClause.clear();
+ for (int * pLit = cnfData->pClauses[i], *pStop = cnfData->pClauses[i
+ + 1]; pLit < pStop; pLit++)
+ {
+ SATSolver::Var var = (*pLit) >> 1;
+ assert ((var < ss->nVars()));
+ Minisat::Lit l = SATSolver::mkLit(var, (*pLit) & 1);
+ satSolverClause.push(l);
+ }
+
+ ss->addClause(satSolverClause);
+ }
+
+ ss->simplify();
+ assert (ss->okay()); // should be satisfiable.
+
// Why we go to all this trouble. The number of clauses.
- int score = cnfData->nClauses;
+ const int score = ss->nClauses();
+ assert(score <= cnfData->nClauses);
+ //////////////
//Cnf_ClearMemory();
Cnf_DataFree(cnfData);
force_writeout = true;
}
+volatile bool debug_usr2 = false;
+
+//toggle.
+void do_usr2(int ignore)
+{
+ debug_usr2=!debug_usr2;
+}
+
int
startup()
GlobalSTP = new STP(mgr, simp, at, tosat, abs);
-#ifndef NOTSIMPLIFYING_NF
- nf = new SimplifyingNodeFactory(*(mgr->hashingNodeFactory), *mgr);
- mgr->defaultNodeFactory = nf;
-#else
- nf = mgr->hashingNodeFactory;
- mgr->defaultNodeFactory = mgr->hashingNodeFactory;
-#endif
+ simpNf = new SimplifyingNodeFactory(*(mgr->hashingNodeFactory), *mgr);
+ nf = new TypeChecker(*simpNf, *mgr);
+ mgr->defaultNodeFactory = simpNf;
mgr->UserFlags.stats_flag = false;
mgr->UserFlags.optimize_flag = true;
// Prime the cache with 100..
for (int i = 0; i < 100; i++)
{
- saved_array.push_back(ASTVec());
+ saved_array.push_back(new ASTVec());
}
zero = mgr->CreateZeroConst(bits);
// Write out the work so far..
signal(SIGUSR1,do_write_out);
+ signal(SIGUSR2,do_usr2);
+
}
void
HASHMAP<uint64_t, ASTVec>::iterator it2;
cout << "Split into " << map.size() << " pieces\n";
+ if (depth > 0)
+ {
+ if(map.size() ==1)
+ {
+ cerr << values[0].getV();
+ cerr << values[0].getW();
+ assert(false);
+ }
+ }
int id = 1;
for (it2 = map.begin(); it2 != map.end(); it2++)
}
ASTVec& equiv = expressions;
-
// Sort so that constants, and smaller expressions will be checked first.
std::sort(equiv.begin(), equiv.end(), lessThan);
if (equiv[i].GetKind() == UNDEFINED || equiv[j].GetKind() == UNDEFINED)
continue;
- equiv[j] = rewrite_system.rewriteNode(equiv[j]);
-
- ASTNode from = equiv[i];
- ASTNode to = equiv[j];
- bool r = orderEquivalence(from, to);
-
- if (!r)
- {
- if (from != to)
- cout << "can't be ordered" << from << to;
- continue;
- }
+ ASTNode& from = equiv[i];
+ ASTNode& to = equiv[j];
VariableAssignment different;
bool bad = false;
const int st = getCurrentTime();
+ if (from.isConstant() && to.isConstant())
+ continue;
+
if (checkRule(from, to, different, bad))
{
- cout << "Discovered a new rule.";
- cout << from;
- cout << to;
- cout << getDifficulty(from) << " to " << getDifficulty(to) << endl;
- cout << "After rewriting";
- cout << rewrite_system.rewriteNode(from);
- cout << rewrite_system.rewriteNode(to);
- cout << "------";
+ equiv[i] = rewrite_system.rewriteNode(equiv[i]);
+ equiv[j] = rewrite_system.rewriteNode(equiv[j]);
- Rewrite_rule rr(mgr, from, to, getCurrentTime() - st);
+ equiv[i] = rewriteThroughWithAIGS(equiv[i]);
+ equiv[j] = rewriteThroughWithAIGS(equiv[j]);
- if (!rr.timedCheck(10000))
+ ASTNode& f = equiv[i];
+ ASTNode& t = equiv[j];
+
+ bool r = orderEquivalence(f, t);
+
+ if (!r)
continue;
+ cout << "Discovered a new rule.";
+ cout << f << t;
+ cout << getDifficulty(f) << " to " << getDifficulty(t) << endl;
+
+ Rewrite_rule rr(mgr, f, t, getCurrentTime() - st);
+
+ if (!rr.timedCheck(1000))
+ {
+ cout << "Rule failed extended verification.";
+ continue;
+ }
+ cout << "Verified Rule to: " << rr.getVerifiedToBits() << " bits" << endl;
+ cout << "------";
+
rewrite_system.push_back(rr);
// Remove the more difficult expression.
}
// Write out the rules intermitently.
- if (force_writeout || lastOutput + 5000 < rewrite_system.size())
+ if (force_writeout || lastOutput + 500 < rewrite_system.size())
{
rewrite_system.rewriteAll();
writeOutRules();
ASTVec children;
children.push_back(from);
children.push_back(to);
+
+ // The simplifying node factory sometimes meant it couldn't be widended.
const ASTNode n = mgr->hashingNodeFactory->CreateNode(EQ, children);
assert(widen_to > bits);
}
// Send it to the SAT solver to verify that the widening has the same answer.
- bool result = isConstant(widened, assignment);
+ bool result = isConstant(widened, assignment, bits);
if (!result)
{
}
// Detected it's not a constant, or is constant FALSE.
- cout << "*" << i - bits << "*";
+ cout << "*" << i - bits << "*";
return false;
}
}
void
writeOutRules()
{
+ cerr << "Writing out: " << rewrite_system.size() << " rules" << endl;
force_writeout = false;
std::vector<string> output;
for (Rewrite_system::RewriteRuleContainer::iterator it = rewrite_system.toWrite.begin() ; it != rewrite_system.toWrite.end(); it++)
{
- if (!it->isOK())
- {
- rewrite_system.erase(it--);
- continue;
- }
-
ASTNode to = it->getTo();
ASTNode from = it->getFrom();
if (dup.find(sofar) != dup.end())
{
- cout << "-----";
+ cout << "-----Writing out has found a duplicate rule-----";
cout << sofar;
ASTNode f = it->getFrom();
- cout << f << std::endl;
- cout << dup.find(sofar)->second.getFrom();
+ cout << "This:" << f << std::endl;
+ cout << "Has the same text as this: " << dup.find(sofar)->second.getFrom();
+ cout << "Rule " << it->getId() << " has the same text as " << dup.find(sofar)->second.getId() << endl;
ASTNodeMap fromTo;
-
- cerr << f;
f = renameVars(f);
- //cerr << "renamed" << f;
- bool result = unifyNode(f,dup.find(sofar)->second.getFrom(),fromTo,2) ;
- cout << "unified" << result << endl;
+ bool result = commutative_matchNode(f,dup.find(sofar)->second.getFrom(),fromTo,2) ;
+ cout << "Has it unified:" << result << endl;
ASTNodeMap seen;
- cout << rewrite(f,*it,seen );
// The text of this rule is the same as another rule.
rewrite_system.erase(it--);
bucket("n.GetKind() ==", output, buckets);
ofstream outputFile;
+
+ // Because we output the difficulty (i.e. the number of CNF clauses),
+ // this is very slow.
+ #ifdef OUTPUT_CPP_RULES
outputFile.open("rewrite_data_new.cpp", ios::trunc);
+
// output the C++ code.
hash_map<string, vector<string>, hashF<std::string> >::const_iterator it;
for (it = buckets.begin(); it != buckets.end(); it++)
outputFile << "}" << endl;
}
outputFile.close();
+ #endif
///////////////
outputFile.open("rules_new.smt2", ios::trunc);
{
n = renameVars(n);
ASTNodeMap seen;
- n = rewrite(n,original_rule,seen);
+ n = rewrite(n,original_rule,seen,0);
return renameVarsBack(n);
}
// assumes the variables in n are two characters wide.
ASTNode
-rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen)
+rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen, int depth)
{
- if (n.isAtom())
+ if (depth > 50)
return n;
- // nb. won't rewrite through EQ etc.
- if (n.GetType() != BITVECTOR_TYPE)
+ if (n.isAtom())
return n;
ASTVec v;
+ v.reserve(n.Degree());
for (int i = 0; i < n.Degree(); i++)
- v.push_back(rewrite(n[i],original_rule,seen));
+ v.push_back(rewrite(n[i],original_rule,seen,depth+1));
assert(v.size() > 0);
ASTNode n2;
if (v!=n.GetChildren())
- n2 = mgr->CreateTerm(n.GetKind(), n.GetValueWidth(), v);
+ {
+ if (n.GetType() != BOOLEAN_TYPE)
+ n2 = mgr->CreateArrayTerm(n.GetKind(),n.GetIndexWidth(), n.GetValueWidth(), v);
+ else
+ n2 = mgr->CreateNode(n.GetKind(), v);
+ }
else
n2 = n;
ASTNodeMap fromTo;
- vector<Rewrite_rule>& rr =
- n[0].Degree() > 0 ?
- (rewrite_system.kind_kind_to_rr[n.GetKind()][n[0].GetKind()]) :
- (rewrite_system.kind_to_rr[n.GetKind()]) ;
+ if (rewrite_system.lookups_invalid)
+ rewrite_system.buildLookupTable();
+ vector<Rewrite_rule>& rr = rewrite_system.kind_to_rr[n.GetKind()];
for (int i = 0; i < rr.size(); i++)
{
const ASTNode& f = rr[i].getFrom();
- if (unifyNode(f, n2, fromTo,1))
+ if (commutative_matchNode(f, n2, fromTo, 1))
{
- /*
- cerr << "Unifying" << f;
- cerr << "with:" << n2;
+ if (debug_usr2)
+ {
+ cerr << "Original Rule(" << original_rule.getId() << ")";
+ cerr << original_rule.getFrom();
+ cerr << "->" << original_rule.getTo();
- cerr << "Now" << rr[i].getTo();
- cerr << "reducing rule" << rr[i].getN();
+ cerr << "Matching Rule(" << rr[i].getId() << ")";
+ cerr << rr[i].getFrom();
+ cerr << "->" << rr[i].getTo();
- for (ASTNodeMap::iterator it = fromTo.begin(); it != fromTo.end(); it++)
- {
- cerr << it->first << "to" << it->second << endl;
- }
+ cerr << "--------------";
+ cerr << "Unifying" << f;
+ cerr << "with:" << n2;
+ cerr << "--------------";
- cerr << "--------------";
- */
+ for (ASTNodeMap::iterator it = fromTo.begin(); it != fromTo.end(); it++)
+ {
+ cerr << it->first << "to" << it->second << endl;
+ }
+
+ cerr << "--------------";
+ }
if (seen.find(n) != seen.end())
return seen.find(n)->second;
- seen.insert(make_pair(n,rr[i].getTo()));
+ seen.insert(make_pair(n, rr[i].getTo()));
ASTNodeMap cache;
- ASTNode r= SubstitutionMap::replace(rr[i].getTo(), fromTo, cache, nf, false, true);
+ ASTNode r = SubstitutionMap::replace(rr[i].getTo(), fromTo, cache, nf, false, true);
seen.erase(n);
- seen.insert(make_pair(n,r));
- r= rewrite(r,original_rule,seen);
+ seen.insert(make_pair(n, r));
+ r = rewrite(r, original_rule, seen, depth + 1);
seen.erase(n);
- seen.insert(make_pair(n,r));
+ seen.insert(make_pair(n, r));
return r;
+
}
}
return n2;
}
+
int smt2_scan_string(const char *yy_str);
// http://stackoverflow.com/questions/3418231/c-replace-part-of-a-string-with-another-string
void
-loadNewRules()
+load_new_rules(const string fileName = "rules_new.smt2")
{
FILE * in;
bool opended= false;
- string fileName = "rules_new.smt2"
if(!ifstream(fileName.c_str())) /// use stdin if the default file is not found.
in = stdin;
opended = true; // so we know to fclose it.
}
- // We store references to "v" and "w", so we need to removed the
+ // We store references to "v" and "w", so we need to remove the
// definitions from the input we parse.
v = mgr->LookupOrCreateSymbol("v");
int rv = sscanf(line, ";id:%d\tverified_to:%d\ttime:%d\tfrom_difficulty:%d\tto_difficulty:%d\n", &id, &verified_to_bits, &time_used, &from_v, &to_v);
if (rv !=5)
{
+ cerr << line;
done = true;
break;
}
assert(values.size() ==1);
- ASTNode from = values[0][0];
- ASTNode to = values[0][1];
+ // The nodes have been built with the hashing node factory.
+ // In practice we want to match nodes that are created with the simplifying NF.
+ // If we enabled the simplifying NF, the EQUALS checks would remove rules we want to keep.
+ ASTNode from = withNF(values[0][0]);
+ ASTNode to = withNF(values[0][1]);
// Rule should be orderable.
bool ok = orderEquivalence(from, to);
cout << "discarding rule that can't be ordered";
cout << from << to;
cout << "----";
+ mgr->PopQuery();
+ parserInterface->popToFirstLevel();
continue;
}
Rewrite_rule r(mgr, from, to, 0, id);
r.setVerified(verified_to_bits,time_used);
- assert(r.isOK());
rewrite_system.push_back(r);
mgr->PopQuery();
parserInterface->popToFirstLevel();
}
+ extern int smt2lex_destroy(void);
+ smt2lex_destroy();
+
parserInterface->cleanUp();
if (opended)
- fclose(in);
+ {
+ cout << "New Style Rules Loaded:" << rewrite_system.size() << endl;
+ fclose(in);
+ }
+
+ // So we don't output as soon as one is discovered...
+ lastOutput = rewrite_system.size();
- cout << "New Style Rules Loaded:" << rewrite_system.size() << endl;
}
-//read from stdin, then tests it until the timeout.
+//Reads in new format rules. And tests each of them for the allocated time.
void
-expandRules(int timeout_ms)
+expandRules(int timeout_ms, const char* fileName = "")
{
- loadNewRules();
+ load_new_rules(fileName);
createVariables();
for (Rewrite_system::RewriteRuleContainer::iterator it = rewrite_system.begin(); it!= rewrite_system.end();it++)
{
if ((*it).timedCheck(timeout_ms))
- it->writeOut(cout); // omit failed.
+ {
+ it->writeOut(cout); // omit failed.
+ cerr << getDifficulty(it->getFrom()) <<" " << getDifficulty(it->getTo());
+ }
}
}
+
+void t2()
+{
+ extern FILE *smt2in;
+
+ smt2in = fopen("big_array.smt2", "r");
+ TypeChecker nfTypeCheckDefault(*mgr->hashingNodeFactory, *mgr);
+ Cpp_interface piTypeCheckDefault(*mgr, &nfTypeCheckDefault);
+ parserInterface = &piTypeCheckDefault;
+
+ mgr->GetRunTimes()->start(RunTimes::Parsing);
+ smt2parse();
+
+ ASTVec v = FlattenKind(AND, piTypeCheckDefault.GetAsserts());
+ ASTNode n = nf->CreateNode(XOR, v);
+
+ //rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen)
+ ASTNodeMap seen;
+ createVariables();
+ ASTNode r = rename_then_rewrite(n,Rewrite_rule::getNullRule());
+ cerr << r;
+
+
+}
+
// loads the already existing rules.
-void loadExistingRules(string fileName)
+void load_old_rules(string fileName)
{
if(!ifstream(fileName.c_str()))
return; // file doesn't exist.
Rewrite_rule r(mgr, from, to, 0);
- if (r.isOK());
- rewrite_system.push_back(r);
+ rewrite_system.push_back(r);
}
mgr->PopQuery();
int test()
{
// Test code.
- loadExistingRules("test.smt2");
+ load_old_rules("test.smt2");
v = mgr->LookupOrCreateSymbol("v");
v.SetValueWidth(bits);
w0.SetValueWidth(bits);
}
+void unit_test()
+{
+
+ // Create the negation and not terms in different orders. This tests the commutative matching.
+ ASTVec c;
+ c.push_back(v);
+ ASTNode not_v = create(BEEV::BVNEG, c);
+ ASTNode neg_v = create(BEEV::BVUMINUS, c);
+
+ ASTNode plus_v = create(BVPLUS,not_v,neg_v);
+
+ c.clear();
+ c.push_back(w);
+ ASTNode neg_w = create(BEEV::BVUMINUS, c);
+ ASTNode not_w = create(BEEV::BVNEG, c);
+ ASTNode plus_w = create(BVPLUS,not_w,neg_w);
+
+ ASTNodeMap sub;
+ plus_w = renameVars(plus_w);
+ assert(commutative_matchNode(plus_w,plus_v,sub,2));
+ sub.clear();
+
+ assert(commutative_matchNode(plus_v,plus_w,sub,1));
+
+
+}
+
+
int
main(int argc, const char* argv[])
{
startup();
+
if (argc == 1) // Read the current rule set, find new rules.
{
- loadExistingRules("array.smt2");
+ load_new_rules();
createVariables();
////////////
rewrite_system.buildLookupTable();
rewrite_system.rewriteAll();
writeOutRules();
}
- else if (argc == 3 && !strcmp("expand",argv[1])) // expand the bit-widths rules are tested at.
+ else if (argc ==2 && !strcmp("unit-test",argv[1]))
{
+ load_new_rules();
+ createVariables();
+ unit_test();
+ }
+ else if (argc ==2 && !strcmp("verify",argv[1]))
+ {
+ load_new_rules();
+ rewrite_system.verifyAllwithSAT();
+ }
+ else if ((argc == 4 || argc ==3) && !strcmp("expand",argv[1]))
+ {
+ // expand the bit-widths rules are tested at.
int timeout_ms = atoi(argv[2]);
assert(timeout_ms > 0);
- expandRules(timeout_ms);
+ expandRules(timeout_ms, (argc == 4 ? argv[3]:""));
}
- else if (argc == 2 && !strcmp("verify-all",argv[1]))
+ else if (argc == 2 && !strcmp("rewrite",argv[1]))
{
- loadNewRules();
+ // load the rules and apply the rewrite system to itself.
+ load_new_rules();
createVariables();
- rewrite_system.verifyAllwithSAT();
- writeOutRules(); // have the times now..
+ rewrite_system.rewriteAll();
+ writeOutRules();
}
else if (argc == 2 && !strcmp("write-out",argv[1]))
{
- loadNewRules();
+ load_new_rules();
createVariables();
rewrite_system.rewriteAll();
writeOutRules(); // have the times now..
}
+ else if (argc == 2 && !strcmp("renumber",argv[1]))
+ {
+ // Intended to merge two sets of rules, then renumber them.
+ load_new_rules();
+ createVariables();
+ rewrite_system.renumber();
+ writeOutRules();
+ }
else if (argc == 2 && !strcmp("test",argv[1]))
{
testProps();
}
else if (argc == 2 && !strcmp("delete-failed",argv[1]))
{
- loadNewRules();
+ load_new_rules();
ifstream fin;
fin.open("failed.txt",ios::in);
char line[256];
createVariables();
writeOutRules();
}
-}
-
+ else if (argc == 2 && !strcmp("test2",argv[1]))
+ {
+ load_new_rules();
+ t2();
+ }
+ for (int i=0; i< saved_array.size();i++)
+ delete saved_array[i];
+}
+#if 0
// Term variables have a specified width!!!
bool
-unifyNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width)
+matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width)
{
// Pointers to the same value. OK.
if (n0 == n1)
if (n0.GetKind() == SYMBOL && strlen(n0.GetName()) == term_variable_width)
{
if (fromTo.find(n0) != fromTo.end())
- return unifyNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width);
+ return matchNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width);
fromTo.insert(make_pair(n0, n1));
- return unifyNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width);
+ return matchNode(fromTo.find(n0)->second, n1, fromTo, term_variable_width);
}
// Here:
for (int i = 0; i < n0.Degree(); i++)
{
- if (!unifyNode(n0[i], n1[i], fromTo, term_variable_width))
- return false;
+ if (!matchNode(n0[i], n1[i], fromTo, term_variable_width))
+ return false;
}
return true;
}
+#endif
+
+
+bool debug_matching = false;
+
+/////////
+// Term variables have a specified width!!!
+// "false" if it definately can't be matched with any possible commutative ordering.
+// "true" can be matched, later you need to check if all the "commutative" can be matched.
+bool
+commutative_matchNode(const ASTNode& n0, const ASTNode& n1, const int term_variable_width,
+ deque<pair< ASTNode, ASTNode> >& commutative, ASTNode& vNode, ASTNode& wNode)
+{
+ // Pointers to the same value. OK.
+ if (n0 == n1)
+ return true;
+
+ // If we try and match sub-terms of concatenations,e,g. 000::x = 000111, we want it to fail.
+ if(n0.GetValueWidth() != n1.GetValueWidth())
+ return false;
+
+ if (n0.GetKind() == SYMBOL && strlen(n0.GetName()) == term_variable_width)
+ {
+ if (n0.GetName()[0] == 'v')
+ {
+ if (vNode != mgr->ASTUndefined)
+ return commutative_matchNode(vNode, n1, term_variable_width, commutative, vNode, wNode);
+ else
+ {
+ vNode = n1;
+ return true;
+ }
+ }
+ else if (n0.GetName()[0] == 'w')
+ {
+ if (wNode != mgr->ASTUndefined)
+ return commutative_matchNode(wNode, n1, term_variable_width, commutative, vNode, wNode);
+ else
+ {
+ wNode = n1;
+ return true;
+ }
+ }
+ else
+ FatalError("nefeafs");
+ }
+
+ // Here:
+ // They could be different BVConsts, different symbols, or
+ // different functions.
+
+ if (n0.Degree() != n1.Degree() || (n0.Degree() == 0))
+ return false;
+
+ if (n0.GetKind() != n1.GetKind())
+ return false;
+
+ // If it's commutative, check it specially / seprately later.
+ if (isCommutative(n0.GetKind()) && n0.Degree() > 1)
+ {
+ commutative.push_back(make_pair(n0,n1));
+ return true;
+ }
+ else
+ {
+ for (int i = 0; i < n0.Degree(); i++)
+ {
+ if (!commutative_matchNode(n0[i], n1[i], term_variable_width,commutative,vNode,wNode))
+ return false;
+ }
+ }
+ return true;
+}
+
+//
+// Term variables have a specified width!!!
+bool
+c_matchNode(const ASTNode& n0, const ASTNode& n1, const int term_variable_width,
+ deque<pair<ASTNode, ASTNode> >& commutative_to_check, ASTNode& vNode, ASTNode& wNode)
+{
+ ASTNode vNode_copy = vNode;
+ ASTNode wNode_copy = wNode;
+
+ const int init_comm_size = commutative_to_check.size();
+
+ bool r = commutative_matchNode(n0, n1, term_variable_width, commutative_to_check, vNode,wNode);
+ assert(commutative_to_check.size() >= init_comm_size); // if anything, only pushed onto the back.
+
+ if (debug_matching)
+ {
+ cerr << "======Commut-match=======" << r << endl;
+ cerr << "given" << n0 << n1;
+ cerr << "Commutative still to match:" << endl;
+ for (int j=0;j < commutative_to_check.size(); j++)
+ {
+ cerr << "++++++++++" << endl;
+ cerr << "first" << commutative_to_check[j].first;
+ cerr << "second" << commutative_to_check[j].second;
+ }
+ cerr << "From To Map is:" << endl;
+ cerr << "vNode" << vNode;
+ cerr << "wNode" << wNode;
+ cerr << "=============";
+ }
+
+ if (!r)
+ {
+ // If it's bad we restore it all back.
+ commutative_to_check.erase(commutative_to_check.begin() + init_comm_size, commutative_to_check.end());
+ vNode = vNode_copy;
+ wNode = wNode_copy;
+ return false;
+ }
+
+ // base case.
+ if (commutative_to_check.size() == 0)
+ return r;
+
+ pair<ASTNode, ASTNode> p = commutative_to_check.back();
+ commutative_to_check.pop_back();
+ assert(p.first.GetKind() == p.second.GetKind());
+ const ASTVec& f = p.first.GetChildren();
+ ASTVec s = p.second.GetChildren(); // non-const, needs to be sorted later.
+
+ if (f.size() != s.size())
+ {
+ cerr << "different sized!!!";
+ // If it's bad we restore it all back.
+ if (commutative_to_check.size() < init_comm_size)
+ commutative_to_check.push_back(p);
+ else
+ commutative_to_check.erase(commutative_to_check.begin() + init_comm_size, commutative_to_check.end());
+
+ vNode = vNode_copy;
+ wNode = wNode_copy;
+
+ return false;
+ }
+
+ // The next_permutation function requires this.
+ sort(s.begin(), s.end());
+
+ ASTNode vNode_copy2 = vNode;
+ ASTNode wNode_copy2 = wNode;
+
+ //deque<pair<ASTNode, ASTNode> > todo_copy2 = commutative_to_check;
+ const int new_comm_size = commutative_to_check.size();
+
+
+ // Check each permutation of the commutative operation's operands.
+ do
+ {
+ // Check each of the operands matches. Store Extra away in "commutative_to_check".
+ bool good= true;
+ for (int i=0;i < f.size(); i++)
+ {
+ if (!commutative_matchNode(f[i], s[i], term_variable_width, commutative_to_check, vNode, wNode))
+ {
+ good = false;
+ break;
+ }
+ }
+
+ // Empty out the "commutative_to_check".
+ if (good)
+ if (!c_matchNode(mgr->ASTTrue, mgr->ASTTrue, term_variable_width, commutative_to_check,vNode,wNode))
+ good =false;
+
+ if (good)
+ {
+ assert(commutative_to_check.size() ==0);
+ return true; // all match.
+ }
+ else
+ {
+ vNode = vNode_copy2;
+ wNode = wNode_copy2;
+ commutative_to_check.erase(commutative_to_check.begin() + new_comm_size, commutative_to_check.end());
+ //assert(commutative_to_check == todo_copy2);
+ //commutative_to_check = todo_copy2;
+ }
+ }
+ while (next_permutation(s.begin(), s.end()));
+
+ // None of the permutations match. We return the data unchanged.
+
+ vNode = vNode_copy;
+ wNode = wNode_copy;
+
+
+ if (commutative_to_check.size() < init_comm_size)
+ commutative_to_check.push_back(p);
+ else
+ commutative_to_check.erase(commutative_to_check.begin() + init_comm_size, commutative_to_check.end());
+
+
+ return false;
+}
+
+/* This does commutative matching of nodes. A substitution to the term variables (which are the
+ * those with a name of the width specified), of n0 is found. That is if the variables of n0 are
+ * substituted with the "substitution", then it will equal n1.
+ *
+ * Initially I thought commutative matching was easy to get right. NO!
+ *
+ * NB: This uses a "statc" container so this can't be called recursively.
+ */
+bool
+commutative_matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& substitution, const int term_variable_width)
+{
+ assert(substitution.size() ==0);
+
+#ifdef PEDANTIC_MATCHING_ASSERTS
+ {
+ // There shouldn't be any term variables on the RHS.
+ vector<ASTNode> vars = getVariables(n1);
+ vector<ASTNode>::iterator it = vars.begin();
+ while (it != vars.end())
+ {
+ assert(strlen(it->GetName()) != term_variable_width);
+ assert(it->GetName()[0] == 'v' || it->GetName()[0] == 'w');
+ it++;
+ }
+ assert(vars.size() <=2);
+
+ // All the LHS variables should be term variables.
+ vars = getVariables(n0);
+ it = vars.begin();
+ while (it != vars.end())
+ {
+ assert(strlen(it->GetName()) == term_variable_width);
+ it++;
+ }
+ assert(vars.size() <=2);
+ }
+
+
+#endif
+
+ deque<pair<ASTNode, ASTNode> > commutative;
+ commutative.clear();
+
+ ASTNode vNode = mgr->ASTUndefined;
+ ASTNode wNode = mgr->ASTUndefined;
+ bool r = c_matchNode(n0,n1,term_variable_width,commutative,vNode,wNode);
+
+ if (r)
+ {
+ vector<ASTNode> s = getVariables(n0);
+ for (vector<ASTNode>::iterator it = s.begin(); it != s.end();it++)
+ {
+ assert(it->GetKind() ==SYMBOL);
+ assert(strlen(it->GetName()) == term_variable_width);
+ if (it->GetName()[0] == 'v')
+ {
+ assert(vNode != mgr->ASTUndefined);
+ assert(vNode.GetValueWidth() == it->GetValueWidth());
+ substitution.insert(make_pair(*it,vNode));
+ }
+ if (it->GetName()[0] == 'w')
+ {
+ assert(wNode != mgr->ASTUndefined);
+ assert(wNode.GetValueWidth() == it->GetValueWidth());
+ substitution.insert(make_pair(*it,wNode));
+ }
+ }
+ }
+
+ if (debug_matching)
+ {
+ cerr << "=======" << endl << "The result is: " << r << "for the inputs" << n0 << n1 << "=-===";
+ }
+
+ if (!r)
+ {
+ assert(substitution.size() == 0);
+ assert(commutative.size() ==0); // should be none left to process.
+ }
+
+ return r;
+}
+
+ASTNode
+rewriteThroughWithAIGS(const ASTNode &n_)
+{
+ assert(n_.GetType() == BITVECTOR_TYPE);
+ ASTNode f = mgr->LookupOrCreateSymbol("rewriteThroughWithAIGS");
+ f.SetValueWidth(n_.GetValueWidth());
+ ASTNode n = create(EQ, n_, f);
+
+ BBNodeManagerAIG nm;
+ BitBlaster<BBNodeAIG, BBNodeManagerAIG> bb(&nm, simp, mgr->defaultNodeFactory, &mgr->UserFlags);
+ ASTNodeMap fromTo;
+ ASTNodeMap equivs;
+ bb.getConsts(n, fromTo, equivs);
+
+ ASTNode result = n_;
+ if (equivs.size() > 0)
+ {
+ ASTNodeMap cache;
+ result = SubstitutionMap::replace(result, equivs, cache, nf, false, true);
+ }
+
+ if (fromTo.size() > 0)
+ {
+ ASTNodeMap cache;
+ result = SubstitutionMap::replace(result, fromTo, cache, nf);
+ }
+ return result;
+}
widen(const ASTNode& w, int width);
bool
-unifyNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo);
+matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width);
+
+bool
+commutative_matchNode(const ASTNode& n0, const ASTNode& n1, ASTNodeMap& fromTo, const int term_variable_width);
+
ASTNode
renameVars(const ASTNode &n);
ASTNode
-rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen);
+rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen, int depth);
bool
checkRule(const ASTNode & from, const ASTNode & to, VariableAssignment& assignment, bool&bad);
bool
isConstant(const ASTNode& n, VariableAssignment& different);
+ASTNode
+rewriteThroughWithAIGS(const ASTNode &n_);
+
+
class Rewrite_system
{
public:
void
writeOutRules();
-
- friend ASTNode rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen);
+ friend ASTNode rewrite(const ASTNode&n, const Rewrite_rule& original_rule, ASTNodeMap& seen, int depth);
RewriteRuleContainer toWrite;
+
std::map< Kind, vector<Rewrite_rule> > kind_to_rr;
- std::map< Kind, std::map< Kind, vector<Rewrite_rule> > > kind_kind_to_rr;
+ bool lookups_invalid; // whether the above table is bad.
+
+ void
+ addRuleToLookup(Rewrite_rule& r)
+ {
+ const ASTNode& from = r.getFrom();
+ kind_to_rr[from.GetKind()].push_back(r);
+ assert(from.Degree() > 0); // Shouldn't map from a constant, nor from a variable.
+ }
+
public:
Rewrite_system()
{
+ lookups_invalid = false;
}
RewriteRuleContainer::iterator
return toWrite.end();
}
+ bool areLookupsGood()
+ {
+ return lookups_invalid;
+ }
void
- addRuleToLookup(Rewrite_rule& r)
+ renumber()
{
- const ASTNode& from = r.getFrom();
- kind_to_rr[from.GetKind()].push_back(r);
-
- assert(from.Degree() > 0); // Shouldn't map from a constant, nor from a variable.
-
- if (from[0].Degree() > 0)
- kind_kind_to_rr[from.GetKind()][from[0].GetKind()].push_back(r);
+ int id=0;
+ for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
+ it->id = id++;
+ lookups_invalid=true;
}
void
{
toWrite.erase(it--);
cerr << "matched" << id;
+ lookups_invalid = true;
}
}
}
buildLookupTable()
{
kind_to_rr.clear();
- kind_kind_to_rr.clear();
for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
{
addRuleToLookup(*it);
}
- }
-
- void
- removeBad()
- {
- for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
- {
- if (!it->isOK())
- {
- cout << "Removing Rule that is bad";
- cout << it->getFrom();
- cout << it->getTo();
- cout << "----\n";
-
- toWrite.erase(it--);
- }
- }
+ lookups_invalid =false;
}
void
eraseDuplicates()
{
- removeBad();
toWrite.sort();
toWrite.unique();
+ lookups_invalid = true;
}
+ // Rewrite the "from" and "To", add the rule if it's still good.
+ // NB: Doesn't rebuild the lookup table.
void
- push_back(Rewrite_rule rr)
+ push_back(Rewrite_rule& rr)
{
toWrite.push_back(rr);
addRuleToLookup(rr);
erase(RewriteRuleContainer::iterator it)
{
toWrite.erase(it);
+ lookups_invalid = true;
}
int
return toWrite.size();
}
- ASTNode rewriteNode(ASTNode n)
+ static ASTNode rewriteNode(ASTNode n)
{
- Rewrite_rule null_rule( Rewrite_rule(mgr,mgr->CreateZeroConst(1), mgr->CreateZeroConst(1),0));
- return rename_then_rewrite(n,null_rule);
+ return rename_then_rewrite(n,Rewrite_rule::getNullRule());
}
void
if (i % 1000 == 0)
cout << "rewrite all:" << i << " of " << toWrite.size() << endl;
- // if not OK, should have been removed during duplicates.
- // shouldn't add extra rules that aren't ok.
- assert (it->isOK());
+ ASTNode from_wide = renameVars(it->getFrom());
+ ASTNode to_wide = renameVars(it->getTo());
+
+ // The renamed should match the original, and vice versa.
+ {
+ ASTNodeMap fromTo;
+ assert(commutative_matchNode(from_wide, it->getFrom(),fromTo,2));
+ fromTo.clear();
+ assert(commutative_matchNode(it->getFrom(), from_wide, fromTo, 1));
+ fromTo.clear();
+ assert(commutative_matchNode(to_wide, it->getTo(),fromTo,2));
+ fromTo.clear();
+ assert(commutative_matchNode(it->getTo(),to_wide, fromTo,1));
+ }
- ASTNode n = renameVars(it->getFrom());
ASTNodeMap seen;
- ASTNode rewritten_from = rewrite(n, *it,seen);
+ ASTNode from_wide_rewritten = rewrite(from_wide, *it,seen,0);
+ seen = ASTNodeMap();
+ ASTNode to_wide_rewritten = rewrite(to_wide, *it,seen,0);
+ seen = ASTNodeMap();
+
+ // Also apply the AIG rules.
+ to_wide_rewritten = rewriteThroughWithAIGS(to_wide_rewritten);
- if (n != rewritten_from)
+ if ((from_wide != from_wide_rewritten) || (to_wide != to_wide_rewritten))
{
- assert (isConstantToSat(create(EQ, rewritten_from,n)));
+ ASTNode from_rewritten = renameVarsBack(from_wide_rewritten);
+ ASTNode to_rewritten = renameVarsBack(to_wide_rewritten);
+
+ assert(BVTypeCheckRecursive(from_rewritten));
+ assert(BVTypeCheckRecursive(to_rewritten));
- rewritten_from = renameVarsBack(rewritten_from);
- ASTNode to = it->getTo();
- bool ok = orderEquivalence(rewritten_from, to);
+ assert (isConstantToSat(create(EQ, from_wide_rewritten,from_wide)));
+ assert (isConstantToSat(create(EQ, to_wide_rewritten,to_wide)));
+ assert (isConstantToSat(create(EQ, it->getFrom(),from_rewritten)));
+ assert (isConstantToSat(create(EQ, it->getTo(),to_rewritten)));
+
+ bool ok = orderEquivalence(from_rewritten, to_rewritten);
if (ok)
{
- Rewrite_rule rr(mgr, rewritten_from, to, 0);
- if (rr.isOK())
- {
- cout << "Modifying Rule\n";
- cout << "Initially From";
- cout << it->getFrom();
- cout << "new From";
- cout << rewritten_from;
- cout << "---";
-
- *it= rr;
- buildLookupTable(); // Otherwise two rules will remove each other?
- }
- else
- {
- cout << "Erasing rule";
- cout << "Initially From";
- cout << it->getFrom();
- cout << "new From";
- cout << rewritten_from;
- cout << "---";
-
- erase(it--);
- i--;
- buildLookupTable(); // Otherwise two rules will remove each other?
- }
+ Rewrite_rule rr(mgr, from_rewritten, to_rewritten, 0);
+ cout << "Modifying Rule\n";
+ cout << "Initially From";
+ cout << it->getFrom();
+ cout << "Initially To";
+ cout << it->getTo();
+ cout << "New From";
+ cout << from_rewritten;
+ cout << "New To";
+ cout << to_rewritten;
+ cout << "---";
+ cout << getDifficulty(rr.getFrom()) << " --> " << getDifficulty(rr.getTo()) << endl;
+ cout << "replacing" << it->getId() << " with " << rr.getId() << endl;
+
+ *it = rr;
+ lookups_invalid = true;
+
}
- else
- {
- if (rewritten_from != to)
- {
- cout << "Mapped but couldn't order";
- cout << rewritten_from << to;
- }
+ if (!ok)
+ {
+ cout << "Erasing bad rule.\n";
erase(it--);
i--;
- buildLookupTable(); // Otherwise two rules will remove each other?
+ lookups_invalid = true;
}
}
}
eraseDuplicates();
cout << "Size after rewriteAll:" << toWrite.size() << endl;
+ buildLookupTable();
}
void clear()
{
toWrite.clear();
+ buildLookupTable();
}
void
verifyAllwithSAT()
{
+ cerr << "Started verifying all" << endl;
for (RewriteRuleContainer::iterator it = toWrite.begin() ; it != toWrite.end(); it++)
{
VariableAssignment assignment;
assert(r);
assert(!bad);
}
- if (bits > it->getVerifiedToBits())
+ if (bits >= it->getVerifiedToBits())
it->setVerified(bits,getCurrentTime() - st);
}
}
-
void
writeOut(ostream &o)
{