summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extmod/asyncio/stream.py26
-rw-r--r--tests/net_hosted/asyncio_start_server.py38
-rw-r--r--tests/net_hosted/asyncio_start_server.py.exp18
3 files changed, 79 insertions, 3 deletions
diff --git a/extmod/asyncio/stream.py b/extmod/asyncio/stream.py
index c47c48cf0..5547bfbd5 100644
--- a/extmod/asyncio/stream.py
+++ b/extmod/asyncio/stream.py
@@ -127,20 +127,30 @@ class Server:
await self.wait_closed()
def close(self):
+ # Note: the _serve task must have already started by now due to the sleep
+ # in start_server, so `state` won't be clobbered at the start of _serve.
+ self.state = True
self.task.cancel()
async def wait_closed(self):
await self.task
async def _serve(self, s, cb):
+ self.state = False
# Accept incoming connections
while True:
try:
yield core._io_queue.queue_read(s)
- except core.CancelledError:
- # Shutdown server
+ except core.CancelledError as er:
+ # The server task was cancelled, shutdown server and close socket.
s.close()
- return
+ if self.state:
+ # If the server was explicitly closed, ignore the cancellation.
+ return
+ else:
+ # Otherwise e.g. the parent task was cancelled, propagate
+ # cancellation.
+ raise er
try:
s2, addr = s.accept()
except:
@@ -167,6 +177,16 @@ async def start_server(cb, host, port, backlog=5):
# Create and return server object and task.
srv = Server()
srv.task = core.create_task(srv._serve(s, cb))
+ try:
+ # Ensure that the _serve task has been scheduled so that it gets to
+ # handle cancellation.
+ await core.sleep_ms(0)
+ except core.CancelledError as er:
+ # If the parent task is cancelled during this first sleep, then
+ # we will leak the task and it will sit waiting for the socket, so
+ # cancel it.
+ srv.task.cancel()
+ raise er
return srv
diff --git a/tests/net_hosted/asyncio_start_server.py b/tests/net_hosted/asyncio_start_server.py
index 316221898..e76faf7ed 100644
--- a/tests/net_hosted/asyncio_start_server.py
+++ b/tests/net_hosted/asyncio_start_server.py
@@ -22,6 +22,44 @@ async def test():
print("sleep")
await asyncio.sleep(0)
+ # Test that cancellation works before the server starts if
+ # the subsequent code raises.
+ print("create server3")
+ server3 = await asyncio.start_server(None, "0.0.0.0", 8000)
+ try:
+ async with server3:
+ raise OSError
+ except OSError as er:
+ print("OSError")
+
+ # Test that closing doesn't raise CancelledError.
+ print("create server4")
+ server4 = await asyncio.start_server(None, "0.0.0.0", 8000)
+ server4.close()
+ await server4.wait_closed()
+ print("server4 closed")
+
+ # Test that cancelling the task will still raise CancelledError, checking
+ # edge cases around how many times the tasks have been re-scheduled by
+ # sleep.
+ async def task(n):
+ print("create task server", n)
+ srv = await asyncio.start_server(None, "0.0.0.0", 8000)
+ await srv.wait_closed()
+ # This should be unreachable.
+ print("task finished")
+
+ for num_sleep in range(0, 5):
+ print("sleep", num_sleep)
+ t = asyncio.create_task(task(num_sleep))
+ for _ in range(num_sleep):
+ await asyncio.sleep(0)
+ t.cancel()
+ try:
+ await t
+ except asyncio.CancelledError:
+ print("CancelledError")
+
print("done")
diff --git a/tests/net_hosted/asyncio_start_server.py.exp b/tests/net_hosted/asyncio_start_server.py.exp
index 0fb8e6a63..58982a108 100644
--- a/tests/net_hosted/asyncio_start_server.py.exp
+++ b/tests/net_hosted/asyncio_start_server.py.exp
@@ -2,4 +2,22 @@ create server1
create server2
OSError
sleep
+create server3
+OSError
+create server4
+server4 closed
+sleep 0
+CancelledError
+sleep 1
+create task server 1
+CancelledError
+sleep 2
+create task server 2
+CancelledError
+sleep 3
+create task server 3
+CancelledError
+sleep 4
+create task server 4
+CancelledError
done