diff options
Diffstat (limited to 'py/obj.c')
-rw-r--r-- | py/obj.c | 112 |
1 files changed, 67 insertions, 45 deletions
@@ -189,7 +189,7 @@ bool mp_obj_is_callable(mp_obj_t o_in) { return mp_obj_instance_is_callable(o_in); } -// This function implements the '==' operator (and so the inverse of '!='). +// This function implements the '==' and '!=' operators. // // From the Python language reference: // (https://docs.python.org/3/reference/expressions.html#not-in) @@ -202,67 +202,89 @@ bool mp_obj_is_callable(mp_obj_t o_in) { // Furthermore, from the v3.4.2 code for object.c: "Practical amendments: If rich // comparison returns NotImplemented, == and != are decided by comparing the object // pointer." -bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { - // Float (and complex) NaN is never equal to anything, not even itself, - // so we must have a special check here to cover those cases. - if (o1 == o2 - #if MICROPY_PY_BUILTINS_FLOAT - && !mp_obj_is_float(o1) - #endif - #if MICROPY_PY_BUILTINS_COMPLEX - && !mp_obj_is_type(o1, &mp_type_complex) - #endif - ) { - return true; - } - if (o1 == mp_const_none || o2 == mp_const_none) { - return false; - } +mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2) { + mp_obj_t local_true = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_false : mp_const_true; + mp_obj_t local_false = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_true : mp_const_false; + int pass_number = 0; - // fast path for small ints - if (mp_obj_is_small_int(o1)) { - if (mp_obj_is_small_int(o2)) { - // both SMALL_INT, and not equal if we get here - return false; - } else { - mp_obj_t temp = o2; o2 = o1; o1 = temp; - // o2 is now the SMALL_INT, o1 is not - // fall through to generic op - } + // Shortcut for very common cases + if (o1 == o2 && + (mp_obj_is_small_int(o1) || !(mp_obj_get_type(o1)->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST))) { + return local_true; } // fast path for strings if (mp_obj_is_str(o1)) { if (mp_obj_is_str(o2)) { // both strings, use special function - return mp_obj_str_equal(o1, o2); + return mp_obj_str_equal(o1, o2) ? local_true : local_false; + #if MICROPY_PY_STR_BYTES_CMP_WARN + } else if (mp_obj_is_type(o2, &mp_type_bytes)) { + str_bytes_cmp: + mp_warning(MP_WARN_CAT(BytesWarning), "Comparison between bytes and str"); + return local_false; + #endif } else { - // a string is never equal to anything else - goto str_cmp_err; + goto skip_one_pass; } - } else if (mp_obj_is_str(o2)) { + #if MICROPY_PY_STR_BYTES_CMP_WARN + } else if (mp_obj_is_str(o2) && mp_obj_is_type(o1, &mp_type_bytes)) { // o1 is not a string (else caught above), so the objects are not equal - str_cmp_err: - #if MICROPY_PY_STR_BYTES_CMP_WARN - if (mp_obj_is_type(o1, &mp_type_bytes) || mp_obj_is_type(o2, &mp_type_bytes)) { - mp_warning(MP_WARN_CAT(BytesWarning), "Comparison between bytes and str"); + goto str_bytes_cmp; + #endif + } + + // fast path for small ints + if (mp_obj_is_small_int(o1)) { + if (mp_obj_is_small_int(o2)) { + // both SMALL_INT, and not equal if we get here + return local_false; + } else { + goto skip_one_pass; } - #endif - return false; } // generic type, call binary_op(MP_BINARY_OP_EQUAL) - const mp_obj_type_t *type = mp_obj_get_type(o1); - if (type->binary_op != NULL) { - mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2); - if (r != MP_OBJ_NULL) { - return r == mp_const_true ? true : false; + while (pass_number < 2) { + const mp_obj_type_t *type = mp_obj_get_type(o1); + // If a full equality test is not needed and the other object is a different + // type then we don't need to bother trying the comparison. + if (type->binary_op != NULL && + ((type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST) || mp_obj_get_type(o2) == type)) { + // CPython is asymmetric: it will try __eq__ if there's no __ne__ but not the + // other way around. If the class doesn't need a full test we can skip __ne__. + if (op == MP_BINARY_OP_NOT_EQUAL && (type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST)) { + mp_obj_t r = type->binary_op(MP_BINARY_OP_NOT_EQUAL, o1, o2); + if (r != MP_OBJ_NULL) { + return r; + } + } + + // Try calling __eq__. + mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2); + if (r != MP_OBJ_NULL) { + if (op == MP_BINARY_OP_EQUAL) { + return r; + } else { + return mp_obj_is_true(r) ? local_true : local_false; + } + } } + + skip_one_pass: + // Try the other way around if none of the above worked + ++pass_number; + mp_obj_t temp = o1; + o1 = o2; + o2 = temp; } - // equality not implemented, and objects are not the same object, so - // they are defined as not equal - return false; + // equality not implemented, so fall back to pointer conparison + return (o1 == o2) ? local_true : local_false; +} + +bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { + return mp_obj_is_true(mp_obj_equal_not_equal(MP_BINARY_OP_EQUAL, o1, o2)); } mp_int_t mp_obj_get_int(mp_const_obj_t arg) { |