summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--py/mpz.c55
-rw-r--r--tests/basics/int_big_div.py4
2 files changed, 20 insertions, 39 deletions
diff --git a/py/mpz.c b/py/mpz.c
index 51d0c9689..e0d249c21 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -531,60 +531,37 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
quo /= lead_den_digit;
// Multiply quo by den and subtract from num to get remainder.
- // We have different code here to handle different compile-time
- // configurations of mpz:
- //
- // 1. DIG_SIZE is stricly less than half the number of bits
- // available in mpz_dbl_dig_t. In this case we can use a
- // slightly more optimal (in time and space) routine that
- // uses the extra bits in mpz_dbl_dig_signed_t to store a
- // sign bit.
- //
- // 2. DIG_SIZE is exactly half the number of bits available in
- // mpz_dbl_dig_t. In this (common) case we need to be careful
- // not to overflow the borrow variable. And the shifting of
- // borrow needs some special logic (it's a shift right with
- // round up).
- //
+ // Must be careful with overflow of the borrow variable. Both
+ // borrow and low_digs are signed values and need signed right-shift,
+ // but x is unsigned and may take a full-range value.
const mpz_dig_t *d = den_dig;
mpz_dbl_dig_t d_norm = 0;
- mpz_dbl_dig_t borrow = 0;
+ mpz_dbl_dig_signed_t borrow = 0;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
+ // Get the next digit in (den).
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
+ // Multiply the next digit in (quo * den).
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
- #if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
- borrow += (mpz_dbl_dig_t)*n - x; // will overflow if DIG_SIZE >= MPZ_DBL_DIG_SIZE/2
- *n = borrow & DIG_MASK;
- borrow = (mpz_dbl_dig_signed_t)borrow >> DIG_SIZE;
- #else // DIG_SIZE == MPZ_DBL_DIG_SIZE / 2
- if (x >= *n || *n - x <= borrow) {
- borrow += x - (mpz_dbl_dig_t)*n;
- *n = (-borrow) & DIG_MASK;
- borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
- } else {
- *n = ((mpz_dbl_dig_t)*n - x - borrow) & DIG_MASK;
- borrow = 0;
- }
- #endif
+ // Compute the low DIG_MASK bits of the next digit in (num - quo * den)
+ mpz_dbl_dig_signed_t low_digs = (borrow & DIG_MASK) + *n - (x & DIG_MASK);
+ // Store the digit result for (num).
+ *n = low_digs & DIG_MASK;
+ // Compute the borrow, shifted right before summing to avoid overflow.
+ borrow = (borrow >> DIG_SIZE) - (x >> DIG_SIZE) + (low_digs >> DIG_SIZE);
}
- #if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
- // Borrow was negative in the above for-loop, make it positive for next if-block.
- borrow = -borrow;
- #endif
-
// At this point we have either:
//
// 1. quo was the correct value and the most-sig-digit of num is exactly
- // cancelled by borrow (borrow == *num_dig). In this case there is
+ // cancelled by borrow (borrow + *num_dig == 0). In this case there is
// nothing more to do.
//
// 2. quo was too large, we subtracted too many den from num, and the
- // most-sig-digit of num is 1 less than borrow (borrow == *num_dig + 1).
+ // most-sig-digit of num is less than needed (borrow + *num_dig < 0).
// In this case we must reduce quo and add back den to num until the
// carry from this operation cancels out the borrow.
//
- borrow -= *num_dig;
+ borrow += *num_dig;
for (; borrow != 0; --quo) {
d = den_dig;
d_norm = 0;
@@ -595,7 +572,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
*n = carry & DIG_MASK;
carry >>= DIG_SIZE;
}
- borrow -= carry;
+ borrow += carry;
}
// store this digit of the quotient
diff --git a/tests/basics/int_big_div.py b/tests/basics/int_big_div.py
index 642f051d4..29fd40597 100644
--- a/tests/basics/int_big_div.py
+++ b/tests/basics/int_big_div.py
@@ -8,3 +8,7 @@ x = 0x8000000000000000
print((x + 1) // x)
x = 0x86c60128feff5330
print((x + 1) // x)
+
+# these check edge cases where borrow overflows
+print((2 ** 48 - 1) ** 2 // (2 ** 48 - 1))
+print((2 ** 256 - 2 ** 32) ** 2 // (2 ** 256 - 2 ** 32))