From f5d3acd47466f094beb36f7a5d05520466713f93 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Wed, 12 Mar 2025 13:29:48 -0400
Subject: [PATCH] [BugFix][V1] Fix parallel sampling finishing/aborts (#14512)

Signed-off-by: Nick Hill
---
 tests/v1/engine/test_async_llm.py             | 56 ++++++++++++++--
 .../v1/entrypoints/openai/test_completion.py  | 21 ++++--
 vllm/outputs.py                               | 64 ++++++-------------
 vllm/v1/engine/async_llm.py                   |  3 +-
 vllm/v1/engine/llm_engine.py                  |  2 +-
 vllm/v1/engine/output_processor.py            | 49 ++++++++------
 vllm/v1/engine/parallel_sampling.py           | 55 +++++++---------
 7 files changed, 137 insertions(+), 113 deletions(-)

diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py
index 0de0026eb2842..5b9725d59ddc5 100644
--- a/tests/v1/engine/test_async_llm.py
+++ b/tests/v1/engine/test_async_llm.py
@@ -46,6 +46,7 @@ async def generate(engine: AsyncLLM,
                    prompt: PromptType,
                    output_kind: RequestOutputKind,
                    max_tokens: int,
+                   n: int = 1,
                    prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
     # Ensure generate doesn't complete too fast for cancellation test.
     await asyncio.sleep(0.2)
@@ -54,13 +55,15 @@ async def generate(engine: AsyncLLM,
     sampling_params = SamplingParams(max_tokens=max_tokens,
                                      ignore_eos=True,
                                      output_kind=output_kind,
-                                     temperature=0,
+                                     temperature=0.5,
+                                     seed=33,
+                                     n=n,
                                      prompt_logprobs=prompt_logprobs)
     async for out in engine.generate(request_id=request_id,
                                      prompt=prompt,
                                      sampling_params=sampling_params):
 
-        num_tokens = len(out.outputs[0].token_ids)
+        num_tokens = sum(len(output.token_ids) for output in out.outputs)
         if output_kind == RequestOutputKind.DELTA:
             count += num_tokens
         else:
@@ -136,17 +139,22 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
 
         NUM_REQUESTS = 100
         NUM_EXPECTED_TOKENS = 100
+        NUM_EXPECTED_TOKENS_LONG = 50000
         REQUEST_IDS_TO_ABORT = range(1, 100, 10)
+        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)
 
         request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
 
         # Create concurrent requests.
         tasks: list[asyncio.Task] = []
-        for request_id in request_ids:
+        for idx, request_id in enumerate(request_ids):
+            max_tokens = NUM_EXPECTED_TOKENS_LONG if (
+                idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
+            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
             tasks.append(
                 asyncio.create_task(
                     generate(engine, request_id, prompt, output_kind,
-                             NUM_EXPECTED_TOKENS)))
+                             max_tokens, n)))
 
         # API server cancels requests when they disconnect.
         for idx in REQUEST_IDS_TO_ABORT:
@@ -162,10 +170,13 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
             else:
                 # Otherwise, make sure the request was not impacted.
                 num_generated_tokens, request_id = await task
-                assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
+                expected_tokens = NUM_EXPECTED_TOKENS * n
+                assert num_generated_tokens == expected_tokens, (
                     f"{request_id} generated {num_generated_tokens} but "
-                    f"expected {NUM_EXPECTED_TOKENS}")
+                    f"expected {expected_tokens}")
 
+        # Make sure all aborted requests were really aborted.
         assert not engine.output_processor.has_unfinished_requests()
 
         # Confirm we can do another generation.
@@ -176,3 +187,36 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
         num_generated_tokens, request_id = await task
         assert num_generated_tokens == NUM_EXPECTED_TOKENS
         assert not engine.output_processor.has_unfinished_requests()
+
+
+@pytest.mark.parametrize("n", [1, 3])
+@pytest.mark.parametrize("engine_args_and_prompt",
+                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
+                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.asyncio
+async def test_finished_flag(monkeypatch, n: int,
+                             engine_args_and_prompt: tuple[AsyncEngineArgs,
+                                                           PromptType]):
+
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+        engine_args, prompt = engine_args_and_prompt
+
+        engine = AsyncLLM.from_engine_args(engine_args)
+        after.callback(engine.shutdown)
+
+        sampling_params = SamplingParams(max_tokens=100,
+                                         output_kind=RequestOutputKind.DELTA,
+                                         temperature=1.0,
+                                         seed=33,
+                                         n=n)
+        outputs = [
+            out
+            async for out in engine.generate(request_id="request-33",
+                                             prompt=prompt,
+                                             sampling_params=sampling_params)
+        ]
+
+        # Assert only the last output has the finished flag set
+        assert all(not out.finished for out in outputs[:-1])
+        assert outputs[-1].finished
diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py
index 171c84176eae7..57ca99e1f68c6 100644
--- a/tests/v1/entrypoints/openai/test_completion.py
+++ b/tests/v1/entrypoints/openai/test_completion.py
@@ -263,15 +263,16 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
 
     prompt = "What is an LLM?"
     n = 3
-    max_tokens = 5
+    max_tokens = 50  # we want some to finish earlier than others
 
     # High temperature to maximize chance of unique completions.
     completion = await client.completions.create(model=model_name,
                                                  prompt=prompt,
                                                  max_tokens=max_tokens,
                                                  n=n,
-                                                 temperature=0.95,
+                                                 temperature=1.0,
                                                  stream=False,
+                                                 logprobs=0,
                                                  seed=42)
 
     # Assert `n` completions
@@ -279,6 +280,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
     assert num_completions == n, (
         f"Num completions {num_completions} but expected {n}.")
     completion_repeats: dict[str, int] = {}
+    output_token_lengths = set()
     for idx, choice in enumerate(completion.choices):
         # Assert correct completion index & some finish reason.
         assert choice.index == idx, (
@@ -287,6 +289,9 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
             "None finish_reason is invalid.")
         text = choice.text
         completion_repeats[text] = completion_repeats.get(text, 0) + 1
+        output_token_lengths.add(len(choice.logprobs.tokens))
+    # Assert subrequests finished at different times
+    assert len(output_token_lengths) > 1
     # Assert `n` unique completions
     num_unique = len(completion_repeats)
     if num_unique != n:
@@ -312,16 +317,16 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
 
     prompt = "What is an LLM?"
     n = 3
-    max_tokens = 5
+    max_tokens = 50  # we want some to finish earlier than others
 
     stream = await client.completions.create(model=model_name,
                                              prompt=prompt,
                                              max_tokens=max_tokens,
                                              n=n,
-                                             temperature=0.95,
+                                             temperature=1.0,
                                              stream=True,
                                              seed=42)
-    chunks: list[list[str]] = [[] for i in range(n)]
+    chunks: list[list[str]] = [[] for _ in range(n)]
     finish_reason_count = 0
     async for chunk in stream:
         index = chunk.choices[0].index
@@ -333,14 +338,18 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
     assert finish_reason_count == n, (
         f"Expected {n} completions with valid indices and finish_reason.")
     completion_repeats: dict[str, int] = {}
+    chunk_lengths = set()
     for chunk in chunks:
         chunk_len = len(chunk)
         # Assert correct number of completion tokens
-        assert chunk_len == max_tokens, (
+        chunk_lengths.add(chunk_len)
+        assert chunk_len <= max_tokens, (
             f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
         text = "".join(chunk)
         completion_repeats[text] = completion_repeats.get(text, 0) + 1
         print(text)
+    # Assert subrequests finished at different times
+    assert len(chunk_lengths) > 1
     # Assert `n` unique completions
     num_unique = len(completion_repeats)
     if num_unique != n:
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 8c355c89e3e9b..7a20c340edcf7 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -134,57 +134,29 @@ class RequestOutput:
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
 
-    @classmethod
-    def new(
-        cls,
-        request_id: str,
-        prompt: Optional[str],
-        prompt_token_ids: Optional[list[int]],
-        text: str,
-        token_ids: list[int],
-        logprobs: Optional[SampleLogprobs],
-        prompt_logprobs: Optional[PromptLogprobs],
-        cumulative_logprob: Optional[float],
-        finished: bool = False,
-    ) -> "RequestOutput":
-        """Initialize a new RequestOutput object."""
-
-        # TODO: Support `n` > 1.
-        completion_output = CompletionOutput(
-            index=0,
-            text=text,
-            token_ids=token_ids,
-            cumulative_logprob=cumulative_logprob,
-            logprobs=logprobs)
-
-        return RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            prompt_logprobs=prompt_logprobs,
-            outputs=[completion_output],
-            finished=finished,
-        )
-
     def add(self, next_output: "RequestOutput") -> None:
         """Merge subsequent RequestOutput into this one"""
-        self.prompt = next_output.prompt
-        self.prompt_token_ids = next_output.prompt_token_ids
-        self.prompt_logprobs = next_output.prompt_logprobs
         self.finished |= next_output.finished
 
-        #TODO assuming n == 1 for now
-        completion = self.outputs[0]
-        next_completion = next_output.outputs[0]
-        completion.text += next_completion.text
-        if not isinstance(completion.token_ids, MutableSequence):
-            completion.token_ids = list(completion.token_ids)
-        completion.token_ids.extend(next_completion.token_ids)
-        if next_completion.logprobs:
-            assert completion.logprobs is not None
-            completion.logprobs.extend(next_completion.logprobs)
-        completion.cumulative_logprob = next_completion.cumulative_logprob
+        for next_completion in next_output.outputs:
+            for completion in self.outputs:
+                if completion.index == next_completion.index:
+                    # Merge outputs with same index
+                    completion.text += next_completion.text
+                    if not isinstance(completion.token_ids, MutableSequence):
+                        completion.token_ids = list(completion.token_ids)
+                    completion.token_ids.extend(next_completion.token_ids)
+                    if next_completion.logprobs:
+                        assert completion.logprobs is not None
+                        completion.logprobs.extend(next_completion.logprobs)
+                    completion.cumulative_logprob = (
+                        next_completion.cumulative_logprob)
+                    completion.finish_reason = next_completion.finish_reason
+                    completion.stop_reason = next_completion.stop_reason
+                    break
+            else:
+                self.outputs.append(next_completion)
 
     @classmethod
     def from_seq_group(
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 3dc513a728339..05633352be6c0 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -298,9 +298,8 @@ class AsyncLLM(EngineClient):
 
     async def abort(self, request_id: str) -> None:
         """Abort RequestId in OutputProcessor and EngineCore."""
-        request_ids = [request_id]
+        request_ids = self.output_processor.abort_requests((request_id, ))
         await self.engine_core.abort_requests_async(request_ids)
-        self.output_processor.abort_requests(request_ids)
 
         if self.log_requests:
             logger.info("Aborted request %s.", request_id)
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 213faaa451605..d56aee1accc2d 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -137,8 +137,8 @@ class LLMEngine:
 
     def abort_request(self, request_ids: list[str]) -> None:
         """Remove request_ids from EngineCore and Detokenizer."""
+        request_ids = self.output_processor.abort_requests(request_ids)
         self.engine_core.abort_requests(request_ids)
-        self.output_processor.abort_requests(request_ids)
 
     def add_request(
         self,
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index aea526188a8f5..83180b66bea0d 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
+from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional, Union
 
@@ -102,8 +103,7 @@ class RequestState:
     ) -> Optional[RequestOutput]:
 
         finished = finish_reason is not None
-        output_kind = self.output_kind
-        final_only = output_kind == RequestOutputKind.FINAL_ONLY
+        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
 
         # In follow up, we will switch to invariant where EngineCore
         # does not stream partial prefills.
@@ -111,24 +111,24 @@ class RequestState:
             # Only the final output is required in FINAL_ONLY mode.
             return None
 
-        def new_request_output(request_id: str) -> RequestOutput:
-            return self._new_request_output(request_id, finished)
-
         completion_output = self._new_completion_output(
             new_token_ids, finish_reason, stop_reason)
 
-        if self.parent_req is not None:
-            return self.parent_req.make_request_output(final_only,
-                                                       completion_output,
-                                                       new_request_output)
+        request_id = self.request_id
+        if self.parent_req is None:
+            outputs = [completion_output]
+        else:
+            request_id, outputs, finished = self.parent_req.get_outputs(
+                request_id, completion_output)
+            if not outputs:
+                return None
 
-        request_output = new_request_output(self.request_id)
-        request_output.outputs.append(completion_output)
-        return request_output
+        return self._new_request_output(request_id, outputs, finished)
 
     def _new_request_output(
         self,
         request_id: str,
+        outputs: list[CompletionOutput],
         finished: bool,
     ) -> RequestOutput:
 
@@ -143,7 +143,7 @@ class RequestState:
             prompt=self.prompt,
             prompt_token_ids=self.prompt_token_ids,
             prompt_logprobs=prompt_logprobs,
-            outputs=[],
+            outputs=outputs,
             finished=finished,
         )
 
@@ -188,6 +188,7 @@ class OutputProcessor:
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: dict[str, RequestState] = {}
+        self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates()
 
     def get_num_unfinished_requests(self):
@@ -198,14 +199,20 @@ class OutputProcessor:
 
     def abort_requests(
         self,
-        request_ids: list[str],
-    ) -> None:
+        request_ids: Iterable[str],
+    ) -> list[str]:
+        request_ids_to_abort = []
         for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
                self.lora_states.abort_request(req_state)
-                if req_state.parent_req is not None:
-                    req_state.parent_req.finish_child_request(request_id)
+                request_ids_to_abort.append(request_id)
+            else:
+                parent = self.parent_requests.pop(request_id, None)
+                if parent and parent.child_requests:
+                    self.abort_requests(parent.child_requests)
+                    request_ids_to_abort.extend(parent.child_requests)
+        return request_ids_to_abort
 
     def add_request(
         self,
@@ -227,6 +234,8 @@ class OutputProcessor:
             log_stats=self.log_stats)
         self.request_states[request_id] = req_state
         self.lora_states.add_request(req_state)
+        if parent_req:
+            self.parent_requests[parent_req.request_id] = parent_req
 
     def process_outputs(
         self,
@@ -314,12 +323,14 @@ class OutputProcessor:
             # Free completed requests.
             if finish_reason is not None:
                 self.request_states.pop(req_id)
+                # Remove parent request if applicable.
+                parent_req = req_state.parent_req
+                if parent_req and not parent_req.child_requests:
+                    self.parent_requests.pop(parent_req.request_id, None)
                 if not engine_core_output.finished:
                     # If req not finished in EngineCore, but Detokenizer
                     # detected stop string, abort needed in EngineCore.
                     reqs_to_abort.append(req_id)
-                if req_state.parent_req is not None:
-                    req_state.parent_req.finish_child_request(req_id)
 
                 # Track per-request stats
                 self._update_stats_from_finished(req_state, finish_reason,
diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py
index 4e2c78173b513..0eeca657406e5 100644
--- a/vllm/v1/engine/parallel_sampling.py
+++ b/vllm/v1/engine/parallel_sampling.py
@@ -1,11 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from copy import copy
-from typing import Callable, Optional, Union
+from typing import Optional, Union
 
-from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.outputs import CompletionOutput
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.v1.metrics.stats import IterationStats
 
 
@@ -23,7 +23,7 @@ class ParentRequest:
     child_requests: set[str]
 
     # To aggregate child completions when not streaming
-    output_aggregator: Optional[RequestOutput]
+    output_aggregator: list[CompletionOutput]
 
     # To find the max number of generated tokens across all children
     max_num_generation_tokens: int
@@ -37,7 +37,9 @@ class ParentRequest:
         self.sampling_params = sampling_params
 
         self.child_requests = set()
-        self.output_aggregator = None
+        self.output_aggregator = [None] * sampling_params.n if (
+            sampling_params.output_kind
+            == RequestOutputKind.FINAL_ONLY) else []
         self.max_num_generation_tokens = 0
         self.cached_child_sampling_params = None
 
@@ -93,43 +95,30 @@ class ParentRequest:
         """
         child_req_id = f"{index}_{self.request_id}"
         self.child_requests.add(child_req_id)
-        return (child_req_id, self._get_child_sampling_params(index))
-
-    def finish_child_request(self, req_id: str):
-        self.child_requests.remove(req_id)
+        return child_req_id, self._get_child_sampling_params(index)
 
     @property
     def n(self) -> int:
         return self.sampling_params.n
 
-    def make_request_output(
+    def get_outputs(
         self,
-        final_only: bool,
+        child_request_id: str,
         completion_output: CompletionOutput,
-        new_request_output: Callable[[str], RequestOutput],
-    ) -> Optional[RequestOutput]:
-        # Use an existing RequestOutput if we're aggregating
-        request_output = self.output_aggregator
+    ) -> tuple[str, list[CompletionOutput], bool]:
+        if completion_output.finished():
+            self.child_requests.remove(child_request_id)
 
-        # Make new RequestOutput otherwise
-        if request_output is None:
-            request_output = new_request_output(self.request_id)
+        if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
+            # If streaming, just return the current output.
+            outputs = [completion_output]
+        else:
+            # If not streaming, aggregate the n final outputs.
+            self.output_aggregator[completion_output.index] = completion_output
+            outputs = [] if self.child_requests else self.output_aggregator
 
-        # Add a new completion
-        request_output.outputs.append(completion_output)
-
-        # If not streaming, aggregate until all child requests complete
-        if final_only and len(request_output.outputs) != self.n:
-            self.output_aggregator = request_output
-            return None
-
-        # We're done aggregating
-        self.output_aggregator = None
-
-        # Parent completion output list must be sorted by index
-        request_output.outputs = sorted(request_output.outputs,
-                                        key=lambda x: x.index)
-        return request_output
+        finished = not self.child_requests
+        return self.request_id, outputs, finished
 
     def observe_num_generation_tokens(self, num_generation_tokens: int):
         self.max_num_generation_tokens = max(num_generation_tokens,