Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions Lib/test/test_free_threading/test_itertools_tee_race.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import itertools
import unittest

from test.support import threading_helper


@threading_helper.requires_working_threading()
class TestTeeConcurrent(unittest.TestCase):
    # All branches produced by one itertools.tee() call walk the same linked
    # list of internal data cells.  Iterating them from several threads at
    # once must not corrupt that shared state or crash the free-threaded
    # build.  A crash takes the whole interpreter down, so there is nothing
    # for these tests to catch in that case.  tee is documented as not
    # thread-safe, so a ``RuntimeError`` raised by its re-entrancy guard is an
    # accepted outcome and is ignored; any other exception is recorded and
    # fails the test.

    def _drain(self, it, errors):
        # Exhaust ``it``.  RuntimeError (and its subclasses) is tolerated;
        # every other exception is appended to ``errors`` for the final
        # assertion in the calling test.
        try:
            for _ in it:
                pass
        except Exception as exc:
            if not isinstance(exc, RuntimeError):
                errors.append(exc)

    def test_same_branch(self):
        # Eight threads all pull items from one and the same tee branch.
        errors = []

        for _round in range(100):
            # The second branch is never iterated, but it is kept referenced
            # while the threads run, exactly as the unpacking requires.
            shared, _sibling = itertools.tee(iter(range(2000)), 2)
            threading_helper.run_concurrently(
                self._drain, nthreads=8, args=(shared, errors)
            )

        self.assertEqual(errors, [], msg=f"unexpected errors: {errors}")

    def test_sibling_branches(self):
        # Eight threads, each pulling from its own sibling branch of one tee.
        errors = []

        for _round in range(100):
            siblings = itertools.tee(iter(range(4000)), 8)
            # Bind each branch as a default argument so every worker drains
            # its own iterator rather than the last one of the loop.
            workers = [
                (lambda it=branch: self._drain(it, errors))
                for branch in siblings
            ]
            threading_helper.run_concurrently(workers)

        self.assertEqual(errors, [], msg=f"unexpected errors: {errors}")


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix a crash when concurrently iterating an :func:`itertools.tee` iterator on
the free-threaded build.
80 changes: 73 additions & 7 deletions Modules/itertoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -768,13 +768,17 @@ teedataobject_newinternal(itertools_state *state, PyObject *it)
/* Return a new strong reference to the data object that follows tdo in the
   chain, creating it with teedataobject_newinternal(state, tdo->it) the first
   time it is requested.  The check, the creation and the reference grab all
   happen inside tdo's critical section.  The result is NULL when nextlink is
   still NULL after the creation attempt (Py_XNewRef passes NULL through). */
static PyObject *
teedataobject_jumplink(itertools_state *state, teedataobject *tdo)
{
PyObject *link;
Py_BEGIN_CRITICAL_SECTION(tdo);
if (tdo->nextlink == NULL)
tdo->nextlink = teedataobject_newinternal(state, tdo->it);
/* NOTE(review): the next line appears to be the pre-change statement that
   this diff hunk removes (the page dropped the +/- markers); the three lines
   after it are its replacement, which moves the return to after
   Py_END_CRITICAL_SECTION().  Confirm against the raw diff. */
return Py_XNewRef(tdo->nextlink);
link = Py_XNewRef(tdo->nextlink);
Py_END_CRITICAL_SECTION();
return link;
}

static PyObject *
teedataobject_getitem(teedataobject *tdo, int i)
teedataobject_getitem_lock_held(teedataobject *tdo, int i)
{
PyObject *value;

Expand All @@ -800,6 +804,16 @@ teedataobject_getitem(teedataobject *tdo, int i)
return Py_NewRef(value);
}

/* Locking wrapper around teedataobject_getitem_lock_held(): runs the lookup
   of item i inside tdo's critical section and returns whatever the helper
   returned, unchanged. */
static PyObject *
teedataobject_getitem(teedataobject *tdo, int i)
{
PyObject *result;
Py_BEGIN_CRITICAL_SECTION(tdo);
result = teedataobject_getitem_lock_held(tdo, i);
Py_END_CRITICAL_SECTION();
/* The result is held in a local so that the return happens only after the
   critical section has been closed. */
return result;
}

static int
teedataobject_traverse(PyObject *op, visitproc visit, void * arg)
{
Expand All @@ -819,8 +833,11 @@ teedataobject_safe_decref(PyObject *obj)
{
while (obj && _PyObject_IsUniquelyReferenced(obj)) {
teedataobject *tmp = teedataobject_CAST(obj);
PyObject *nextlink = tmp->nextlink;
PyObject *nextlink;
Py_BEGIN_CRITICAL_SECTION(obj);
nextlink = tmp->nextlink;
tmp->nextlink = NULL;
Py_END_CRITICAL_SECTION();
Py_SETREF(obj, nextlink);
}
Py_XDECREF(obj);
Expand All @@ -833,11 +850,13 @@ teedataobject_clear(PyObject *op)
PyObject *tmp;
teedataobject *tdo = teedataobject_CAST(op);

Py_BEGIN_CRITICAL_SECTION(op);
Py_CLEAR(tdo->it);
for (i=0 ; i<tdo->numread ; i++)
Py_CLEAR(tdo->values[i]);
tmp = tdo->nextlink;
tdo->nextlink = NULL;
Py_END_CRITICAL_SECTION();
teedataobject_safe_decref(tmp);
return 0;
}
Expand Down Expand Up @@ -930,20 +949,67 @@ static PyObject *
tee_next(PyObject *op)
{
/* Produce the next item for this tee branch, or NULL when the underlying
   lookup or link creation returns NULL.

   Two implementations are selected at compile time:
   - without Py_GIL_DISABLED: step to the next data object when index has
     reached LINKCELLS, fetch the item at index, then increment index;
   - with Py_GIL_DISABLED: work on a snapshot of (dataobj, index) and only
     write back to the branch if it still matches that snapshot.

   NOTE(review): this text comes from a diff view without +/- markers.  The
   lines flagged below appear to be the removed pre-change lines, each shown
   directly before its replacement; confirm against the raw diff. */
teeobject *to = teeobject_CAST(op);
/* NOTE(review): likely removed line; replaced by the declaration after it. */
PyObject *value, *link;
PyObject *value;

#ifndef Py_GIL_DISABLED
/* The GIL already serializes access, so keep the simple path without the
snapshot and revalidation that the free-threaded build needs. */
if (to->index >= LINKCELLS) {
/* NOTE(review): the next two lines are likely the removed pre-change form of
   the call and NULL check that follow them. */
link = teedataobject_jumplink(to->state, to->dataobj);
if (link == NULL)
PyObject *link = teedataobject_jumplink(to->state, to->dataobj);
if (link == NULL) {
return NULL;
}
/* Move the branch onto the next data object and restart at its first slot. */
Py_SETREF(to->dataobj, (teedataobject *)link);
to->index = 0;
}
value = teedataobject_getitem(to->dataobj, to->index);
/* NOTE(review): the next line is likely the removed unbraced form of the
   check that follows it. */
if (value == NULL)
if (value == NULL) {
return NULL;
}
to->index++;
return value;
#else
/* Each pass either returns an item or tries to step the branch onto the next
   data object and then loops to read from the refreshed position. */
for (;;) {
teedataobject *dataobj;
int index;

/* Snapshot the branch position (strong ref to the shared data object)
under the tee lock; the data object is locked separately, not nested,
then the advance is revalidated. */
Py_BEGIN_CRITICAL_SECTION(op);
dataobj = (teedataobject *)Py_NewRef((PyObject *)to->dataobj);
index = to->index;
Py_END_CRITICAL_SECTION();

if (index < LINKCELLS) {
value = teedataobject_getitem(dataobj, index);
if (value != NULL) {
/* Advance only if the branch is still at the snapshotted position; if it
   was changed in the meantime, the item is still returned but the index is
   left as it is. */
Py_BEGIN_CRITICAL_SECTION(op);
if (to->dataobj == dataobj && to->index == index) {
to->index = index + 1;
}
Py_END_CRITICAL_SECTION();
}
/* Drop the snapshot reference; value may be NULL here and is returned
   as-is. */
Py_DECREF(dataobj);
return value;
}

/* The snapshot is past the last slot of this data object: get (or create)
   the following one. */
PyObject *link = teedataobject_jumplink(to->state, dataobj);
if (link == NULL) {
Py_DECREF(dataobj);
return NULL;
}
/* Install the new data object only if the branch still points at the
   snapshotted one.  On success the reference held in link is handed to the
   branch and link is set to NULL so the Py_XDECREF below does nothing;
   otherwise that Py_XDECREF releases the unused reference. */
Py_BEGIN_CRITICAL_SECTION(op);
if (to->dataobj == dataobj) {
Py_SETREF(to->dataobj, (teedataobject *)link);
to->index = 0;
link = NULL;
}
Py_END_CRITICAL_SECTION();
Py_XDECREF(link);
Py_DECREF(dataobj);
}
#endif
}

static int
Expand Down
Loading