[Bugfix] Properly abort pooling request. (#25734)

Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 515e30b023
commit fb0eece290
@@ -12,6 +12,7 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
                                     STOP_STRINGS,
                                     DummyOutputProcessorTestVectors,
                                     MockEngineCore)
+from vllm import PoolingParams
 from vllm.logprobs import PromptLogprobs, SampleLogprobs
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -998,3 +999,35 @@ async def test_cumulative_output_collector_n():
     third = [k for k in result.outputs if k.index == 2]
     assert len(third) == 1
     assert third[0].text == "c"
+
+
+@pytest.mark.parametrize("runner", ["generate", "pooling"])
+def test_abort_requests(runner: str, dummy_test_vectors):
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
+                                       log_stats=True)
+    requests = [
+        EngineCoreRequest(
+            request_id=f"request-{idx}",
+            prompt_token_ids=prompt_tokens,
+            mm_features=None,
+            eos_token_id=None,
+            arrival_time=0,
+            lora_request=None,
+            cache_salt=None,
+            data_parallel_rank=None,
+            sampling_params=SamplingParams() if runner == "generate" else None,
+            pooling_params=PoolingParams(
+                task="embed") if runner == "pooling" else None,
+        ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
+    ]
+
+    for request in requests:
+        if runner == "generate":
+            output_kind = request.sampling_params.output_kind
+        else:
+            output_kind = request.pooling_params.output_kind
+        queue = RequestOutputCollector(output_kind=output_kind)
+        output_processor.add_request(request, None, queue=queue)
+
+    for request in requests:
+        output_processor.abort_requests([request.request_id])
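The new test is parametrized over both runners: for each dummy prompt it builds an EngineCoreRequest with either SamplingParams or PoolingParams(task="embed"), registers it with the OutputProcessor together with a RequestOutputCollector, and then aborts every request id. There are no assertions, so it passes as long as abort_requests handles both request types without raising. Something like the following should run just this test; the test-file path is an assumption inferred from the tests.v1.engine.utils imports, not stated in the diff:

# Hypothetical invocation for running only the new test. The file path below
# is an assumption (inferred from the hunk header's imports), not given in
# the diff itself.
import pytest

pytest.main([
    "tests/v1/engine/test_output_processor.py",
    "-k", "test_abort_requests",
    "-v",
])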
@@ -335,7 +335,14 @@ class OutputProcessor:
                 # Produce final abort output.
                 if req_state.queue is not None and (
                         request_output := req_state.make_request_output(
-                            [], None, FinishReason.ABORT, None, None)):
+                            new_token_ids=[],
+                            # Set pooling_output to a non-None value so the
+                            # abort correctly enters the pooling branch.
+                            pooling_output=torch.randn(0, device="cpu")
+                            if req_state.detokenizer is None else None,
+                            finish_reason=FinishReason.ABORT,
+                            stop_reason=None,
+                            kv_transfer_params=None)):
                     req_state.queue.put(request_output)
             elif parent := self.parent_requests.get(request_id):
                 # Abort children prior to removing the parent.
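The core of the fix is in OutputProcessor.abort_requests: make_request_output is now called with keyword arguments, and when the request has no detokenizer (which the added comment treats as the pooling case) an empty CPU tensor is passed as pooling_output, so the abort enters the pooling branch and a final output is still put on the request's queue. Below is a minimal sketch of that dispatch idea, assuming make_request_output branches on pooling_output being non-None; all names are illustrative, not the actual vLLM internals.

# Sketch only: illustrates why the abort path needs a non-None (even empty)
# pooling_output for pooling requests. Not the real vLLM implementation.
from typing import Optional

import torch


def make_abort_output_sketch(has_detokenizer: bool) -> str:
    # Mirrors the diff: pooling requests (no detokenizer) get an empty
    # tensor as a sentinel pooling_output; generation requests get None.
    pooling_output: Optional[torch.Tensor] = (
        None if has_detokenizer else torch.randn(0, device="cpu"))

    if pooling_output is not None:
        # "Abort pooling branch": a PoolingRequestOutput-style result is
        # produced even though the tensor is empty.
        return f"pooling abort output ({pooling_output.numel()} elements)"
    # Generation branch: a CompletionOutput-style result with no new tokens.
    return "completion abort output (0 tokens)"


print(make_abort_output_sketch(has_detokenizer=False))  # pooling abort output (0 elements)
print(make_abort_output_sketch(has_detokenizer=True))   # completion abort output (0 tokens)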