[Frontend] Reduce frequency of client cancellation checking (#7959)

This commit is contained in:
Nick Hill 2024-10-21 21:28:10 +01:00 committed by GitHub
parent 5241aa1494
commit 9d9186be97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,10 +13,11 @@ import subprocess
import sys import sys
import tempfile import tempfile
import threading import threading
import time
import uuid import uuid
import warnings import warnings
import weakref import weakref
from asyncio import FIRST_COMPLETED, ensure_future from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import Mapping from collections.abc import Mapping
from functools import lru_cache, partial, wraps from functools import lru_cache, partial, wraps
from platform import uname from platform import uname
@ -437,6 +438,12 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return _async_wrapper return _async_wrapper
def _next_task(iterator: AsyncGenerator[T, None],
loop: AbstractEventLoop) -> Task:
# Can use anext() in python >= 3.10
return loop.create_task(iterator.__anext__()) # type: ignore[arg-type]
async def iterate_with_cancellation( async def iterate_with_cancellation(
iterator: AsyncGenerator[T, None], iterator: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]], is_cancelled: Callable[[], Awaitable[bool]],
@ -445,19 +452,27 @@ async def iterate_with_cancellation(
at least once per second to check for client cancellation. at least once per second to check for client cancellation.
""" """
# Can use anext() in python >= 3.10 loop = asyncio.get_running_loop()
awaits = [ensure_future(iterator.__anext__())]
awaits: List[Future[T]] = [_next_task(iterator, loop)]
next_cancel_check: float = 0
while True: while True:
done, pending = await asyncio.wait(awaits, timeout=1) done, pending = await asyncio.wait(awaits, timeout=1.5)
# Check for cancellation at most once per second
time_now = time.time()
if time_now >= next_cancel_check:
if await is_cancelled(): if await is_cancelled():
with contextlib.suppress(BaseException): with contextlib.suppress(BaseException):
awaits[0].cancel() awaits[0].cancel()
await iterator.aclose() await iterator.aclose()
raise asyncio.CancelledError("client cancelled") raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1
if done: if done:
try: try:
item = await awaits[0] item = await awaits[0]
awaits[0] = ensure_future(iterator.__anext__()) awaits[0] = _next_task(iterator, loop)
yield item yield item
except StopAsyncIteration: except StopAsyncIteration:
# we are done # we are done
@ -478,25 +493,29 @@ async def merge_async_iterators(
to check for client cancellation. to check for client cancellation.
""" """
# Can use anext() in python >= 3.10 loop = asyncio.get_running_loop()
awaits = {
ensure_future(pair[1].__anext__()): pair awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
for pair in enumerate(iterators) timeout = None if is_cancelled is None else 1.5
} next_cancel_check: float = 0
timeout = None if is_cancelled is None else 1
try: try:
while awaits: while awaits:
done, pending = await asyncio.wait(awaits.keys(), done, pending = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED, return_when=FIRST_COMPLETED,
timeout=timeout) timeout=timeout)
if is_cancelled is not None and await is_cancelled(): if is_cancelled is not None:
# Check for cancellation at most once per second
time_now = time.time()
if time_now >= next_cancel_check:
if await is_cancelled():
raise asyncio.CancelledError("client cancelled") raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1
for d in done: for d in done:
pair = awaits.pop(d) pair = awaits.pop(d)
try: try:
item = await d item = await d
i, it = pair i, it = pair
awaits[ensure_future(it.__anext__())] = pair awaits[_next_task(it, loop)] = pair
yield i, item yield i, item
except StopAsyncIteration: except StopAsyncIteration:
pass pass