[BugFix][V1] Fix parallel sampling finishing/aborts (#14512)
Signed-off-by: Nick Hill <nhill@redhat.com>
commit f5d3acd474
parent 916836bbfb
@@ -46,6 +46,7 @@ async def generate(engine: AsyncLLM,
                    prompt: PromptType,
                    output_kind: RequestOutputKind,
                    max_tokens: int,
+                   n: int = 1,
                    prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
     # Ensure generate doesn't complete too fast for cancellation test.
     await asyncio.sleep(0.2)
@@ -54,13 +55,15 @@ async def generate(engine: AsyncLLM,
     sampling_params = SamplingParams(max_tokens=max_tokens,
                                      ignore_eos=True,
                                      output_kind=output_kind,
-                                     temperature=0,
+                                     temperature=0.5,
+                                     seed=33,
+                                     n=n,
                                      prompt_logprobs=prompt_logprobs)
     async for out in engine.generate(request_id=request_id,
                                      prompt=prompt,
                                      sampling_params=sampling_params):

-        num_tokens = len(out.outputs[0].token_ids)
+        num_tokens = sum(len(output.token_ids) for output in out.outputs)
         if output_kind == RequestOutputKind.DELTA:
             count += num_tokens
         else:
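With n > 1 a RequestOutput carries one CompletionOutput per sample, so the helper above now sums tokens across all of them instead of reading only outputs[0]. A minimal stand-in sketch of that counting change (simplified dataclasses, not vLLM's real types):

from dataclasses import dataclass


@dataclass
class FakeCompletion:
    token_ids: list[int]


@dataclass
class FakeRequestOutput:
    outputs: list[FakeCompletion]


out = FakeRequestOutput(outputs=[
    FakeCompletion(token_ids=[1, 2, 3]),
    FakeCompletion(token_ids=[4, 5]),
    FakeCompletion(token_ids=[6]),
])

# Old behavior: only the first sample was counted.
assert len(out.outputs[0].token_ids) == 3
# New behavior: count tokens across all n samples.
assert sum(len(o.token_ids) for o in out.outputs) == 6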
@@ -136,17 +139,22 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,

     NUM_REQUESTS = 100
     NUM_EXPECTED_TOKENS = 100
+    NUM_EXPECTED_TOKENS_LONG = 50000
     REQUEST_IDS_TO_ABORT = range(1, 100, 10)
+    PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

     request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

     # Create concurrent requests.
     tasks: list[asyncio.Task] = []
-    for request_id in request_ids:
+    for idx, request_id in enumerate(request_ids):
+        max_tokens = NUM_EXPECTED_TOKENS_LONG if (
+            idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
+        n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
         tasks.append(
             asyncio.create_task(
                 generate(engine, request_id, prompt, output_kind,
-                         NUM_EXPECTED_TOKENS)))
+                         max_tokens, n)))

     # API server cancels requests when they disconnect.
     for idx in REQUEST_IDS_TO_ABORT:
@@ -162,10 +170,13 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
         else:
             # Otherwise, make sure the request was not impacted.
             num_generated_tokens, request_id = await task
-            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
+            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
+            expected_tokens = NUM_EXPECTED_TOKENS * n
+            assert num_generated_tokens == expected_tokens, (
                 f"{request_id} generated {num_generated_tokens} but "
-                f"expected {NUM_EXPECTED_TOKENS}")
+                f"expected {expected_tokens}")
+
     # Make sure all aborted requests were really aborted.
     assert not engine.output_processor.has_unfinished_requests()

     # Confirm we can do another generation.
@@ -176,3 +187,36 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
     num_generated_tokens, request_id = await task
     assert num_generated_tokens == NUM_EXPECTED_TOKENS
     assert not engine.output_processor.has_unfinished_requests()
+
+
+@pytest.mark.parametrize("n", [1, 3])
+@pytest.mark.parametrize("engine_args_and_prompt",
+                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
+                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
+@pytest.mark.asyncio
+async def test_finished_flag(monkeypatch, n: int,
+                             engine_args_and_prompt: tuple[AsyncEngineArgs,
+                                                           PromptType]):
+
+    with monkeypatch.context() as m, ExitStack() as after:
+        m.setenv("VLLM_USE_V1", "1")
+        engine_args, prompt = engine_args_and_prompt
+
+        engine = AsyncLLM.from_engine_args(engine_args)
+        after.callback(engine.shutdown)
+
+        sampling_params = SamplingParams(max_tokens=100,
+                                         output_kind=RequestOutputKind.DELTA,
+                                         temperature=1.0,
+                                         seed=33,
+                                         n=n)
+        outputs = [
+            out
+            async for out in engine.generate(request_id="request-33",
+                                             prompt=prompt,
+                                             sampling_params=sampling_params)
+        ]
+
+        # Assert only the last output has the finished flag set
+        assert all(not out.finished for out in outputs[:-1])
+        assert outputs[-1].finished
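The new test_finished_flag pins down the streaming invariant this fix restores for parallel sampling: every RequestOutput yielded by AsyncLLM.generate has finished=False except the final one. A self-contained toy illustration of that invariant and of how a consumer would check it (the fake stream below stands in for the engine, it is not the real API):

import asyncio
from dataclasses import dataclass


@dataclass
class FakeOutput:
    text: str
    finished: bool = False


async def fake_stream(n_chunks: int):
    # Stand-in for AsyncLLM.generate(): only the final item is finished.
    for i in range(n_chunks):
        await asyncio.sleep(0)
        yield FakeOutput(text=f"tok{i}", finished=(i == n_chunks - 1))


async def main():
    outputs = [out async for out in fake_stream(5)]
    # Same invariant the test asserts.
    assert all(not out.finished for out in outputs[:-1])
    assert outputs[-1].finished


asyncio.run(main())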
@@ -263,15 +263,16 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,

     prompt = "What is an LLM?"
     n = 3
-    max_tokens = 5
+    max_tokens = 50  # we want some to finish earlier than others

     # High temperature to maximize chance of unique completions.
     completion = await client.completions.create(model=model_name,
                                                  prompt=prompt,
                                                  max_tokens=max_tokens,
                                                  n=n,
-                                                 temperature=0.95,
+                                                 temperature=1.0,
                                                  stream=False,
+                                                 logprobs=0,
                                                  seed=42)

     # Assert `n` completions
@@ -279,6 +280,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
     assert num_completions == n, (
         f"Num completions {num_completions} but expected {n}.")
     completion_repeats: dict[str, int] = {}
+    output_token_lengths = set()
     for idx, choice in enumerate(completion.choices):
         # Assert correct completion index & some finish reason.
         assert choice.index == idx, (
@@ -287,6 +289,9 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
             "None finish_reason is invalid.")
         text = choice.text
         completion_repeats[text] = completion_repeats.get(text, 0) + 1
+        output_token_lengths.add(len(choice.logprobs.tokens))
+    # Assert subrequests finished at different times
+    assert len(output_token_lengths) > 1
     # Assert `n` unique completions
     num_unique = len(completion_repeats)
     if num_unique != n:
@@ -312,16 +317,16 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):

     prompt = "What is an LLM?"
     n = 3
-    max_tokens = 5
+    max_tokens = 50  # we want some to finish earlier than others

     stream = await client.completions.create(model=model_name,
                                              prompt=prompt,
                                              max_tokens=max_tokens,
                                              n=n,
-                                             temperature=0.95,
+                                             temperature=1.0,
                                              stream=True,
                                              seed=42)
-    chunks: list[list[str]] = [[] for i in range(n)]
+    chunks: list[list[str]] = [[] for _ in range(n)]
     finish_reason_count = 0
     async for chunk in stream:
         index = chunk.choices[0].index
@@ -333,14 +338,18 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
     assert finish_reason_count == n, (
         f"Expected {n} completions with valid indices and finish_reason.")
     completion_repeats: dict[str, int] = {}
+    chunk_lengths = set()
     for chunk in chunks:
         chunk_len = len(chunk)
-        # Assert correct number of completion tokens
-        assert chunk_len == max_tokens, (
+        chunk_lengths.add(chunk_len)
+        assert chunk_len <= max_tokens, (
             f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
         text = "".join(chunk)
         completion_repeats[text] = completion_repeats.get(text, 0) + 1
         print(text)
+
+    # Assert subrequests finished at different times
+    assert len(chunk_lengths) > 1
     # Assert `n` unique completions
     num_unique = len(completion_repeats)
     if num_unique != n:
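For reference, the behavior exercised above can be reproduced from any OpenAI-compatible client; the base URL and model name below are placeholders, not values taken from this test suite. With n > 1, streamed chunks interleave across samples and are grouped by choice.index:

import asyncio

import openai  # pip install openai

# Placeholder endpoint/model; adjust to the server under test.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL = "your-model-name"


async def main():
    # Non-streaming: one choice per sample, indexed 0..n-1.
    completion = await client.completions.create(model=MODEL,
                                                 prompt="What is an LLM?",
                                                 n=3,
                                                 max_tokens=50,
                                                 temperature=1.0,
                                                 seed=42)
    assert [c.index for c in completion.choices] == [0, 1, 2]

    # Streaming: chunks from different samples interleave; group them by index.
    stream = await client.completions.create(model=MODEL,
                                             prompt="What is an LLM?",
                                             n=3,
                                             max_tokens=50,
                                             temperature=1.0,
                                             seed=42,
                                             stream=True)
    texts = ["", "", ""]
    async for chunk in stream:
        choice = chunk.choices[0]
        texts[choice.index] += choice.text
    print(texts)


asyncio.run(main())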
@@ -134,57 +134,29 @@ class RequestOutput:
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens

-    @classmethod
-    def new(
-        cls,
-        request_id: str,
-        prompt: Optional[str],
-        prompt_token_ids: Optional[list[int]],
-        text: str,
-        token_ids: list[int],
-        logprobs: Optional[SampleLogprobs],
-        prompt_logprobs: Optional[PromptLogprobs],
-        cumulative_logprob: Optional[float],
-        finished: bool = False,
-    ) -> "RequestOutput":
-        """Initialize a new RequestOutput object."""
-
-        # TODO: Support `n` > 1.
-        completion_output = CompletionOutput(
-            index=0,
-            text=text,
-            token_ids=token_ids,
-            cumulative_logprob=cumulative_logprob,
-            logprobs=logprobs)
-
-        return RequestOutput(
-            request_id=request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            prompt_logprobs=prompt_logprobs,
-            outputs=[completion_output],
-            finished=finished,
-        )
-
     def add(self, next_output: "RequestOutput") -> None:
         """Merge subsequent RequestOutput into this one"""

         self.prompt = next_output.prompt
         self.prompt_token_ids = next_output.prompt_token_ids
         self.prompt_logprobs = next_output.prompt_logprobs
         self.finished |= next_output.finished

-        #TODO assuming n == 1 for now
-        completion = self.outputs[0]
-        next_completion = next_output.outputs[0]
-        completion.text += next_completion.text
-        if not isinstance(completion.token_ids, MutableSequence):
-            completion.token_ids = list(completion.token_ids)
-        completion.token_ids.extend(next_completion.token_ids)
-        if next_completion.logprobs:
-            assert completion.logprobs is not None
-            completion.logprobs.extend(next_completion.logprobs)
-        completion.cumulative_logprob = next_completion.cumulative_logprob
+        for next_completion in next_output.outputs:
+            for completion in self.outputs:
+                if completion.index == next_completion.index:
+                    # Merge outputs with same index
+                    completion.text += next_completion.text
+                    if not isinstance(completion.token_ids, MutableSequence):
+                        completion.token_ids = list(completion.token_ids)
+                    completion.token_ids.extend(next_completion.token_ids)
+                    if next_completion.logprobs:
+                        assert completion.logprobs is not None
+                        completion.logprobs.extend(next_completion.logprobs)
+                    completion.cumulative_logprob = (
+                        next_completion.cumulative_logprob)
+                    completion.finish_reason = next_completion.finish_reason
+                    completion.stop_reason = next_completion.stop_reason
+                    break
+            else:
+                self.outputs.append(next_completion)

     @classmethod
     def from_seq_group(
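RequestOutput.add now merges by CompletionOutput.index, so deltas from the n parallel samples accumulate into the right slot and previously unseen indices are appended. A stand-alone sketch of the same merge pattern with simplified stand-in classes (not the real vLLM types):

from dataclasses import dataclass, field


@dataclass
class Completion:
    index: int
    text: str
    token_ids: list[int] = field(default_factory=list)


@dataclass
class Output:
    outputs: list[Completion] = field(default_factory=list)
    finished: bool = False

    def add(self, next_output: "Output") -> None:
        """Merge a later chunk into this one, matching samples by index."""
        self.finished |= next_output.finished
        for next_completion in next_output.outputs:
            for completion in self.outputs:
                if completion.index == next_completion.index:
                    # Same sample: append the delta.
                    completion.text += next_completion.text
                    completion.token_ids.extend(next_completion.token_ids)
                    break
            else:
                # First chunk seen for this sample index.
                self.outputs.append(next_completion)


agg = Output(outputs=[Completion(index=0, text="Hello", token_ids=[1])])
agg.add(Output(outputs=[Completion(index=1, text="Hi", token_ids=[7])]))
agg.add(Output(outputs=[Completion(index=0, text=" world", token_ids=[2])],
               finished=True))
assert [c.text for c in agg.outputs] == ["Hello world", "Hi"]
assert agg.finished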
@@ -298,9 +298,8 @@ class AsyncLLM(EngineClient):
     async def abort(self, request_id: str) -> None:
         """Abort RequestId in OutputProcessor and EngineCore."""

-        request_ids = [request_id]
+        request_ids = self.output_processor.abort_requests((request_id, ))
         await self.engine_core.abort_requests_async(request_ids)
-        self.output_processor.abort_requests(request_ids)

         if self.log_requests:
             logger.info("Aborted request %s.", request_id)
@@ -137,8 +137,8 @@ class LLMEngine:
     def abort_request(self, request_ids: list[str]) -> None:
         """Remove request_ids from EngineCore and Detokenizer."""

+        request_ids = self.output_processor.abort_requests(request_ids)
         self.engine_core.abort_requests(request_ids)
-        self.output_processor.abort_requests(request_ids)

     def add_request(
         self,
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import asyncio
+from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Optional, Union

@@ -102,8 +103,7 @@ class RequestState:
     ) -> Optional[RequestOutput]:

         finished = finish_reason is not None
-        output_kind = self.output_kind
-        final_only = output_kind == RequestOutputKind.FINAL_ONLY
+        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

         # In follow up, we will switch to invariant where EngineCore
         # does not stream partial prefills.
@@ -111,24 +111,24 @@ class RequestState:
             # Only the final output is required in FINAL_ONLY mode.
             return None

-        def new_request_output(request_id: str) -> RequestOutput:
-            return self._new_request_output(request_id, finished)
-
         completion_output = self._new_completion_output(
             new_token_ids, finish_reason, stop_reason)

-        if self.parent_req is not None:
-            return self.parent_req.make_request_output(final_only,
-                                                       completion_output,
-                                                       new_request_output)
+        request_id = self.request_id
+        if self.parent_req is None:
+            outputs = [completion_output]
+        else:
+            request_id, outputs, finished = self.parent_req.get_outputs(
+                request_id, completion_output)
+            if not outputs:
+                return None

-        request_output = new_request_output(self.request_id)
-        request_output.outputs.append(completion_output)
-        return request_output
+        return self._new_request_output(request_id, outputs, finished)

     def _new_request_output(
         self,
         request_id: str,
+        outputs: list[CompletionOutput],
         finished: bool,
     ) -> RequestOutput:

@@ -143,7 +143,7 @@ class RequestState:
             prompt=self.prompt,
             prompt_token_ids=self.prompt_token_ids,
             prompt_logprobs=prompt_logprobs,
-            outputs=[],
+            outputs=outputs,
             finished=finished,
         )

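The reworked make_request_output path delegates to the ParentRequest when one exists: the parent decides the externally visible request id, which CompletionOutputs to emit, and whether the request is finished, and an empty output list means the update is suppressed until more children report in. A condensed stand-in sketch of that control flow (strings stand in for CompletionOutput; the fake parent assumes FINAL_ONLY aggregation and elides finish handling for the single-sample case):

from typing import Optional


class FakeParent:
    """Stand-in for ParentRequest: aggregate the final outputs of n children."""

    def __init__(self, request_id: str, n: int):
        self.request_id = request_id
        self.pending = n
        self.collected: list[str] = []

    def get_outputs(self, child_id: str, completion: str):
        self.pending -= 1
        self.collected.append(completion)
        # Emit nothing until the last child has reported in.
        outputs = [] if self.pending else self.collected
        return self.request_id, outputs, self.pending == 0


def make_request_output(child_id: str, completion: str,
                        parent: Optional[FakeParent]):
    request_id = child_id
    if parent is None:
        # Single-sample case: emit immediately.
        outputs, finished = [completion], True
    else:
        request_id, outputs, finished = parent.get_outputs(child_id, completion)
        if not outputs:
            # Nothing to emit yet: wait for the remaining children.
            return None
    return request_id, outputs, finished


parent = FakeParent("req-7", n=2)
assert make_request_output("0_req-7", "A", parent) is None
assert make_request_output("1_req-7", "B", parent) == ("req-7", ["A", "B"], True)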
@@ -188,6 +188,7 @@ class OutputProcessor:
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: dict[str, RequestState] = {}
+        self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates()

     def get_num_unfinished_requests(self):
@@ -198,14 +199,20 @@ class OutputProcessor:

     def abort_requests(
         self,
-        request_ids: list[str],
-    ) -> None:
+        request_ids: Iterable[str],
+    ) -> list[str]:
+        request_ids_to_abort = []
         for request_id in request_ids:
             req_state = self.request_states.pop(request_id, None)
             if req_state is not None:
                 self.lora_states.abort_request(req_state)
-                if req_state.parent_req is not None:
-                    req_state.parent_req.finish_child_request(request_id)
+                request_ids_to_abort.append(request_id)
+            else:
+                parent = self.parent_requests.pop(request_id, None)
+                if parent and parent.child_requests:
+                    self.abort_requests(parent.child_requests)
+                    request_ids_to_abort.extend(parent.child_requests)
+        return request_ids_to_abort

     def add_request(
         self,
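abort_requests now returns the engine-level ids that were actually aborted: a plain request id maps to itself, while a parent (n > 1) id is resolved to its child ids, which is why the callers shown earlier (AsyncLLM.abort and LLMEngine.abort_request) call the output processor first and forward its result to EngineCore. A toy sketch of the resolution step, with dictionaries standing in for the processor's state:

# Toy state: EngineCore only knows the child ids of a fan-out request.
request_states = {"0_req-9": "state", "1_req-9": "state", "req-1": "state"}
parent_requests = {"req-9": {"0_req-9", "1_req-9"}}


def abort_requests(request_ids):
    """Return the engine-level ids that should actually be aborted."""
    to_abort = []
    for request_id in request_ids:
        if request_states.pop(request_id, None) is not None:
            to_abort.append(request_id)
        else:
            # A parent id: resolve to its children and abort those instead.
            children = parent_requests.pop(request_id, None)
            if children:
                abort_requests(children)
                to_abort.extend(children)
    return to_abort


# Aborting the user-facing id "req-9" resolves to its two child requests.
assert sorted(abort_requests(["req-9"])) == ["0_req-9", "1_req-9"]
assert sorted(abort_requests(["req-1"])) == ["req-1"]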
@@ -227,6 +234,8 @@ class OutputProcessor:
             log_stats=self.log_stats)
         self.request_states[request_id] = req_state
         self.lora_states.add_request(req_state)
+        if parent_req:
+            self.parent_requests[parent_req.request_id] = parent_req

     def process_outputs(
         self,
@@ -314,12 +323,14 @@ class OutputProcessor:
             # Free completed requests.
             if finish_reason is not None:
                 self.request_states.pop(req_id)
+                # Remove parent request if applicable.
+                parent_req = req_state.parent_req
+                if parent_req and not parent_req.child_requests:
+                    self.parent_requests.pop(parent_req.request_id, None)
                 if not engine_core_output.finished:
                     # If req not finished in EngineCore, but Detokenizer
                     # detected stop string, abort needed in EngineCore.
                     reqs_to_abort.append(req_id)
-                if req_state.parent_req is not None:
-                    req_state.parent_req.finish_child_request(req_id)

                 # Track per-request stats
                 self._update_stats_from_finished(req_state, finish_reason,
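On the normal finish path the per-request state is popped and, once the last child of a parent has completed, the parent entry is dropped as well so nothing leaks. A minimal sketch of that bookkeeping with plain dicts and sets (in the real code the child is removed from ParentRequest.child_requests inside get_outputs, not here):

# Minimal stand-in for the finish-path bookkeeping (not the real classes).
request_states = {"0_req-9": "state", "1_req-9": "state"}
parent_requests = {"req-9": {"0_req-9", "1_req-9"}}


def on_child_finished(req_id: str, parent_id: str):
    request_states.pop(req_id)
    children = parent_requests[parent_id]
    children.discard(req_id)  # done by ParentRequest.get_outputs in vLLM
    if not children:
        # Last child done: drop the parent entry as well.
        parent_requests.pop(parent_id, None)


on_child_finished("0_req-9", "req-9")
assert "req-9" in parent_requests      # one child still running
on_child_finished("1_req-9", "req-9")
assert not parent_requests             # parent cleaned up with last child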
@@ -1,11 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0

 from copy import copy
-from typing import Callable, Optional, Union
+from typing import Optional, Union

-from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.outputs import CompletionOutput
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.v1.metrics.stats import IterationStats

@@ -23,7 +23,7 @@ class ParentRequest:
     child_requests: set[str]

     # To aggregate child completions when not streaming
-    output_aggregator: Optional[RequestOutput]
+    output_aggregator: list[CompletionOutput]

     # To find the max number of generated tokens across all children
     max_num_generation_tokens: int
@@ -37,7 +37,9 @@ class ParentRequest:
         self.sampling_params = sampling_params

         self.child_requests = set()
-        self.output_aggregator = None
+        self.output_aggregator = [None] * sampling_params.n if (
+            sampling_params.output_kind
+            == RequestOutputKind.FINAL_ONLY) else []
         self.max_num_generation_tokens = 0
         self.cached_child_sampling_params = None

@@ -93,43 +95,30 @@ class ParentRequest:
         """
         child_req_id = f"{index}_{self.request_id}"
         self.child_requests.add(child_req_id)
-        return (child_req_id, self._get_child_sampling_params(index))
-
-    def finish_child_request(self, req_id: str):
-        self.child_requests.remove(req_id)
+        return child_req_id, self._get_child_sampling_params(index)

     @property
     def n(self) -> int:
         return self.sampling_params.n

-    def make_request_output(
+    def get_outputs(
         self,
-        final_only: bool,
+        child_request_id: str,
         completion_output: CompletionOutput,
-        new_request_output: Callable[[str], RequestOutput],
-    ) -> Optional[RequestOutput]:
-        # Use an existing RequestOutput if we're aggregating
-        request_output = self.output_aggregator
+    ) -> tuple[str, list[CompletionOutput], bool]:
+        if completion_output.finished():
+            self.child_requests.remove(child_request_id)

-        # Make new RequestOutput otherwise
-        if request_output is None:
-            request_output = new_request_output(self.request_id)
+        if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
+            # If streaming, just return the current output.
+            outputs = [completion_output]
+        else:
+            # If not streaming, aggregate the n final outputs.
+            self.output_aggregator[completion_output.index] = completion_output
+            outputs = [] if self.child_requests else self.output_aggregator

-        # Add a new completion
-        request_output.outputs.append(completion_output)
-
-        # If not streaming, aggregate until all child requests complete
-        if final_only and len(request_output.outputs) != self.n:
-            self.output_aggregator = request_output
-            return None
-
-        # We're done aggregating
-        self.output_aggregator = None
-
-        # Parent completion output list must be sorted by index
-        request_output.outputs = sorted(request_output.outputs,
-                                        key=lambda x: x.index)
-        return request_output
+        finished = not self.child_requests
+        return self.request_id, outputs, finished

     def observe_num_generation_tokens(self, num_generation_tokens: int):
         self.max_num_generation_tokens = max(num_generation_tokens,
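The rewritten ParentRequest no longer builds RequestOutput objects itself. get_outputs drops a finished child from the tracking set, then either streams the single CompletionOutput straight through or, in FINAL_ONLY mode, parks it in a fixed-size slot list and releases the whole list only when the last child completes. A condensed stand-alone sketch of that policy (strings stand in for CompletionOutput, a bool for the finish reason):

class MiniParent:
    """Condensed stand-in for ParentRequest aggregation logic."""

    def __init__(self, request_id: str, n: int, final_only: bool):
        self.request_id = request_id
        self.final_only = final_only
        self.child_requests = {f"{i}_{request_id}" for i in range(n)}
        # Fixed slots when aggregating, unused otherwise.
        self.output_aggregator = [None] * n if final_only else []

    def get_outputs(self, child_id: str, index: int, completion: str,
                    child_finished: bool):
        if child_finished:
            self.child_requests.remove(child_id)

        if not self.final_only:
            # Streaming: pass each child's delta straight through.
            outputs = [completion]
        else:
            # FINAL_ONLY: park the result until every child is done.
            self.output_aggregator[index] = completion
            outputs = [] if self.child_requests else self.output_aggregator

        return self.request_id, outputs, not self.child_requests


parent = MiniParent("req-3", n=2, final_only=True)
assert parent.get_outputs("0_req-3", 0, "first", True) == ("req-3", [], False)
assert parent.get_outputs("1_req-3", 1, "second", True) == (
    "req-3", ["first", "second"], True)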