diff options
| -rw-r--r-- | py/builtin.c | 2 | ||||
| -rw-r--r-- | py/mpz.c | 42 | ||||
| -rw-r--r-- | py/mpz.h | 2 | ||||
| -rw-r--r-- | py/objint_mpz.c | 6 | 
4 files changed, 22 insertions, 30 deletions
| diff --git a/py/builtin.c b/py/builtin.c index dabc99016..88a724fcd 100644 --- a/py/builtin.c +++ b/py/builtin.c @@ -175,7 +175,7 @@ MP_DEFINE_CONST_FUN_OBJ_1(mp_builtin_callable_obj, mp_builtin_callable);  STATIC mp_obj_t mp_builtin_chr(mp_obj_t o_in) {      #if MICROPY_PY_BUILTINS_STR_UNICODE -    mp_int_t c = mp_obj_get_int(o_in); +    mp_uint_t c = mp_obj_get_int(o_in);      char str[4];      int len = 0;      if (c < 0x80) { @@ -1229,49 +1229,41 @@ mp_int_t mpz_hash(const mpz_t *z) {      return val;  } -// TODO check that this correctly handles overflow in all cases -mp_int_t mpz_as_int(const mpz_t *i) { +bool mpz_as_int_checked(const mpz_t *i, mp_int_t *value) {      mp_int_t val = 0;      mpz_dig_t *d = i->dig + i->len;      while (--d >= i->dig) { -        mp_int_t oldval = val; -        val = (val << DIG_SIZE) | *d; -        if (val < oldval) { -            // overflow, return +/- "infinity" -            if (i->neg == 0) { -                // +infinity -                return ~WORD_MSBIT_HIGH; -            } else { -                // -infinity -                return WORD_MSBIT_HIGH; -            } +        if (val > (~(WORD_MSBIT_HIGH) >> DIG_SIZE)) { +            // will overflow +            return false;          } +        val = (val << DIG_SIZE) | *d;      }      if (i->neg != 0) {          val = -val;      } -    return val; +    *value = val; +    return true;  } -// TODO check that this correctly handles overflow in all cases -bool mpz_as_int_checked(const mpz_t *i, mp_int_t *value) { -    mp_int_t val = 0; +bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) { +    if (i->neg != 0) { +        // can't represent signed values +        return false; +    } + +    mp_uint_t val = 0;      mpz_dig_t *d = i->dig + i->len;      while (--d >= i->dig) { -        mp_int_t oldval = val; -        val = (val << DIG_SIZE) | *d; -        if (val < oldval) { -            // overflow +        if (val > ((~0) >> DIG_SIZE)) { +            // will overflow              return false;          } -    } - -    if (i->neg != 0) { -        val = -val; +        val = (val << DIG_SIZE) | *d;      }      *value = val; @@ -97,8 +97,8 @@ mpz_t *mpz_div(const mpz_t *lhs, const mpz_t *rhs);  mpz_t *mpz_mod(const mpz_t *lhs, const mpz_t *rhs);  mp_int_t mpz_hash(const mpz_t *z); -mp_int_t mpz_as_int(const mpz_t *z);  bool mpz_as_int_checked(const mpz_t *z, mp_int_t *value); +bool mpz_as_uint_checked(const mpz_t *z, mp_uint_t *value);  #if MICROPY_PY_BUILTINS_FLOAT  mp_float_t mpz_as_float(const mpz_t *z);  #endif diff --git a/py/objint_mpz.c b/py/objint_mpz.c index c60e5c2b8..8233773b8 100644 --- a/py/objint_mpz.c +++ b/py/objint_mpz.c @@ -225,8 +225,7 @@ mp_obj_t mp_obj_int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {              case MP_BINARY_OP_INPLACE_LSHIFT:              case MP_BINARY_OP_RSHIFT:              case MP_BINARY_OP_INPLACE_RSHIFT: { -                // TODO check conversion overflow -                mp_int_t irhs = mpz_as_int(zrhs); +                mp_int_t irhs = mp_obj_int_get_checked(rhs_in);                  if (irhs < 0) {                      nlr_raise(mp_obj_new_exception_msg(&mp_type_ValueError, "negative shift count"));                  } @@ -303,7 +302,8 @@ mp_int_t mp_obj_int_get(mp_const_obj_t self_in) {          return MP_OBJ_SMALL_INT_VALUE(self_in);      } else {          const mp_obj_int_t *self = self_in; -        return mpz_as_int(&self->mpz); +        // TODO this is a hack until we remove mp_obj_int_get function entirely +        return mpz_hash(&self->mpz);      }  } | 
