summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--py/malloc.c33
-rw-r--r--tests/extmod/ssl_threads.py57
2 files changed, 89 insertions, 1 deletions
diff --git a/py/malloc.c b/py/malloc.c
index f557ade44..05daeb35d 100644
--- a/py/malloc.c
+++ b/py/malloc.c
@@ -209,6 +209,31 @@ void m_free(void *ptr)
#if MICROPY_TRACKED_ALLOC
+#if MICROPY_PY_THREAD && !MICROPY_PY_THREAD_GIL
+// If there's no GIL, use the GC recursive mutex to protect the tracked node linked list
+// under m_tracked_head.
+//
+// (For ports with GIL, the expectation is to only call tracked alloc functions
+// while holding the GIL.)
+
+static inline void m_tracked_node_lock(void) {
+ mp_thread_recursive_mutex_lock(&MP_STATE_MEM(gc_mutex), 1);
+}
+
+static inline void m_tracked_node_unlock(void) {
+ mp_thread_recursive_mutex_unlock(&MP_STATE_MEM(gc_mutex));
+}
+
+#else
+
+static inline void m_tracked_node_lock(void) {
+}
+
+static inline void m_tracked_node_unlock(void) {
+}
+
+#endif
+
#define MICROPY_TRACKED_ALLOC_STORE_SIZE (!MICROPY_ENABLE_GC)
typedef struct _m_tracked_node_t {
@@ -222,6 +247,7 @@ typedef struct _m_tracked_node_t {
#if MICROPY_DEBUG_VERBOSE
static size_t m_tracked_count_links(size_t *nb) {
+ m_tracked_node_lock();
m_tracked_node_t *node = MP_STATE_VM(m_tracked_head);
size_t n = 0;
*nb = 0;
@@ -234,6 +260,7 @@ static size_t m_tracked_count_links(size_t *nb) {
#endif
node = node->next;
}
+ m_tracked_node_unlock();
return n;
}
#endif
@@ -248,12 +275,14 @@ void *m_tracked_calloc(size_t nmemb, size_t size) {
size_t n = m_tracked_count_links(&nb);
DEBUG_printf("m_tracked_calloc(%u, %u) -> (%u;%u) %p\n", (int)nmemb, (int)size, (int)n, (int)nb, node);
#endif
+ m_tracked_node_lock();
if (MP_STATE_VM(m_tracked_head) != NULL) {
MP_STATE_VM(m_tracked_head)->prev = node;
}
node->prev = NULL;
node->next = MP_STATE_VM(m_tracked_head);
MP_STATE_VM(m_tracked_head) = node;
+ m_tracked_node_unlock();
#if MICROPY_TRACKED_ALLOC_STORE_SIZE
node->size = nmemb * size;
#endif
@@ -278,7 +307,8 @@ void m_tracked_free(void *ptr_in) {
size_t nb;
size_t n = m_tracked_count_links(&nb);
DEBUG_printf("m_tracked_free(%p, [%p, %p], nbytes=%u, links=%u;%u)\n", node, node->prev, node->next, (int)data_bytes, (int)n, (int)nb);
- #endif
+ #endif // MICROPY_DEBUG_VERBOSE
+ m_tracked_node_lock();
if (node->next != NULL) {
node->next->prev = node->prev;
}
@@ -287,6 +317,7 @@ void m_tracked_free(void *ptr_in) {
} else {
MP_STATE_VM(m_tracked_head) = node->next;
}
+ m_tracked_node_unlock();
m_free(node
#if MICROPY_MALLOC_USES_ALLOCATED_SIZE
#if MICROPY_TRACKED_ALLOC_STORE_SIZE
diff --git a/tests/extmod/ssl_threads.py b/tests/extmod/ssl_threads.py
new file mode 100644
index 000000000..4564abd3d
--- /dev/null
+++ b/tests/extmod/ssl_threads.py
@@ -0,0 +1,57 @@
+# Ensure that SSL sockets can be allocated from multiple
+# threads without thread safety issues
+import unittest
+
+try:
+ import _thread
+ import io
+ import tls
+ import time
+except ImportError:
+ print("SKIP")
+ raise SystemExit
+
+
+class TestSocket(io.IOBase):
+ def write(self, buf):
+ return len(buf)
+
+ def readinto(self, buf):
+ return 0
+
+ def ioctl(self, cmd, arg):
+ return 0
+
+ def setblocking(self, value):
+ pass
+
+
+ITERS = 256
+
+
+class TLSThreads(unittest.TestCase):
+ def test_sslsocket_threaded(self):
+ self.done = False
+ # only run in two threads: too much RAM demand otherwise, and rp2 only
+ # supports two anyhow
+ _thread.start_new_thread(self._alloc_many_sockets, (True,))
+ self._alloc_many_sockets(False)
+ while not self.done:
+ time.sleep(0.1)
+ print("done")
+
+ def _alloc_many_sockets(self, set_done_flag):
+ print("start", _thread.get_ident())
+ ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT)
+ ctx.verify_mode = tls.CERT_NONE
+ for n in range(ITERS):
+ s = TestSocket()
+ s = ctx.wrap_socket(s, do_handshake_on_connect=False)
+ s.close() # Free associated resources now from thread, not in a GC pass
+ print("done", _thread.get_ident())
+ if set_done_flag:
+ self.done = True
+
+
+if __name__ == "__main__":
+ unittest.main()