From d42b436a8094026f41e80f27d37b12079a5d859d Mon Sep 17 00:00:00 2001 From: trevor_hansen Date: Sat, 7 Apr 2012 12:42:23 +0000 Subject: [PATCH] Small improvements to the multiplication propagator. git-svn-id: https://stp-fast-prover.svn.sourceforge.net/svnroot/stp-fast-prover/trunk/stp@1630 e59a4935-1847-0410-ae03-e826735625c1 --- .../ConstantBitP_Multiplication.cpp | 493 +++++++++-------- .../multiplication/ColumnCounts.h | 512 +++++++++--------- 2 files changed, 514 insertions(+), 491 deletions(-) diff --git a/src/simplifier/constantBitP/ConstantBitP_Multiplication.cpp b/src/simplifier/constantBitP/ConstantBitP_Multiplication.cpp index 3dd4a21..84e1a3b 100644 --- a/src/simplifier/constantBitP/ConstantBitP_Multiplication.cpp +++ b/src/simplifier/constantBitP/ConstantBitP_Multiplication.cpp @@ -20,21 +20,23 @@ namespace simplifier const bool debug_multiply = false; std::ostream& log = std::cerr; +#if 0 // The maximum size of the carry into a column for MULTIPLICATION int maximumCarryInForMultiplication(int column) - { + { int result = 0; int currIndex = 0; while (currIndex < column) { - currIndex++; - result = (result + currIndex) / 2; + currIndex++; + result = (result + currIndex) / 2; } return result; - } + } +#endif Result fixIfCanForMultiplication(vector& children, const int index, const int aspirationalSum) @@ -50,56 +52,56 @@ namespace simplifier int columnUnfixed = cs.columnUnfixed; // both unfixed. int columnOneFixed = cs.columnOneFixed; // one of the values is fixed to one. int columnOnes = cs.columnOnes; // both are ones. - //int columnZeroes = cs.columnZeroes; // either are zero. Result result = NO_CHANGE; + // only one of the conditionals can run. + bool run = false; + // We need every value that is unfixed to be set to one. if (aspirationalSum == columnOnes + columnOneFixed + columnUnfixed && ((columnOneFixed + columnUnfixed) > 0)) { - for (unsigned i = 0; i <= (unsigned) index; i++) + for (unsigned i = 0; i <= (unsigned) index; i++) + { + // If y is unfixed, and it's not anded with zero. + if (!y.isFixed(i) && !(x.isFixed(index - i) && !x.getValue(index - i))) { - // If y is unfixed, and it's not anded with zero. - if (!y.isFixed(i) && !(x.isFixed(index - i) && !x.getValue(index - i))) - { - y.setFixed(i, true); - y.setValue(i, true); - result = CHANGED; - } - - if (!x.isFixed(index - i) && !(y.isFixed(i) && !y.getValue(i))) - { - x.setFixed(index - i, true); - x.setValue(index - i, true); - result = CHANGED; - } + y.setFixed(i, true); + y.setValue(i, true); + result = CHANGED; } + + if (!x.isFixed(index - i) && !(y.isFixed(i) && !y.getValue(i))) + { + x.setFixed(index - i, true); + x.setValue(index - i, true); + result = CHANGED; + } + }assert(result == CHANGED); + run = true; } // We have all the ones that we need already. (thanks). Set everything we can to zero. if (aspirationalSum == columnOnes && (columnUnfixed > 0 || columnOneFixed > 0)) { - for (unsigned i = 0; i <= (unsigned) index; i++) + assert(!run); + for (unsigned i = 0; i <= (unsigned) index; i++) + { + if (!y.isFixed(i) && x.isFixed(index - i) && x.getValue(index - i)) // one fixed. + + { + y.setFixed(i, true); + y.setValue(i, false); + result = CHANGED; + } + + if (!x.isFixed(index - i) && y.isFixed(i) && y.getValue(i)) // one fixed other way. { - if (!y.isFixed(i) && x.isFixed(index - i) && x.getValue(index - i)) // one fixed. - - { - y.setFixed(i, true); - y.setValue(i, false); - //columnZeroes++; - //columnOneFixed--; - result = CHANGED; - } - - if (!x.isFixed(index - i) && y.isFixed(i) && y.getValue(i)) // one fixed other way. - { - x.setFixed(index - i, true); - x.setValue(index - i, false); - //columnZeroes++; - //columnOneFixed--; - result = CHANGED; - } + x.setFixed(index - i, true); + x.setValue(index - i, false); + result = CHANGED; } + } } if (debug_multiply && result == CONFLICT) log << "CONFLICT" << endl; @@ -118,42 +120,42 @@ namespace simplifier for (unsigned i = 0; i < bitWidth; i++) { - yFixedFalse[i] = y.isFixed(i) && !y.getValue(i); - xFixedFalse[i] = x.isFixed(i) && !x.getValue(i); + yFixedFalse[i] = y.isFixed(i) && !y.getValue(i); + xFixedFalse[i] = x.isFixed(i) && !x.getValue(i); } for (unsigned i = 0; i < bitWidth; i++) { - // decrease using zeroes. - if (yFixedFalse[i]) + // decrease using zeroes. + if (yFixedFalse[i]) + { + for (unsigned j = i; j < bitWidth; j++) { - for (unsigned j = i; j < bitWidth; j++) - { - columnH[j]--; - } + columnH[j]--; } + } - if (xFixedFalse[i]) + if (xFixedFalse[i]) + { + for (unsigned j = i; j < bitWidth; j++) { - for (unsigned j = i; j < bitWidth; j++) - { - // if the row hasn't already been zeroed out. - if (!yFixedFalse[j - i]) - columnH[j]--; - } + // if the row hasn't already been zeroed out. + if (!yFixedFalse[j - i]) + columnH[j]--; } + } - // check if there are any pairs of ones. - if (x.isFixed(i) && x.getValue(i)) - for (unsigned j = 0; j < (bitWidth - i); j++) + // check if there are any pairs of ones. + if (x.isFixed(i) && x.getValue(i)) + for (unsigned j = 0; j < (bitWidth - i); j++) + { + assert(i + j < bitWidth); + if (y.isFixed(j) && y.getValue(j)) { - assert(i + j < bitWidth); - if (y.isFixed(j) && y.getValue(j)) - { - // a pair of ones. Increase the lower bound. - columnL[i + j]++; - } + // a pair of ones. Increase the lower bound. + columnL[i + j]++; } + } } return NO_CHANGE; } @@ -169,14 +171,14 @@ namespace simplifier /***NB < to ***/ for (int i = from; i < to; i++) { - if (y[i] == '*') - { - y.setFixed(i, true); - y.setValue(i, false); - r = CHANGED; - } - else if (y[i] == '1') - return CONFLICT; + if (y[i] == '*') + { + y.setFixed(i, true); + y.setValue(i, false); + r = CHANGED; + } + else if (y[i] == '1') + return CONFLICT; } return r; } @@ -197,13 +199,13 @@ namespace simplifier for (int i = output.getWidth() - 1; i > maxOutputOneFromInputs; i--) if (!output.isFixed(i)) { - output.setFixed(i, true); - output.setValue(i, false); + output.setFixed(i, true); + output.setValue(i, false); } else { - if (output.getValue(i)) - return CONFLICT; + if (output.getValue(i)) + return CONFLICT; } return NOT_IMPLEMENTED; @@ -226,11 +228,11 @@ namespace simplifier for (int i = 0; i < bitWidth; i++) { - if (x[i] == '1' || x[i] == '*') - CONSTANTBV::BitVector_Bit_On(x_c, i); + if (x[i] == '1' || x[i] == '*') + CONSTANTBV::BitVector_Bit_On(x_c, i); - if (y[i] == '1' || y[i] == '*') - CONSTANTBV::BitVector_Bit_On(y_c, i); + if (y[i] == '1' || y[i] == '*') + CONSTANTBV::BitVector_Bit_On(y_c, i); } BEEV::CBV result = CONSTANTBV::BitVector_Create(2 * bitWidth + 1, true); @@ -239,22 +241,22 @@ namespace simplifier for (int j = (2 * bitWidth) - 1; j >= 0; j--) { - if (CONSTANTBV::BitVector_bit_test(result, j)) - break; - if (j < bitWidth) + if (CONSTANTBV::BitVector_bit_test(result, j)) + break; + if (j < bitWidth) + { + if (!output.isFixed(j)) { - if (!output.isFixed(j)) - { - output.setFixed(j, true); - output.setValue(j, false); - } - else - { - if (output.getValue(j)) - return CONFLICT; - } - + output.setFixed(j, true); + output.setValue(j, false); } + else + { + if (output.getValue(j)) + return CONFLICT; + } + + } } #ifndef NDEBUG @@ -290,29 +292,29 @@ namespace simplifier bool done = false; for (int i = x_min; i <= std::min(x_max, bitwidth - 1); i++) { - if (x[i] == '1') - break; + if (x[i] == '1') + break; - if (x[i] == '0') - continue; + if (x[i] == '0') + continue; - assert(!done); - for (int j = y_min; j <= std::min(y_max, output_max); j++) - { - if (j + i >= bitwidth || (y[j] != '0' && output[i + j] != '0')) - { - done = true; - break; - } - } - if (!done) + assert(!done); + for (int j = y_min; j <= std::min(y_max, output_max); j++) + { + if (j + i >= bitwidth || (y[j] != '0' && output[i + j] != '0')) { - x.setFixed(i, true); - x.setValue(i, false); - r = CHANGED; - } - else + done = true; break; + } + } + if (!done) + { + x.setFixed(i, true); + x.setValue(i, false); + r = CHANGED; + } + else + break; } return r; } @@ -360,13 +362,13 @@ namespace simplifier if (xBottom > yBottom) { - toInvert = &x; - toSet = &y; + toInvert = &x; + toSet = &y; } else { - toInvert = &y; - toSet = &x; + toInvert = &y; + toSet = &x; } invertCount--; // position of the least fixed. @@ -381,52 +383,52 @@ namespace simplifier if (CONSTANTBV::BitVector_bit_test(toInvertCBV, 0)) { - if (debug_multiply) - cerr << "Value to Invert:" << *toInvertCBV << endl; + if (debug_multiply) + cerr << "Value to Invert:" << *toInvertCBV << endl; - BEEV::Simplifier simplifier(bm); - BEEV::CBV inverse = simplifier.MultiplicativeInverse(bm->CreateBVConst(toInvertCBV, width)).GetBVConst(); - BEEV::CBV toMultiplyBy = output.GetBVConst(invertCount, 0); + BEEV::Simplifier simplifier(bm); + BEEV::CBV inverse = simplifier.MultiplicativeInverse(bm->CreateBVConst(toInvertCBV, width)).GetBVConst(); + BEEV::CBV toMultiplyBy = output.GetBVConst(invertCount, 0); - BEEV::CBV toSetEqualTo = CONSTANTBV::BitVector_Create(2 * (width), true); + BEEV::CBV toSetEqualTo = CONSTANTBV::BitVector_Create(2 * (width), true); - CONSTANTBV::ErrCode ec = CONSTANTBV::BitVector_Multiply(toSetEqualTo, inverse, toMultiplyBy); - if (ec != CONSTANTBV::ErrCode_Ok) - { - assert(false); - throw 2314231; - } + CONSTANTBV::ErrCode ec = CONSTANTBV::BitVector_Multiply(toSetEqualTo, inverse, toMultiplyBy); + if (ec != CONSTANTBV::ErrCode_Ok) + { + assert(false); + throw 2314231; + } - if (false && debug_multiply) + if (false && debug_multiply) + { + cerr << x << "*" << y << "=" << output << endl; + cerr << "Invert bit count" << invertCount << endl; + cerr << "To set" << *toSet; + cerr << "To set equal to:" << *toSetEqualTo << endl; + } + + // Write in the value. + for (int i = 0; i <= invertCount; i++) + { + bool expected = CONSTANTBV::BitVector_bit_test(toSetEqualTo, i); + + if (toSet->isFixed(i) && (toSet->getValue(i) ^ expected)) { - cerr << x << "*" << y << "=" << output << endl; - cerr << "Invert bit count" << invertCount << endl; - cerr << "To set" << *toSet; - cerr << "To set equal to:" << *toSetEqualTo << endl; + status = CONFLICT; } - - // Write in the value. - for (int i = 0; i <= invertCount; i++) + else if (!toSet->isFixed(i)) { - bool expected = CONSTANTBV::BitVector_bit_test(toSetEqualTo, i); - - if (toSet->isFixed(i) && (toSet->getValue(i) ^ expected)) - { - status = CONFLICT; - } - else if (!toSet->isFixed(i)) - { - toSet->setFixed(i, true); - toSet->setValue(i, expected); - } + toSet->setFixed(i, true); + toSet->setValue(i, expected); } + } - // Don't delete the "inverse" because it's reference counted by the ASTNode. + // Don't delete the "inverse" because it's reference counted by the ASTNode. - CONSTANTBV::BitVector_Destroy(toSetEqualTo); - CONSTANTBV::BitVector_Destroy(toMultiplyBy); + CONSTANTBV::BitVector_Destroy(toSetEqualTo); + CONSTANTBV::BitVector_Destroy(toMultiplyBy); - //cerr << "result" << *toSet; + //cerr << "result" << *toSet; } else CONSTANTBV::BitVector_Destroy(toInvertCBV); @@ -459,22 +461,22 @@ namespace simplifier CONSTANTBV::ErrCode ec = CONSTANTBV::BitVector_Multiply(result, xCBV, yCBV); if (ec != CONSTANTBV::ErrCode_Ok) { - assert(false); - throw 2314231; + assert(false); + throw 2314231; } Result status = NOT_IMPLEMENTED; for (int i = 0; i <= minV; i++) { - bool expected = CONSTANTBV::BitVector_bit_test(result, i); + bool expected = CONSTANTBV::BitVector_bit_test(result, i); - if (output.isFixed(i) && (output.getValue(i) ^ expected)) - status = CONFLICT; - else if (!output.isFixed(i)) - { - output.setFixed(i, true); - output.setValue(i, expected); - } + if (output.isFixed(i) && (output.getValue(i) ^ expected)) + status = CONFLICT; + else if (!output.isFixed(i)) + { + output.setFixed(i, true); + output.setValue(i, expected); + } } CONSTANTBV::BitVector_Destroy(xCBV); @@ -489,12 +491,12 @@ namespace simplifier { for (int i = 0; i < bitWidth; i++) { - log << sumL[bitWidth - 1 - i] << " "; + log << sumL[bitWidth - 1 - i] << " "; } log << endl; for (int i = 0; i < bitWidth; i++) { - log << sumH[bitWidth - 1 - i] << " "; + log << sumH[bitWidth - 1 - i] << " "; } log << endl; } @@ -515,10 +517,10 @@ namespace simplifier if (debug_multiply) { - cerr << "Initial Fixing"; - cerr << x << "*"; - cerr << y << "="; - cerr << output << endl; + cerr << "Initial Fixing"; + cerr << x << "*"; + cerr << y << "="; + cerr << output << endl; } Result r = useTrailingZeroesToFix(x, y, output); @@ -528,96 +530,93 @@ namespace simplifier bool changed = true; while (changed) { - changed = false; - signed columnH[bitWidth]; // maximum number of true partial products. - signed columnL[bitWidth]; // minimum "" "" - signed sumH[bitWidth]; - signed sumL[bitWidth]; - - ColumnCounts cc(columnH, columnL, sumH, sumL, bitWidth); + changed = false; + signed columnH[bitWidth]; // maximum number of true partial products. + signed columnL[bitWidth]; // minimum "" "" + signed sumH[bitWidth]; + signed sumL[bitWidth]; - // Use the number of zeroes and ones in a column to update the possible counts. - adjustColumns(x, y, columnL, columnH); + ColumnCounts cc(columnH, columnL, sumH, sumL, bitWidth, output); - cc.rebuildSums(); - Result r = cc.fixedPoint(output); + // Use the number of zeroes and ones in a column to update the possible counts. + adjustColumns(x, y, columnL, columnH); - assert(cc.fixedPoint(output) != CHANGED); - // idempotent + cc.rebuildSums(); + Result r = cc.fixedPoint(); - if (r == CONFLICT) - return CONFLICT; + if (r == CONFLICT) + return CONFLICT; - r = NO_CHANGE; + r = NO_CHANGE; - // If any of the sums have a cardinality of 1. Set the result. - for (unsigned column = 0; column < bitWidth; column++) + // If any of the sums have a cardinality of 1. Set the result. + for (unsigned column = 0; column < bitWidth; column++) + { + if (cc.sumL[column] == cc.sumH[column]) { - if (cc.sumL[column] == cc.sumH[column]) - { - //(1) If the output has a known value. Set the output. - bool newValue = !(sumH[column] % 2 == 0); - if (!output.isFixed(column)) - { - output.setFixed(column, true); - output.setValue(column, newValue); - r = CHANGED; - } - else if (output.getValue(column) != newValue) - return CONFLICT; - } + //(1) If the output has a known value. Set the output. + bool newValue = !(sumH[column] % 2 == 0); + if (!output.isFixed(column)) + { + output.setFixed(column, true); + output.setValue(column, newValue); + r = CHANGED; + } + else if (output.getValue(column) != newValue) + return CONFLICT; } + } - if (CHANGED == r) - changed = true; + if (CHANGED == r) + changed = true; - for (unsigned column = 0; column < bitWidth; column++) + for (unsigned column = 0; column < bitWidth; column++) + { + if (cc.columnL[column] == cc.columnH[column]) { - if (cc.columnL[column] == cc.columnH[column]) - { - //(2) Knowledge of the sum may fix the operands. - Result tempResult = fixIfCanForMultiplication(children, column, cc.columnH[column]); + //(2) Knowledge of the sum may fix the operands. + Result tempResult = fixIfCanForMultiplication(children, column, cc.columnH[column]); - if (CONFLICT == tempResult) - return CONFLICT; + if (CONFLICT == tempResult) + return CONFLICT; - if (CHANGED == tempResult) - r = CHANGED; - } + if (CHANGED == tempResult) + r = CHANGED; } + } - if (debug_multiply) - { - cerr << "At end"; - cerr << "x:" << x << endl; - cerr << "y:" << y << endl; - cerr << "output:" << output << endl; - } + if (debug_multiply) + { + cerr << "At end"; + cerr << "x:" << x << endl; + cerr << "y:" << y << endl; + cerr << "output:" << output << endl; + } - assert(CONFLICT != r); + assert(CONFLICT != r); - if (CHANGED == r) - changed = true; + if (CHANGED == r) + changed = true; - if (ms != NULL) - { - *ms = MultiplicationStats(bitWidth, cc.columnL, cc.columnH, cc.sumL, cc.sumH); - ms->x = *children[0]; - ms->y = *children[1]; - ms->r = output; - } + if (ms != NULL) + { + *ms = MultiplicationStats(bitWidth, cc.columnL, cc.columnH, cc.sumL, cc.sumH); + ms->x = *children[0]; + ms->y = *children[1]; + ms->r = output; + } - if (changed) - { - useTrailingZeroesToFix(x, y, output); - // if (r == NO_CHANGE) - // changed= false; - } + if (changed) + { + useTrailingZeroesToFix(x, y, output); + // if (r == NO_CHANGE) + // changed= false; + } } if (children[0]->isTotallyFixed() && children[1]->isTotallyFixed()) { - assert(output.isTotallyFixed()); + assert(output.isTotallyFixed()); } // The below assertions are for performance only. It's not maximally precise anyway!!! @@ -625,22 +624,22 @@ namespace simplifier #ifndef NDEBUG if (r != CONFLICT) { - FixedBits x_c(x), y_c(y), o_c(output); + FixedBits x_c(x), y_c(y), o_c(output); - // These are subsumed by the consistency over the columns.. - useTrailingFixedToFix(x_c, y_c, o_c); - useLeadingZeroesToFix(x_c, y_c, o_c); - useInversesToSolve(x_c, y_c, o_c, bm); + // These are subsumed by the consistency over the columns.. + useTrailingFixedToFix(x_c, y_c, o_c); + useLeadingZeroesToFix(x_c, y_c, o_c); + useInversesToSolve(x_c, y_c, o_c, bm); - // This one should have been called to fixed point! - useTrailingZeroesToFix(x_c, y_c, o_c); + // This one should have been called to fixed point! + useTrailingZeroesToFix(x_c, y_c, o_c); - if (!FixedBits::equals(x_c, x) || !FixedBits::equals(y_c, y) || !FixedBits::equals(o_c, output)) - { - cerr << x << y << output << endl; - cerr << x_c << y_c << o_c << endl; - assert(false); - } + if (!FixedBits::equals(x_c, x) || !FixedBits::equals(y_c, y) || !FixedBits::equals(o_c, output)) + { + cerr << x << y << output << endl; + cerr << x_c << y_c << o_c << endl; + assert(false); + } } #endif diff --git a/src/simplifier/constantBitP/multiplication/ColumnCounts.h b/src/simplifier/constantBitP/multiplication/ColumnCounts.h index 890a152..39e9da0 100644 --- a/src/simplifier/constantBitP/multiplication/ColumnCounts.h +++ b/src/simplifier/constantBitP/multiplication/ColumnCounts.h @@ -10,250 +10,274 @@ namespace simplifier { -namespace constantBitP -{ - -extern std::ostream& log; - - -struct Interval -{ - int& low; - int& high; - Interval(int& _low, int& _high) : - low(_low), high(_high) - { - } -}; - -struct ColumnCounts -{ - signed *columnH; // maximum number of true partial products. - signed *columnL; // minimum "" "" - signed *sumH; - signed *sumL; - unsigned int bitWidth; - - ColumnCounts(signed _columnH[], signed _columnL[], signed _sumH[], - signed _sumL[], unsigned _bitWidth) : - columnH(_columnH), columnL(_columnL), sumH(_sumH), sumL(_sumL) - { - // setup the low and highs. - bitWidth = _bitWidth; - // initialise 'em. - for (unsigned i = 0; i < bitWidth; i++) - { - columnL[i] = 0; - columnH[i] = i + 1; - } - } - - void rebuildSums() - { - // Initialise sums. - sumL[0] = columnL[0]; - sumH[0] = columnH[0]; - for (unsigned i = /**/1 /**/; i < bitWidth; i++) - { - assert((columnH[i] >= columnL[i]) && (columnL[i] >= 0)); - sumH[i] = columnH[i] + (sumH[i - 1] / 2); - sumL[i] = columnL[i] + (sumL[i - 1] / 2); - } - } - - void print(string message) - { - log << message << endl; - log << " columnL:"; - for (unsigned i = 0; i < bitWidth; i++) - { - log << columnL[bitWidth - 1 - i] << " "; - } - log << endl; - log << " columnH:"; - for (unsigned i = 0; i < bitWidth; i++) - { - log << columnH[bitWidth - 1 - i] << " "; - } - log << endl; - log << " sumL: "; - - for (unsigned i = 0; i < bitWidth; i++) - { - log << sumL[bitWidth - 1 - i] << " "; - } - log << endl; - log << " sumH: "; - for (unsigned i = 0; i < bitWidth; i++) - { - log << sumH[bitWidth - 1 - i] << " "; - } - log << endl; - } - - // update the sum of a column to the parity of the output for that column. e.g. [0,2] if the answer is 1, goes to [1,1]. - Result snapTo(const FixedBits& output) - { - Result r = NO_CHANGE; - - // Make sure each column's sum is consistent with the output. - for (unsigned i = 0; i < bitWidth; i++) - { - if (output.isFixed(i)) - { - //bool changed = false; - int expected = output.getValue(i) ? 1 : 0; - - // output is true. So the maximum and minimum can only be even. - if ((sumH[i] & 1) != expected) - { - sumH[i]--; - r = CHANGED; - } - if ((sumL[i] & 1) != expected) - { - sumL[i]++; - r = CHANGED; - } - - if (((sumH[i] < sumL[i]) || (sumL[i] < 0))) - return CONFLICT; - } - } - return r; - } - - bool inConflict() - { - for (unsigned i = 0; i < bitWidth; i++) - if ((sumL[i] > sumH[i]) || (columnL[i] > columnH[i])) - return true; - return false; - - } - - Result fixedPoint(FixedBits & output) - { - if (inConflict()) - return CONFLICT; - - bool changed = true; - bool totalChanged = false; - - while (changed) - { - changed = false; - - Result r = snapTo(output); - if (r == CHANGED) - changed = true; - if (r == CONFLICT) - return CONFLICT; - - r = propagate(output); - if (r == CHANGED) - changed = true; - if (r == CONFLICT) - return CONFLICT; - - if (changed) - totalChanged = true; - } - - if (totalChanged) - return CHANGED; - else - return NO_CHANGE; - } - - //Assert that all the counts are consistent. - Result propagate(const FixedBits& output) - { - bool changed = false; - - int i = 0; - - // - if (sumL[i] > columnL[i]) - { - columnL[i] = sumL[i]; - changed = true; - } - if (sumL[i] < columnL[i]) - { - sumL[i] = columnL[i]; - changed = true; - } - if (sumH[i] < columnH[i]) - { - columnH[i] = sumH[i]; - changed = true; - } - if (sumH[i] > columnH[i]) - { - sumH[i] = columnH[i]; - changed = true; - } - - for (unsigned i = 1; i < bitWidth; i++) - { - Interval a(sumL[i], sumH[i]); - Interval b(columnL[i], columnH[i]); - - int low = sumL[i - 1] / 2; // interval takes references. - int high = sumH[i - 1] / 2; - Interval c(low, high); - - if (a.low < b.low + c.low) - { - a.low = b.low + c.low; - changed = true; - } - - if (a.high > b.high + c.high) - { - changed = true; - a.high = b.high + c.high; - } - - if (a.low - b.high > c.low) - { - int toAssign = ((a.low - b.high) * 2); - assert(toAssign > sumL[i-1]); - sumL[i - 1] = toAssign; - changed = true; - } - - if (a.high - b.low < c.high) - { - int toAssign = ((a.high - b.low) * 2) + 1; - assert(toAssign < sumH[i-1]); - sumH[i - 1] = toAssign; - changed = true; - } - - if (a.low - c.high > b.low) - { - b.low = a.low - c.high; - changed = true; - } - - if (a.high - c.low < b.high) - { - b.high = a.high - c.low; - changed = true; - } - - } - if (changed) - return CHANGED; - else - return NO_CHANGE; - } - -}; - -} + namespace constantBitP + { + + extern std::ostream& log; + + struct Interval + { + int& low; + int& high; + Interval(int& _low, int& _high) : + low(_low), high(_high) + { + } + }; + + struct ColumnCounts + { + signed *columnH; // maximum number of true partial products. + signed *columnL; // minimum "" "" + signed *sumH; + signed *sumL; + unsigned int bitWidth; + const FixedBits & output; + + ColumnCounts(signed _columnH[], signed _columnL[], signed _sumH[], signed _sumL[], unsigned _bitWidth, + FixedBits& output_) : + columnH(_columnH), columnL(_columnL), sumH(_sumH), sumL(_sumL), output(output_) + { + // setup the low and highs. + bitWidth = _bitWidth; + // initialise 'em. + for (unsigned i = 0; i < bitWidth; i++) + { + columnL[i] = 0; + columnH[i] = i + 1; + } + } + + void + rebuildSums() + { + // Initialise sums. + sumL[0] = columnL[0]; + sumH[0] = columnH[0]; + snapTo(0); + + for (unsigned i = /**/1 /**/; i < bitWidth; i++) + { + assert((columnH[i] >= columnL[i]) && (columnL[i] >= 0)); + sumH[i] = columnH[i] + (sumH[i - 1] / 2); + sumL[i] = columnL[i] + (sumL[i - 1] / 2); + if (output.isFixed(i)) + snapTo(i); + } + } + + void + print(string message) + { + log << message << endl; + log << " columnL:"; + for (unsigned i = 0; i < bitWidth; i++) + { + log << columnL[bitWidth - 1 - i] << " "; + } + log << endl; + log << " columnH:"; + for (unsigned i = 0; i < bitWidth; i++) + { + log << columnH[bitWidth - 1 - i] << " "; + } + log << endl; + log << " sumL: "; + + for (unsigned i = 0; i < bitWidth; i++) + { + log << sumL[bitWidth - 1 - i] << " "; + } + log << endl; + log << " sumH: "; + for (unsigned i = 0; i < bitWidth; i++) + { + log << sumH[bitWidth - 1 - i] << " "; + } + log << endl; + } + + Result + snapTo(int i) + { + Result r = NO_CHANGE; + if (output.isFixed(i)) + { + //bool changed = false; + int expected = output.getValue(i) ? 1 : 0; + + // output is true. So the maximum and minimum can only be even. + if ((sumH[i] & 1) != expected) + { + sumH[i]--; + r = CHANGED; + } + if ((sumL[i] & 1) != expected) + { + sumL[i]++; + r = CHANGED; + } + + if (((sumH[i] < sumL[i]) || (sumL[i] < 0))) + return CONFLICT; + } + return r; + } + + // update the sum of a column to the parity of the output for that column. e.g. [0,2] if the answer is 1, goes to [1,1]. + Result + snapTo() + { + Result r = NO_CHANGE; + + // Make sure each column's sum is consistent with the output. + for (unsigned i = 0; i < bitWidth; i++) + { + r = merge(r, snapTo(i)); + } + return r; + } + + bool + inConflict() + { + for (unsigned i = 0; i < bitWidth; i++) + if ((sumL[i] > sumH[i]) || (columnL[i] > columnH[i])) + return true; + return false; + + } + + Result + fixedPoint() + { + if (inConflict()) + return CONFLICT; + + bool changed = true; + bool totalChanged = false; + + while (changed) + { + changed = false; + + Result r = snapTo(); + if (r == CHANGED) + changed = true; + if (r == CONFLICT) + return CONFLICT; + + r = propagate(); + if (r == CHANGED) + changed = true; + if (r == CONFLICT) + return CONFLICT; + + if (changed) + totalChanged = true; + } + + if (inConflict()) + return CONFLICT; + + assert(propagate() == NO_CHANGE); + assert(snapTo() == NO_CHANGE); + + if (totalChanged) + return CHANGED; + else + return NO_CHANGE; + } + + //Assert that all the counts are consistent. + Result + propagate() + { + bool changed = false; + + int i = 0; + + // + if (sumL[i] > columnL[i]) + { + columnL[i] = sumL[i]; + changed = true; + } + if (sumL[i] < columnL[i]) + { + sumL[i] = columnL[i]; + changed = true; + } + if (sumH[i] < columnH[i]) + { + columnH[i] = sumH[i]; + changed = true; + } + if (sumH[i] > columnH[i]) + { + sumH[i] = columnH[i]; + changed = true; + } + + for (unsigned i = 1; i < bitWidth; i++) + { + Interval a(sumL[i], sumH[i]); + Interval b(columnL[i], columnH[i]); + + int low = sumL[i - 1] / 2; // interval takes references. + int high = sumH[i - 1] / 2; + Interval c(low, high); + + if (a.low < b.low + c.low) + { + a.low = b.low + c.low; + changed = true; + } + + if (a.high > b.high + c.high) + { + changed = true; + a.high = b.high + c.high; + } + + if (a.low - b.high > c.low) + { + int toAssign = ((a.low - b.high) * 2); + assert(toAssign > sumL[i - 1]); + sumL[i - 1] = toAssign; + changed = true; + } + + if (a.high - b.low < c.high) + { + int toAssign = ((a.high - b.low) * 2) + 1; + assert(toAssign < sumH[i - 1]); + sumH[i - 1] = toAssign; + changed = true; + } + + if (a.low - c.high > b.low) + { + b.low = a.low - c.high; + changed = true; + } + + if (a.high - c.low < b.high) + { + b.high = a.high - c.low; + changed = true; + } + + } + if (changed) + return CHANGED; + else + return NO_CHANGE; + } + + }; + + } } -- 2.47.3