From fb0eece2903f8550edd826d57f97421be5bbaf58 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 26 Sep 2025 20:47:34 +0800 Subject: [PATCH] [Bugfix] Properly abort pooling request. (#25734) Signed-off-by: wang.yuqi Co-authored-by: Cyrus Leung Signed-off-by: yewentao256 --- tests/v1/engine/test_output_processor.py | 33 ++++++++++++++++++++++++ vllm/v1/engine/output_processor.py | 9 ++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index bdb40be99aa3f..72c0a9a13e231 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -12,6 +12,7 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, STOP_STRINGS, DummyOutputProcessorTestVectors, MockEngineCore) +from vllm import PoolingParams from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -998,3 +999,35 @@ async def test_cumulative_output_collector_n(): third = [k for k in result.outputs if k.index == 2] assert len(third) == 1 assert third[0].text == "c" + + +@pytest.mark.parametrize("runner", ["generate", "pooling"]) +def test_abort_requests(runner: str, dummy_test_vectors): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, + log_stats=True) + requests = [ + EngineCoreRequest( + request_id=f"request-{idx}", + prompt_token_ids=prompt_tokens, + mm_features=None, + eos_token_id=None, + arrival_time=0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + sampling_params=SamplingParams() if runner == "generate" else None, + pooling_params=PoolingParams( + task="embed") if runner == "pooling" else None, + ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) + ] + + for request in requests: + if runner == "generate": + output_kind = request.sampling_params.output_kind + else: + output_kind = request.pooling_params.output_kind + queue = RequestOutputCollector(output_kind=output_kind) + output_processor.add_request(request, None, queue=queue) + + for request in requests: + output_processor.abort_requests([request.request_id]) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index c17dc3e204ecd..38b2d6824b473 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -335,7 +335,14 @@ class OutputProcessor: # Produce final abort output. if req_state.queue is not None and ( request_output := req_state.make_request_output( - [], None, FinishReason.ABORT, None, None)): + new_token_ids=[], + # Set pooling_output is not None to + # correctly enter the abort pooling branch + pooling_output=torch.randn(0, device="cpu") + if req_state.detokenizer is None else None, + finish_reason=FinishReason.ABORT, + stop_reason=None, + kv_transfer_params=None)): req_state.queue.put(request_output) elif parent := self.parent_requests.get(request_id): # Abort children prior to removing the parent.