[Bugfix] Properly abort pooling request. (#25734)

Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
wang.yuqi 2025-09-26 20:47:34 +08:00 committed by yewentao256
parent 515e30b023
commit fb0eece290
2 changed files with 41 additions and 1 deletions

View File

@ -12,6 +12,7 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
STOP_STRINGS,
DummyOutputProcessorTestVectors,
MockEngineCore)
from vllm import PoolingParams
from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
@ -998,3 +999,35 @@ async def test_cumulative_output_collector_n():
third = [k for k in result.outputs if k.index == 2]
assert len(third) == 1
assert third[0].text == "c"
@pytest.mark.parametrize("runner", ["generate", "pooling"])
def test_abort_requests(runner: str, dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
log_stats=True)
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams() if runner == "generate" else None,
pooling_params=PoolingParams(
task="embed") if runner == "pooling" else None,
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
for request in requests:
if runner == "generate":
output_kind = request.sampling_params.output_kind
else:
output_kind = request.pooling_params.output_kind
queue = RequestOutputCollector(output_kind=output_kind)
output_processor.add_request(request, None, queue=queue)
for request in requests:
output_processor.abort_requests([request.request_id])

View File

@ -335,7 +335,14 @@ class OutputProcessor:
# Produce final abort output.
if req_state.queue is not None and (
request_output := req_state.make_request_output(
[], None, FinishReason.ABORT, None, None)):
new_token_ids=[],
# Set pooling_output is not None to
# correctly enter the abort pooling branch
pooling_output=torch.randn(0, device="cpu")
if req_state.detokenizer is None else None,
finish_reason=FinishReason.ABORT,
stop_reason=None,
kv_transfer_params=None)):
req_state.queue.put(request_output)
elif parent := self.parent_requests.get(request_id):
# Abort children prior to removing the parent.