From 5008a54afb271e8d539bf34e62c9d7f882657749 Mon Sep 17 00:00:00 2001 From: Sadha Chilukoori Date: Fri, 26 Jun 2026 17:24:48 -0700 Subject: [PATCH] Replace TrinoResult generator with class-based iterator Python generators are permanently finalized after an unhandled exception propagates through yield. With lazy spooled segment iteration (PR #597), a transient I/O error (e.g. S3 timeout) during iteration kills the generator. Subsequent fetchone() calls return None as if the query completed normally, silently dropping remaining rows. Replace the generator-based __iter__ with __iter__/__next__ on the class itself. Instance fields (_row_iter, _next_rows) survive exceptions, so the iterator can resume on the next next() call. Fixes #598 --- tests/unit/test_client.py | 43 +++++++++++++++++++++++++++++++++++++++ trino/client.py | 32 +++++++++++++++++++++-------- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 29f3f388..a75723b9 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1145,6 +1145,49 @@ def text(self): assert isinstance(result, TrinoResult) +def test_trino_result_iterator_survives_transient_error(): + """Iterator recovers from mid-iteration exceptions without losing rows.""" + + class FailOnceIterator: + def __init__(self): + self._count = 0 + self._failed = False + + def __iter__(self): + return self + + def __next__(self): + self._count += 1 + if self._count == 3 and not self._failed: + self._failed = True + raise IOError("S3 segment download failed") + if self._count > 5: + raise StopIteration + return [self._count] + + class FakeQuery: + finished = True + + def fetch(self): + return [] + + result = TrinoResult(FakeQuery(), FailOnceIterator()) + it = iter(result) + + assert next(it) == [1] + assert next(it) == [2] + + with pytest.raises(IOError, match="S3 segment download failed"): + next(it) + + # Key: iterator resumes after the error + assert next(it) == [4] 
+ assert next(it) == [5] + + with pytest.raises(StopIteration): + next(it) + + def test_delay_exponential_without_jitter(): max_delay = 1200.0 get_delay = _DelayExponential(base=5, jitter=False, max_delay=max_delay) diff --git a/trino/client.py b/trino/client.py index 692d10fa..99adc066 100644 --- a/trino/client.py +++ b/trino/client.py @@ -821,8 +821,10 @@ class TrinoResult: """ Represent the result of a Trino query as an iterator on rows. - This class implements the iterator protocol as a generator type - https://docs.python.org/3/library/stdtypes.html#generator-types + This class implements the iterator protocol using __next__ so that + transient exceptions during iteration do not permanently kill the + iterator (unlike a generator, whose frame is finalized after an + unhandled exception). """ def __init__(self, query, rows: List[Any]): @@ -830,6 +832,8 @@ def __init__(self, query, rows: List[Any]): # Initial rows from the first POST request self._rows = rows self._rownumber = 0 + self._row_iter: Optional[Iterator[Any]] = None + self._next_rows = None @property def rows(self): @@ -844,15 +848,27 @@ def rownumber(self) -> int: return self._rownumber def __iter__(self): + return self + + def __next__(self): + # Lazy init: prefetch the next batch before exposing current rows. # A query only transitions to a FINISHED state when the results are fully consumed: # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi. 
-        while not self._query.finished or self._rows is not None:
-            next_rows = self._query.fetch() if not self._query.finished else None
-            for row in self._rows:
-                self._rownumber += 1
-                yield row
+        while True:
+            if self._row_iter is None:
+                # Fetch first so that a failed fetch is retried by the next call.
+                self._next_rows = self._query.fetch() if not self._query.finished else None
+                self._row_iter = iter(self._rows)
 
-            self._rows = next_rows
+            try:
+                row = next(self._row_iter)
+                self._rownumber += 1
+                return row
+            except StopIteration:
+                if self._next_rows is None:
+                    raise
+                # Promote the prefetched batch; the loop top then fetches the one after it.
+                self._rows, self._row_iter = self._next_rows, None
 
 
 class TrinoQuery: