diff options
| author | Damien George <damien@micropython.org> | 2023-09-27 13:36:25 +1000 |
|---|---|---|
| committer | Damien George <damien@micropython.org> | 2023-09-29 12:03:00 +1000 |
| commit | 58f63497e5ee4d7915ea23929dad7d59712b7c01 (patch) | |
| tree | 07fa849397840bfb98472576dc48b41aea4aeaf8 | |
| parent | 03a3af417e749860b2771642a738151aef830586 (diff) | |
extmod/modssl_axtls: Only close underlying socket once if it was used.
To match the behaviour of the mbedtls implementation, and pass the
ssl_basic.py test.
Signed-off-by: Damien George <damien@micropython.org>
| -rw-r--r-- | extmod/modssl_axtls.c | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/extmod/modssl_axtls.c b/extmod/modssl_axtls.c index d169d89a2..6cb999c13 100644 --- a/extmod/modssl_axtls.c +++ b/extmod/modssl_axtls.c @@ -208,7 +208,7 @@ STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t o->base.type = &ssl_socket_type; o->buf = NULL; o->bytes_left = 0; - o->sock = sock; + o->sock = MP_OBJ_NULL; o->blocking = true; uint32_t options = SSL_SERVER_VERIFY_LATER; @@ -262,6 +262,10 @@ STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t } } + // Populate the socket entry now that the SSLSocket is fully set up. + // This prevents closing the socket if an exception is raised above. + o->sock = sock; + return o; } @@ -348,11 +352,21 @@ eagain: STATIC mp_uint_t ssl_socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, int *errcode) { mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(o_in); - if (request == MP_STREAM_CLOSE && self->ssl_sock != NULL) { + if (request == MP_STREAM_CLOSE) { + if (self->ssl_sock == NULL) { + // Already closed socket, do nothing. + return 0; + } ssl_free(self->ssl_sock); ssl_ctx_free(self->ssl_ctx); self->ssl_sock = NULL; } + + if (self->sock == MP_OBJ_NULL) { + // Underlying socket may be null if the constructor raised an exception. + return 0; + } + // Pass all requests down to the underlying socket return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode); } |
