[Bugfix][Core] Update cancellation logic in generate() to handle Generator exits (#19225)
Co-authored-by: Adolfo Victoria <adovi@meta.com>
parent aad30bd306
commit ca27f0f9c1
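The fix in this commit makes AsyncLLM.generate() abort a request not only when the surrounding task is cancelled but also when the per-request async generator is closed or garbage collected after the client stops consuming it. A minimal sketch of that pattern, assuming a simplified streaming engine (FakeEngine and stream_request are illustrative names, not vLLM's API):

import asyncio


class FakeEngine:
    """Toy stand-in for a streaming engine; for illustration only."""

    def __init__(self) -> None:
        self.aborted: set[str] = set()

    async def stream_request(self, request_id: str):
        try:
            for i in range(1000):
                await asyncio.sleep(0)
                yield f"token-{i}"
        # GeneratorExit is thrown into the generator when the consumer stops
        # iterating (aclose() or garbage collection); CancelledError when the
        # surrounding task is cancelled. Both paths must release the request.
        except (asyncio.CancelledError, GeneratorExit):
            self.aborted.add(request_id)
            raise


async def main() -> None:
    engine = FakeEngine()
    stream = engine.stream_request("request-0")
    async for token in stream:
        if token == "token-4":
            break  # early exit, like a client disconnecting mid-stream
    await stream.aclose()  # deliver GeneratorExit deterministically
    assert engine.aborted == {"request-0"}


asyncio.run(main())

Re-raising after the cleanup matters: swallowing GeneratorExit makes Python report "async generator ignored GeneratorExit".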
@@ -22,9 +22,11 @@ if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
                 allow_module_level=True)

-TEXT_ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct",
-                                   enforce_eager=True,
-                                   disable_log_requests=True)
+TEXT_ENGINE_ARGS = AsyncEngineArgs(
+    model="meta-llama/Llama-3.2-1B-Instruct",
+    enforce_eager=True,
+    disable_log_requests=True,
+)

 VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct",
                                      enforce_eager=True,
@@ -41,28 +43,33 @@ VISION_PROMPT = {
     "prompt": VISION_PROMPT_TEMPLATE,
     "multi_modal_data": {
         "image": ImageAsset("stop_sign").pil_image
     },
 }


-async def generate(engine: AsyncLLM,
-                   request_id: str,
-                   prompt: PromptType,
-                   output_kind: RequestOutputKind,
-                   max_tokens: int,
-                   n: int = 1,
-                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
+async def generate(
+    engine: AsyncLLM,
+    request_id: str,
+    prompt: PromptType,
+    output_kind: RequestOutputKind,
+    max_tokens: int,
+    n: int = 1,
+    prompt_logprobs: Optional[int] = None,
+    cancel_after: Optional[int] = None,
+) -> tuple[int, str]:
     # Ensure generate doesn't complete too fast for cancellation test.
     await asyncio.sleep(0.2)

     count = 0
-    sampling_params = SamplingParams(max_tokens=max_tokens,
-                                     ignore_eos=True,
-                                     output_kind=output_kind,
-                                     temperature=0.5,
-                                     seed=33,
-                                     n=n,
-                                     prompt_logprobs=prompt_logprobs)
+    sampling_params = SamplingParams(
+        max_tokens=max_tokens,
+        ignore_eos=True,
+        output_kind=output_kind,
+        temperature=0.5,
+        seed=33,
+        n=n,
+        prompt_logprobs=prompt_logprobs,
+    )
     async for out in engine.generate(request_id=request_id,
                                      prompt=prompt,
                                      sampling_params=sampling_params):
@@ -73,20 +80,27 @@ async def generate(engine: AsyncLLM,
         else:
             count = num_tokens

-        await asyncio.sleep(0.)
+        if cancel_after is not None and count >= cancel_after:
+            return count, request_id
+
+        await asyncio.sleep(0.0)

     return count, request_id


 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-@pytest.mark.parametrize("engine_args,prompt",
-                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
-                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
 @pytest.mark.asyncio
-async def test_load(monkeypatch: pytest.MonkeyPatch,
-                    output_kind: RequestOutputKind,
-                    engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_load(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):
     # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
     # so that in the future when we switch, we don't have to change all the
     # tests.
@@ -125,13 +139,17 @@ async def test_load(monkeypatch: pytest.MonkeyPatch,

 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-@pytest.mark.parametrize("engine_args,prompt",
-                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
-                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
 @pytest.mark.asyncio
-async def test_abort(monkeypatch: pytest.MonkeyPatch,
-                     output_kind: RequestOutputKind,
-                     engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_abort(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):

     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
@@ -150,8 +168,9 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,
         # Create concurrent requests.
         tasks: list[asyncio.Task] = []
         for idx, request_id in enumerate(request_ids):
-            max_tokens = NUM_EXPECTED_TOKENS_LONG if (
-                idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
+            max_tokens = (NUM_EXPECTED_TOKENS_LONG if
+                          (idx
+                           in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
             n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
             tasks.append(
                 asyncio.create_task(
@@ -192,12 +211,17 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,


 @pytest.mark.parametrize("n", [1, 3])
-@pytest.mark.parametrize("engine_args,prompt",
-                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
-                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
 @pytest.mark.asyncio
-async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
-                             engine_args: AsyncEngineArgs, prompt: PromptType):
+async def test_finished_flag(
+    monkeypatch: pytest.MonkeyPatch,
+    n: int,
+    engine_args: AsyncEngineArgs,
+    prompt: PromptType,
+):

     with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
@@ -205,11 +229,13 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
         engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)

-        sampling_params = SamplingParams(max_tokens=100,
-                                         output_kind=RequestOutputKind.DELTA,
-                                         temperature=1.0,
-                                         seed=33,
-                                         n=n)
+        sampling_params = SamplingParams(
+            max_tokens=100,
+            output_kind=RequestOutputKind.DELTA,
+            temperature=1.0,
+            seed=33,
+            n=n,
+        )
         outputs = [
             out
             async for out in engine.generate(request_id="request-33",
@@ -222,6 +248,63 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
     assert outputs[-1].finished


+@pytest.mark.parametrize(
+    "engine_args,prompt",
+    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
+)
+@pytest.mark.asyncio
+async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
+                                       engine_args: AsyncEngineArgs,
+                                       prompt: PromptType):
+    """Test that requests can be cancelled mid-stream."""
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine = AsyncLLM.from_engine_args(engine_args)
+        after.callback(engine.shutdown)
+
+        NUM_REQUESTS = 100
+        NUM_TOKENS = 1000
+        NUM_EXPECTED_TOKENS = 20
+
+        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
+
+        # Create concurrent requests that will be cancelled mid-stream
+        tasks = []
+        for request_id in request_ids:
+            tasks.append(
+                asyncio.create_task(
+                    generate(
+                        engine,
+                        request_id,
+                        prompt,
+                        RequestOutputKind.DELTA,
+                        NUM_TOKENS,
+                        cancel_after=NUM_EXPECTED_TOKENS,
+                    )))
+
+        # Wait for all tasks to complete
+        results = await asyncio.gather(*tasks)
+
+        # Verify all tasks were cancelled at the expected point
+        for num_generated_tokens, request_id in results:
+            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+                f"{request_id} generated {num_generated_tokens} tokens but "
+                f"expected to cancel after {NUM_EXPECTED_TOKENS}")
+
+        # Make sure no requests are left hanging
+        assert not engine.output_processor.has_unfinished_requests()
+
+        # Confirm we can reuse the request id after the cancellations.
+        request_id = request_ids[0]
+        task = asyncio.create_task(
+            generate(engine, request_id, prompt, RequestOutputKind.DELTA,
+                     NUM_EXPECTED_TOKENS))
+        num_generated_tokens, request_id = await task
+        assert num_generated_tokens == NUM_EXPECTED_TOKENS
+        assert not engine.output_processor.has_unfinished_requests()
+
+
 class MockLoggingStatLogger(LoggingStatLogger):

     def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
@@ -332,8 +332,9 @@ class AsyncLLM(EngineClient):
                 yield out

         # If the request is disconnected by the client, generate()
-        # is cancelled. So, we abort the request if we end up here.
-        except asyncio.CancelledError:
+        # is cancelled or the generator is garbage collected. So,
+        # we abort the request if we end up here.
+        except (asyncio.CancelledError, GeneratorExit):
             await self.abort(request_id)
             if self.log_requests:
                 logger.info("Request %s aborted.", request_id)
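The hunk above is the behavioral change in AsyncLLM.generate() (presumably the engine-side source file for the AsyncLLM class). On the consumer side, the new test stops iterating early instead of cancelling its task, so the new except branch is only reached once the abandoned generator is finalized. A small consumer-side sketch, assuming Python 3.10+ for contextlib.aclosing and using illustrative names (stream, consume) rather than vLLM code, that closes the generator explicitly so the cleanup branch runs deterministically:

import asyncio
from contextlib import aclosing


async def stream():
    """Illustrative token stream with the same cleanup pattern as above."""
    try:
        for i in range(100):
            await asyncio.sleep(0)
            yield i
    except (asyncio.CancelledError, GeneratorExit):
        # Abort/cleanup work would go here.
        raise


async def consume(cancel_after: int) -> int:
    count = 0
    # aclosing() guarantees agen.aclose() runs when we break out, so
    # GeneratorExit reaches the generator immediately rather than at GC time.
    async with aclosing(stream()) as agen:
        async for _ in agen:
            count += 1
            if count >= cancel_after:
                break
    return count


assert asyncio.run(consume(cancel_after=5)) == 5

Relying on garbage collection instead, as simply dropping the generator does, still hits the new except clause, but the abort then happens at an unpredictable time.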