summaryrefslogtreecommitdiff
path: root/extmod/modussl_mbedtls.c
diff options
context:
space:
mode:
Diffstat (limited to 'extmod/modussl_mbedtls.c')
-rw-r--r--extmod/modussl_mbedtls.c69
1 files changed, 68 insertions, 1 deletions
diff --git a/extmod/modussl_mbedtls.c b/extmod/modussl_mbedtls.c
index 95cf12c97..eea2d7953 100644
--- a/extmod/modussl_mbedtls.c
+++ b/extmod/modussl_mbedtls.c
@@ -46,6 +46,8 @@
#include "mbedtls/debug.h"
#include "mbedtls/error.h"
+#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
+
typedef struct _mp_obj_ssl_socket_t {
mp_obj_base_t base;
mp_obj_t sock;
@@ -56,6 +58,9 @@ typedef struct _mp_obj_ssl_socket_t {
mbedtls_x509_crt cacert;
mbedtls_x509_crt cert;
mbedtls_pk_context pkey;
+
+ uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next
+ int last_error; // The last error code, if any
} mp_obj_ssl_socket_t;
struct ssl_args {
@@ -165,6 +170,8 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
#endif
o->base.type = &ussl_socket_type;
o->sock = sock;
+ o->poll_mask = 0;
+ o->last_error = 0;
int ret;
mbedtls_ssl_init(&o->ssl);
@@ -306,6 +313,12 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin
STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
+ o->poll_mask = 0;
+
+ if (o->last_error) {
+ *errcode = o->last_error;
+ return MP_STREAM_ERROR;
+ }
int ret = mbedtls_ssl_read(&o->ssl, buf, size);
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
@@ -322,6 +335,9 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
// wanting to write next handshake message. The same may happen with
// renegotation.
ret = MP_EWOULDBLOCK;
+ o->poll_mask = MP_STREAM_POLL_WR;
+ } else {
+ o->last_error = ret;
}
*errcode = ret;
return MP_STREAM_ERROR;
@@ -329,6 +345,12 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
+ o->poll_mask = 0;
+
+ if (o->last_error) {
+ *errcode = o->last_error;
+ return MP_STREAM_ERROR;
+ }
int ret = mbedtls_ssl_write(&o->ssl, buf, size);
if (ret >= 0) {
@@ -341,6 +363,9 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
// wanting to read next handshake message. The same may happen with
// renegotation.
ret = MP_EWOULDBLOCK;
+ o->poll_mask = MP_STREAM_POLL_RD;
+ } else {
+ o->last_error = ret;
}
*errcode = ret;
return MP_STREAM_ERROR;
@@ -358,7 +383,16 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking);
STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, int *errcode) {
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(o_in);
+ mp_uint_t ret = 0;
+ uintptr_t saved_arg = 0;
+ mp_obj_t sock = self->sock;
+ if (sock == MP_OBJ_NULL || (request != MP_STREAM_CLOSE && self->last_error != 0)) {
+ // Closed or error socket:
+ return MP_STREAM_POLL_NVAL;
+ }
+
if (request == MP_STREAM_CLOSE) {
+ self->sock = MP_OBJ_NULL;
mbedtls_pk_free(&self->pkey);
mbedtls_x509_crt_free(&self->cert);
mbedtls_x509_crt_free(&self->cacert);
@@ -366,9 +400,39 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
mbedtls_ssl_config_free(&self->conf);
mbedtls_ctr_drbg_free(&self->ctr_drbg);
mbedtls_entropy_free(&self->entropy);
+ } else if (request == MP_STREAM_POLL) {
+ // If the library signaled us that it needs reading or writing, only check that direction,
+ // but save what the caller asked because we need to restore it later
+ if (self->poll_mask && (arg & MP_STREAM_POLL_RDWR)) {
+ saved_arg = arg & MP_STREAM_POLL_RDWR;
+ arg = (arg & ~saved_arg) | self->poll_mask;
+ }
+
+ // Take into account that the library might have buffered data already
+ int has_pending = 0;
+ if (arg & MP_STREAM_POLL_RD) {
+ has_pending = mbedtls_ssl_check_pending(&self->ssl);
+ if (has_pending) {
+ ret |= MP_STREAM_POLL_RD;
+ if (arg == MP_STREAM_POLL_RD) {
+ // Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
+ return MP_STREAM_POLL_RD;
+ }
+ }
+ }
}
+
// Pass all requests down to the underlying socket
- return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
+ ret |= mp_get_stream(sock)->ioctl(sock, request, arg, errcode);
+
+ if (request == MP_STREAM_POLL) {
+ // The direction the library needed is available, return a fake result to the caller so that
+ // it reenters a read or a write to allow the handshake to progress
+ if (ret & self->poll_mask) {
+ ret |= saved_arg;
+ }
+ }
+ return ret;
}
STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
@@ -381,6 +445,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
#if MICROPY_PY_USSL_FINALISER
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) },
#endif
+ #if MICROPY_UNIX_COVERAGE
+ { MP_ROM_QSTR(MP_QSTR_ioctl), MP_ROM_PTR(&mp_stream_ioctl_obj) },
+ #endif
{ MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) },
};