summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/library/ssl.rst21
-rw-r--r--extmod/mbedtls/mbedtls_config_common.h1
-rw-r--r--extmod/modtls_mbedtls.c76
-rw-r--r--ports/esp32/boards/sdkconfig.base3
-rw-r--r--tests/extmod/tls_dtls.py51
-rw-r--r--tests/extmod/tls_dtls.py.exp3
6 files changed, 151 insertions, 4 deletions
diff --git a/docs/library/ssl.rst b/docs/library/ssl.rst
index dff90b8da..4327c74ba 100644
--- a/docs/library/ssl.rst
+++ b/docs/library/ssl.rst
@@ -117,11 +117,32 @@ Exceptions
This exception does NOT exist. Instead its base class, OSError, is used.
+DTLS support
+------------
+
+.. admonition:: Difference to CPython
+ :class: attention
+
+ This is a MicroPython extension.
+
+This module supports DTLS in client and server mode via the `PROTOCOL_DTLS_CLIENT`
+and `PROTOCOL_DTLS_SERVER` constants that can be used as the ``protocol`` argument
+of `SSLContext`.
+
+In this case the underlying socket is expected to behave as a datagram socket (i.e.
+like the socket opened with ``socket.socket`` with ``socket.AF_INET`` as ``af`` and
+``socket.SOCK_DGRAM`` as ``type``).
+
+DTLS is only supported on ports that use mbed TLS, and it is not enabled by default:
+it requires enabling ``MBEDTLS_SSL_PROTO_DTLS`` in the specific port configuration.
+
Constants
---------
.. data:: ssl.PROTOCOL_TLS_CLIENT
ssl.PROTOCOL_TLS_SERVER
+ ssl.PROTOCOL_DTLS_CLIENT (when DTLS support is enabled)
+ ssl.PROTOCOL_DTLS_SERVER (when DTLS support is enabled)
Supported values for the *protocol* parameter.
diff --git a/extmod/mbedtls/mbedtls_config_common.h b/extmod/mbedtls/mbedtls_config_common.h
index 6ea8540af..6cd14befc 100644
--- a/extmod/mbedtls/mbedtls_config_common.h
+++ b/extmod/mbedtls/mbedtls_config_common.h
@@ -89,6 +89,7 @@
#define MBEDTLS_SHA384_C
#define MBEDTLS_SHA512_C
#define MBEDTLS_SSL_CLI_C
+#define MBEDTLS_SSL_PROTO_DTLS
#define MBEDTLS_SSL_SRV_C
#define MBEDTLS_SSL_TLS_C
#define MBEDTLS_X509_CRT_PARSE_C
diff --git a/extmod/modtls_mbedtls.c b/extmod/modtls_mbedtls.c
index 3fd416d72..6c34805da 100644
--- a/extmod/modtls_mbedtls.c
+++ b/extmod/modtls_mbedtls.c
@@ -37,6 +37,7 @@
#include "py/stream.h"
#include "py/objstr.h"
#include "py/reader.h"
+#include "py/mphal.h"
#include "py/gc.h"
#include "extmod/vfs.h"
@@ -47,6 +48,9 @@
#include "mbedtls/pk.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
+#ifdef MBEDTLS_SSL_PROTO_DTLS
+#include "mbedtls/timing.h"
+#endif
#include "mbedtls/debug.h"
#include "mbedtls/error.h"
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
@@ -65,6 +69,14 @@
#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
+#define MP_ENDPOINT_IS_SERVER (1 << 0)
+#define MP_TRANSPORT_IS_DTLS (1 << 1)
+
+#define MP_PROTOCOL_TLS_CLIENT 0
+#define MP_PROTOCOL_TLS_SERVER MP_ENDPOINT_IS_SERVER
+#define MP_PROTOCOL_DTLS_CLIENT MP_TRANSPORT_IS_DTLS
+#define MP_PROTOCOL_DTLS_SERVER MP_ENDPOINT_IS_SERVER | MP_TRANSPORT_IS_DTLS
+
// This corresponds to an SSLContext object.
typedef struct _mp_obj_ssl_context_t {
mp_obj_base_t base;
@@ -91,6 +103,12 @@ typedef struct _mp_obj_ssl_socket_t {
uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next
int last_error; // The last error code, if any
+
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
+ mp_uint_t timer_start_ms;
+ mp_uint_t timer_fin_ms;
+ mp_uint_t timer_int_ms;
+ #endif
} mp_obj_ssl_socket_t;
static const mp_obj_type_t ssl_context_type;
@@ -242,7 +260,10 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
mp_arg_check_num(n_args, n_kw, 1, 1, false);
// This is the "protocol" argument.
- mp_int_t endpoint = mp_obj_get_int(args[0]);
+ mp_int_t protocol = mp_obj_get_int(args[0]);
+
+ int endpoint = (protocol & MP_ENDPOINT_IS_SERVER) ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT;
+ int transport = (protocol & MP_TRANSPORT_IS_DTLS) ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM;
// Create SSLContext object.
#if MICROPY_PY_SSL_FINALISER
@@ -282,7 +303,7 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
}
ret = mbedtls_ssl_config_defaults(&self->conf, endpoint,
- MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
+ transport, MBEDTLS_SSL_PRESET_DEFAULT);
if (ret != 0) {
mbedtls_raise_error(ret);
}
@@ -525,6 +546,39 @@ static int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
}
}
+#ifdef MBEDTLS_SSL_PROTO_DTLS
+static void _mbedtls_timing_set_delay(void *ctx, uint32_t int_ms, uint32_t fin_ms) {
+ mp_obj_ssl_socket_t *o = (mp_obj_ssl_socket_t *)ctx;
+
+ o->timer_int_ms = int_ms;
+ o->timer_fin_ms = fin_ms;
+
+ if (fin_ms != 0) {
+ o->timer_start_ms = mp_hal_ticks_ms();
+ }
+}
+
+static int _mbedtls_timing_get_delay(void *ctx) {
+ mp_obj_ssl_socket_t *o = (mp_obj_ssl_socket_t *)ctx;
+
+ if (o->timer_fin_ms == 0) {
+ return -1;
+ }
+
+ mp_uint_t elapsed_ms = mp_hal_ticks_ms() - o->timer_start_ms;
+
+ if (elapsed_ms >= o->timer_fin_ms) {
+ return 2;
+ }
+
+ if (elapsed_ms >= o->timer_int_ms) {
+ return 1;
+ }
+
+ return 0;
+}
+#endif
+
static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) {
@@ -577,6 +631,10 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
mp_raise_ValueError(MP_ERROR_TEXT("CERT_REQUIRED requires server_hostname"));
}
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
+ mbedtls_ssl_set_timer_cb(&o->ssl, o, _mbedtls_timing_set_delay, _mbedtls_timing_get_delay);
+ #endif
+
mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
if (do_handshake_on_connect) {
@@ -788,6 +846,12 @@ static const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
{ MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) },
{ MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) },
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
+ { MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&mp_stream_read1_obj) },
+ { MP_ROM_QSTR(MP_QSTR_recv_into), MP_ROM_PTR(&mp_stream_readinto_obj) },
+ { MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&mp_stream_write1_obj) },
+ { MP_ROM_QSTR(MP_QSTR_sendall), MP_ROM_PTR(&mp_stream_write_obj) },
+ #endif
{ MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) },
{ MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) },
#if MICROPY_PY_SSL_FINALISER
@@ -879,8 +943,12 @@ static const mp_rom_map_elem_t mp_module_tls_globals_table[] = {
// Constants.
{ MP_ROM_QSTR(MP_QSTR_MBEDTLS_VERSION), MP_ROM_PTR(&mbedtls_version_obj)},
- { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(MBEDTLS_SSL_IS_CLIENT) },
- { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(MBEDTLS_SSL_IS_SERVER) },
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(MP_PROTOCOL_TLS_CLIENT) },
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(MP_PROTOCOL_TLS_SERVER) },
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_DTLS_CLIENT), MP_ROM_INT(MP_PROTOCOL_DTLS_CLIENT) },
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_DTLS_SERVER), MP_ROM_INT(MP_PROTOCOL_DTLS_SERVER) },
+ #endif
{ MP_ROM_QSTR(MP_QSTR_CERT_NONE), MP_ROM_INT(MBEDTLS_SSL_VERIFY_NONE) },
{ MP_ROM_QSTR(MP_QSTR_CERT_OPTIONAL), MP_ROM_INT(MBEDTLS_SSL_VERIFY_OPTIONAL) },
{ MP_ROM_QSTR(MP_QSTR_CERT_REQUIRED), MP_ROM_INT(MBEDTLS_SSL_VERIFY_REQUIRED) },
diff --git a/ports/esp32/boards/sdkconfig.base b/ports/esp32/boards/sdkconfig.base
index e20835c70..530db4271 100644
--- a/ports/esp32/boards/sdkconfig.base
+++ b/ports/esp32/boards/sdkconfig.base
@@ -64,6 +64,9 @@ CONFIG_MBEDTLS_HAVE_TIME_DATE=y
CONFIG_MBEDTLS_PLATFORM_TIME_ALT=y
CONFIG_MBEDTLS_HAVE_TIME=y
+# Enable DTLS
+CONFIG_MBEDTLS_SSL_PROTO_DTLS=y
+
# Disable ALPN support as it's not implemented in MicroPython
CONFIG_MBEDTLS_SSL_ALPN=n
diff --git a/tests/extmod/tls_dtls.py b/tests/extmod/tls_dtls.py
new file mode 100644
index 000000000..b2d716769
--- /dev/null
+++ b/tests/extmod/tls_dtls.py
@@ -0,0 +1,51 @@
+# Test DTLS functionality including timeout handling
+
+try:
+ from tls import PROTOCOL_DTLS_CLIENT, PROTOCOL_DTLS_SERVER, SSLContext, CERT_NONE
+ import io
+except ImportError:
+ print("SKIP")
+ raise SystemExit
+
+
+class DummySocket(io.IOBase):
+ def __init__(self):
+ self.write_buffer = bytearray()
+ self.read_buffer = bytearray()
+
+ def write(self, data):
+ return len(data)
+
+ def readinto(self, buf):
+ # This is a placeholder socket that doesn't actually read anything
+ # so the read buffer is always empty.
+ return None
+
+ def ioctl(self, req, arg):
+ if req == 4: # MP_STREAM_CLOSE
+ return 0
+ return -1
+
+
+# Create dummy sockets for testing
+server_socket = DummySocket()
+client_socket = DummySocket()
+
+# Wrap the DTLS Server
+dtls_server_ctx = SSLContext(PROTOCOL_DTLS_SERVER)
+dtls_server_ctx.verify_mode = CERT_NONE
+dtls_server = dtls_server_ctx.wrap_socket(server_socket, do_handshake_on_connect=False)
+print("Wrapped DTLS Server")
+
+# Wrap the DTLS Client
+dtls_client_ctx = SSLContext(PROTOCOL_DTLS_CLIENT)
+dtls_client_ctx.verify_mode = CERT_NONE
+dtls_client = dtls_client_ctx.wrap_socket(client_socket, do_handshake_on_connect=False)
+print("Wrapped DTLS Client")
+
+# Trigger the timing check multiple times with different elapsed times
+for i in range(10): # Try multiple iterations to hit the timing window
+ dtls_client.write(b"test")
+ data = dtls_server.read(1024) # This should eventually hit the timing condition
+
+print("OK")
diff --git a/tests/extmod/tls_dtls.py.exp b/tests/extmod/tls_dtls.py.exp
new file mode 100644
index 000000000..78d72bff1
--- /dev/null
+++ b/tests/extmod/tls_dtls.py.exp
@@ -0,0 +1,3 @@
+Wrapped DTLS Server
+Wrapped DTLS Client
+OK