[BugFix][V1] Fix parallel sampling finishing/aborts (#14512)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill authored on 2025-03-12 13:29:48 -04:00, committed by GitHub
parent 916836bbfb
commit f5d3acd474
7 changed files with 137 additions and 113 deletions

View File

@@ -46,6 +46,7 @@ async def generate(engine: AsyncLLM,
prompt: PromptType,
output_kind: RequestOutputKind,
max_tokens: int,
n: int = 1,
prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)
@@ -54,13 +55,15 @@ async def generate(engine: AsyncLLM,
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True,
output_kind=output_kind,
temperature=0,
temperature=0.5,
seed=33,
n=n,
prompt_logprobs=prompt_logprobs)
async for out in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):
num_tokens = len(out.outputs[0].token_ids)
num_tokens = sum(len(output.token_ids) for output in out.outputs)
if output_kind == RequestOutputKind.DELTA:
count += num_tokens
else:
@@ -136,17 +139,22 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
NUM_REQUESTS = 100
NUM_EXPECTED_TOKENS = 100
NUM_EXPECTED_TOKENS_LONG = 50000
REQUEST_IDS_TO_ABORT = range(1, 100, 10)
PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks: list[asyncio.Task] = []
for request_id in request_ids:
for idx, request_id in enumerate(request_ids):
max_tokens = NUM_EXPECTED_TOKENS_LONG if (
idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
tasks.append(
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))
max_tokens, n)))
# API server cancels requests when they disconnect.
for idx in REQUEST_IDS_TO_ABORT:
@@ -162,10 +170,13 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
else:
# Otherwise, make sure the request was not impacted.
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
expected_tokens = NUM_EXPECTED_TOKENS * n
assert num_generated_tokens == expected_tokens, (
f"{request_id} generated {num_generated_tokens} but "
f"expected {NUM_EXPECTED_TOKENS}")
f"expected {expected_tokens}")
# Make sure all aborted requests were really aborted.
assert not engine.output_processor.has_unfinished_requests()
# Confirm we can do another generation.
@@ -176,3 +187,36 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
num_generated_tokens, request_id = await task
assert num_generated_tokens == NUM_EXPECTED_TOKENS
assert not engine.output_processor.has_unfinished_requests()
@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize("engine_args_and_prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio
async def test_finished_flag(monkeypatch, n: int,
engine_args_and_prompt: tuple[AsyncEngineArgs,
PromptType]):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine_args, prompt = engine_args_and_prompt
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
sampling_params = SamplingParams(max_tokens=100,
output_kind=RequestOutputKind.DELTA,
temperature=1.0,
seed=33,
n=n)
outputs = [
out
async for out in engine.generate(request_id="request-33",
prompt=prompt,
sampling_params=sampling_params)
]
# Assert only the last output has the finished flag set
assert all(not out.finished for out in outputs[:-1])
assert outputs[-1].finished
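To make the accounting in the hunks above concrete: with parallel sampling (n > 1), each RequestOutput can carry one CompletionOutput per child sequence, so the test now sums token counts across all of them and expects NUM_EXPECTED_TOKENS * n per request. A minimal stand-alone sketch of that arithmetic, using hypothetical stand-in classes rather than the real CompletionOutput/RequestOutput:

from dataclasses import dataclass

@dataclass
class FakeCompletion:          # hypothetical stand-in for CompletionOutput
    token_ids: list[int]

@dataclass
class FakeRequestOutput:       # hypothetical stand-in for RequestOutput
    outputs: list[FakeCompletion]

def count_new_tokens(out: FakeRequestOutput) -> int:
    # Same accounting as the updated test: sum over every child completion,
    # not just outputs[0].
    return sum(len(c.token_ids) for c in out.outputs)

# One DELTA-mode step in which three child sequences each produced one token:
step = FakeRequestOutput(outputs=[FakeCompletion([101]),
                                  FakeCompletion([102]),
                                  FakeCompletion([103])])
assert count_new_tokens(step) == 3

# Across a whole request the test therefore expects NUM_EXPECTED_TOKENS * n.
NUM_EXPECTED_TOKENS, n = 100, 3
assert NUM_EXPECTED_TOKENS * n == 300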

View File

@@ -263,15 +263,16 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
prompt = "What is an LLM?"
n = 3
max_tokens = 5
max_tokens = 50 # we want some to finish earlier than others
# High temperature to maximize chance of unique completions.
completion = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
temperature=0.95,
temperature=1.0,
stream=False,
logprobs=0,
seed=42)
# Assert `n` completions
@@ -279,6 +280,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
assert num_completions == n, (
f"Num completions {num_completions} but expected {n}.")
completion_repeats: dict[str, int] = {}
output_token_lengths = set()
for idx, choice in enumerate(completion.choices):
# Assert correct completion index & some finish reason.
assert choice.index == idx, (
@@ -287,6 +289,9 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
"None finish_reason is invalid.")
text = choice.text
completion_repeats[text] = completion_repeats.get(text, 0) + 1
output_token_lengths.add(len(choice.logprobs.tokens))
# Assert subrequests finished at different times
assert len(output_token_lengths) > 1
# Assert `n` unique completions
num_unique = len(completion_repeats)
if num_unique != n:
@@ -312,16 +317,16 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
prompt = "What is an LLM?"
n = 3
max_tokens = 5
max_tokens = 50 # we want some to finish earlier than others
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
temperature=0.95,
temperature=1.0,
stream=True,
seed=42)
chunks: list[list[str]] = [[] for i in range(n)]
chunks: list[list[str]] = [[] for _ in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
@@ -333,14 +338,18 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
assert finish_reason_count == n, (
f"Expected {n} completions with valid indices and finish_reason.")
completion_repeats: dict[str, int] = {}
chunk_lengths = set()
for chunk in chunks:
chunk_len = len(chunk)
# Assert correct number of completion tokens
assert chunk_len == max_tokens, (
chunk_lengths.add(chunk_len)
assert chunk_len <= max_tokens, (
f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
text = "".join(chunk)
completion_repeats[text] = completion_repeats.get(text, 0) + 1
print(text)
# Assert subrequests finished at different times
assert len(chunk_lengths) > 1
# Assert `n` unique completions
num_unique = len(completion_repeats)
if num_unique != n:
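For reference, a minimal client-side sketch (not part of the diff) of the behaviour the streaming test now checks: with n=3 and a larger max_tokens, child completions may stop at different lengths, but none exceeds max_tokens. The model name, base URL, and API key below are placeholders for a locally running vLLM OpenAI-compatible server.

import asyncio

import openai

MODEL_NAME = "my-model"        # placeholder
N, MAX_TOKENS = 3, 50

async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.completions.create(model=MODEL_NAME,
                                              prompt="What is an LLM?",
                                              max_tokens=MAX_TOKENS,
                                              n=N,
                                              temperature=1.0,
                                              stream=True)
    # Group streamed chunks by choice index, as the test does.
    chunks: list[list[str]] = [[] for _ in range(N)]
    async for chunk in stream:
        choice = chunk.choices[0]
        chunks[choice.index].append(choice.text or "")
    # Children that hit a stop condition early produce fewer chunks.
    assert all(len(c) <= MAX_TOKENS for c in chunks)
    print("per-child chunk counts:", sorted(len(c) for c in chunks))

asyncio.run(main())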

View File

@@ -134,57 +134,29 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
@classmethod
def new(
cls,
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[list[int]],
text: str,
token_ids: list[int],
logprobs: Optional[SampleLogprobs],
prompt_logprobs: Optional[PromptLogprobs],
cumulative_logprob: Optional[float],
finished: bool = False,
) -> "RequestOutput":
"""Initialize a new RequestOutput object."""
# TODO: Support `n` > 1.
completion_output = CompletionOutput(
index=0,
text=text,
token_ids=token_ids,
cumulative_logprob=cumulative_logprob,
logprobs=logprobs)
return RequestOutput(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs,
outputs=[completion_output],
finished=finished,
)
def add(self, next_output: "RequestOutput") -> None:
"""Merge subsequent RequestOutput into this one"""
self.prompt = next_output.prompt
self.prompt_token_ids = next_output.prompt_token_ids
self.prompt_logprobs = next_output.prompt_logprobs
self.finished |= next_output.finished
#TODO assuming n == 1 for now
completion = self.outputs[0]
next_completion = next_output.outputs[0]
completion.text += next_completion.text
if not isinstance(completion.token_ids, MutableSequence):
completion.token_ids = list(completion.token_ids)
completion.token_ids.extend(next_completion.token_ids)
if next_completion.logprobs:
assert completion.logprobs is not None
completion.logprobs.extend(next_completion.logprobs)
completion.cumulative_logprob = next_completion.cumulative_logprob
for next_completion in next_output.outputs:
for completion in self.outputs:
if completion.index == next_completion.index:
# Merge outputs with same index
completion.text += next_completion.text
if not isinstance(completion.token_ids, MutableSequence):
completion.token_ids = list(completion.token_ids)
completion.token_ids.extend(next_completion.token_ids)
if next_completion.logprobs:
assert completion.logprobs is not None
completion.logprobs.extend(next_completion.logprobs)
completion.cumulative_logprob = (
next_completion.cumulative_logprob)
completion.finish_reason = next_completion.finish_reason
completion.stop_reason = next_completion.stop_reason
break
else:
self.outputs.append(next_completion)
@classmethod
def from_seq_group(
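The reworked add() above merges delta outputs per child index instead of assuming n == 1. A small illustrative model of that merge, using simplified stand-in classes rather than the real CompletionOutput:

from dataclasses import dataclass, field

@dataclass
class Completion:              # simplified stand-in for CompletionOutput
    index: int
    text: str
    token_ids: list[int] = field(default_factory=list)

def merge(accumulated: list[Completion], incoming: list[Completion]) -> None:
    """Fold a later batch of deltas into the accumulated outputs."""
    for nxt in incoming:
        for cur in accumulated:
            if cur.index == nxt.index:
                # Same child sequence: append its delta.
                cur.text += nxt.text
                cur.token_ids.extend(nxt.token_ids)
                break
        else:
            # First delta seen for this child index: keep it as-is.
            accumulated.append(nxt)

outputs = [Completion(0, "Hello", [1])]
merge(outputs, [Completion(0, " world", [2]), Completion(1, "Hi", [7])])
assert outputs[0].text == "Hello world"
assert [c.index for c in outputs] == [0, 1]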

View File

@@ -298,9 +298,8 @@ class AsyncLLM(EngineClient):
async def abort(self, request_id: str) -> None:
"""Abort RequestId in OutputProcessor and EngineCore."""
request_ids = [request_id]
request_ids = self.output_processor.abort_requests((request_id, ))
await self.engine_core.abort_requests_async(request_ids)
self.output_processor.abort_requests(request_ids)
if self.log_requests:
logger.info("Aborted request %s.", request_id)
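The key change here (mirrored in LLMEngine.abort_request below) is the ordering: the OutputProcessor is asked to abort first, because for a parallel-sampling request it resolves the external parent id into the child request ids that the EngineCore actually tracks. A toy illustration of that id fan-out, based on the child-id scheme visible in the ParentRequest hunks at the end of this diff; the helper below is hypothetical, not a vLLM function.

def child_request_ids(parent_id: str, n: int) -> list[str]:
    # Child ids follow the "<index>_<parent id>" scheme used by ParentRequest.
    return [f"{i}_{parent_id}" for i in range(n)]

# Aborting "request-33" with n=3 must forward these ids to EngineCore,
# which is why the output processor is consulted before the core.
assert child_request_ids("request-33", 3) == [
    "0_request-33", "1_request-33", "2_request-33"]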

View File

@@ -137,8 +137,8 @@ class LLMEngine:
def abort_request(self, request_ids: list[str]) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""
request_ids = self.output_processor.abort_requests(request_ids)
self.engine_core.abort_requests(request_ids)
self.output_processor.abort_requests(request_ids)
def add_request(
self,

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional, Union
@@ -102,8 +103,7 @@ class RequestState:
) -> Optional[RequestOutput]:
finished = finish_reason is not None
output_kind = self.output_kind
final_only = output_kind == RequestOutputKind.FINAL_ONLY
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
@@ -111,24 +111,24 @@
# Only the final output is required in FINAL_ONLY mode.
return None
def new_request_output(request_id: str) -> RequestOutput:
return self._new_request_output(request_id, finished)
completion_output = self._new_completion_output(
new_token_ids, finish_reason, stop_reason)
if self.parent_req is not None:
return self.parent_req.make_request_output(final_only,
completion_output,
new_request_output)
request_id = self.request_id
if self.parent_req is None:
outputs = [completion_output]
else:
request_id, outputs, finished = self.parent_req.get_outputs(
request_id, completion_output)
if not outputs:
return None
request_output = new_request_output(self.request_id)
request_output.outputs.append(completion_output)
return request_output
return self._new_request_output(request_id, outputs, finished)
def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
finished: bool,
) -> RequestOutput:
@@ -143,7 +143,7 @@ class RequestState:
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
prompt_logprobs=prompt_logprobs,
outputs=[],
outputs=outputs,
finished=finished,
)
@@ -188,6 +188,7 @@ class OutputProcessor:
self.log_stats = log_stats
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates()
def get_num_unfinished_requests(self):
@@ -198,14 +199,20 @@
def abort_requests(
self,
request_ids: list[str],
) -> None:
request_ids: Iterable[str],
) -> list[str]:
request_ids_to_abort = []
for request_id in request_ids:
req_state = self.request_states.pop(request_id, None)
if req_state is not None:
self.lora_states.abort_request(req_state)
if req_state.parent_req is not None:
req_state.parent_req.finish_child_request(request_id)
request_ids_to_abort.append(request_id)
else:
parent = self.parent_requests.pop(request_id, None)
if parent and parent.child_requests:
self.abort_requests(parent.child_requests)
request_ids_to_abort.extend(parent.child_requests)
return request_ids_to_abort
def add_request(
self,
@@ -227,6 +234,8 @@
log_stats=self.log_stats)
self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
def process_outputs(
self,
@@ -314,12 +323,14 @@
# Free completed requests.
if finish_reason is not None:
self.request_states.pop(req_id)
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
if req_state.parent_req is not None:
req_state.parent_req.finish_child_request(req_id)
# Track per-request stats
self._update_stats_from_finished(req_state, finish_reason,
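A stand-alone toy model (illustrative only, not the vLLM implementation) of the abort fan-out that abort_requests gains above: an id matching a live request state is aborted directly, while an id matching a parent request expands to whichever child requests are still running.

def abort_requests(request_ids, request_states, parent_requests):
    """Return the concrete request ids to abort in the engine core."""
    to_abort = []
    for rid in request_ids:
        state = request_states.pop(rid, None)
        if state is not None:
            to_abort.append(rid)
        else:
            parent_children = parent_requests.pop(rid, None)
            if parent_children:
                # Parent id: recurse over the remaining child requests.
                to_abort.extend(abort_requests(parent_children,
                                               request_states,
                                               parent_requests))
    return to_abort

states = {"0_req": object(), "1_req": object(), "solo": object()}
parents = {"req": ["0_req", "1_req"]}
assert abort_requests(["req", "solo"], states, parents) == [
    "0_req", "1_req", "solo"]
assert not states and not parents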

View File

@@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from copy import copy
from typing import Callable, Optional, Union
from typing import Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.outputs import CompletionOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats
@@ -23,7 +23,7 @@ class ParentRequest:
child_requests: set[str]
# To aggregate child completions when not streaming
output_aggregator: Optional[RequestOutput]
output_aggregator: list[CompletionOutput]
# To find the max number of generated tokens across all children
max_num_generation_tokens: int
@@ -37,7 +37,9 @@
self.sampling_params = sampling_params
self.child_requests = set()
self.output_aggregator = None
self.output_aggregator = [None] * sampling_params.n if (
sampling_params.output_kind
== RequestOutputKind.FINAL_ONLY) else []
self.max_num_generation_tokens = 0
self.cached_child_sampling_params = None
@@ -93,43 +95,30 @@
"""
child_req_id = f"{index}_{self.request_id}"
self.child_requests.add(child_req_id)
return (child_req_id, self._get_child_sampling_params(index))
def finish_child_request(self, req_id: str):
self.child_requests.remove(req_id)
return child_req_id, self._get_child_sampling_params(index)
@property
def n(self) -> int:
return self.sampling_params.n
def make_request_output(
def get_outputs(
self,
final_only: bool,
child_request_id: str,
completion_output: CompletionOutput,
new_request_output: Callable[[str], RequestOutput],
) -> Optional[RequestOutput]:
# Use an existing RequestOutput if we're aggregating
request_output = self.output_aggregator
) -> tuple[str, list[CompletionOutput], bool]:
if completion_output.finished():
self.child_requests.remove(child_request_id)
# Make new RequestOutput otherwise
if request_output is None:
request_output = new_request_output(self.request_id)
if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
# If streaming, just return the current output.
outputs = [completion_output]
else:
# If not streaming, aggregate the n final outputs.
self.output_aggregator[completion_output.index] = completion_output
outputs = [] if self.child_requests else self.output_aggregator
# Add a new completion
request_output.outputs.append(completion_output)
# If not streaming, aggregate until all child requests complete
if final_only and len(request_output.outputs) != self.n:
self.output_aggregator = request_output
return None
# We're done aggregating
self.output_aggregator = None
# Parent completion output list must be sorted by index
request_output.outputs = sorted(request_output.outputs,
key=lambda x: x.index)
return request_output
finished = not self.child_requests
return self.request_id, outputs, finished
def observe_num_generation_tokens(self, num_generation_tokens: int):
self.max_num_generation_tokens = max(num_generation_tokens,
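Finally, a toy model (again illustrative, not the real ParentRequest) of the new get_outputs() contract: when streaming, each child delta is returned immediately under the parent request id; under FINAL_ONLY, finished child outputs are slotted into their index and nothing is emitted until every child has finished, which is also when the finished flag flips to True.

from typing import Optional

class ToyParent:
    def __init__(self, request_id: str, n: int, streaming: bool):
        self.request_id = request_id
        self.child_requests = {f"{i}_{request_id}" for i in range(n)}
        # FINAL_ONLY mode pre-sizes one slot per child; streaming needs none.
        self.slots: list[Optional[str]] = [] if streaming else [None] * n
        self.streaming = streaming

    def get_outputs(self, child_id: str, index: int, text: str,
                    finished: bool):
        if finished:
            self.child_requests.discard(child_id)
        if self.streaming:
            # Streaming: emit this child's delta right away.
            return self.request_id, [text], not self.child_requests
        # FINAL_ONLY: park the finished output at its index and wait.
        self.slots[index] = text
        outputs = [] if self.child_requests else self.slots
        return self.request_id, outputs, not self.child_requests

p = ToyParent("req", n=2, streaming=False)
assert p.get_outputs("0_req", 0, "A", finished=True)[1] == []   # still waiting
rid, outs, done = p.get_outputs("1_req", 1, "B", finished=True)
assert (rid, outs, done) == ("req", ["A", "B"], True)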