From ebcce2cd36a75effd10556942f0467f5f670a080 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 14 Aug 2025 14:49:02 -0700 Subject: [PATCH] [Core] Return final response for aborted requests from `AsyncLLM.generate` (#22283) Signed-off-by: Nick Hill --- tests/v1/engine/test_async_llm.py | 87 ++++++++++++++++++++++++++++++ vllm/v1/engine/output_processor.py | 33 +++++++----- 2 files changed, 107 insertions(+), 13 deletions(-) diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 21694491dd73a..484640233f522 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -13,6 +13,7 @@ from vllm.assets.image import ImageAsset from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType +from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.utils import set_default_torch_num_threads @@ -398,3 +399,89 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch): # Test 3: Verify healthy engine still works after mock await engine.check_health() + + +@pytest.mark.parametrize( + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) +@pytest.mark.asyncio +async def test_abort_final_output( + monkeypatch: pytest.MonkeyPatch, + output_kind: RequestOutputKind, +): + """Test that abort() returns a final output with correct information.""" + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + with set_default_torch_num_threads(1): + engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) + after.callback(engine.shutdown) + + request_id = "test-abort-final-output" + + # Start a long-running request + sampling_params = SamplingParams( + max_tokens=3000, # Long enough to allow abort + ignore_eos=True, + output_kind=output_kind, + temperature=0.5, + seed=42, + ) + + outputs: list[RequestOutput] = [] + generated = asyncio.create_task( + collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, + outputs)) + + # Let it generate some tokens + await asyncio.sleep(0.5) + + # Abort the request + await engine.abort(request_id) + + # Wait for generation to complete and return final output + final_output = await generated + + # Verify we got a final output + assert final_output is not None + assert final_output.finished + assert len(final_output.outputs) == 1 + + assert final_output.outputs[0].finish_reason == "abort" + assert final_output.outputs[0].stop_reason is None + + # Verify num_cached_tokens is set correctly + assert hasattr(final_output, 'num_cached_tokens') + assert final_output.num_cached_tokens >= 0 + + # If we got intermediate outputs, verify they are consistent + if output_kind == RequestOutputKind.DELTA: + # For DELTA, sum all intermediate tokens should <= final tokens + token_count = sum( + len(output.outputs[0].token_ids) for output in outputs) + assert token_count > 0 + assert len(final_output.outputs[0].token_ids) == 0 + else: + # For FINAL_ONLY, we should only get the final output + assert len(outputs) == 0 + assert len(final_output.outputs[0].token_ids) > 0 + + assert not engine.output_processor.has_unfinished_requests() + + +async def collect_outputs( + engine: AsyncLLM, + request_id: str, + prompt: PromptType, + sampling_params: SamplingParams, + outputs_list: list[RequestOutput], +) -> Optional[RequestOutput]: + """Helper to collect outputs and return the final one.""" + final_output: Optional[RequestOutput] = None + 
async for output in engine.generate(request_id=request_id, + prompt=prompt, + sampling_params=sampling_params): + if not output.finished: + outputs_list.append(output) + final_output = output + return final_output diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 3be6c48212140..2ee55b585da6c 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -107,6 +107,7 @@ class RequestState: self.max_tokens_param = max_tokens_param self.is_prefilling = True self.queue = queue + self.num_cached_tokens = 0 self.stats = RequestStateStats( arrival_time=arrival_time) if log_stats else None @@ -167,7 +168,6 @@ class RequestState: finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], kv_transfer_params: Optional[dict[str, Any]] = None, - num_cached_tokens: int = 0, ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: finished = finish_reason is not None @@ -195,7 +195,7 @@ class RequestState: return None return self._new_request_output(request_id, outputs, finished, - kv_transfer_params, num_cached_tokens) + kv_transfer_params) def _new_request_output( self, @@ -203,14 +203,14 @@ class RequestState: outputs: Union[list[CompletionOutput], list[PoolingOutput]], finished: bool, kv_transfer_params: Optional[dict[str, Any]] = None, - num_cached_tokens: int = 0, ) -> Union[RequestOutput, PoolingRequestOutput]: - if isinstance(outputs[0], PoolingOutput): + first_output = outputs[0] + if isinstance(first_output, PoolingOutput): assert len(outputs) == 1 return PoolingRequestOutput( request_id=request_id, - outputs=outputs[0], + outputs=first_output, prompt_token_ids=self.prompt_token_ids, finished=finished, ) @@ -229,7 +229,7 @@ class RequestState: outputs=cast(list[CompletionOutput], outputs), finished=finished, kv_transfer_params=kv_transfer_params, - num_cached_tokens=num_cached_tokens, + num_cached_tokens=self.num_cached_tokens, ) def _new_completion_output( @@ -308,11 +308,18 @@ class OutputProcessor: if req_state is not None: self.lora_states.abort_request(req_state) request_ids_to_abort.append(request_id) - else: - parent = self.parent_requests.pop(request_id, None) - if parent and parent.child_requests: - self.abort_requests(parent.child_requests) - request_ids_to_abort.extend(parent.child_requests) + # Produce final abort output. + if req_state.queue is not None and ( + request_output := req_state.make_request_output( + [], None, FinishReason.ABORT, None, None)): + req_state.queue.put(request_output) + elif parent := self.parent_requests.get(request_id): + # Abort children prior to removing the parent. + if parent.child_requests: + child_reqs = list(parent.child_requests) + child_reqs = self.abort_requests(child_reqs) + request_ids_to_abort.extend(child_reqs) + self.parent_requests.pop(request_id, None) return request_ids_to_abort def add_request( @@ -390,7 +397,7 @@ class OutputProcessor: finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason kv_transfer_params = engine_core_output.kv_transfer_params - num_cached_tokens = engine_core_output.num_cached_tokens + req_state.num_cached_tokens = engine_core_output.num_cached_tokens req_state.is_prefilling = False if pooling_output is None: @@ -411,7 +418,7 @@ class OutputProcessor: # 4) Create and handle RequestOutput objects. 
             if request_output := req_state.make_request_output(
                     new_token_ids, pooling_output, finish_reason, stop_reason,
-                    kv_transfer_params, num_cached_tokens):
+                    kv_transfer_params):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put(request_output)
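
Illustrative usage sketch (hedged: the helper name, arguments, and timing below
are hypothetical; `engine`, `request_id`, `prompt`, and `params` are assumed to
be set up the same way as in the test above). It shows the consumer-side effect
of this change: after `abort()`, `AsyncLLM.generate` still yields one final
`RequestOutput` with `finished=True` and `finish_reason == "abort"` instead of
ending silently.

    import asyncio

    from vllm.inputs import PromptType
    from vllm.sampling_params import SamplingParams
    from vllm.v1.engine.async_llm import AsyncLLM

    async def abort_and_collect_final(engine: AsyncLLM, request_id: str,
                                      prompt: PromptType,
                                      params: SamplingParams,
                                      delay_s: float = 0.5):
        """Start a generation, abort it after delay_s, and return the final
        RequestOutput that generate() now yields for the aborted request."""
        final_output = None

        async def consume():
            nonlocal final_output
            async for out in engine.generate(request_id=request_id,
                                             prompt=prompt,
                                             sampling_params=params):
                final_output = out  # the last output has finished=True

        task = asyncio.create_task(consume())
        await asyncio.sleep(delay_s)
        await engine.abort(request_id)

        # generate() terminates cleanly after the abort rather than hanging.
        await task
        assert final_output is not None and final_output.finished
        assert final_output.outputs[0].finish_reason == "abort"
        return final_output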