summaryrefslogtreecommitdiff
path: root/extmod/modssl_axtls.c
diff options
context:
space:
mode:
Diffstat (limited to 'extmod/modssl_axtls.c')
-rw-r--r--extmod/modssl_axtls.c155
1 files changed, 127 insertions, 28 deletions
diff --git a/extmod/modssl_axtls.c b/extmod/modssl_axtls.c
index de6e0ce5d..d169d89a2 100644
--- a/extmod/modssl_axtls.c
+++ b/extmod/modssl_axtls.c
@@ -4,6 +4,7 @@
* The MIT License (MIT)
*
* Copyright (c) 2015-2019 Paul Sokolovsky
+ * Copyright (c) 2023 Damien P. George
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -35,6 +36,17 @@
#include "ssl.h"
+#define PROTOCOL_TLS_CLIENT (0)
+#define PROTOCOL_TLS_SERVER (1)
+
+// This corresponds to an SSLContext object.
+typedef struct _mp_obj_ssl_context_t {
+ mp_obj_base_t base;
+ mp_obj_t key;
+ mp_obj_t cert;
+} mp_obj_ssl_context_t;
+
+// This corresponds to an SSLSocket object.
typedef struct _mp_obj_ssl_socket_t {
mp_obj_base_t base;
mp_obj_t sock;
@@ -53,8 +65,15 @@ struct ssl_args {
mp_arg_val_t do_handshake;
};
+STATIC const mp_obj_type_t ssl_context_type;
STATIC const mp_obj_type_t ssl_socket_type;
+STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
+ bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname);
+
+/******************************************************************************/
+// Helper functions.
+
// Table of error strings corresponding to SSL_xxx error codes.
STATIC const char *const ssl_error_tab1[] = {
"NOT_OK",
@@ -116,8 +135,71 @@ STATIC NORETURN void ssl_raise_error(int err) {
nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, args));
}
+/******************************************************************************/
+// SSLContext type.
+
+STATIC mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) {
+ mp_arg_check_num(n_args, n_kw, 1, 1, false);
+
+ // The "protocol" argument is ignored in this implementation.
+
+ // Create SSLContext object.
+ #if MICROPY_PY_SSL_FINALISER
+ mp_obj_ssl_context_t *self = m_new_obj_with_finaliser(mp_obj_ssl_context_t);
+ #else
+ mp_obj_ssl_context_t *self = m_new_obj(mp_obj_ssl_context_t);
+ #endif
+ self->base.type = type_in;
+ self->key = mp_const_none;
+ self->cert = mp_const_none;
+
+ return MP_OBJ_FROM_PTR(self);
+}
+
+STATIC void ssl_context_load_key(mp_obj_ssl_context_t *self, mp_obj_t key_obj, mp_obj_t cert_obj) {
+ self->key = key_obj;
+ self->cert = cert_obj;
+}
+
+STATIC mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
+ enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname };
+ static const mp_arg_t allowed_args[] = {
+ { MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} },
+ { MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
+ { MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
+ };
+
+ // Parse arguments.
+ mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(pos_args[0]);
+ mp_obj_t sock = pos_args[1];
+ mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
+ mp_arg_parse_all(n_args - 2, pos_args + 2, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
+
+ // Create and return the new SSLSocket object.
+ return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool,
+ args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj);
+}
+STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket);
+
+STATIC const mp_rom_map_elem_t ssl_context_locals_dict_table[] = {
+ { MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) },
+};
+STATIC MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table);
+
+STATIC MP_DEFINE_CONST_OBJ_TYPE(
+ ssl_context_type,
+ MP_QSTR_SSLContext,
+ MP_TYPE_FLAG_NONE,
+ make_new, ssl_context_make_new,
+ locals_dict, &ssl_context_locals_dict
+ );
+
+/******************************************************************************/
+// SSLSocket type.
+
+STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
+ bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) {
-STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args) {
#if MICROPY_PY_SSL_FINALISER
mp_obj_ssl_socket_t *o = m_new_obj_with_finaliser(mp_obj_ssl_socket_t);
#else
@@ -130,43 +212,43 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args)
o->blocking = true;
uint32_t options = SSL_SERVER_VERIFY_LATER;
- if (!args->do_handshake.u_bool) {
+ if (!do_handshake_on_connect) {
options |= SSL_CONNECT_IN_PARTS;
}
- if (args->key.u_obj != mp_const_none) {
+ if (ssl_context->key != mp_const_none) {
options |= SSL_NO_DEFAULT_KEY;
}
if ((o->ssl_ctx = ssl_ctx_new(options, SSL_DEFAULT_CLNT_SESS)) == NULL) {
mp_raise_OSError(MP_EINVAL);
}
- if (args->key.u_obj != mp_const_none) {
+ if (ssl_context->key != mp_const_none) {
size_t len;
- const byte *data = (const byte *)mp_obj_str_get_data(args->key.u_obj, &len);
+ const byte *data = (const byte *)mp_obj_str_get_data(ssl_context->key, &len);
int res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_RSA_KEY, data, len, NULL);
if (res != SSL_OK) {
mp_raise_ValueError(MP_ERROR_TEXT("invalid key"));
}
- data = (const byte *)mp_obj_str_get_data(args->cert.u_obj, &len);
+ data = (const byte *)mp_obj_str_get_data(ssl_context->cert, &len);
res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_X509_CERT, data, len, NULL);
if (res != SSL_OK) {
mp_raise_ValueError(MP_ERROR_TEXT("invalid cert"));
}
}
- if (args->server_side.u_bool) {
+ if (server_side) {
o->ssl_sock = ssl_server_new(o->ssl_ctx, (long)sock);
} else {
SSL_EXTENSIONS *ext = ssl_ext_new();
- if (args->server_hostname.u_obj != mp_const_none) {
- ext->host_name = (char *)mp_obj_str_get_str(args->server_hostname.u_obj);
+ if (server_hostname != mp_const_none) {
+ ext->host_name = (char *)mp_obj_str_get_str(server_hostname);
}
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);
- if (args->do_handshake.u_bool) {
+ if (do_handshake_on_connect) {
int r = ssl_handshake_status(o->ssl_sock);
if (r != SSL_OK) {
@@ -178,18 +260,11 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args)
ssl_raise_error(r);
}
}
-
}
return o;
}
-STATIC void ssl_socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) {
- (void)kind;
- mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in);
- mp_printf(print, "<_SSLSocket %p>", self->ssl_sock);
-}
-
STATIC mp_uint_t ssl_socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
@@ -305,7 +380,6 @@ STATIC const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) },
#endif
};
-
STATIC MP_DEFINE_CONST_DICT(ssl_socket_locals_dict, ssl_socket_locals_dict_table);
STATIC const mp_stream_p_t ssl_socket_stream_p = {
@@ -316,16 +390,23 @@ STATIC const mp_stream_p_t ssl_socket_stream_p = {
STATIC MP_DEFINE_CONST_OBJ_TYPE(
ssl_socket_type,
- // Save on qstr's, reuse same as for module
- MP_QSTR_ssl,
+ MP_QSTR_SSLSocket,
MP_TYPE_FLAG_NONE,
- print, ssl_socket_print,
protocol, &ssl_socket_stream_p,
locals_dict, &ssl_socket_locals_dict
);
+/******************************************************************************/
+// ssl module.
+
STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
- // TODO: Implement more args
+ enum {
+ ARG_key,
+ ARG_cert,
+ ARG_server_side,
+ ARG_server_hostname,
+ ARG_do_handshake,
+ };
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
{ MP_QSTR_cert, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
@@ -334,22 +415,40 @@ STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_
{ MP_QSTR_do_handshake, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
};
- // TODO: Check that sock implements stream protocol
+ // Parse arguments.
mp_obj_t sock = pos_args[0];
+ mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
+ mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
+
+ // Create SSLContext.
+ mp_int_t protocol = args[ARG_server_side].u_bool ? PROTOCOL_TLS_SERVER : PROTOCOL_TLS_CLIENT;
+ mp_obj_t ssl_context_args[1] = { MP_OBJ_NEW_SMALL_INT(protocol) };
+ mp_obj_ssl_context_t *ssl_context = MP_OBJ_TO_PTR(ssl_context_make_new(&ssl_context_type, 1, 0, ssl_context_args));
- struct ssl_args args;
- mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args,
- MP_ARRAY_SIZE(allowed_args), allowed_args, (mp_arg_val_t *)&args);
+ // Load key and cert if given.
+ if (args[ARG_key].u_obj != mp_const_none) {
+ ssl_context_load_key(ssl_context, args[ARG_key].u_obj, args[ARG_cert].u_obj);
+ }
- return MP_OBJ_FROM_PTR(ssl_socket_new(sock, &args));
+ // Create and return the new SSLSocket object.
+ return ssl_socket_make_new(ssl_context, sock, args[ARG_server_side].u_bool,
+ args[ARG_do_handshake].u_bool, args[ARG_server_hostname].u_obj);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socket);
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ssl) },
+
+ // Functions.
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
-};
+ // Classes.
+ { MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) },
+
+ // Constants.
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(PROTOCOL_TLS_CLIENT) },
+ { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(PROTOCOL_TLS_SERVER) },
+};
STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
const mp_obj_module_t mp_module_ssl = {