diff options
-rw-r--r-- | ports/rp2/mpthreadport.c | 8 | ||||
-rw-r--r-- | ports/rp2/mutex_extra.c | 13 | ||||
-rw-r--r-- | ports/rp2/mutex_extra.h | 4 | ||||
-rw-r--r-- | tests/thread/disable_irq.py | 51 | ||||
-rw-r--r-- | tests/thread/disable_irq.py.exp | 2 |
5 files changed, 68 insertions, 10 deletions
diff --git a/ports/rp2/mpthreadport.c b/ports/rp2/mpthreadport.c index 8d8f13a08..5b8e804f9 100644 --- a/ports/rp2/mpthreadport.c +++ b/ports/rp2/mpthreadport.c @@ -46,13 +46,13 @@ static uint32_t *core1_stack = NULL; static size_t core1_stack_num_words = 0; // Thread mutex. -static mutex_t atomic_mutex; +static recursive_mutex_t atomic_mutex; uint32_t mp_thread_begin_atomic_section(void) { if (core1_entry) { // When both cores are executing, we also need to provide // full mutual exclusion. - return mutex_enter_blocking_and_disable_interrupts(&atomic_mutex); + return recursive_mutex_enter_blocking_and_disable_interrupts(&atomic_mutex); } else { return save_and_disable_interrupts(); } @@ -60,7 +60,7 @@ uint32_t mp_thread_begin_atomic_section(void) { void mp_thread_end_atomic_section(uint32_t state) { if (atomic_mutex.owner != LOCK_INVALID_OWNER_ID) { - mutex_exit_and_restore_interrupts(&atomic_mutex, state); + recursive_mutex_exit_and_restore_interrupts(&atomic_mutex, state); } else { restore_interrupts(state); } @@ -70,7 +70,7 @@ void mp_thread_end_atomic_section(uint32_t state) { void mp_thread_init(void) { assert(get_core_num() == 0); - mutex_init(&atomic_mutex); + recursive_mutex_init(&atomic_mutex); // Allow MICROPY_BEGIN_ATOMIC_SECTION to be invoked from core1. multicore_lockout_victim_init(); diff --git a/ports/rp2/mutex_extra.c b/ports/rp2/mutex_extra.c index 6df57e64c..7a70a40ac 100644 --- a/ports/rp2/mutex_extra.c +++ b/ports/rp2/mutex_extra.c @@ -9,22 +9,27 @@ // These functions are taken from lib/pico-sdk/src/common/pico_sync/mutex.c and modified // so that they atomically obtain the mutex and disable interrupts. -uint32_t __time_critical_func(mutex_enter_blocking_and_disable_interrupts)(mutex_t * mtx) { +uint32_t __time_critical_func(recursive_mutex_enter_blocking_and_disable_interrupts)(recursive_mutex_t * mtx) { lock_owner_id_t caller = lock_get_caller_owner_id(); do { uint32_t save = spin_lock_blocking(mtx->core.spin_lock); - if (!lock_is_owner_id_valid(mtx->owner)) { + if (mtx->owner == caller || !lock_is_owner_id_valid(mtx->owner)) { mtx->owner = caller; + uint __unused total = ++mtx->enter_count; spin_unlock_unsafe(mtx->core.spin_lock); + assert(total); // check for overflow return save; } lock_internal_spin_unlock_with_wait(&mtx->core, save); } while (true); } -void __time_critical_func(mutex_exit_and_restore_interrupts)(mutex_t * mtx, uint32_t save) { +void __time_critical_func(recursive_mutex_exit_and_restore_interrupts)(recursive_mutex_t * mtx, uint32_t save) { spin_lock_unsafe_blocking(mtx->core.spin_lock); assert(lock_is_owner_id_valid(mtx->owner)); - mtx->owner = LOCK_INVALID_OWNER_ID; + assert(mtx->enter_count); + if (!--mtx->enter_count) { + mtx->owner = LOCK_INVALID_OWNER_ID; + } lock_internal_spin_unlock_with_notify(&mtx->core, save); } diff --git a/ports/rp2/mutex_extra.h b/ports/rp2/mutex_extra.h index be7cb96dc..61b6b4035 100644 --- a/ports/rp2/mutex_extra.h +++ b/ports/rp2/mutex_extra.h @@ -28,7 +28,7 @@ #include "pico/mutex.h" -uint32_t mutex_enter_blocking_and_disable_interrupts(mutex_t *mtx); -void mutex_exit_and_restore_interrupts(mutex_t *mtx, uint32_t save); +uint32_t recursive_mutex_enter_blocking_and_disable_interrupts(recursive_mutex_t *mtx); +void recursive_mutex_exit_and_restore_interrupts(recursive_mutex_t *mtx, uint32_t save); #endif // MICROPY_INCLUDED_RP2_MUTEX_EXTRA_H diff --git a/tests/thread/disable_irq.py b/tests/thread/disable_irq.py new file mode 100644 index 000000000..3f1ac74f3 --- /dev/null +++ b/tests/thread/disable_irq.py @@ -0,0 +1,51 @@ +# Ensure that disabling IRQs creates mutual exclusion between threads +# (also tests nesting of disable_irq across threads) +import machine +import time +import _thread + +if not hasattr(machine, "disable_irq"): + print("SKIP") + raise SystemExit + +count = 0 +thread_done = False + + +def inc_count(): + global count + a = machine.disable_irq() + try: + count += 1 + i = 0 + while i < 20: + b = machine.disable_irq() + try: + count += 1 + count -= 1 + i += 1 + finally: + machine.enable_irq(b) + finally: + machine.enable_irq(a) + + +def inc_count_multiple(times): + for _ in range(times): + inc_count() + + +def thread_entry(inc_times): + global thread_done + inc_count_multiple(inc_times) + thread_done = True + + +_thread.start_new_thread(thread_entry, (1000,)) +inc_count_multiple(1000) + +time.sleep(1) + +print("count", count, thread_done) +if count == 2000: + print("PASS") diff --git a/tests/thread/disable_irq.py.exp b/tests/thread/disable_irq.py.exp new file mode 100644 index 000000000..2174b91d0 --- /dev/null +++ b/tests/thread/disable_irq.py.exp @@ -0,0 +1,2 @@ +count 2000 True +PASS |