diff options
author | Nicko van Someren <nicko@nicko.org> | 2019-12-31 15:19:12 -0700 |
---|---|---|
committer | Damien George <damien.p.george@gmail.com> | 2020-01-30 14:53:07 +1100 |
commit | 3aab54bf434e7f025a91ea05052f1bac439fad8c (patch) | |
tree | 838593685c39f9ee39dd0556715e4a1fc7038e93 /py/obj.c | |
parent | c3450effd4c3a402eeccf44a84a05ef4b36d69a0 (diff) |
py: Support non-boolean results for equality and inequality tests.
This commit implements a more complete replication of CPython's behaviour
for equality and inequality testing of objects. This addresses the issues
discussed in #5382 and a few other inconsistencies. Improvements over the
old code include:
- Support for returning non-boolean results from comparisons (as used by
numpy and others).
- Support for non-reflexive equality tests.
- Preferential use of __ne__ methods and MP_BINARY_OP_NOT_EQUAL binary
operators for inequality tests, when available.
- Fallback to op2 == op1 or op2 != op1 when op1 does not implement the
(in)equality operators.
The scheme here makes use of a new flag, MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
in the flags word of mp_obj_type_t to indicate if various shortcuts can or
cannot be used when performing equality and inequality tests. Currently
four built-in classes have the flag set: float and complex are
non-reflexive (since nan != nan) while bytearray and frozenszet instances
can equal other builtin class instances (bytes and set respectively). The
flag is also set for any new class defined by the user.
This commit also includes a more comprehensive set of tests for the
behaviour of (in)equality operators implemented in special methods.
Diffstat (limited to 'py/obj.c')
-rw-r--r-- | py/obj.c | 112 |
1 files changed, 67 insertions, 45 deletions
@@ -189,7 +189,7 @@ bool mp_obj_is_callable(mp_obj_t o_in) { return mp_obj_instance_is_callable(o_in); } -// This function implements the '==' operator (and so the inverse of '!='). +// This function implements the '==' and '!=' operators. // // From the Python language reference: // (https://docs.python.org/3/reference/expressions.html#not-in) @@ -202,67 +202,89 @@ bool mp_obj_is_callable(mp_obj_t o_in) { // Furthermore, from the v3.4.2 code for object.c: "Practical amendments: If rich // comparison returns NotImplemented, == and != are decided by comparing the object // pointer." -bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { - // Float (and complex) NaN is never equal to anything, not even itself, - // so we must have a special check here to cover those cases. - if (o1 == o2 - #if MICROPY_PY_BUILTINS_FLOAT - && !mp_obj_is_float(o1) - #endif - #if MICROPY_PY_BUILTINS_COMPLEX - && !mp_obj_is_type(o1, &mp_type_complex) - #endif - ) { - return true; - } - if (o1 == mp_const_none || o2 == mp_const_none) { - return false; - } +mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2) { + mp_obj_t local_true = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_false : mp_const_true; + mp_obj_t local_false = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_true : mp_const_false; + int pass_number = 0; - // fast path for small ints - if (mp_obj_is_small_int(o1)) { - if (mp_obj_is_small_int(o2)) { - // both SMALL_INT, and not equal if we get here - return false; - } else { - mp_obj_t temp = o2; o2 = o1; o1 = temp; - // o2 is now the SMALL_INT, o1 is not - // fall through to generic op - } + // Shortcut for very common cases + if (o1 == o2 && + (mp_obj_is_small_int(o1) || !(mp_obj_get_type(o1)->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST))) { + return local_true; } // fast path for strings if (mp_obj_is_str(o1)) { if (mp_obj_is_str(o2)) { // both strings, use special function - return mp_obj_str_equal(o1, o2); + return mp_obj_str_equal(o1, o2) ? local_true : local_false; + #if MICROPY_PY_STR_BYTES_CMP_WARN + } else if (mp_obj_is_type(o2, &mp_type_bytes)) { + str_bytes_cmp: + mp_warning(MP_WARN_CAT(BytesWarning), "Comparison between bytes and str"); + return local_false; + #endif } else { - // a string is never equal to anything else - goto str_cmp_err; + goto skip_one_pass; } - } else if (mp_obj_is_str(o2)) { + #if MICROPY_PY_STR_BYTES_CMP_WARN + } else if (mp_obj_is_str(o2) && mp_obj_is_type(o1, &mp_type_bytes)) { // o1 is not a string (else caught above), so the objects are not equal - str_cmp_err: - #if MICROPY_PY_STR_BYTES_CMP_WARN - if (mp_obj_is_type(o1, &mp_type_bytes) || mp_obj_is_type(o2, &mp_type_bytes)) { - mp_warning(MP_WARN_CAT(BytesWarning), "Comparison between bytes and str"); + goto str_bytes_cmp; + #endif + } + + // fast path for small ints + if (mp_obj_is_small_int(o1)) { + if (mp_obj_is_small_int(o2)) { + // both SMALL_INT, and not equal if we get here + return local_false; + } else { + goto skip_one_pass; } - #endif - return false; } // generic type, call binary_op(MP_BINARY_OP_EQUAL) - const mp_obj_type_t *type = mp_obj_get_type(o1); - if (type->binary_op != NULL) { - mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2); - if (r != MP_OBJ_NULL) { - return r == mp_const_true ? true : false; + while (pass_number < 2) { + const mp_obj_type_t *type = mp_obj_get_type(o1); + // If a full equality test is not needed and the other object is a different + // type then we don't need to bother trying the comparison. + if (type->binary_op != NULL && + ((type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST) || mp_obj_get_type(o2) == type)) { + // CPython is asymmetric: it will try __eq__ if there's no __ne__ but not the + // other way around. If the class doesn't need a full test we can skip __ne__. + if (op == MP_BINARY_OP_NOT_EQUAL && (type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST)) { + mp_obj_t r = type->binary_op(MP_BINARY_OP_NOT_EQUAL, o1, o2); + if (r != MP_OBJ_NULL) { + return r; + } + } + + // Try calling __eq__. + mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2); + if (r != MP_OBJ_NULL) { + if (op == MP_BINARY_OP_EQUAL) { + return r; + } else { + return mp_obj_is_true(r) ? local_true : local_false; + } + } } + + skip_one_pass: + // Try the other way around if none of the above worked + ++pass_number; + mp_obj_t temp = o1; + o1 = o2; + o2 = temp; } - // equality not implemented, and objects are not the same object, so - // they are defined as not equal - return false; + // equality not implemented, so fall back to pointer conparison + return (o1 == o2) ? local_true : local_false; +} + +bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { + return mp_obj_is_true(mp_obj_equal_not_equal(MP_BINARY_OP_EQUAL, o1, o2)); } mp_int_t mp_obj_get_int(mp_const_obj_t arg) { |