[Core] Return final response for aborted requests from AsyncLLM.generate (#22283)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-08-14 14:49:02 -07:00 committed by GitHub
parent 4121de512e
commit ebcce2cd36
2 changed files with 107 additions and 13 deletions
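
With this change, AsyncLLM.generate yields one last, finished RequestOutput (finish_reason "abort") when a request is aborted, instead of ending the stream without a terminal result. Below is a minimal consumer-side sketch of the new behavior, mirroring the test added in this commit; the model name, prompt, and the consume() helper are illustrative placeholders, not part of the commit:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM


async def main() -> None:
    # Placeholder engine setup; any small model works for the demo.
    engine = AsyncLLM.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    try:
        request_id = "demo-abort"
        params = SamplingParams(max_tokens=3000, ignore_eos=True)

        async def consume():
            final = None
            async for out in engine.generate(request_id=request_id,
                                             prompt="Hello, my name is",
                                             sampling_params=params):
                final = out
            return final

        task = asyncio.create_task(consume())
        await asyncio.sleep(0.5)        # let a few tokens stream out
        await engine.abort(request_id)  # cancel mid-generation
        final = await task              # the stream now ends with a final output
        assert final is not None and final.finished
        assert final.outputs[0].finish_reason == "abort"
    finally:
        engine.shutdown()


if __name__ == "__main__":
    asyncio.run(main())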

tests/v1/engine/test_async_llm.py

@@ -13,6 +13,7 @@ from vllm.assets.image import ImageAsset
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import PromptType
+from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
 from vllm.utils import set_default_torch_num_threads
@@ -398,3 +399,89 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch):
 
         # Test 3: Verify healthy engine still works after mock
         await engine.check_health()
+
+
+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+@pytest.mark.asyncio
+async def test_abort_final_output(
+    monkeypatch: pytest.MonkeyPatch,
+    output_kind: RequestOutputKind,
+):
+    """Test that abort() returns a final output with correct information."""
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        after.callback(engine.shutdown)
+
+        request_id = "test-abort-final-output"
+
+        # Start a long-running request
+        sampling_params = SamplingParams(
+            max_tokens=3000,  # Long enough to allow abort
+            ignore_eos=True,
+            output_kind=output_kind,
+            temperature=0.5,
+            seed=42,
+        )
+
+        outputs: list[RequestOutput] = []
+        generated = asyncio.create_task(
+            collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params,
+                            outputs))
+
+        # Let it generate some tokens
+        await asyncio.sleep(0.5)
+
+        # Abort the request
+        await engine.abort(request_id)
+
+        # Wait for generation to complete and return final output
+        final_output = await generated
+
+        # Verify we got a final output
+        assert final_output is not None
+        assert final_output.finished
+        assert len(final_output.outputs) == 1
+        assert final_output.outputs[0].finish_reason == "abort"
+        assert final_output.outputs[0].stop_reason is None
+
+        # Verify num_cached_tokens is set correctly
+        assert hasattr(final_output, 'num_cached_tokens')
+        assert final_output.num_cached_tokens >= 0
+
+        # If we got intermediate outputs, verify they are consistent
+        if output_kind == RequestOutputKind.DELTA:
+            # For DELTA, tokens arrive in the intermediate outputs; the
+            # final aborted output carries an empty token_ids delta
+            token_count = sum(
+                len(output.outputs[0].token_ids) for output in outputs)
+            assert token_count > 0
+            assert len(final_output.outputs[0].token_ids) == 0
+        else:
+            # For FINAL_ONLY, we should only get the final output
+            assert len(outputs) == 0
+            assert len(final_output.outputs[0].token_ids) > 0
+
+        assert not engine.output_processor.has_unfinished_requests()
+
+
+async def collect_outputs(
+    engine: AsyncLLM,
+    request_id: str,
+    prompt: PromptType,
+    sampling_params: SamplingParams,
+    outputs_list: list[RequestOutput],
+) -> Optional[RequestOutput]:
+    """Helper to collect outputs and return the final one."""
+    final_output: Optional[RequestOutput] = None
+    async for output in engine.generate(request_id=request_id,
+                                        prompt=prompt,
+                                        sampling_params=sampling_params):
+        if not output.finished:
+            outputs_list.append(output)
+        final_output = output
+    return final_output
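
Note the split the test pins down between the two output kinds: with RequestOutputKind.DELTA the generated tokens arrive in the intermediate outputs and the final aborted output carries an empty token_ids delta, while with FINAL_ONLY no intermediate outputs are produced and the single final output carries the whole accumulated sequence. In both cases the stream now terminates with finished=True, finish_reason "abort", and a populated num_cached_tokens.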

vllm/v1/engine/output_processor.py

@@ -107,6 +107,7 @@ class RequestState:
         self.max_tokens_param = max_tokens_param
         self.is_prefilling = True
         self.queue = queue
+        self.num_cached_tokens = 0
 
         self.stats = RequestStateStats(
             arrival_time=arrival_time) if log_stats else None
@@ -167,7 +168,6 @@ class RequestState:
         finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
         kv_transfer_params: Optional[dict[str, Any]] = None,
-        num_cached_tokens: int = 0,
     ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]:
 
         finished = finish_reason is not None
@@ -195,7 +195,7 @@ class RequestState:
             return None
 
         return self._new_request_output(request_id, outputs, finished,
-                                        kv_transfer_params, num_cached_tokens)
+                                        kv_transfer_params)
 
     def _new_request_output(
         self,
@@ -203,14 +203,14 @@ class RequestState:
         outputs: Union[list[CompletionOutput], list[PoolingOutput]],
         finished: bool,
         kv_transfer_params: Optional[dict[str, Any]] = None,
-        num_cached_tokens: int = 0,
     ) -> Union[RequestOutput, PoolingRequestOutput]:
-        if isinstance(outputs[0], PoolingOutput):
+        first_output = outputs[0]
+        if isinstance(first_output, PoolingOutput):
             assert len(outputs) == 1
             return PoolingRequestOutput(
                 request_id=request_id,
-                outputs=outputs[0],
+                outputs=first_output,
                 prompt_token_ids=self.prompt_token_ids,
                 finished=finished,
             )
@@ -229,7 +229,7 @@ class RequestState:
             outputs=cast(list[CompletionOutput], outputs),
             finished=finished,
             kv_transfer_params=kv_transfer_params,
-            num_cached_tokens=num_cached_tokens,
+            num_cached_tokens=self.num_cached_tokens,
         )
 
     def _new_completion_output(
@@ -308,11 +308,18 @@ class OutputProcessor:
             if req_state is not None:
                 self.lora_states.abort_request(req_state)
                 request_ids_to_abort.append(request_id)
-            else:
-                parent = self.parent_requests.pop(request_id, None)
-                if parent and parent.child_requests:
-                    self.abort_requests(parent.child_requests)
-                    request_ids_to_abort.extend(parent.child_requests)
+
+                # Produce final abort output.
+                if req_state.queue is not None and (
+                        request_output := req_state.make_request_output(
+                            [], None, FinishReason.ABORT, None, None)):
+                    req_state.queue.put(request_output)
+            elif parent := self.parent_requests.get(request_id):
+                # Abort children prior to removing the parent.
+                if parent.child_requests:
+                    child_reqs = list(parent.child_requests)
+                    child_reqs = self.abort_requests(child_reqs)
+                    request_ids_to_abort.extend(child_reqs)
+                self.parent_requests.pop(request_id, None)
         return request_ids_to_abort
 
     def add_request(
@@ -390,7 +397,7 @@ class OutputProcessor:
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
             kv_transfer_params = engine_core_output.kv_transfer_params
-            num_cached_tokens = engine_core_output.num_cached_tokens
+            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
             req_state.is_prefilling = False
 
             if pooling_output is None:
@@ -411,7 +418,7 @@ class OutputProcessor:
             # 4) Create and handle RequestOutput objects.
             if request_output := req_state.make_request_output(
                     new_token_ids, pooling_output, finish_reason, stop_reason,
-                    kv_transfer_params, num_cached_tokens):
+                    kv_transfer_params):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put(request_output)
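
Taken together, the output-processor changes implement the new behavior in two steps. First, num_cached_tokens moves from a make_request_output parameter to state cached on RequestState, so a final output synthesized after the engine core has stopped reporting for the request can still populate it. Second, abort_requests now builds that terminal output itself, calling make_request_output([], None, FinishReason.ABORT, None, None) and putting it on the per-request output queue that AsyncLLM.generate drains, which is what lets the generator yield one last finished result instead of ending silently.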