diff options
Diffstat (limited to 'py/mpz.c')
| -rw-r--r-- | py/mpz.c | 87 | 
1 files changed, 71 insertions, 16 deletions
| @@ -37,7 +37,9 @@  #if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ  #define DIG_SIZE (MPZ_DIG_SIZE) -#define DIG_MASK ((1 << DIG_SIZE) - 1) +#define DIG_MASK ((1L << DIG_SIZE) - 1) +#define DIG_MSB  (1L << (DIG_SIZE - 1)) +#define DIG_BASE (1L << DIG_SIZE)  /*   mpz is an arbitrary precision integer type with a public API. @@ -61,7 +63,7 @@ STATIC mp_int_t mpn_cmp(const mpz_dig_t *idig, mp_uint_t ilen, const mpz_dig_t *      if (ilen > jlen) { return 1; }      for (idig += ilen, jdig += ilen; ilen > 0; --ilen) { -        mp_int_t cmp = *(--idig) - *(--jdig); +        mpz_dbl_dig_signed_t cmp = (mpz_dbl_dig_t)*(--idig) - (mpz_dbl_dig_t)*(--jdig);          if (cmp < 0) { return -1; }          if (cmp > 0) { return 1; }      } @@ -127,7 +129,7 @@ STATIC mp_uint_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mp_ui      for (mp_uint_t i = jlen; i > 0; i--, idig++, jdig++) {          mpz_dbl_dig_t d = *jdig;          if (i > 1) { -            d |= jdig[1] << DIG_SIZE; +            d |= (mpz_dbl_dig_t)jdig[1] << DIG_SIZE;          }          d >>= n_part;          *idig = d & DIG_MASK; @@ -152,7 +154,7 @@ STATIC mp_uint_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,      jlen -= klen;      for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) { -        carry += *jdig + *kdig; +        carry += (mpz_dbl_dig_t)*jdig + (mpz_dbl_dig_t)*kdig;          *idig = carry & DIG_MASK;          carry >>= DIG_SIZE;      } @@ -182,7 +184,7 @@ STATIC mp_uint_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, mp_uint_t jlen,      jlen -= klen;      for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) { -        borrow += *jdig - *kdig; +        borrow += (mpz_dbl_dig_t)*jdig - (mpz_dbl_dig_t)*kdig;          *idig = borrow & DIG_MASK;          borrow >>= DIG_SIZE;      } @@ -301,7 +303,7 @@ STATIC mp_uint_t mpn_mul_dig_add_dig(mpz_dig_t *idig, mp_uint_t ilen, mpz_dig_t      mpz_dbl_dig_t carry = dadd;      for (; ilen > 0; --ilen, ++idig) { -        carry += *idig * dmul; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2 +        carry += (mpz_dbl_dig_t)*idig * (mpz_dbl_dig_t)dmul; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2          *idig = carry & DIG_MASK;          carry >>= DIG_SIZE;      } @@ -328,7 +330,7 @@ STATIC mp_uint_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, mp_uint_t jlen, mpz_d          mp_uint_t jl = jlen;          for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) { -            carry += *id + *jd * *kdig; // will never overflow so long as DIG_SIZE <= WORD_SIZE / 2 +            carry += (mpz_dbl_dig_t)*id + (mpz_dbl_dig_t)*jd * (mpz_dbl_dig_t)*kdig; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2              *id = carry & DIG_MASK;              carry >>= DIG_SIZE;          } @@ -375,7 +377,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,      // count number of leading zeros in leading digit of denominator      {          mpz_dig_t d = den_dig[den_len - 1]; -        while ((d & (1 << (DIG_SIZE - 1))) == 0) { +        while ((d & DIG_MSB) == 0) {              d <<= 1;              ++norm_shift;          } @@ -412,21 +414,36 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,      // keep going while we have enough digits to divide      while (*num_len > den_len) { -        mpz_dbl_dig_t quo = (*num_dig << DIG_SIZE) | num_dig[-1]; +        mpz_dbl_dig_t quo = ((mpz_dbl_dig_t)*num_dig << DIG_SIZE) | num_dig[-1];          // get approximate quotient          quo /= lead_den_digit; -        // multiply quo by den and subtract from num get remainder -        { +        // Multiply quo by den and subtract from num to get remainder. +        // We have different code here to handle different compile-time +        // configurations of mpz: +        // +        //   1. DIG_SIZE is stricly less than half the number of bits +        //      available in mpz_dbl_dig_t.  In this case we can use a +        //      slightly more optimal (in time and space) routine that +        //      uses the extra bits in mpz_dbl_dig_signed_t to store a +        //      sign bit. +        // +        //   2. DIG_SIZE is exactly half the number of bits available in +        //      mpz_dbl_dig_t.  In this (common) case we need to be careful +        //      not to overflow the borrow variable.  And the shifting of +        //      borrow needs some special logic (it's a shift right with +        //      round up). + +        if (DIG_SIZE < 8 * sizeof(mpz_dbl_dig_t) / 2) {              mpz_dbl_dig_signed_t borrow = 0;              for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { -                borrow += *n - quo * *d; // will overflow if DIG_SIZE >= 16 +                borrow += (mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)*d; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2                  *n = borrow & DIG_MASK;                  borrow >>= DIG_SIZE;              } -            borrow += *num_dig; // will overflow if DIG_SIZE >= 16 +            borrow += *num_dig; // will overflow if DIG_SIZE >= 8*sizeof(mpz_dbl_dig_t)/2              *num_dig = borrow & DIG_MASK;              borrow >>= DIG_SIZE; @@ -434,7 +451,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,              for (; borrow != 0; --quo) {                  mpz_dbl_dig_t carry = 0;                  for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { -                    carry += *n + *d; +                    carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d;                      *n = carry & DIG_MASK;                      carry >>= DIG_SIZE;                  } @@ -444,6 +461,44 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig,                  borrow += carry;              } +        } else { // DIG_SIZE == 8 * sizeof(mpz_dbl_dig_t) / 2 +            mpz_dbl_dig_t borrow = 0; + +            for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { +                mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (mpz_dbl_dig_t)(*d); +                if (x >= *n || *n - x <= borrow) { +                    borrow += (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)*n; +                    *n = (-borrow) & DIG_MASK; +                    borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up +                } else { +                    *n = ((mpz_dbl_dig_t)*n - (mpz_dbl_dig_t)x - (mpz_dbl_dig_t)borrow) & DIG_MASK; +                    borrow = 0; +                } +            } +            if (borrow >= *num_dig) { +                borrow -= (mpz_dbl_dig_t)*num_dig; +                *num_dig = (-borrow) & DIG_MASK; +                borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up +            } else { +                *num_dig = (*num_dig - borrow) & DIG_MASK; +                borrow = 0; +            } + +            // adjust quotient if it is too big +            for (; borrow != 0; --quo) { +                mpz_dbl_dig_t carry = 0; +                for (mpz_dig_t *n = num_dig - den_len, *d = den_dig; n < num_dig; ++n, ++d) { +                    carry += (mpz_dbl_dig_t)*n + (mpz_dbl_dig_t)*d; +                    *n = carry & DIG_MASK; +                    carry >>= DIG_SIZE; +                } +                carry += (mpz_dbl_dig_t)*num_dig; +                *num_dig = carry & DIG_MASK; +                carry >>= DIG_SIZE; + +                //assert(borrow >= carry); // enable this to check the logic +                borrow -= carry; +            }          }          // store this digit of the quotient @@ -1256,7 +1311,7 @@ bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) {      mpz_dig_t *d = i->dig + i->len;      while (--d >= i->dig) { -        if (val > ((~0) >> DIG_SIZE)) { +        if (val > (~(WORD_MSBIT_HIGH) >> (DIG_SIZE - 1))) {              // will overflow              return false;          } @@ -1273,7 +1328,7 @@ mp_float_t mpz_as_float(const mpz_t *i) {      mpz_dig_t *d = i->dig + i->len;      while (--d >= i->dig) { -        val = val * (1 << DIG_SIZE) + *d; +        val = val * DIG_BASE + *d;      }      if (i->neg != 0) { | 
