summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ports/esp32/modsocket.c19
-rw-r--r--tests/esp32/resolve_on_connect.py59
-rw-r--r--tests/net_inet/getaddrinfo.py52
3 files changed, 123 insertions, 7 deletions
diff --git a/ports/esp32/modsocket.c b/ports/esp32/modsocket.c
index 69a74ec25..85433e575 100644
--- a/ports/esp32/modsocket.c
+++ b/ports/esp32/modsocket.c
@@ -244,19 +244,24 @@ static int _socket_getaddrinfo2(const mp_obj_t host, const mp_obj_t portx, struc
int res = _socket_getaddrinfo3(host_str, port_str, &hints, resp);
MP_THREAD_GIL_ENTER();
+ // Per docs: instead of raising gaierror getaddrinfo raises negative error number
+ if (res != 0) {
+ mp_raise_OSError(res > 0 ? -res : res);
+ }
+ // Somehow LwIP returns a resolution of 0.0.0.0 for failed lookups, traced it as far back
+ // as netconn_gethostbyname_addrtype returning OK instead of error.
+ if (*resp == NULL ||
+ (strcmp(resp[0]->ai_canonname, "0.0.0.0") == 0 && strcmp(host_str, "0.0.0.0") != 0)) {
+ mp_raise_OSError(-2); // name or service not known
+ }
+
return res;
}
STATIC void _socket_getaddrinfo(const mp_obj_t addrtuple, struct addrinfo **resp) {
mp_obj_t *elem;
mp_obj_get_array_fixed_n(addrtuple, 2, &elem);
- int res = _socket_getaddrinfo2(elem[0], elem[1], resp);
- if (res != 0) {
- mp_raise_OSError(res);
- }
- if (*resp == NULL) {
- mp_raise_OSError(-2); // name or service not known
- }
+ _socket_getaddrinfo2(elem[0], elem[1], resp);
}
STATIC mp_obj_t socket_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) {
diff --git a/tests/esp32/resolve_on_connect.py b/tests/esp32/resolve_on_connect.py
new file mode 100644
index 000000000..068757ab2
--- /dev/null
+++ b/tests/esp32/resolve_on_connect.py
@@ -0,0 +1,59 @@
+# Test that the esp32's socket module performs DNS resolutions on bind and connect
+import sys
+
+if sys.implementation.name == "micropython" and sys.platform != "esp32":
+ print("SKIP")
+ raise SystemExit
+
+try:
+ import usocket as socket, sys
+except:
+ import socket, sys
+
+
+def test_bind_resolves_0_0_0_0():
+ s = socket.socket()
+ try:
+ s.bind(("0.0.0.0", 31245))
+ print("bind actually bound")
+ s.close()
+ except Exception as e:
+ print("bind raised", e)
+
+
+def test_bind_resolves_localhost():
+ s = socket.socket()
+ try:
+ s.bind(("localhost", 31245))
+ print("bind actually bound")
+ s.close()
+ except Exception as e:
+ print("bind raised", e)
+
+
+def test_connect_resolves():
+ s = socket.socket()
+ try:
+ s.connect(("micropython.org", 80))
+ print("connect actually connected")
+ s.close()
+ except Exception as e:
+ print("connect raised", e)
+
+
+def test_connect_non_existent():
+ s = socket.socket()
+ try:
+ s.connect(("nonexistent.example.com", 80))
+ print("connect actually connected")
+ s.close()
+ except OSError as e:
+ print("connect raised OSError")
+ except Exception as e:
+ print("connect raised", e)
+
+
+test_funs = [n for n in dir() if n.startswith("test_")]
+for f in sorted(test_funs):
+ print("--", f, end=": ")
+ eval(f + "()")
diff --git a/tests/net_inet/getaddrinfo.py b/tests/net_inet/getaddrinfo.py
new file mode 100644
index 000000000..765723ae7
--- /dev/null
+++ b/tests/net_inet/getaddrinfo.py
@@ -0,0 +1,52 @@
+try:
+ import usocket as socket, sys
+except:
+ import socket, sys
+
+
+def test_non_existent():
+ try:
+ res = socket.getaddrinfo("nonexistent.example.com", 80)
+ print("getaddrinfo returned", res)
+ except OSError as e:
+ print("getaddrinfo raised")
+
+
+def test_bogus():
+ try:
+ res = socket.getaddrinfo("hey.!!$$", 80)
+ print("getaddrinfo returned", res)
+ except OSError as e:
+ print("getaddrinfo raised")
+ except Exception as e:
+ print("getaddrinfo raised") # CPython raises UnicodeError!?
+
+
+def test_ip_addr():
+ try:
+ res = socket.getaddrinfo("10.10.10.10", 80)
+ print("getaddrinfo returned resolutions")
+ except Exception as e:
+ print("getaddrinfo raised", e)
+
+
+def test_0_0_0_0():
+ try:
+ res = socket.getaddrinfo("0.0.0.0", 80)
+ print("getaddrinfo returned resolutions")
+ except Exception as e:
+ print("getaddrinfo raised", e)
+
+
+def test_valid():
+ try:
+ res = socket.getaddrinfo("micropython.org", 80)
+ print("getaddrinfo returned resolutions")
+ except Exception as e:
+ print("getaddrinfo raised", e)
+
+
+test_funs = [n for n in dir() if n.startswith("test_")]
+for f in sorted(test_funs):
+ print("--", f, end=": ")
+ eval(f + "()")