diff options
| author | Jeff Epler <jepler@gmail.com> | 2025-06-13 17:24:08 +0200 |
|---|---|---|
| committer | Damien George <damien@micropython.org> | 2025-06-17 10:15:59 +1000 |
| commit | b6b7d64bd9257cc1c138d37033108dcff5ceb89e (patch) | |
| tree | 34cbfaaeec6a62bf98b9d6be82668c15351f391f | |
| parent | 5ade8b7058f95e77a0d62d523c8443647f176532 (diff) | |
py/modio: Fix the case where write fails in BufferedWriter.flush.
Previously, there was no test coverage of the "write failed" path. In
fact, the assertion would fire instead of gracefully raising a Python
exception.
Slightly re-organize the code to place the assertion later. Add a test
case which exercises all paths, and update the expected output.
Signed-off-by: Jeff Epler <jepler@gmail.com>
| -rw-r--r-- | py/modio.c | 9 | ||||
| -rw-r--r-- | tests/basics/io_buffered_writer.py | 24 | ||||
| -rw-r--r-- | tests/basics/io_buffered_writer.py.exp | 5 |
3 files changed, 34 insertions, 4 deletions
diff --git a/py/modio.c b/py/modio.c index d3e563dbc..9aeb42d30 100644 --- a/py/modio.c +++ b/py/modio.c @@ -169,12 +169,13 @@ static mp_obj_t bufwriter_flush(mp_obj_t self_in) { int err; mp_uint_t out_sz = mp_stream_write_exactly(self->stream, self->buf, self->len, &err); (void)out_sz; - // TODO: try to recover from a case of non-blocking stream, e.g. move - // remaining chunk to the beginning of buffer. - assert(out_sz == self->len); - self->len = 0; if (err != 0) { mp_raise_OSError(err); + } else { + // TODO: try to recover from a case of non-blocking stream, e.g. move + // remaining chunk to the beginning of buffer. + assert(out_sz == self->len); + self->len = 0; } } diff --git a/tests/basics/io_buffered_writer.py b/tests/basics/io_buffered_writer.py index 5c065f158..60cf2c837 100644 --- a/tests/basics/io_buffered_writer.py +++ b/tests/basics/io_buffered_writer.py @@ -28,3 +28,27 @@ print(bts.getvalue()) # hashing a BufferedWriter print(type(hash(buf))) + +# Test failing flush() +class MyIO(io.IOBase): + def __init__(self): + self.count = 0 + + def write(self, buf): + self.count += 1 + if self.count < 3: + return None + print("writing", buf) + return len(buf) + + +buf = io.BufferedWriter(MyIO(), 8) + +buf.write(b"foobar") + +for _ in range(4): + try: + buf.flush() + print("flushed") + except OSError: + print("OSError") diff --git a/tests/basics/io_buffered_writer.py.exp b/tests/basics/io_buffered_writer.py.exp index 2209348f5..d61eb148b 100644 --- a/tests/basics/io_buffered_writer.py.exp +++ b/tests/basics/io_buffered_writer.py.exp @@ -4,3 +4,8 @@ b'foobarfoobar' b'foobarfoobar' b'foo' <class 'int'> +OSError +OSError +writing bytearray(b'foobar') +flushed +flushed |
