summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKeenan Johnson <keenan.johnson@gmail.com>2025-02-13 13:11:38 -0800
committerDamien George <damien@micropython.org>2025-02-14 12:55:25 +1100
commit321b30ca564bb33c625292247d00f7dd29dc9559 (patch)
tree34f04febf16e5825f069b002494dd512b0c30331
parentaef6705a321fbefb06288b5be1f5931bf8c42fe3 (diff)
extmod/modtls_mbedtls: Wire in support for DTLS.
This commit enables support for DTLS, i.e. TLS over datagram transport protocols like UDP. While support for DTLS is absent in CPython, it is worth supporting it in MicroPython because it is the basis of the ubiquitous CoAP protocol, used in many IoT projects. To select DTLS, a new set of "protocols" are added to SSLContext: - ssl.PROTOCOL_DTLS_CLIENT - ssl.PROTOCOL_DTLS_SERVER If one of these is set, the library assumes that the underlying socket is a datagram-like socket (i.e. UDP or similar). Our own timer callbacks are implemented because the out of the box implementation relies on `gettimeofday()`. This new DTLS feature is enabled on all ports that use mbedTLS. This commit is an update to a previous PR #10062. Addresses issue #5270 which requested DTLS support. Signed-off-by: Keenan Johnson <keenan.johnson@gmail.com>
-rw-r--r--docs/library/ssl.rst21
-rw-r--r--extmod/mbedtls/mbedtls_config_common.h1
-rw-r--r--extmod/modtls_mbedtls.c76
-rw-r--r--ports/esp32/boards/sdkconfig.base3
-rw-r--r--tests/extmod/tls_dtls.py51
-rw-r--r--tests/extmod/tls_dtls.py.exp3
6 files changed, 151 insertions, 4 deletions
diff --git a/docs/library/ssl.rst b/docs/library/ssl.rst
index dff90b8da..4327c74ba 100644
--- a/docs/library/ssl.rst
+++ b/docs/library/ssl.rst
@@ -117,11 +117,32 @@ Exceptions
This exception does NOT exist. Instead its base class, OSError, is used.
+DTLS support
+------------
+
+.. admonition:: Difference to CPython
+ :class: attention
+
+ This is a MicroPython extension.
+
+This module supports DTLS in client and server mode via the `PROTOCOL_DTLS_CLIENT`
+and `PROTOCOL_DTLS_SERVER` constants that can be used as the ``protocol`` argument
+of `SSLContext`.
+
+In this case the underlying socket is expected to behave as a datagram socket (i.e.
+like the socket opened with ``socket.socket`` with ``socket.AF_INET`` as ``af`` and
+``socket.SOCK_DGRAM`` as ``type``).
+
+DTLS is only supported on ports that use mbed TLS, and it is not enabled by default:
+it requires enabling ``MBEDTLS_SSL_PROTO_DTLS`` in the specific port configuration.
+
Constants
---------
.. data:: ssl.PROTOCOL_TLS_CLIENT
ssl.PROTOCOL_TLS_SERVER
+ ssl.PROTOCOL_DTLS_CLIENT (when DTLS support is enabled)
+ ssl.PROTOCOL_DTLS_SERVER (when DTLS support is enabled)
Supported values for the *protocol* parameter.
diff --git a/extmod/mbedtls/mbedtls_config_common.h b/extmod/mbedtls/mbedtls_config_common.h
index 6ea8540af..6cd14befc 100644
--- a/extmod/mbedtls/mbedtls_config_common.h
+++ b/extmod/mbedtls/mbedtls_config_common.h
@@ -89,6 +89,7 @@
#define MBEDTLS_SHA384_C
#define MBEDTLS_SHA512_C
#define MBEDTLS_SSL_CLI_C
+#define MBEDTLS_SSL_PROTO_DTLS
#define MBEDTLS_SSL_SRV_C
#define MBEDTLS_SSL_TLS_C
#define MBEDTLS_X509_CRT_PARSE_C
diff --git a/extmod/modtls_mbedtls.c b/extmod/modtls_mbedtls.c
index 3fd416d72..6c34805da 100644
--- a/extmod/modtls_mbedtls.c
+++ b/extmod/modtls_mbedtls.c
@@ -37,6 +37,7 @@
#include "py/stream.h"
#include "py/objstr.h"
#include "py/reader.h"
+#include "py/mphal.h"
#include "py/gc.h"
#include "extmod/vfs.h"
@@ -47,6 +48,9 @@
#include "mbedtls/pk.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
+#ifdef MBEDTLS_SSL_PROTO_DTLS
+#include "mbedtls/timing.h"
+#endif
#include "mbedtls/debug.h"
#include "mbedtls/error.h"
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
@@ -65,6 +69,14 @@
#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
+#define MP_ENDPOINT_IS_SERVER (1 << 0)
+#define MP_TRANSPORT_IS_DTLS (1 << 1)
+
+#define MP_PROTOCOL_TLS_CLIENT 0
+#define MP_PROTOCOL_TLS_SERVER MP_ENDPOINT_IS_SERVER
+#define MP_PROTOCOL_DTLS_CLIENT MP_TRANSPORT_IS_DTLS
+#define MP_PROTOCOL_DTLS_SERVER MP_ENDPOINT_IS_SERVER | MP_TRANSPORT_IS_DTLS
+
// This corresponds to an SSLContext object.
typedef struct _mp_obj_ssl_context_t {
mp_obj_base_t base;
@@ -91,6 +103,12 @@ typedef struct _mp_obj_ssl_socket_t {
uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next
int last_error; // The last error code, if any
+
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
+ mp_uint_t timer_start_ms;
+ mp_uint_t timer_fin_ms;
+ mp_uint_t timer_int_ms;
+ #endif
} mp_obj_ssl_socket_t;
static const mp_obj_type_t ssl_context_type;
@@ -242,7 +260,10 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
mp_arg_check_num(n_args, n_kw, 1, 1, false);
// This is the "protocol" argument.
- mp_int_t endpoint = mp_obj_get_int(args[0]);
+ mp_int_t protocol = mp_obj_get_int(args[0]);
+
+ int endpoint = (protocol & MP_ENDPOINT_IS_SERVER) ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT;
+ int transport = (protocol & MP_TRANSPORT_IS_DTLS) ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM;
// Create SSLContext object.
#if MICROPY_PY_SSL_FINALISER
@@ -282,7 +303,7 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
}
ret = mbedtls_ssl_config_defaults(&self->conf, endpoint,
- MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
+ transport, MBEDTLS_SSL_PRESET_DEFAULT);
if (ret != 0) {
mbedtls_raise_error(ret);
}
@@ -525,6 +546,39 @@ static int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
}
}
+#ifdef MBEDTLS_SSL_PROTO_DTLS
+static void _mbedtls_timing_set_delay(void *ctx, uint32_t int_ms, uint32_t fin_ms) {
+ mp_obj_ssl_socket_t *o = (mp_obj_ssl_socket_t *)ctx;
+
+ o->timer_int_ms = int_ms;
+ o->timer_fin_ms = fin_ms;
+
+ if (fin_ms != 0) {
+ o->timer_start_ms = mp_hal_ticks_ms();
+ }
+}
+
+static int _mbedtls_timing_get_delay(void *ctx) {
+ mp_obj_ssl_socket_t *o = (mp_obj_ssl_socket_t *)ctx;
+
+ if (o->timer_fin_ms == 0) {
+ return -1;
+ }
+
+ mp_uint_t elapsed_ms = mp_hal_ticks_ms() - o->timer_start_ms;
+
+ if (elapsed_ms >= o->timer_fin_ms) {
+ return 2;
+ }
+
+ if (elapsed_ms >= o->timer_int_ms) {
+ return 1;
+ }
+
+ return 0;
+}
+#endif
+
static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) {
@@ -577,6 +631,10 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
mp_raise_ValueError(MP_ERROR_TEXT("CERT_REQUIRED requires server_hostname"));
}
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
+ mbedtls_ssl_set_timer_cb(&o->ssl, o, _mbedtls_timing_set_delay, _mbedtls_timing_get_delay);
+ #endif
+
mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
if (do_handshake_on_connect) {
@@ -788,6 +846,12 @@ static const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
{ MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) },
{ MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) },
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
+ { MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&mp_stream_read1_obj) },
+ { MP_ROM_QSTR(MP_QSTR_recv_into), MP_ROM_PTR(&mp_stream_readinto_obj) },
+ { MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&mp_stream_write1_obj) },
+ { MP_ROM_QSTR(MP_QSTR_sendall), MP_ROM_PTR(&mp_stream_write_obj) },
+ #endif
{ MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) },
{ MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) },
#if MICROPY_PY_SSL_FINALISER
@@ -879,8 +943,12 @@ static const mp_rom_map_elem_t mp_module_tls_globals_table[] = {
// Constants.
{ MP_ROM_QSTR(MP_QSTR_MBEDTLS_VERSION), MP_ROM_PTR(&mbedtls_version_obj)},
- { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(MBEDTLS_SSL_IS_CLIENT) },
- { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(MBEDTLS_SSL_IS_SERVER) },
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(MP_PROTOCOL_TLS_CLIENT) },
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(MP_PROTOCOL_TLS_SERVER) },
+ #ifdef MBEDTLS_SSL_PROTO_DTLS
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_DTLS_CLIENT), MP_ROM_INT(MP_PROTOCOL_DTLS_CLIENT) },
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_DTLS_SERVER), MP_ROM_INT(MP_PROTOCOL_DTLS_SERVER) },
+ #endif
{ MP_ROM_QSTR(MP_QSTR_CERT_NONE), MP_ROM_INT(MBEDTLS_SSL_VERIFY_NONE) },
{ MP_ROM_QSTR(MP_QSTR_CERT_OPTIONAL), MP_ROM_INT(MBEDTLS_SSL_VERIFY_OPTIONAL) },
{ MP_ROM_QSTR(MP_QSTR_CERT_REQUIRED), MP_ROM_INT(MBEDTLS_SSL_VERIFY_REQUIRED) },
diff --git a/ports/esp32/boards/sdkconfig.base b/ports/esp32/boards/sdkconfig.base
index e20835c70..530db4271 100644
--- a/ports/esp32/boards/sdkconfig.base
+++ b/ports/esp32/boards/sdkconfig.base
@@ -64,6 +64,9 @@ CONFIG_MBEDTLS_HAVE_TIME_DATE=y
CONFIG_MBEDTLS_PLATFORM_TIME_ALT=y
CONFIG_MBEDTLS_HAVE_TIME=y
+# Enable DTLS
+CONFIG_MBEDTLS_SSL_PROTO_DTLS=y
+
# Disable ALPN support as it's not implemented in MicroPython
CONFIG_MBEDTLS_SSL_ALPN=n
diff --git a/tests/extmod/tls_dtls.py b/tests/extmod/tls_dtls.py
new file mode 100644
index 000000000..b2d716769
--- /dev/null
+++ b/tests/extmod/tls_dtls.py
@@ -0,0 +1,51 @@
+# Test DTLS functionality including timeout handling
+
+try:
+ from tls import PROTOCOL_DTLS_CLIENT, PROTOCOL_DTLS_SERVER, SSLContext, CERT_NONE
+ import io
+except ImportError:
+ print("SKIP")
+ raise SystemExit
+
+
+class DummySocket(io.IOBase):
+ def __init__(self):
+ self.write_buffer = bytearray()
+ self.read_buffer = bytearray()
+
+ def write(self, data):
+ return len(data)
+
+ def readinto(self, buf):
+ # This is a placeholder socket that doesn't actually read anything
+ # so the read buffer is always empty.
+ return None
+
+ def ioctl(self, req, arg):
+ if req == 4: # MP_STREAM_CLOSE
+ return 0
+ return -1
+
+
+# Create dummy sockets for testing
+server_socket = DummySocket()
+client_socket = DummySocket()
+
+# Wrap the DTLS Server
+dtls_server_ctx = SSLContext(PROTOCOL_DTLS_SERVER)
+dtls_server_ctx.verify_mode = CERT_NONE
+dtls_server = dtls_server_ctx.wrap_socket(server_socket, do_handshake_on_connect=False)
+print("Wrapped DTLS Server")
+
+# Wrap the DTLS Client
+dtls_client_ctx = SSLContext(PROTOCOL_DTLS_CLIENT)
+dtls_client_ctx.verify_mode = CERT_NONE
+dtls_client = dtls_client_ctx.wrap_socket(client_socket, do_handshake_on_connect=False)
+print("Wrapped DTLS Client")
+
+# Trigger the timing check multiple times with different elapsed times
+for i in range(10): # Try multiple iterations to hit the timing window
+ dtls_client.write(b"test")
+ data = dtls_server.read(1024) # This should eventually hit the timing condition
+
+print("OK")
diff --git a/tests/extmod/tls_dtls.py.exp b/tests/extmod/tls_dtls.py.exp
new file mode 100644
index 000000000..78d72bff1
--- /dev/null
+++ b/tests/extmod/tls_dtls.py.exp
@@ -0,0 +1,3 @@
+Wrapped DTLS Server
+Wrapped DTLS Client
+OK