diff options
author | Jeff Epler <jepler@gmail.com> | 2025-07-19 11:24:34 -0500 |
---|---|---|
committer | Damien George <damien@micropython.org> | 2025-07-28 23:58:46 +1000 |
commit | 062e82a7cd60165db7edd8ee32105ea19a31ae2f (patch) | |
tree | 840bece2d256b11b78c3b5f9cd52465532dddde2 | |
parent | ebc9525c953dfbb6c44ab99320b5278e4314dafe (diff) |
py/objint_mpz: Fix pow3 where third argument is zero.
This finding is based on fuzzing MicroPython. I manually minimized the
test case it provided.
Signed-off-by: Jeff Epler <jepler@gmail.com>
-rw-r--r-- | py/objint_mpz.c | 7 | ||||
-rw-r--r-- | tests/basics/builtin_pow3_intbig.py | 5 |
2 files changed, 9 insertions, 3 deletions
diff --git a/py/objint_mpz.c b/py/objint_mpz.c index 6f2ea616c..ea4e409a2 100644 --- a/py/objint_mpz.c +++ b/py/objint_mpz.c @@ -356,9 +356,10 @@ static mpz_t *mp_mpz_for_int(mp_obj_t arg, mpz_t *temp) { mp_obj_t mp_obj_int_pow3(mp_obj_t base, mp_obj_t exponent, mp_obj_t modulus) { if (!mp_obj_is_int(base) || !mp_obj_is_int(exponent) || !mp_obj_is_int(modulus)) { mp_raise_TypeError(MP_ERROR_TEXT("pow() with 3 arguments requires integers")); + } else if (modulus == MP_OBJ_NEW_SMALL_INT(0)) { + mp_raise_ValueError(MP_ERROR_TEXT("divide by zero")); } else { - mp_obj_t result = mp_obj_new_int_from_ull(0); // Use the _from_ull version as this forces an mpz int - mp_obj_int_t *res_p = (mp_obj_int_t *)MP_OBJ_TO_PTR(result); + mp_obj_int_t *res_p = mp_obj_int_new_mpz(); mpz_t l_temp, r_temp, m_temp; mpz_t *lhs = mp_mpz_for_int(base, &l_temp); @@ -376,7 +377,7 @@ mp_obj_t mp_obj_int_pow3(mp_obj_t base, mp_obj_t exponent, mp_obj_t modulus) { if (mod == &m_temp) { mpz_deinit(mod); } - return result; + return MP_OBJ_FROM_PTR(res_p); } } #endif diff --git a/tests/basics/builtin_pow3_intbig.py b/tests/basics/builtin_pow3_intbig.py index bedc8b36b..41d2acbc0 100644 --- a/tests/basics/builtin_pow3_intbig.py +++ b/tests/basics/builtin_pow3_intbig.py @@ -20,3 +20,8 @@ print(hex(pow(2, x-1, x))) # Should be 1, since x is prime print(hex(pow(y, x-1, x))) # Should be 1, since x is prime print(hex(pow(y, y-1, x))) # Should be a 'big value' print(hex(pow(y, y-1, y))) # Should be a 'big value' + +try: + print(pow(1, 2, 0)) +except ValueError: + print("ValueError") |