diff options
| author | Damien George <damien@micropython.org> | 2024-07-24 17:17:39 +1000 |
|---|---|---|
| committer | Damien George <damien@micropython.org> | 2024-07-25 13:07:42 +1000 |
| commit | 233f5ce661d31c8e606ff727ac0c6de27d32a715 (patch) | |
| tree | 482e6467632b9f4c1cca6601ce581e682eb456aa | |
| parent | 07bf3179f6d51a0e80fc48f06490f4e4d0d8488d (diff) | |
py/runtime: Fix self arg passed to classmethod when accessed via super.
Thanks to @AJMansfield for the original test case.
Signed-off-by: Damien George <damien@micropython.org>
| -rw-r--r-- | py/runtime.c | 4 | ||||
| -rw-r--r-- | tests/basics/subclass_classmethod.py | 31 |
2 files changed, 35 insertions, 0 deletions
diff --git a/py/runtime.c b/py/runtime.c index 1836f5d92..fd0a8e690 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -1153,6 +1153,10 @@ void mp_convert_member_lookup(mp_obj_t self, const mp_obj_type_t *type, mp_obj_t // base type (which is what is passed in the `type` argument to this function). if (self != MP_OBJ_NULL) { type = mp_obj_get_type(self); + if (type == &mp_type_type) { + // `self` is already a type, so use `self` directly. + type = MP_OBJ_TO_PTR(self); + } } dest[0] = ((mp_obj_static_class_method_t *)MP_OBJ_TO_PTR(member))->fun; dest[1] = MP_OBJ_FROM_PTR(type); diff --git a/tests/basics/subclass_classmethod.py b/tests/basics/subclass_classmethod.py index 00a2ebd7c..3089726aa 100644 --- a/tests/basics/subclass_classmethod.py +++ b/tests/basics/subclass_classmethod.py @@ -35,3 +35,34 @@ class B(A): B.bar() # class calling classmethod B().bar() # instance calling classmethod B().baz() # instance calling normal method + +# super inside a classmethod +# ensure the argument of the super method that is called is the child type + + +class C: + @classmethod + def f(cls): + print("C.f", cls.__name__) # cls should be D + + @classmethod + def g(cls): + print("C.g", cls.__name__) # cls should be D + + +class D(C): + @classmethod + def f(cls): + print("D.f", cls.__name__) + super().f() + + @classmethod + def g(cls): + print("D.g", cls.__name__) + super(D, cls).g() + + +D.f() +D.g() +D().f() +D().g() |
