diff options
author | iabdalkader <i.abdalkader@gmail.com> | 2024-10-16 14:08:43 +0200 |
---|---|---|
committer | Damien George <damien@micropython.org> | 2024-10-25 01:12:36 +1100 |
commit | 2644f577f1562a641c62d223dfb1fd80dd541ac9 (patch) | |
tree | 4b749fe728b76ebee9b93de21c90f4659744b5fc | |
parent | 09ea901317810e9e8fbd564cdbf8d50b3569d07f (diff) |
extmod/modtls_mbedtls: Add a thread-global ptr for current SSL context.
This is necessary for mbedTLS callbacks that do not carry any user state,
so those callbacks can be customised per SSL context.
Signed-off-by: iabdalkader <i.abdalkader@gmail.com>
-rw-r--r-- | extmod/modtls_mbedtls.c | 19 | ||||
-rw-r--r-- | py/mpconfig.h | 5 | ||||
-rw-r--r-- | py/mpstate.h | 4 |
3 files changed, 28 insertions, 0 deletions
diff --git a/extmod/modtls_mbedtls.c b/extmod/modtls_mbedtls.c index 30118e200..b261e7a70 100644 --- a/extmod/modtls_mbedtls.c +++ b/extmod/modtls_mbedtls.c @@ -166,6 +166,13 @@ static NORETURN void mbedtls_raise_error(int err) { #endif } +// Stores the current SSLContext for use in mbedtls callbacks where the current state is not passed. +static inline void store_active_context(mp_obj_ssl_context_t *ssl_context) { + #if MICROPY_PY_SSL_MBEDTLS_NEED_ACTIVE_CONTEXT + MP_STATE_THREAD(tls_ssl_context) = ssl_context; + #endif +} + static void ssl_check_async_handshake_failure(mp_obj_ssl_socket_t *sslsock, int *errcode) { if ( #if MBEDTLS_VERSION_NUMBER >= 0x03000000 @@ -497,6 +504,9 @@ static int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) { static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) { + // Store the current SSL context. + store_active_context(ssl_context); + // Verify the socket object has the full stream protocol mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL); @@ -602,6 +612,9 @@ static mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc return MP_STREAM_ERROR; } + // Store the current SSL context. + store_active_context(o->ssl_context); + int ret = mbedtls_ssl_read(&o->ssl, buf, size); if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { // end of stream @@ -643,6 +656,9 @@ static mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in return MP_STREAM_ERROR; } + // Store the current SSL context. + store_active_context(o->ssl_context); + int ret = mbedtls_ssl_write(&o->ssl, buf, size); if (ret >= 0) { return ret; @@ -680,6 +696,9 @@ static mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i mp_obj_t sock = self->sock; if (request == MP_STREAM_CLOSE) { + // Clear the SSL context. + store_active_context(NULL); + if (sock == MP_OBJ_NULL) { // Already closed socket, do nothing. return 0; diff --git a/py/mpconfig.h b/py/mpconfig.h index 34eafa9e5..7aa22d52e 100644 --- a/py/mpconfig.h +++ b/py/mpconfig.h @@ -1814,6 +1814,11 @@ typedef double mp_float_t; #define MICROPY_PY_SSL_FINALISER (MICROPY_ENABLE_FINALISER) #endif +// Whether to add a root pointer for the current ssl object +#ifndef MICROPY_PY_SSL_MBEDTLS_NEED_ACTIVE_CONTEXT +#define MICROPY_PY_SSL_MBEDTLS_NEED_ACTIVE_CONTEXT (MICROPY_PY_SSL_ECDSA_SIGN_ALT) +#endif + // Whether to provide the "vfs" module #ifndef MICROPY_PY_VFS #define MICROPY_PY_VFS (MICROPY_CONFIG_ROM_LEVEL_AT_LEAST_CORE_FEATURES && MICROPY_VFS) diff --git a/py/mpstate.h b/py/mpstate.h index af55e764f..54eca596d 100644 --- a/py/mpstate.h +++ b/py/mpstate.h @@ -293,6 +293,10 @@ typedef struct _mp_state_thread_t { bool prof_callback_is_executing; struct _mp_code_state_t *current_code_state; #endif + + #if MICROPY_PY_SSL_MBEDTLS_NEED_ACTIVE_CONTEXT + struct _mp_obj_ssl_context_t *tls_ssl_context; + #endif } mp_state_thread_t; // This structure combines the above 3 structures. |