[Bugfix][Core] Update cancellation logic in generate() to handle Generator exits (#19225)

Co-authored-by: Adolfo Victoria <adovi@meta.com>
Authored by Adolfo Victoria on 2025-06-06 13:17:54 -07:00, committed by GitHub
parent aad30bd306
commit ca27f0f9c1
2 changed files with 129 additions and 45 deletions
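For context, the substance of the change is in AsyncLLM.generate(): the request must now be aborted not only when the consuming task is cancelled (asyncio.CancelledError), but also when the async generator itself is abandoned and closed (GeneratorExit). The snippet below is a minimal illustrative sketch of that pattern, not the vLLM code; `engine.output_stream` and `engine.abort` are hypothetical placeholder calls.

```python
import asyncio
from typing import AsyncGenerator


async def stream_outputs(engine, request_id: str) -> AsyncGenerator[str, None]:
    """Yield outputs for request_id, aborting it if the consumer goes away."""
    try:
        async for out in engine.output_stream(request_id):  # hypothetical API
            yield out
    # A disconnected client cancels the consuming task, which raises
    # asyncio.CancelledError here. A consumer that simply drops the generator
    # (e.g. returns out of its `async for` loop) instead causes GeneratorExit
    # to be thrown in when the generator is closed or garbage collected.
    # Both paths must abort the request, or it keeps running in the engine
    # with nothing reading its outputs.
    except (asyncio.CancelledError, GeneratorExit):
        await engine.abort(request_id)  # hypothetical abort call
        raise
```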


@@ -22,9 +22,11 @@ if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
allow_module_level=True)
-TEXT_ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct",
-                                   enforce_eager=True,
-                                   disable_log_requests=True)
+TEXT_ENGINE_ARGS = AsyncEngineArgs(
+    model="meta-llama/Llama-3.2-1B-Instruct",
+    enforce_eager=True,
+    disable_log_requests=True,
+)
VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct",
enforce_eager=True,
@@ -41,28 +43,33 @@ VISION_PROMPT = {
"prompt": VISION_PROMPT_TEMPLATE,
"multi_modal_data": {
"image": ImageAsset("stop_sign").pil_image
-    }
+    },
}
-async def generate(engine: AsyncLLM,
-                   request_id: str,
-                   prompt: PromptType,
-                   output_kind: RequestOutputKind,
-                   max_tokens: int,
-                   n: int = 1,
-                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
+async def generate(
+    engine: AsyncLLM,
+    request_id: str,
+    prompt: PromptType,
+    output_kind: RequestOutputKind,
+    max_tokens: int,
+    n: int = 1,
+    prompt_logprobs: Optional[int] = None,
+    cancel_after: Optional[int] = None,
+) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)
count = 0
-    sampling_params = SamplingParams(max_tokens=max_tokens,
-                                     ignore_eos=True,
-                                     output_kind=output_kind,
-                                     temperature=0.5,
-                                     seed=33,
-                                     n=n,
-                                     prompt_logprobs=prompt_logprobs)
+    sampling_params = SamplingParams(
+        max_tokens=max_tokens,
+        ignore_eos=True,
+        output_kind=output_kind,
+        temperature=0.5,
+        seed=33,
+        n=n,
+        prompt_logprobs=prompt_logprobs,
+    )
async for out in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):
@@ -73,20 +80,27 @@ async def generate(engine: AsyncLLM,
else:
count = num_tokens
-        await asyncio.sleep(0.)
+        if cancel_after is not None and count >= cancel_after:
+            return count, request_id
+
+        await asyncio.sleep(0.0)
return count, request_id
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.parametrize(
"engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
-async def test_load(monkeypatch: pytest.MonkeyPatch,
-                    output_kind: RequestOutputKind,
-                    engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_load(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the
# tests.
@@ -125,13 +139,17 @@ async def test_load(monkeypatch: pytest.MonkeyPatch,
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.parametrize(
"engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
-async def test_abort(monkeypatch: pytest.MonkeyPatch,
-                     output_kind: RequestOutputKind,
-                     engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_abort(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
@@ -150,8 +168,9 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,
# Create concurrent requests.
tasks: list[asyncio.Task] = []
for idx, request_id in enumerate(request_ids):
-            max_tokens = NUM_EXPECTED_TOKENS_LONG if (
-                idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
+            max_tokens = (NUM_EXPECTED_TOKENS_LONG if
+                          (idx
+                           in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
tasks.append(
asyncio.create_task(
@@ -192,12 +211,17 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,
@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize("engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.parametrize(
"engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
-async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
-                             engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_finished_flag(
+    monkeypatch: pytest.MonkeyPatch,
+    n: int,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
@@ -205,11 +229,13 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
-        sampling_params = SamplingParams(max_tokens=100,
-                                         output_kind=RequestOutputKind.DELTA,
-                                         temperature=1.0,
-                                         seed=33,
-                                         n=n)
+        sampling_params = SamplingParams(
+            max_tokens=100,
+            output_kind=RequestOutputKind.DELTA,
+            temperature=1.0,
+            seed=33,
+            n=n,
+        )
outputs = [
out
async for out in engine.generate(request_id="request-33",
@@ -222,6 +248,63 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
assert outputs[-1].finished
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
+@pytest.mark.asyncio
+async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
+                                       engine_args: AsyncEngineArgs,
+                                       prompt: PromptType):
+    """Test that requests can be cancelled mid-stream."""
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine = AsyncLLM.from_engine_args(engine_args)
+        after.callback(engine.shutdown)
+
+        NUM_REQUESTS = 100
+        NUM_TOKENS = 1000
+        NUM_EXPECTED_TOKENS = 20
+
+        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
+
+        # Create concurrent requests that will be cancelled mid-stream
+        tasks = []
+        for request_id in request_ids:
+            tasks.append(
+                asyncio.create_task(
+                    generate(
+                        engine,
+                        request_id,
+                        prompt,
+                        RequestOutputKind.DELTA,
+                        NUM_TOKENS,
+                        cancel_after=NUM_EXPECTED_TOKENS,
+                    )))
+
+        # Wait for all tasks to complete
+        results = await asyncio.gather(*tasks)
+
+        # Verify all tasks were cancelled at the expected point
+        for num_generated_tokens, request_id in results:
+            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+                f"{request_id} generated {num_generated_tokens} tokens but "
+                f"expected to cancel after {NUM_EXPECTED_TOKENS}")
+
+        # Make sure no requests are left hanging
+        assert not engine.output_processor.has_unfinished_requests()
+
+        # Confirm we can reuse the request id after the cancellations.
+        request_id = request_ids[0]
+        task = asyncio.create_task(
+            generate(engine, request_id, prompt, RequestOutputKind.DELTA,
+                     NUM_EXPECTED_TOKENS))
+        num_generated_tokens, request_id = await task
+        assert num_generated_tokens == NUM_EXPECTED_TOKENS
+        assert not engine.output_processor.has_unfinished_requests()
class MockLoggingStatLogger(LoggingStatLogger):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):


@@ -332,8 +332,9 @@ class AsyncLLM(EngineClient):
yield out
# If the request is disconnected by the client, generate()
-        # is cancelled. So, we abort the request if we end up here.
-        except asyncio.CancelledError:
+        # is cancelled or the generator is garbage collected. So,
+        # we abort the request if we end up here.
+        except (asyncio.CancelledError, GeneratorExit):
await self.abort(request_id)
if self.log_requests:
logger.info("Request %s aborted.", request_id)
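The new cancel_after path in the test returns out of the `async for` loop rather than cancelling the task, so the async generator returned by engine.generate() is abandoned; when an abandoned async generator is closed (explicitly or by the garbage collector), GeneratorExit, not asyncio.CancelledError, is raised at its suspended yield, which is why the except clause above now catches both. Below is a small self-contained demonstration of that behaviour using plain asyncio and toy stand-in names (ticker, consume); no vLLM code is involved.

```python
import asyncio


async def ticker():
    """Toy async generator standing in for an output stream."""
    try:
        i = 0
        while True:
            yield i
            i += 1
            await asyncio.sleep(0)
    except (asyncio.CancelledError, GeneratorExit):
        # In AsyncLLM.generate() this is where the request gets aborted.
        print("cleanup ran")
        raise


async def consume(cancel_after: int) -> int:
    agen = ticker()
    count = 0
    async for _ in agen:
        count += 1
        if count >= cancel_after:
            break  # abandon the stream mid-way, like cancel_after in the test
    # Closing (or garbage-collecting) the abandoned generator throws
    # GeneratorExit into it at the suspended yield; CancelledError would
    # only appear if the consuming task itself were cancelled.
    await agen.aclose()
    return count


print(asyncio.run(consume(5)))  # prints "cleanup ran", then 5
```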