diff options
| -rw-r--r-- | ports/windows/mpconfigport.h | 2 | ||||
| -rw-r--r-- | py/modmath.c | 12 | ||||
| -rw-r--r-- | py/mpconfig.h | 5 | ||||
| -rw-r--r-- | py/objfloat.c | 6 | ||||
| -rw-r--r-- | tests/float/math_domain.py | 2 | 
5 files changed, 26 insertions, 1 deletions
| diff --git a/ports/windows/mpconfigport.h b/ports/windows/mpconfigport.h index 30389c700..faf10752e 100644 --- a/ports/windows/mpconfigport.h +++ b/ports/windows/mpconfigport.h @@ -236,6 +236,8 @@ extern const struct _mp_obj_module_t mp_module_time;  #define MICROPY_PY_MATH_FMOD_FIX_INFNAN (1)  #ifdef _WIN64  #define MICROPY_PY_MATH_MODF_FIX_NEGZERO (1) +#else +#define MICROPY_PY_MATH_POW_FIX_NAN (1)  #endif  #endif diff --git a/py/modmath.c b/py/modmath.c index b312eeb3d..b7948f39e 100644 --- a/py/modmath.c +++ b/py/modmath.c @@ -98,7 +98,19 @@ mp_float_t MICROPY_FLOAT_C_FUN(log2)(mp_float_t x) {  // sqrt(x): returns the square root of x  MATH_FUN_1(sqrt, sqrt)  // pow(x, y): returns x to the power of y +#if MICROPY_PY_MATH_POW_FIX_NAN +mp_float_t pow_func(mp_float_t x, mp_float_t y) { +    // pow(base, 0) returns 1 for any base, even when base is NaN +    // pow(+1, exponent) returns 1 for any exponent, even when exponent is NaN +    if (x == MICROPY_FLOAT_CONST(1.0) || y == MICROPY_FLOAT_CONST(0.0)) { +        return MICROPY_FLOAT_CONST(1.0); +    } +    return MICROPY_FLOAT_C_FUN(pow)(x, y); +} +MATH_FUN_2(pow, pow_func) +#else  MATH_FUN_2(pow, pow) +#endif  // exp(x)  MATH_FUN_1(exp, exp)  #if MICROPY_PY_MATH_SPECIAL_FUNCTIONS diff --git a/py/mpconfig.h b/py/mpconfig.h index ae4cabfdc..12517316c 100644 --- a/py/mpconfig.h +++ b/py/mpconfig.h @@ -1160,6 +1160,11 @@ typedef double mp_float_t;  #define MICROPY_PY_MATH_MODF_FIX_NEGZERO (0)  #endif +// Whether to provide fix for pow(1, NaN) and pow(NaN, 0), which both should be 1 not NaN. +#ifndef MICROPY_PY_MATH_POW_FIX_NAN +#define MICROPY_PY_MATH_POW_FIX_NAN (0) +#endif +  // Whether to provide "cmath" module  #ifndef MICROPY_PY_CMATH  #define MICROPY_PY_CMATH (0) diff --git a/py/objfloat.c b/py/objfloat.c index f1e401ecc..451609492 100644 --- a/py/objfloat.c +++ b/py/objfloat.c @@ -300,6 +300,12 @@ mp_obj_t mp_obj_float_binary_op(mp_binary_op_t op, mp_float_t lhs_val, mp_obj_t                  mp_raise_ValueError(MP_ERROR_TEXT("complex values not supported"));                  #endif              } +            #if MICROPY_PY_MATH_POW_FIX_NAN // Also see modmath.c. +            if (lhs_val == MICROPY_FLOAT_CONST(1.0) || rhs_val == MICROPY_FLOAT_CONST(0.0)) { +                lhs_val = MICROPY_FLOAT_CONST(1.0); +                break; +            } +            #endif              lhs_val = MICROPY_FLOAT_C_FUN(pow)(lhs_val, rhs_val);              break;          case MP_BINARY_OP_DIVMOD: { diff --git a/tests/float/math_domain.py b/tests/float/math_domain.py index e63628cf5..0c25dc08b 100644 --- a/tests/float/math_domain.py +++ b/tests/float/math_domain.py @@ -38,7 +38,7 @@ for name, f, args in (  # double argument functions  for name, f, args in ( -    ("pow", math.pow, ((0, 2), (-1, 2), (0, -1), (-1, 2.3))), +    ("pow", math.pow, ((0, 2), (-1, 2), (0, -1), (-1, 2.3), (nan, 0), (1, nan))),      ("fmod", math.fmod, ((1.2, inf), (1.2, -inf), (1.2, 0), (inf, 1.2))),      ("atan2", math.atan2, ((0, 0), (-inf, inf), (-inf, -inf), (inf, -inf))),      ("copysign", math.copysign, ()), | 
