mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 20:27:08 +08:00
[Bugfix] Properly abort pooling request. (#25734)
Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
515e30b023
commit
fb0eece290
@ -12,6 +12,7 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
|||||||
STOP_STRINGS,
|
STOP_STRINGS,
|
||||||
DummyOutputProcessorTestVectors,
|
DummyOutputProcessorTestVectors,
|
||||||
MockEngineCore)
|
MockEngineCore)
|
||||||
|
from vllm import PoolingParams
|
||||||
from vllm.logprobs import PromptLogprobs, SampleLogprobs
|
from vllm.logprobs import PromptLogprobs, SampleLogprobs
|
||||||
from vllm.outputs import CompletionOutput, RequestOutput
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||||
@ -998,3 +999,35 @@ async def test_cumulative_output_collector_n():
|
|||||||
third = [k for k in result.outputs if k.index == 2]
|
third = [k for k in result.outputs if k.index == 2]
|
||||||
assert len(third) == 1
|
assert len(third) == 1
|
||||||
assert third[0].text == "c"
|
assert third[0].text == "c"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("runner", ["generate", "pooling"])
def test_abort_requests(runner: str, dummy_test_vectors):
    """Smoke-test aborting registered requests for each runner type.

    One request is built per prompt in ``dummy_test_vectors.prompt_tokens``.
    For the "generate" runner it carries ``SamplingParams``; for the
    "pooling" runner it carries ``PoolingParams(task="embed")``. Every
    request is registered with its own ``RequestOutputCollector`` and is
    then aborted individually by request id.

    There are no explicit assertions: the test passes as long as
    ``add_request`` and ``abort_requests`` do not raise.
    """
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
                                       log_stats=True)

    # Exactly one of the two params objects is non-None for a given runner.
    requests = []
    for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens):
        sampling_params = SamplingParams() if runner == "generate" else None
        pooling_params = (PoolingParams(task="embed")
                          if runner == "pooling" else None)
        requests.append(
            EngineCoreRequest(
                request_id=f"request-{idx}",
                prompt_token_ids=prompt_tokens,
                mm_features=None,
                eos_token_id=None,
                arrival_time=0,
                lora_request=None,
                cache_salt=None,
                data_parallel_rank=None,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
            ))

    # Register each request with a collector matching its output kind.
    for request in requests:
        params = (request.sampling_params
                  if runner == "generate" else request.pooling_params)
        collector = RequestOutputCollector(output_kind=params.output_kind)
        output_processor.add_request(request, None, queue=collector)

    # Abort one request at a time, mirroring independent client aborts.
    for request in requests:
        output_processor.abort_requests([request.request_id])
|
||||||
|
|||||||
@ -335,7 +335,14 @@ class OutputProcessor:
|
|||||||
# Produce final abort output.
|
# Produce final abort output.
|
||||||
if req_state.queue is not None and (
|
if req_state.queue is not None and (
|
||||||
request_output := req_state.make_request_output(
|
request_output := req_state.make_request_output(
|
||||||
[], None, FinishReason.ABORT, None, None)):
|
new_token_ids=[],
|
||||||
|
# Set pooling_output to a non-None value to
|
||||||
|
# correctly enter the abort pooling branch
|
||||||
|
pooling_output=torch.randn(0, device="cpu")
|
||||||
|
if req_state.detokenizer is None else None,
|
||||||
|
finish_reason=FinishReason.ABORT,
|
||||||
|
stop_reason=None,
|
||||||
|
kv_transfer_params=None)):
|
||||||
req_state.queue.put(request_output)
|
req_state.queue.put(request_output)
|
||||||
elif parent := self.parent_requests.get(request_id):
|
elif parent := self.parent_requests.get(request_id):
|
||||||
# Abort children prior to removing the parent.
|
# Abort children prior to removing the parent.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user