summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--py/mpz.c114
-rw-r--r--py/mpz.h4
2 files changed, 48 insertions, 70 deletions
diff --git a/py/mpz.c b/py/mpz.c
index 018a5454f..78dee3132 100644
--- a/py/mpz.c
+++ b/py/mpz.c
@@ -537,83 +537,57 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
// not to overflow the borrow variable. And the shifting of
// borrow needs some special logic (it's a shift right with
// round up).
-
- if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {
- const mpz_dig_t *d = den_dig;
- mpz_dbl_dig_t d_norm = 0;
- mpz_dbl_dig_signed_t borrow = 0;
-
- for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
- d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
- borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
- *n = borrow & DIG_MASK;
- borrow >>= DIG_SIZE;
- }
- borrow += *num_dig; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2
- *num_dig = borrow & DIG_MASK;
- borrow >>= DIG_SIZE;
-
- // adjust quotient if it is too big
- for (; borrow != 0; --quo) {
- d = den_dig;
- d_norm = 0;
- mpz_dbl_dig_t carry = 0;
- for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
- d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
- carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
- *n = carry & DIG_MASK;
- carry >>= DIG_SIZE;
- }
- carry += *num_dig;
- *num_dig = carry & DIG_MASK;
- carry >>= DIG_SIZE;
-
- borrow += carry;
- }
- } else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2
- const mpz_dig_t *d = den_dig;
- mpz_dbl_dig_t d_norm = 0;
- mpz_dbl_dig_t borrow = 0;
-
- for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
- d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
- mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
- if (x >= *n || *n - x <= borrow) {
- borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n;
- *n = (-borrow) & DIG_MASK;
- borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
- } else {
- *n = ((mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)borrow) & DIG_MASK;
- borrow = 0;
- }
- }
- if (borrow >= *num_dig) {
- borrow -= (mpz_dbl_dig_t)*num_dig;
- *num_dig = (-borrow) & DIG_MASK;
+ //
+ const mpz_dig_t *d = den_dig;
+ mpz_dbl_dig_t d_norm = 0;
+ mpz_dbl_dig_t borrow = 0;
+ for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
+ d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
+ mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
+ #if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
+ borrow += (mpz_dbl_dig_t)*n - x; // will overflow if DIG_SIZE >= MPZ_DBL_DIG_SIZE/2
+ *n = borrow & DIG_MASK;
+ borrow = (mpz_dbl_dig_signed_t)borrow >> DIG_SIZE;
+ #else // DIG_SIZE == MPZ_DBL_DIG_SIZE / 2
+ if (x >= *n || *n - x <= borrow) {
+ borrow += x - (mpz_dbl_dig_t)*n;
+ *n = (-borrow) & DIG_MASK;
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
} else {
- *num_dig = (*num_dig - borrow) & DIG_MASK;
+ *n = ((mpz_dbl_dig_t)*n - x - borrow) & DIG_MASK;
borrow = 0;
}
+ #endif
+ }
- // adjust quotient if it is too big
- for (; borrow != 0; --quo) {
- d = den_dig;
- d_norm = 0;
- mpz_dbl_dig_t carry = 0;
- for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
- d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
- carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
- *n = carry & DIG_MASK;
- carry >>= DIG_SIZE;
- }
- carry += (mpz_dbl_dig_t)*num_dig;
- *num_dig = carry & DIG_MASK;
- carry >>= DIG_SIZE;
+ #if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
+ // Borrow was negative in the above for-loop, make it positive for next if-block.
+ borrow = -borrow;
+ #endif
- //assert(borrow >= carry); // enable this to check the logic
- borrow -= carry;
+ // At this point we have either:
+ //
+ // 1. quo was the correct value and the most-sig-digit of num is exactly
+ // cancelled by borrow (borrow == *num_dig). In this case there is
+ // nothing more to do.
+ //
+ // 2. quo was too large, we subtracted too many den from num, and the
+ // most-sig-digit of num is 1 less than borrow (borrow == *num_dig + 1).
+ // In this case we must reduce quo and add back den to num until the
+ // carry from this operation cancels out the borrow.
+ //
+ borrow -= *num_dig;
+ for (; borrow != 0; --quo) {
+ d = den_dig;
+ d_norm = 0;
+ mpz_dbl_dig_t carry = 0;
+ for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
+ d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
+ carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
+ *n = carry & DIG_MASK;
+ carry >>= DIG_SIZE;
}
+ borrow -= carry;
}
// store this digit of the quotient
diff --git a/py/mpz.h b/py/mpz.h
index e2d0c30aa..3c36cac66 100644
--- a/py/mpz.h
+++ b/py/mpz.h
@@ -55,18 +55,22 @@
#endif
#if MPZ_DIG_SIZE > 16
+#define MPZ_DBL_DIG_SIZE (64)
typedef uint32_t mpz_dig_t;
typedef uint64_t mpz_dbl_dig_t;
typedef int64_t mpz_dbl_dig_signed_t;
#elif MPZ_DIG_SIZE > 8
+#define MPZ_DBL_DIG_SIZE (32)
typedef uint16_t mpz_dig_t;
typedef uint32_t mpz_dbl_dig_t;
typedef int32_t mpz_dbl_dig_signed_t;
#elif MPZ_DIG_SIZE > 4
+#define MPZ_DBL_DIG_SIZE (16)
typedef uint8_t mpz_dig_t;
typedef uint16_t mpz_dbl_dig_t;
typedef int16_t mpz_dbl_dig_signed_t;
#else
+#define MPZ_DBL_DIG_SIZE (8)
typedef uint8_t mpz_dig_t;
typedef uint8_t mpz_dbl_dig_t;
typedef int8_t mpz_dbl_dig_signed_t;