diff options
| author | Damien Tournoud <damien@platform.sh> | 2022-12-15 10:15:18 -0800 |
|---|---|---|
| committer | Damien Tournoud <damien@platform.sh> | 2022-12-15 12:06:22 -0800 |
| commit | ed58d6e4ce5878d5f53457e72686962e7f57f5df (patch) | |
| tree | 8aab3206923a5d9cd56e0294409d802aa185fc59 /extmod/modussl_mbedtls.c | |
| parent | 988b6e2dae4e7f6b6f85c3b149fabeb50bb0519f (diff) | |
extmod/modussl_mbedtls: Fix support for ioctl(MP_STREAM_POLL).
During the initial handshake or subsequent renegotiation, the protocol
might need to read in order to write (or conversely to write in order
to read). It might be blocked from doing so by the state of the
underlying socket (i.e. there is no data to read, or there is no space
to write).
The library indicates this condition by returning one of the errors
`MBEDTLS_ERR_SSL_WANT_READ` or `MBEDTLS_ERR_SSL_WANT_WRITE`. When that
happens, we need to enforce that the next poll operation only considers
the direction that the library indicated.
In addition, mbedtls does its own read buffering that we need to take
into account while polling, and we need to save the last error between
read()/write() and ioctl().
Diffstat (limited to 'extmod/modussl_mbedtls.c')
| -rw-r--r-- | extmod/modussl_mbedtls.c | 69 |
1 files changed, 68 insertions, 1 deletions
diff --git a/extmod/modussl_mbedtls.c b/extmod/modussl_mbedtls.c index 95cf12c97..eea2d7953 100644 --- a/extmod/modussl_mbedtls.c +++ b/extmod/modussl_mbedtls.c @@ -46,6 +46,8 @@ #include "mbedtls/debug.h" #include "mbedtls/error.h" +#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR) + typedef struct _mp_obj_ssl_socket_t { mp_obj_base_t base; mp_obj_t sock; @@ -56,6 +58,9 @@ typedef struct _mp_obj_ssl_socket_t { mbedtls_x509_crt cacert; mbedtls_x509_crt cert; mbedtls_pk_context pkey; + + uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next + int last_error; // The last error code, if any } mp_obj_ssl_socket_t; struct ssl_args { @@ -165,6 +170,8 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { #endif o->base.type = &ussl_socket_type; o->sock = sock; + o->poll_mask = 0; + o->last_error = 0; int ret; mbedtls_ssl_init(&o->ssl); @@ -306,6 +313,12 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) { mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); + o->poll_mask = 0; + + if (o->last_error) { + *errcode = o->last_error; + return MP_STREAM_ERROR; + } int ret = mbedtls_ssl_read(&o->ssl, buf, size); if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { @@ -322,6 +335,9 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc // wanting to write next handshake message. The same may happen with // renegotation. ret = MP_EWOULDBLOCK; + o->poll_mask = MP_STREAM_POLL_WR; + } else { + o->last_error = ret; } *errcode = ret; return MP_STREAM_ERROR; @@ -329,6 +345,12 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) { mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); + o->poll_mask = 0; + + if (o->last_error) { + *errcode = o->last_error; + return MP_STREAM_ERROR; + } int ret = mbedtls_ssl_write(&o->ssl, buf, size); if (ret >= 0) { @@ -341,6 +363,9 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in // wanting to read next handshake message. The same may happen with // renegotation. ret = MP_EWOULDBLOCK; + o->poll_mask = MP_STREAM_POLL_RD; + } else { + o->last_error = ret; } *errcode = ret; return MP_STREAM_ERROR; @@ -358,7 +383,16 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking); STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, int *errcode) { mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(o_in); + mp_uint_t ret = 0; + uintptr_t saved_arg = 0; + mp_obj_t sock = self->sock; + if (sock == MP_OBJ_NULL || (request != MP_STREAM_CLOSE && self->last_error != 0)) { + // Closed or error socket: + return MP_STREAM_POLL_NVAL; + } + if (request == MP_STREAM_CLOSE) { + self->sock = MP_OBJ_NULL; mbedtls_pk_free(&self->pkey); mbedtls_x509_crt_free(&self->cert); mbedtls_x509_crt_free(&self->cacert); @@ -366,9 +400,39 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i mbedtls_ssl_config_free(&self->conf); mbedtls_ctr_drbg_free(&self->ctr_drbg); mbedtls_entropy_free(&self->entropy); + } else if (request == MP_STREAM_POLL) { + // If the library signaled us that it needs reading or writing, only check that direction, + // but save what the caller asked because we need to restore it later + if (self->poll_mask && (arg & MP_STREAM_POLL_RDWR)) { + saved_arg = arg & MP_STREAM_POLL_RDWR; + arg = (arg & ~saved_arg) | self->poll_mask; + } + + // Take into account that the library might have buffered data already + int has_pending = 0; + if (arg & MP_STREAM_POLL_RD) { + has_pending = mbedtls_ssl_check_pending(&self->ssl); + if (has_pending) { + ret |= MP_STREAM_POLL_RD; + if (arg == MP_STREAM_POLL_RD) { + // Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket + return MP_STREAM_POLL_RD; + } + } + } } + // Pass all requests down to the underlying socket - return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode); + ret |= mp_get_stream(sock)->ioctl(sock, request, arg, errcode); + + if (request == MP_STREAM_POLL) { + // The direction the library needed is available, return a fake result to the caller so that + // it reenters a read or a write to allow the handshake to progress + if (ret & self->poll_mask) { + ret |= saved_arg; + } + } + return ret; } STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = { @@ -381,6 +445,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = { #if MICROPY_PY_USSL_FINALISER { MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) }, #endif + #if MICROPY_UNIX_COVERAGE + { MP_ROM_QSTR(MP_QSTR_ioctl), MP_ROM_PTR(&mp_stream_ioctl_obj) }, + #endif { MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) }, }; |
