[Bugfix][Core] Update cancellation logic in generate() to handle Generator exits (#19225)

Co-authored-by: Adolfo Victoria <adovi@meta.com>
This commit is contained in:
Adolfo Victoria 2025-06-06 13:17:54 -07:00 committed by GitHub
parent aad30bd306
commit ca27f0f9c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 129 additions and 45 deletions

View File

@ -22,9 +22,11 @@ if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.", pytest.skip(reason="V1 currently only supported on CUDA.",
allow_module_level=True) allow_module_level=True)
TEXT_ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct", TEXT_ENGINE_ARGS = AsyncEngineArgs(
enforce_eager=True, model="meta-llama/Llama-3.2-1B-Instruct",
disable_log_requests=True) enforce_eager=True,
disable_log_requests=True,
)
VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct", VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct",
enforce_eager=True, enforce_eager=True,
@ -41,28 +43,33 @@ VISION_PROMPT = {
"prompt": VISION_PROMPT_TEMPLATE, "prompt": VISION_PROMPT_TEMPLATE,
"multi_modal_data": { "multi_modal_data": {
"image": ImageAsset("stop_sign").pil_image "image": ImageAsset("stop_sign").pil_image
} },
} }
async def generate(engine: AsyncLLM, async def generate(
request_id: str, engine: AsyncLLM,
prompt: PromptType, request_id: str,
output_kind: RequestOutputKind, prompt: PromptType,
max_tokens: int, output_kind: RequestOutputKind,
n: int = 1, max_tokens: int,
prompt_logprobs: Optional[int] = None) -> tuple[int, str]: n: int = 1,
prompt_logprobs: Optional[int] = None,
cancel_after: Optional[int] = None,
) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test. # Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
count = 0 count = 0
sampling_params = SamplingParams(max_tokens=max_tokens, sampling_params = SamplingParams(
ignore_eos=True, max_tokens=max_tokens,
output_kind=output_kind, ignore_eos=True,
temperature=0.5, output_kind=output_kind,
seed=33, temperature=0.5,
n=n, seed=33,
prompt_logprobs=prompt_logprobs) n=n,
prompt_logprobs=prompt_logprobs,
)
async for out in engine.generate(request_id=request_id, async for out in engine.generate(request_id=request_id,
prompt=prompt, prompt=prompt,
sampling_params=sampling_params): sampling_params=sampling_params):
@ -73,20 +80,27 @@ async def generate(engine: AsyncLLM,
else: else:
count = num_tokens count = num_tokens
await asyncio.sleep(0.) if cancel_after is not None and count >= cancel_after:
return count, request_id
await asyncio.sleep(0.0)
return count, request_id return count, request_id
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args,prompt", @pytest.mark.parametrize(
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), "engine_args,prompt",
(VISION_ENGINE_ARGS, VISION_PROMPT)]) [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load(monkeypatch: pytest.MonkeyPatch, async def test_load(
output_kind: RequestOutputKind, monkeypatch: pytest.MonkeyPatch,
engine_args: AsyncEngineArgs, prompt: PromptType): output_kind: RequestOutputKind,
engine_args: AsyncEngineArgs,
prompt: PromptType,
):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1 # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the # so that in the future when we switch, we don't have to change all the
# tests. # tests.
@ -125,13 +139,17 @@ async def test_load(monkeypatch: pytest.MonkeyPatch,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args,prompt", @pytest.mark.parametrize(
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), "engine_args,prompt",
(VISION_ENGINE_ARGS, VISION_PROMPT)]) [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_abort(monkeypatch: pytest.MonkeyPatch, async def test_abort(
output_kind: RequestOutputKind, monkeypatch: pytest.MonkeyPatch,
engine_args: AsyncEngineArgs, prompt: PromptType): output_kind: RequestOutputKind,
engine_args: AsyncEngineArgs,
prompt: PromptType,
):
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
@ -150,8 +168,9 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,
# Create concurrent requests. # Create concurrent requests.
tasks: list[asyncio.Task] = [] tasks: list[asyncio.Task] = []
for idx, request_id in enumerate(request_ids): for idx, request_id in enumerate(request_ids):
max_tokens = NUM_EXPECTED_TOKENS_LONG if ( max_tokens = (NUM_EXPECTED_TOKENS_LONG if
idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS (idx
in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
@ -192,12 +211,17 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,
@pytest.mark.parametrize("n", [1, 3]) @pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize("engine_args,prompt", @pytest.mark.parametrize(
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), "engine_args,prompt",
(VISION_ENGINE_ARGS, VISION_PROMPT)]) [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int, async def test_finished_flag(
engine_args: AsyncEngineArgs, prompt: PromptType): monkeypatch: pytest.MonkeyPatch,
n: int,
engine_args: AsyncEngineArgs,
prompt: PromptType,
):
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
@ -205,11 +229,13 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
engine = AsyncLLM.from_engine_args(engine_args) engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
sampling_params = SamplingParams(max_tokens=100, sampling_params = SamplingParams(
output_kind=RequestOutputKind.DELTA, max_tokens=100,
temperature=1.0, output_kind=RequestOutputKind.DELTA,
seed=33, temperature=1.0,
n=n) seed=33,
n=n,
)
outputs = [ outputs = [
out out
async for out in engine.generate(request_id="request-33", async for out in engine.generate(request_id="request-33",
@ -222,6 +248,63 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
assert outputs[-1].finished assert outputs[-1].finished
@pytest.mark.parametrize(
"engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
engine_args: AsyncEngineArgs,
prompt: PromptType):
"""Test that requests can be cancelled mid-stream."""
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
NUM_REQUESTS = 100
NUM_TOKENS = 1000
NUM_EXPECTED_TOKENS = 20
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests that will be cancelled mid-stream
tasks = []
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(
engine,
request_id,
prompt,
RequestOutputKind.DELTA,
NUM_TOKENS,
cancel_after=NUM_EXPECTED_TOKENS,
)))
# Wait for all tasks to complete
results = await asyncio.gather(*tasks)
# Verify all tasks were cancelled at the expected point
for num_generated_tokens, request_id in results:
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
f"{request_id} generated {num_generated_tokens} tokens but "
f"expected to cancel after {NUM_EXPECTED_TOKENS}")
# Make sure no requests are left hanging
assert not engine.output_processor.has_unfinished_requests()
# Confirm we can reuse the request id after the cancellations.
request_id = request_ids[0]
task = asyncio.create_task(
generate(engine, request_id, prompt, RequestOutputKind.DELTA,
NUM_EXPECTED_TOKENS))
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS
assert not engine.output_processor.has_unfinished_requests()
class MockLoggingStatLogger(LoggingStatLogger): class MockLoggingStatLogger(LoggingStatLogger):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):

View File

@ -332,8 +332,9 @@ class AsyncLLM(EngineClient):
yield out yield out
# If the request is disconnected by the client, generate() # If the request is disconnected by the client, generate()
# is cancelled. So, we abort the request if we end up here. # is cancelled or the generator is garbage collected. So,
except asyncio.CancelledError: # we abort the request if we end up here.
except (asyncio.CancelledError, GeneratorExit):
await self.abort(request_id) await self.abort(request_id)
if self.log_requests: if self.log_requests:
logger.info("Request %s aborted.", request_id) logger.info("Request %s aborted.", request_id)