summaryrefslogtreecommitdiff
path: root/py
diff options
context:
space:
mode:
Diffstat (limited to 'py')
-rw-r--r--py/obj.c112
-rw-r--r--py/obj.h7
-rw-r--r--py/objarray.c1
-rw-r--r--py/objcomplex.c1
-rw-r--r--py/objfloat.c1
-rw-r--r--py/objset.c1
-rw-r--r--py/objtype.c4
-rw-r--r--py/runtime.c15
8 files changed, 82 insertions, 60 deletions
diff --git a/py/obj.c b/py/obj.c
index 55754f9be..6aa0abf0d 100644
--- a/py/obj.c
+++ b/py/obj.c
@@ -189,7 +189,7 @@ bool mp_obj_is_callable(mp_obj_t o_in) {
return mp_obj_instance_is_callable(o_in);
}
-// This function implements the '==' operator (and so the inverse of '!=').
+// This function implements the '==' and '!=' operators.
//
// From the Python language reference:
// (https://docs.python.org/3/reference/expressions.html#not-in)
@@ -202,67 +202,89 @@ bool mp_obj_is_callable(mp_obj_t o_in) {
// Furthermore, from the v3.4.2 code for object.c: "Practical amendments: If rich
// comparison returns NotImplemented, == and != are decided by comparing the object
// pointer."
-bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) {
- // Float (and complex) NaN is never equal to anything, not even itself,
- // so we must have a special check here to cover those cases.
- if (o1 == o2
- #if MICROPY_PY_BUILTINS_FLOAT
- && !mp_obj_is_float(o1)
- #endif
- #if MICROPY_PY_BUILTINS_COMPLEX
- && !mp_obj_is_type(o1, &mp_type_complex)
- #endif
- ) {
- return true;
- }
- if (o1 == mp_const_none || o2 == mp_const_none) {
- return false;
- }
+mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2) {
+ mp_obj_t local_true = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_false : mp_const_true;
+ mp_obj_t local_false = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_true : mp_const_false;
+ int pass_number = 0;
- // fast path for small ints
- if (mp_obj_is_small_int(o1)) {
- if (mp_obj_is_small_int(o2)) {
- // both SMALL_INT, and not equal if we get here
- return false;
- } else {
- mp_obj_t temp = o2; o2 = o1; o1 = temp;
- // o2 is now the SMALL_INT, o1 is not
- // fall through to generic op
- }
+ // Shortcut for very common cases
+ if (o1 == o2 &&
+ (mp_obj_is_small_int(o1) || !(mp_obj_get_type(o1)->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST))) {
+ return local_true;
}
// fast path for strings
if (mp_obj_is_str(o1)) {
if (mp_obj_is_str(o2)) {
// both strings, use special function
- return mp_obj_str_equal(o1, o2);
+ return mp_obj_str_equal(o1, o2) ? local_true : local_false;
+ #if MICROPY_PY_STR_BYTES_CMP_WARN
+ } else if (mp_obj_is_type(o2, &mp_type_bytes)) {
+ str_bytes_cmp:
+ mp_warning(MP_WARN_CAT(BytesWarning), "Comparison between bytes and str");
+ return local_false;
+ #endif
} else {
- // a string is never equal to anything else
- goto str_cmp_err;
+ goto skip_one_pass;
}
- } else if (mp_obj_is_str(o2)) {
+ #if MICROPY_PY_STR_BYTES_CMP_WARN
+ } else if (mp_obj_is_str(o2) && mp_obj_is_type(o1, &mp_type_bytes)) {
// o1 is not a string (else caught above), so the objects are not equal
- str_cmp_err:
- #if MICROPY_PY_STR_BYTES_CMP_WARN
- if (mp_obj_is_type(o1, &mp_type_bytes) || mp_obj_is_type(o2, &mp_type_bytes)) {
- mp_warning(MP_WARN_CAT(BytesWarning), "Comparison between bytes and str");
+ goto str_bytes_cmp;
+ #endif
+ }
+
+ // fast path for small ints
+ if (mp_obj_is_small_int(o1)) {
+ if (mp_obj_is_small_int(o2)) {
+ // both SMALL_INT, and not equal if we get here
+ return local_false;
+ } else {
+ goto skip_one_pass;
}
- #endif
- return false;
}
// generic type, call binary_op(MP_BINARY_OP_EQUAL)
- const mp_obj_type_t *type = mp_obj_get_type(o1);
- if (type->binary_op != NULL) {
- mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2);
- if (r != MP_OBJ_NULL) {
- return r == mp_const_true ? true : false;
+ while (pass_number < 2) {
+ const mp_obj_type_t *type = mp_obj_get_type(o1);
+ // If a full equality test is not needed and the other object is a different
+ // type then we don't need to bother trying the comparison.
+ if (type->binary_op != NULL &&
+ ((type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST) || mp_obj_get_type(o2) == type)) {
+ // CPython is asymmetric: it will try __eq__ if there's no __ne__ but not the
+ // other way around. If the class doesn't need a full test we can skip __ne__.
+ if (op == MP_BINARY_OP_NOT_EQUAL && (type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST)) {
+ mp_obj_t r = type->binary_op(MP_BINARY_OP_NOT_EQUAL, o1, o2);
+ if (r != MP_OBJ_NULL) {
+ return r;
+ }
+ }
+
+ // Try calling __eq__.
+ mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2);
+ if (r != MP_OBJ_NULL) {
+ if (op == MP_BINARY_OP_EQUAL) {
+ return r;
+ } else {
+ return mp_obj_is_true(r) ? local_true : local_false;
+ }
+ }
}
+
+ skip_one_pass:
+ // Try the other way around if none of the above worked
+ ++pass_number;
+ mp_obj_t temp = o1;
+ o1 = o2;
+ o2 = temp;
}
- // equality not implemented, and objects are not the same object, so
- // they are defined as not equal
- return false;
+ // equality not implemented, so fall back to pointer conparison
+ return (o1 == o2) ? local_true : local_false;
+}
+
+bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) {
+ return mp_obj_is_true(mp_obj_equal_not_equal(MP_BINARY_OP_EQUAL, o1, o2));
}
mp_int_t mp_obj_get_int(mp_const_obj_t arg) {
diff --git a/py/obj.h b/py/obj.h
index 2bc72b586..2e91b6b15 100644
--- a/py/obj.h
+++ b/py/obj.h
@@ -445,8 +445,14 @@ typedef mp_obj_t (*mp_fun_var_t)(size_t n, const mp_obj_t *);
typedef mp_obj_t (*mp_fun_kw_t)(size_t n, const mp_obj_t *, mp_map_t *);
// Flags for type behaviour (mp_obj_type_t.flags)
+// If MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST is clear then all the following hold:
+// (a) the type only implements the __eq__ operator and not the __ne__ operator;
+// (b) __eq__ returns a boolean result (False or True);
+// (c) __eq__ is reflexive (A==A is True);
+// (d) the type can't be equal to an instance of any different class that also clears this flag.
#define MP_TYPE_FLAG_IS_SUBCLASSED (0x0001)
#define MP_TYPE_FLAG_HAS_SPECIAL_ACCESSORS (0x0002)
+#define MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST (0x0004)
typedef enum {
PRINT_STR = 0,
@@ -729,6 +735,7 @@ void mp_obj_print_exception(const mp_print_t *print, mp_obj_t exc);
bool mp_obj_is_true(mp_obj_t arg);
bool mp_obj_is_callable(mp_obj_t o_in);
+mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2);
bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2);
static inline bool mp_obj_is_integer(mp_const_obj_t o) { return mp_obj_is_int(o) || mp_obj_is_bool(o); } // returns true if o is bool, small int or long int
diff --git a/py/objarray.c b/py/objarray.c
index c19617d4e..51f924ba4 100644
--- a/py/objarray.c
+++ b/py/objarray.c
@@ -558,6 +558,7 @@ const mp_obj_type_t mp_type_array = {
const mp_obj_type_t mp_type_bytearray = {
{ &mp_type_type },
.name = MP_QSTR_bytearray,
+ .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
.print = array_print,
.make_new = bytearray_make_new,
.getiter = array_iterator_new,
diff --git a/py/objcomplex.c b/py/objcomplex.c
index bf6fb51dc..0c87f544f 100644
--- a/py/objcomplex.c
+++ b/py/objcomplex.c
@@ -148,6 +148,7 @@ STATIC void complex_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
const mp_obj_type_t mp_type_complex = {
{ &mp_type_type },
.name = MP_QSTR_complex,
+ .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
.print = complex_print,
.make_new = complex_make_new,
.unary_op = complex_unary_op,
diff --git a/py/objfloat.c b/py/objfloat.c
index 3da549bb2..6181d5f57 100644
--- a/py/objfloat.c
+++ b/py/objfloat.c
@@ -186,6 +186,7 @@ STATIC mp_obj_t float_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs
const mp_obj_type_t mp_type_float = {
{ &mp_type_type },
.name = MP_QSTR_float,
+ .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
.print = float_print,
.make_new = float_make_new,
.unary_op = float_unary_op,
diff --git a/py/objset.c b/py/objset.c
index 6f5bcc032..a18b7e051 100644
--- a/py/objset.c
+++ b/py/objset.c
@@ -564,6 +564,7 @@ STATIC MP_DEFINE_CONST_DICT(frozenset_locals_dict, frozenset_locals_dict_table);
const mp_obj_type_t mp_type_frozenset = {
{ &mp_type_type },
.name = MP_QSTR_frozenset,
+ .flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
.print = set_print,
.make_new = set_make_new,
.unary_op = set_unary_op,
diff --git a/py/objtype.c b/py/objtype.c
index c574dcdfe..ae0fe6cae 100644
--- a/py/objtype.c
+++ b/py/objtype.c
@@ -467,7 +467,7 @@ const byte mp_binary_op_method_name[MP_BINARY_OP_NUM_RUNTIME] = {
[MP_BINARY_OP_EQUAL] = MP_QSTR___eq__,
[MP_BINARY_OP_LESS_EQUAL] = MP_QSTR___le__,
[MP_BINARY_OP_MORE_EQUAL] = MP_QSTR___ge__,
- // MP_BINARY_OP_NOT_EQUAL, // a != b calls a == b and inverts result
+ [MP_BINARY_OP_NOT_EQUAL] = MP_QSTR___ne__,
[MP_BINARY_OP_CONTAINS] = MP_QSTR___contains__,
// If an inplace method is not found a normal method will be used as a fallback
@@ -1100,7 +1100,7 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict)
// TODO might need to make a copy of locals_dict; at least that's how CPython does it
// Basic validation of base classes
- uint16_t base_flags = 0;
+ uint16_t base_flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST;
size_t bases_len;
mp_obj_t *bases_items;
mp_obj_tuple_get(bases_tuple, &bases_len, &bases_items);
diff --git a/py/runtime.c b/py/runtime.c
index db044cf7c..4a718c1e2 100644
--- a/py/runtime.c
+++ b/py/runtime.c
@@ -323,19 +323,8 @@ mp_obj_t mp_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
// deal with == and != for all types
if (op == MP_BINARY_OP_EQUAL || op == MP_BINARY_OP_NOT_EQUAL) {
- if (mp_obj_equal(lhs, rhs)) {
- if (op == MP_BINARY_OP_EQUAL) {
- return mp_const_true;
- } else {
- return mp_const_false;
- }
- } else {
- if (op == MP_BINARY_OP_EQUAL) {
- return mp_const_false;
- } else {
- return mp_const_true;
- }
- }
+ // mp_obj_equal_not_equal supports a bunch of shortcuts
+ return mp_obj_equal_not_equal(op, lhs, rhs);
}
// deal with exception_match for all types