summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extmod/modlwip.c28
-rw-r--r--tests/net_hosted/connect_timeout.py24
2 files changed, 40 insertions, 12 deletions
diff --git a/extmod/modlwip.c b/extmod/modlwip.c
index afac512e8..1ae8c114c 100644
--- a/extmod/modlwip.c
+++ b/extmod/modlwip.c
@@ -326,6 +326,10 @@ typedef struct _lwip_socket_obj_t {
int8_t state;
} lwip_socket_obj_t;
+static inline bool socket_is_timedout(lwip_socket_obj_t *socket, mp_uint_t ticks_start) {
+ return socket->timeout != -1 && (mp_uint_t)(mp_hal_ticks_ms() - ticks_start) >= socket->timeout;
+}
+
static inline void poll_sockets(void) {
mp_event_wait_ms(1);
}
@@ -1130,21 +1134,21 @@ static mp_obj_t lwip_socket_connect(mp_obj_t self_in, mp_obj_t addr_in) {
MICROPY_PY_LWIP_EXIT
// And now we wait...
- if (socket->timeout != -1) {
- for (mp_uint_t retries = socket->timeout / 100; retries--;) {
- mp_hal_delay_ms(100);
- if (socket->state != STATE_CONNECTING) {
- break;
- }
- }
- if (socket->state == STATE_CONNECTING) {
- mp_raise_OSError(MP_EINPROGRESS);
+ mp_uint_t ticks_start = mp_hal_ticks_ms();
+ for (;;) {
+ poll_sockets();
+ if (socket->state != STATE_CONNECTING) {
+ break;
}
- } else {
- while (socket->state == STATE_CONNECTING) {
- poll_sockets();
+ if (socket_is_timedout(socket, ticks_start)) {
+ if (socket->timeout == 0) {
+ mp_raise_OSError(MP_EINPROGRESS);
+ } else {
+ mp_raise_OSError(MP_ETIMEDOUT);
+ }
}
}
+
if (socket->state == STATE_CONNECTED) {
err = ERR_OK;
} else {
diff --git a/tests/net_hosted/connect_timeout.py b/tests/net_hosted/connect_timeout.py
new file mode 100644
index 000000000..5f35047c8
--- /dev/null
+++ b/tests/net_hosted/connect_timeout.py
@@ -0,0 +1,24 @@
+# Test that socket.connect() on a socket with timeout raises EINPROGRESS or ETIMEDOUT appropriately.
+
+import errno
+import socket
+
+
+def test(peer_addr, timeout, expected_exc):
+ s = socket.socket()
+ s.settimeout(timeout)
+ try:
+ s.connect(peer_addr)
+ print("OK")
+ except OSError as er:
+ print(er.args[0] in expected_exc)
+ s.close()
+
+
+if __name__ == "__main__":
+ # This test needs an address that doesn't respond to TCP connections.
+ # 1.1.1.1:8000 seem to reliably timeout, so use that.
+ addr = socket.getaddrinfo("1.1.1.1", 8000)[0][-1]
+
+ test(addr, 0, (errno.EINPROGRESS,))
+ test(addr, 1, (errno.ETIMEDOUT, "timed out")) # CPython uses a string instead of errno