summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ports/rp2/mpthreadport.c8
-rw-r--r--ports/rp2/mutex_extra.c13
-rw-r--r--ports/rp2/mutex_extra.h4
-rw-r--r--tests/thread/disable_irq.py51
-rw-r--r--tests/thread/disable_irq.py.exp2
5 files changed, 68 insertions, 10 deletions
diff --git a/ports/rp2/mpthreadport.c b/ports/rp2/mpthreadport.c
index 8d8f13a08..5b8e804f9 100644
--- a/ports/rp2/mpthreadport.c
+++ b/ports/rp2/mpthreadport.c
@@ -46,13 +46,13 @@ static uint32_t *core1_stack = NULL;
static size_t core1_stack_num_words = 0;
// Thread mutex.
-static mutex_t atomic_mutex;
+static recursive_mutex_t atomic_mutex;
uint32_t mp_thread_begin_atomic_section(void) {
if (core1_entry) {
// When both cores are executing, we also need to provide
// full mutual exclusion.
- return mutex_enter_blocking_and_disable_interrupts(&atomic_mutex);
+ return recursive_mutex_enter_blocking_and_disable_interrupts(&atomic_mutex);
} else {
return save_and_disable_interrupts();
}
@@ -60,7 +60,7 @@ uint32_t mp_thread_begin_atomic_section(void) {
void mp_thread_end_atomic_section(uint32_t state) {
if (atomic_mutex.owner != LOCK_INVALID_OWNER_ID) {
- mutex_exit_and_restore_interrupts(&atomic_mutex, state);
+ recursive_mutex_exit_and_restore_interrupts(&atomic_mutex, state);
} else {
restore_interrupts(state);
}
@@ -70,7 +70,7 @@ void mp_thread_end_atomic_section(uint32_t state) {
void mp_thread_init(void) {
assert(get_core_num() == 0);
- mutex_init(&atomic_mutex);
+ recursive_mutex_init(&atomic_mutex);
// Allow MICROPY_BEGIN_ATOMIC_SECTION to be invoked from core1.
multicore_lockout_victim_init();
diff --git a/ports/rp2/mutex_extra.c b/ports/rp2/mutex_extra.c
index 6df57e64c..7a70a40ac 100644
--- a/ports/rp2/mutex_extra.c
+++ b/ports/rp2/mutex_extra.c
@@ -9,22 +9,27 @@
// These functions are taken from lib/pico-sdk/src/common/pico_sync/mutex.c and modified
// so that they atomically obtain the mutex and disable interrupts.
-uint32_t __time_critical_func(mutex_enter_blocking_and_disable_interrupts)(mutex_t * mtx) {
+uint32_t __time_critical_func(recursive_mutex_enter_blocking_and_disable_interrupts)(recursive_mutex_t * mtx) {
lock_owner_id_t caller = lock_get_caller_owner_id();
do {
uint32_t save = spin_lock_blocking(mtx->core.spin_lock);
- if (!lock_is_owner_id_valid(mtx->owner)) {
+ if (mtx->owner == caller || !lock_is_owner_id_valid(mtx->owner)) {
mtx->owner = caller;
+ uint __unused total = ++mtx->enter_count;
spin_unlock_unsafe(mtx->core.spin_lock);
+ assert(total); // check for overflow
return save;
}
lock_internal_spin_unlock_with_wait(&mtx->core, save);
} while (true);
}
-void __time_critical_func(mutex_exit_and_restore_interrupts)(mutex_t * mtx, uint32_t save) {
+void __time_critical_func(recursive_mutex_exit_and_restore_interrupts)(recursive_mutex_t * mtx, uint32_t save) {
spin_lock_unsafe_blocking(mtx->core.spin_lock);
assert(lock_is_owner_id_valid(mtx->owner));
- mtx->owner = LOCK_INVALID_OWNER_ID;
+ assert(mtx->enter_count);
+ if (!--mtx->enter_count) {
+ mtx->owner = LOCK_INVALID_OWNER_ID;
+ }
lock_internal_spin_unlock_with_notify(&mtx->core, save);
}
diff --git a/ports/rp2/mutex_extra.h b/ports/rp2/mutex_extra.h
index be7cb96dc..61b6b4035 100644
--- a/ports/rp2/mutex_extra.h
+++ b/ports/rp2/mutex_extra.h
@@ -28,7 +28,7 @@
#include "pico/mutex.h"
-uint32_t mutex_enter_blocking_and_disable_interrupts(mutex_t *mtx);
-void mutex_exit_and_restore_interrupts(mutex_t *mtx, uint32_t save);
+uint32_t recursive_mutex_enter_blocking_and_disable_interrupts(recursive_mutex_t *mtx);
+void recursive_mutex_exit_and_restore_interrupts(recursive_mutex_t *mtx, uint32_t save);
#endif // MICROPY_INCLUDED_RP2_MUTEX_EXTRA_H
diff --git a/tests/thread/disable_irq.py b/tests/thread/disable_irq.py
new file mode 100644
index 000000000..3f1ac74f3
--- /dev/null
+++ b/tests/thread/disable_irq.py
@@ -0,0 +1,51 @@
+# Ensure that disabling IRQs creates mutual exclusion between threads
+# (also tests nesting of disable_irq across threads)
+import machine
+import time
+import _thread
+
+if not hasattr(machine, "disable_irq"):
+ print("SKIP")
+ raise SystemExit
+
+count = 0
+thread_done = False
+
+
+def inc_count():
+ global count
+ a = machine.disable_irq()
+ try:
+ count += 1
+ i = 0
+ while i < 20:
+ b = machine.disable_irq()
+ try:
+ count += 1
+ count -= 1
+ i += 1
+ finally:
+ machine.enable_irq(b)
+ finally:
+ machine.enable_irq(a)
+
+
+def inc_count_multiple(times):
+ for _ in range(times):
+ inc_count()
+
+
+def thread_entry(inc_times):
+ global thread_done
+ inc_count_multiple(inc_times)
+ thread_done = True
+
+
+_thread.start_new_thread(thread_entry, (1000,))
+inc_count_multiple(1000)
+
+time.sleep(1)
+
+print("count", count, thread_done)
+if count == 2000:
+ print("PASS")
diff --git a/tests/thread/disable_irq.py.exp b/tests/thread/disable_irq.py.exp
new file mode 100644
index 000000000..2174b91d0
--- /dev/null
+++ b/tests/thread/disable_irq.py.exp
@@ -0,0 +1,2 @@
+count 2000 True
+PASS