summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeff Epler <jepler@gmail.com>2025-06-13 17:24:08 +0200
committerDamien George <damien@micropython.org>2025-06-17 10:15:59 +1000
commitb6b7d64bd9257cc1c138d37033108dcff5ceb89e (patch)
tree34cbfaaeec6a62bf98b9d6be82668c15351f391f
parent5ade8b7058f95e77a0d62d523c8443647f176532 (diff)
py/modio: Fix the case where write fails in BufferedWriter.flush.
Previously, there was no test coverage of the "write failed" path. In fact, the assertion would fire instead of gracefully raising a Python exception. Slightly re-organize the code to place the assertion later. Add a test case which exercises all paths, and update the expected output. Signed-off-by: Jeff Epler <jepler@gmail.com>
-rw-r--r--py/modio.c9
-rw-r--r--tests/basics/io_buffered_writer.py24
-rw-r--r--tests/basics/io_buffered_writer.py.exp5
3 files changed, 34 insertions, 4 deletions
diff --git a/py/modio.c b/py/modio.c
index d3e563dbc..9aeb42d30 100644
--- a/py/modio.c
+++ b/py/modio.c
@@ -169,12 +169,13 @@ static mp_obj_t bufwriter_flush(mp_obj_t self_in) {
int err;
mp_uint_t out_sz = mp_stream_write_exactly(self->stream, self->buf, self->len, &err);
(void)out_sz;
- // TODO: try to recover from a case of non-blocking stream, e.g. move
- // remaining chunk to the beginning of buffer.
- assert(out_sz == self->len);
- self->len = 0;
if (err != 0) {
mp_raise_OSError(err);
+ } else {
+ // TODO: try to recover from a case of non-blocking stream, e.g. move
+ // remaining chunk to the beginning of buffer.
+ assert(out_sz == self->len);
+ self->len = 0;
}
}
diff --git a/tests/basics/io_buffered_writer.py b/tests/basics/io_buffered_writer.py
index 5c065f158..60cf2c837 100644
--- a/tests/basics/io_buffered_writer.py
+++ b/tests/basics/io_buffered_writer.py
@@ -28,3 +28,27 @@ print(bts.getvalue())
# hashing a BufferedWriter
print(type(hash(buf)))
+
+# Test failing flush()
+class MyIO(io.IOBase):
+ def __init__(self):
+ self.count = 0
+
+ def write(self, buf):
+ self.count += 1
+ if self.count < 3:
+ return None
+ print("writing", buf)
+ return len(buf)
+
+
+buf = io.BufferedWriter(MyIO(), 8)
+
+buf.write(b"foobar")
+
+for _ in range(4):
+ try:
+ buf.flush()
+ print("flushed")
+ except OSError:
+ print("OSError")
diff --git a/tests/basics/io_buffered_writer.py.exp b/tests/basics/io_buffered_writer.py.exp
index 2209348f5..d61eb148b 100644
--- a/tests/basics/io_buffered_writer.py.exp
+++ b/tests/basics/io_buffered_writer.py.exp
@@ -4,3 +4,8 @@ b'foobarfoobar'
b'foobarfoobar'
b'foo'
<class 'int'>
+OSError
+OSError
+writing bytearray(b'foobar')
+flushed
+flushed