summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDamien George <damien@micropython.org>2023-05-12 23:17:20 +1000
committerDamien George <damien@micropython.org>2023-05-19 13:42:35 +1000
commit4b57330465b98df30ef8a10e19a0e197b5797550 (patch)
treeb3a674e8cb6673ab451d6ef6d97ec29aed910a6b
parentca9068e0efc5e53ff7042ef68ad4a4ec9ff91915 (diff)
py/objstr: Return unsupported binop instead of raising TypeError.
So that user types can implement reverse operators and have them work with str on the left-hand-side, eg `"a" + UserType()`. Signed-off-by: Damien George <damien@micropython.org>
-rw-r--r--py/objstr.c11
-rw-r--r--tests/basics/class_reverse_op.py32
2 files changed, 41 insertions, 2 deletions
diff --git a/py/objstr.c b/py/objstr.c
index 4d9dca04a..e6c5ee71c 100644
--- a/py/objstr.c
+++ b/py/objstr.c
@@ -403,7 +403,16 @@ mp_obj_t mp_obj_str_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_i
} else {
// LHS is str and RHS has an incompatible type
// (except if operation is EQUAL, but that's handled by mp_obj_equal)
- bad_implicit_conversion(rhs_in);
+
+ // CONTAINS must fail with a bad-implicit-conversion exception, because
+ // otherwise mp_binary_op() will fallback to `list(lhs).__contains__(rhs)`.
+ if (op == MP_BINARY_OP_CONTAINS) {
+ bad_implicit_conversion(rhs_in);
+ }
+
+ // All other operations are not supported, and may be handled by another
+ // type, eg for reverse operations.
+ return MP_OBJ_NULL;
}
switch (op) {
diff --git a/tests/basics/class_reverse_op.py b/tests/basics/class_reverse_op.py
index b0dae5f8a..11aba6aad 100644
--- a/tests/basics/class_reverse_op.py
+++ b/tests/basics/class_reverse_op.py
@@ -1,5 +1,7 @@
-class A:
+# Test reverse operators.
+# Test user type with integers.
+class A:
def __init__(self, v):
self.v = v
@@ -14,5 +16,33 @@ class A:
def __repr__(self):
return "A({})".format(self.v)
+
print(A(3) + 1)
print(2 + A(5))
+
+
+# Test user type with strings.
+class B:
+ def __init__(self, v):
+ self.v = v
+
+ def __repr__(self):
+ return "B({})".format(self.v)
+
+ def __ror__(self, o):
+ return B(o + "|" + self.v)
+
+ def __radd__(self, o):
+ return B(o + "+" + self.v)
+
+ def __rmul__(self, o):
+ return B(o + "*" + self.v)
+
+ def __rtruediv__(self, o):
+ return B(o + "/" + self.v)
+
+
+print("a" | B("b"))
+print("a" + B("b"))
+print("a" * B("b"))
+print("a" / B("b"))