summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDamien George <damien@micropython.org>2024-07-24 17:17:39 +1000
committerDamien George <damien@micropython.org>2024-07-25 13:07:42 +1000
commit233f5ce661d31c8e606ff727ac0c6de27d32a715 (patch)
tree482e6467632b9f4c1cca6601ce581e682eb456aa
parent07bf3179f6d51a0e80fc48f06490f4e4d0d8488d (diff)
py/runtime: Fix self arg passed to classmethod when accessed via super.
Thanks to @AJMansfield for the original test case. Signed-off-by: Damien George <damien@micropython.org>
-rw-r--r--py/runtime.c4
-rw-r--r--tests/basics/subclass_classmethod.py31
2 files changed, 35 insertions, 0 deletions
diff --git a/py/runtime.c b/py/runtime.c
index 1836f5d92..fd0a8e690 100644
--- a/py/runtime.c
+++ b/py/runtime.c
@@ -1153,6 +1153,10 @@ void mp_convert_member_lookup(mp_obj_t self, const mp_obj_type_t *type, mp_obj_t
// base type (which is what is passed in the `type` argument to this function).
if (self != MP_OBJ_NULL) {
type = mp_obj_get_type(self);
+ if (type == &mp_type_type) {
+ // `self` is already a type, so use `self` directly.
+ type = MP_OBJ_TO_PTR(self);
+ }
}
dest[0] = ((mp_obj_static_class_method_t *)MP_OBJ_TO_PTR(member))->fun;
dest[1] = MP_OBJ_FROM_PTR(type);
diff --git a/tests/basics/subclass_classmethod.py b/tests/basics/subclass_classmethod.py
index 00a2ebd7c..3089726aa 100644
--- a/tests/basics/subclass_classmethod.py
+++ b/tests/basics/subclass_classmethod.py
@@ -35,3 +35,34 @@ class B(A):
B.bar() # class calling classmethod
B().bar() # instance calling classmethod
B().baz() # instance calling normal method
+
+# super inside a classmethod
+# ensure the argument of the super method that is called is the child type
+
+
+class C:
+ @classmethod
+ def f(cls):
+ print("C.f", cls.__name__) # cls should be D
+
+ @classmethod
+ def g(cls):
+ print("C.g", cls.__name__) # cls should be D
+
+
+class D(C):
+ @classmethod
+ def f(cls):
+ print("D.f", cls.__name__)
+ super().f()
+
+ @classmethod
+ def g(cls):
+ print("D.g", cls.__name__)
+ super(D, cls).g()
+
+
+D.f()
+D.g()
+D().f()
+D().g()