[Frontend] Reduce frequency of client cancellation checking (#7959)

This commit is contained in:
Nick Hill 2024-10-21 21:28:10 +01:00 committed by GitHub
parent 5241aa1494
commit 9d9186be97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,10 +13,11 @@ import subprocess
import sys
import tempfile
import threading
import time
import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, ensure_future
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import Mapping
from functools import lru_cache, partial, wraps
from platform import uname
@ -437,6 +438,12 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return _async_wrapper
def _next_task(iterator: AsyncGenerator[T, None],
loop: AbstractEventLoop) -> Task:
# Can use anext() in python >= 3.10
return loop.create_task(iterator.__anext__()) # type: ignore[arg-type]
async def iterate_with_cancellation(
iterator: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]],
@ -445,19 +452,27 @@ async def iterate_with_cancellation(
at least once per second to check for client cancellation.
"""
# Can use anext() in python >= 3.10
awaits = [ensure_future(iterator.__anext__())]
loop = asyncio.get_running_loop()
awaits: List[Future[T]] = [_next_task(iterator, loop)]
next_cancel_check: float = 0
while True:
done, pending = await asyncio.wait(awaits, timeout=1)
if await is_cancelled():
with contextlib.suppress(BaseException):
awaits[0].cancel()
await iterator.aclose()
raise asyncio.CancelledError("client cancelled")
done, pending = await asyncio.wait(awaits, timeout=1.5)
# Check for cancellation at most once per second
time_now = time.time()
if time_now >= next_cancel_check:
if await is_cancelled():
with contextlib.suppress(BaseException):
awaits[0].cancel()
await iterator.aclose()
raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1
if done:
try:
item = await awaits[0]
awaits[0] = ensure_future(iterator.__anext__())
awaits[0] = _next_task(iterator, loop)
yield item
except StopAsyncIteration:
# we are done
@ -478,25 +493,29 @@ async def merge_async_iterators(
to check for client cancellation.
"""
# Can use anext() in python >= 3.10
awaits = {
ensure_future(pair[1].__anext__()): pair
for pair in enumerate(iterators)
}
timeout = None if is_cancelled is None else 1
loop = asyncio.get_running_loop()
awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
timeout = None if is_cancelled is None else 1.5
next_cancel_check: float = 0
try:
while awaits:
done, pending = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED,
timeout=timeout)
if is_cancelled is not None and await is_cancelled():
raise asyncio.CancelledError("client cancelled")
if is_cancelled is not None:
# Check for cancellation at most once per second
time_now = time.time()
if time_now >= next_cancel_check:
if await is_cancelled():
raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1
for d in done:
pair = awaits.pop(d)
try:
item = await d
i, it = pair
awaits[ensure_future(it.__anext__())] = pair
awaits[_next_task(it, loop)] = pair
yield i, item
except StopAsyncIteration:
pass