[FrontEnd] Make merge_async_iterators is_cancelled arg optional (#7282)

This commit is contained in:
Nick Hill 2024-08-07 13:35:14 -07:00 committed by GitHub
parent 311f743831
commit fc1493a01e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -405,7 +405,7 @@ async def iterate_with_cancellation(
async def merge_async_iterators( async def merge_async_iterators(
*iterators: AsyncGenerator[T, None], *iterators: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]], is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
) -> AsyncGenerator[Tuple[int, T], None]: ) -> AsyncGenerator[Tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator. """Merge multiple asynchronous iterators into a single iterator.
@ -413,8 +413,8 @@ async def merge_async_iterators(
When it yields, it yields a tuple (i, item) where i is the index of the When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item. iterator that yields the item.
It also polls the provided function at least once per second to check It also optionally polls a provided function at least once per second
for client cancellation. to check for client cancellation.
""" """
# Can use anext() in python >= 3.10 # Can use anext() in python >= 3.10
@ -422,12 +422,13 @@ async def merge_async_iterators(
ensure_future(pair[1].__anext__()): pair ensure_future(pair[1].__anext__()): pair
for pair in enumerate(iterators) for pair in enumerate(iterators)
} }
timeout = None if is_cancelled is None else 1
try: try:
while awaits: while awaits:
done, pending = await asyncio.wait(awaits.keys(), done, pending = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED, return_when=FIRST_COMPLETED,
timeout=1) timeout=timeout)
if await is_cancelled(): if is_cancelled is not None and await is_cancelled():
raise asyncio.CancelledError("client cancelled") raise asyncio.CancelledError("client cancelled")
for d in done: for d in done:
pair = awaits.pop(d) pair = awaits.pop(d)