summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--py/obj.c9
-rw-r--r--py/obj.h1
-rw-r--r--py/objcomplex.c5
-rw-r--r--tests/float/cmath_fun.py6
-rw-r--r--tests/float/complex_special_mehods.py15
-rwxr-xr-xtests/run-tests1
6 files changed, 35 insertions, 2 deletions
diff --git a/py/obj.c b/py/obj.c
index 07b161255..ed047acc3 100644
--- a/py/obj.c
+++ b/py/obj.c
@@ -371,7 +371,7 @@ mp_float_t mp_obj_get_float(mp_obj_t arg) {
}
#if MICROPY_PY_BUILTINS_COMPLEX
-void mp_obj_get_complex(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) {
+bool mp_obj_get_complex_maybe(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) {
if (arg == mp_const_false) {
*real = 0;
*imag = 0;
@@ -392,6 +392,13 @@ void mp_obj_get_complex(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) {
} else if (mp_obj_is_type(arg, &mp_type_complex)) {
mp_obj_complex_get(arg, real, imag);
} else {
+ return false;
+ }
+ return true;
+}
+
+void mp_obj_get_complex(mp_obj_t arg, mp_float_t *real, mp_float_t *imag) {
+ if (!mp_obj_get_complex_maybe(arg, real, imag)) {
#if MICROPY_ERROR_REPORTING == MICROPY_ERROR_REPORTING_TERSE
mp_raise_TypeError(MP_ERROR_TEXT("can't convert to complex"));
#else
diff --git a/py/obj.h b/py/obj.h
index 590b9c4b6..1fa24eb18 100644
--- a/py/obj.h
+++ b/py/obj.h
@@ -778,6 +778,7 @@ bool mp_obj_get_int_maybe(mp_const_obj_t arg, mp_int_t *value);
mp_float_t mp_obj_get_float(mp_obj_t self_in);
bool mp_obj_get_float_maybe(mp_obj_t arg, mp_float_t *value);
void mp_obj_get_complex(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag);
+bool mp_obj_get_complex_maybe(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag);
#endif
void mp_obj_get_array(mp_obj_t o, size_t *len, mp_obj_t **items); // *items may point inside a GC block
void mp_obj_get_array_fixed_n(mp_obj_t o, size_t len, mp_obj_t **items); // *items may point inside a GC block
diff --git a/py/objcomplex.c b/py/objcomplex.c
index 91e440230..f4c4aeffc 100644
--- a/py/objcomplex.c
+++ b/py/objcomplex.c
@@ -178,7 +178,10 @@ void mp_obj_complex_get(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag) {
mp_obj_t mp_obj_complex_binary_op(mp_binary_op_t op, mp_float_t lhs_real, mp_float_t lhs_imag, mp_obj_t rhs_in) {
mp_float_t rhs_real, rhs_imag;
- mp_obj_get_complex(rhs_in, &rhs_real, &rhs_imag); // can be any type, this function will convert to float (if possible)
+ if (!mp_obj_get_complex_maybe(rhs_in, &rhs_real, &rhs_imag)) {
+ return MP_OBJ_NULL; // op not supported
+ }
+
switch (op) {
case MP_BINARY_OP_ADD:
case MP_BINARY_OP_INPLACE_ADD:
diff --git a/tests/float/cmath_fun.py b/tests/float/cmath_fun.py
index 7b5e69245..15b72e7a6 100644
--- a/tests/float/cmath_fun.py
+++ b/tests/float/cmath_fun.py
@@ -57,3 +57,9 @@ for f_name, f, test_vals in functions:
if abs(real) < 1e-6:
real = 0.0
print("complex(%.5g, %.5g)" % (real, ret.imag))
+
+# test invalid type passed to cmath function
+try:
+ log([])
+except TypeError:
+ print("TypeError")
diff --git a/tests/float/complex_special_mehods.py b/tests/float/complex_special_mehods.py
new file mode 100644
index 000000000..6789013fa
--- /dev/null
+++ b/tests/float/complex_special_mehods.py
@@ -0,0 +1,15 @@
+# test complex interacting with special methods
+
+
+class A:
+ def __add__(self, x):
+ print("__add__")
+ return 1
+
+ def __radd__(self, x):
+ print("__radd__")
+ return 2
+
+
+print(A() + 1j)
+print(1j + A())
diff --git a/tests/run-tests b/tests/run-tests
index f9e4de4b3..102b0f779 100755
--- a/tests/run-tests
+++ b/tests/run-tests
@@ -355,6 +355,7 @@ def run_tests(pyb, tests, args, base_path="."):
if not has_complex:
skip_tests.add('float/complex1.py')
skip_tests.add('float/complex1_intbig.py')
+ skip_tests.add('float/complex_special_mehods.py')
skip_tests.add('float/int_big_float.py')
skip_tests.add('float/true_value.py')
skip_tests.add('float/types.py')