Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980)
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
parent 444b0f0f62
commit befc402d34
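For context, a minimal sketch of the user-facing behavior this commit enables on the V1 engine (model, prompt, and parameter values are illustrative, mirroring the tests below): a request with SamplingParams(n>1) yields n completions per prompt in a single RequestOutput.

import os

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine, as the tests below do

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
params = SamplingParams(n=3, temperature=0.95, top_p=0.95, seed=42)
for out in llm.generate(["What is an LLM?"], params):
    # With parallel sampling, out.outputs holds n completions, indexed 0..n-1.
    for completion in out.outputs:
        print(completion.index, completion.text)
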
@@ -1,21 +1,114 @@
# SPDX-License-Identifier: Apache-2.0

import random
from typing import Dict, List, Optional, Tuple

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams

MODEL = "facebook/opt-125m"
DTYPE = "half"

def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
    """Test passes if LLMEngine raises an exception when it is configured
    for automatic prefix caching and it receives a request with
    prompt_logprobs enabled, which is incompatible."""

def _vllm_model(apc: bool, vllm_runner, monkeypatch):
    """Set up VllmRunner instance."""
    monkeypatch.setenv("VLLM_USE_V1", "1")
    # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    return vllm_runner(
        MODEL,
        dtype=DTYPE,
        max_model_len=128,
        enforce_eager=True,
        enable_prefix_caching=apc,
        gpu_memory_utilization=0.5,
    )


@pytest.fixture(
    # Function scope decouples tests & allows
    # env var adjustment via monkeypatch
    scope="function",
    # Prefix caching
    params=[False, True])
def vllm_model(vllm_runner, request, monkeypatch):
    """VllmRunner test fixture parameterized by APC True/False."""
    with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model:
        yield vllm_model


@pytest.fixture(scope="function")
def vllm_model_apc(vllm_runner, monkeypatch):
    """VllmRunner test fixture with APC."""
    with _vllm_model(True, vllm_runner, monkeypatch) as vllm_model:
        yield vllm_model


def _get_test_sampling_params(
    prompt_list: List[str],
    seed: Optional[int] = 42,
) -> Tuple[List[SamplingParams], List[int]]:
    """Generate random sampling params for a batch."""

    def get_mostly_n_gt1() -> int:
        """Mostly n \in [2,20], ~1/3 n=1"""
        x = random.randint(0, 28)
        if x < 10:
            return 1
        else:
            return x - 8

    n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
    # High temperature to maximize the chance of unique completions
    return [
        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
        for n in n_list
    ], n_list


def test_parallel_sampling(vllm_model, example_prompts) -> None:
    """Test passes if parallel sampling `n>1` yields `n` unique completions.

    Args:
      vllm_model: VllmRunner instance under test.
      example_prompts: test fixture providing prompts for testing.
    """
    sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
    model: LLM = vllm_model.model
    outputs = model.generate(example_prompts, sampling_params_list)

    # Validate each request response
    for out, n in zip(outputs, n_list):
        completion_counts: Dict[str, int] = {}
        # Assert correct number of completions
        assert len(out.outputs) == n, (
            f"{len(out.outputs)} completions; {n} expected.")
        for idx in range(n):
            comp = out.outputs[idx]
            # Assert correct completion indices
            assert comp.index == idx, (f"Index {comp.index}; expected {idx}.")
            text = comp.text
            completion_counts[text] = completion_counts.get(text, 0) + 1
        # Assert unique completions
        if len(completion_counts) != n:
            repeats = {
                txt: num
                for (txt, num) in completion_counts.items() if num > 1
            }
            raise AssertionError(
                f"{len(completion_counts)} unique completions; expected"
                f" {n}. Repeats: {repeats}")


def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
    """Test passes if LLMEngine raises an exception when it is configured
    for automatic prefix caching and it receives a request with
    prompt_logprobs enabled, which is incompatible."""
    model: LLM = vllm_model_apc.model
    with pytest.raises(ValueError) as excinfo:
        LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
        model.generate(
            "Hello, my name is",
            SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))
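
A note on the seeded parameters above: `seed=42` with `n>1` still produces `n` independent samples because the V1 engine gives every child request its own copy of the parent `SamplingParams` with `n=1` and an offset seed (see `_get_child_sampling_params` in vllm/v1/engine/parallel_sampling.py later in this diff). A small sketch of that derivation, with illustrative values:

from copy import copy

from vllm import SamplingParams

parent = SamplingParams(n=3, temperature=0.95, top_p=0.95, seed=42)
children = []
for index in range(parent.n):
    child = copy(parent)              # shallow copy, as the engine does
    child.n = 1                       # each child samples a single completion
    child.seed = parent.seed + index  # seeds 42, 43, 44 -> distinct samples
    children.append(child)
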
@@ -250,6 +250,108 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
    assert "".join(chunks) == single_output


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
                                     model_name: str):
    """Parallel sampling without streaming.
    A single request output contains a list of completions.
    """

    prompt = "What is an LLM?"
    n = 3
    max_tokens = 5

    # High temperature to maximize chance of unique completions.
    completion = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=max_tokens,
                                                 n=n,
                                                 temperature=0.95,
                                                 stream=False,
                                                 seed=42)

    # Assert `n` completions
    num_completions = len(completion.choices)
    assert num_completions == n, (
        f"Num completions {num_completions} but expected {n}.")
    completion_repeats: Dict[str, int] = {}
    for idx, choice in enumerate(completion.choices):
        # Assert correct completion index & some finish reason.
        assert choice.index == idx, (
            f"Index {choice.index} but expected {idx}.")
        assert choice.finish_reason is not None, (
            "None finish_reason is invalid.")
        text = choice.text
        completion_repeats[text] = completion_repeats.get(text, 0) + 1
    # Assert `n` unique completions
    num_unique = len(completion_repeats)
    if num_unique != n:
        repeats = {
            txt: num
            for (txt, num) in completion_repeats.items() if num > 1
        }
        raise AssertionError(
            f"Expected {n} unique completions, got {num_unique};"
            f" repeats: {repeats}.")


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
    """Streaming for parallel sampling.
    The tokens from multiple samples are flattened into a single stream,
    with an index to indicate which sample the token belongs to.
    """

    prompt = "What is an LLM?"
    n = 3
    max_tokens = 5

    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=max_tokens,
                                             n=n,
                                             temperature=0.95,
                                             stream=True,
                                             seed=42)
    chunks: List[List[str]] = [[] for i in range(n)]
    finish_reason_count = 0
    async for chunk in stream:
        index = chunk.choices[0].index
        text = chunk.choices[0].text
        chunks[index].append(text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # Assert `n` completions with correct finish reasons
    assert finish_reason_count == n, (
        f"Expected {n} completions with valid indices and finish_reason.")
    completion_repeats: Dict[str, int] = {}
    for chunk in chunks:
        chunk_len = len(chunk)
        # Assert correct number of completion tokens
        assert chunk_len == max_tokens, (
            f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
        text = "".join(chunk)
        completion_repeats[text] = completion_repeats.get(text, 0) + 1
        print(text)
    # Assert `n` unique completions
    num_unique = len(completion_repeats)
    if num_unique != n:
        repeats = {
            txt: num
            for (txt, num) in completion_repeats.items() if num > 1
        }
        raise AssertionError(f"{num_unique} unique completions, expected {n};"
                             f" repeats: {repeats}")


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
@@ -24,6 +24,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils import cdiv, kill_process_tree
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
@@ -170,7 +171,7 @@ class AsyncLLM(EngineClient):
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
    async def _generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
@@ -241,6 +242,30 @@ class AsyncLLM(EngineClient):
            await self.abort(request_id)
            raise

    def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        kwargs = dict(prompt=prompt,
                      sampling_params=sampling_params,
                      request_id=request_id,
                      lora_request=lora_request,
                      trace_headers=trace_headers,
                      prompt_adapter_request=prompt_adapter_request,
                      priority=priority)
        if sampling_params.n is None or sampling_params.n == 1:
            return self._generate(**kwargs)
        else:
            # Special handling for parallel sampling requests
            return generate_parallel_sampling_async(generate=self._generate,
                                                    **kwargs)

    async def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
@@ -21,6 +21,7 @@ from vllm.transformers_utils.tokenizer_group import (
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
@@ -48,6 +49,9 @@ class LLMEngine:
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        # Bookkeeping for parallel sampling requests
        self.parallel_manager = SyncParallelSamplingManager()

        # important: init dp group before init the engine_core
        self.parallel_config = vllm_config.parallel_config
        self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
@@ -115,7 +119,8 @@
            multiprocess_mode=enable_multiprocessing)

    def get_num_unfinished_requests(self) -> int:
        return self.output_processor.get_num_unfinished_requests()
        return self.parallel_manager.get_num_unfinished_requests(
            self.output_processor.get_num_unfinished_requests())

    def has_unfinished_requests(self) -> bool:
        has_unfinished = self.output_processor.has_unfinished_requests()
@@ -151,7 +156,36 @@
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add request."""
        kwargs = dict(request_id=request_id,
                      prompt=prompt,
                      params=params,
                      arrival_time=arrival_time,
                      lora_request=lora_request,
                      trace_headers=trace_headers,
                      prompt_adapter_request=prompt_adapter_request,
                      priority=priority)
        # Handle parallel sampling requests differently.
        if params is None or isinstance(params,
                                        PoolingParams) or params.n == 1:
            self._add_request(**kwargs)
        else:
            # Special handling for parallel sampling requests
            self.parallel_manager.add_request_parallel_sampling(
                add_request=self._add_request, **kwargs)

    def _add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add request, `n=1`"""
        # 1) Process raw inputs into the request.
        request = self.processor.process_inputs(request_id, prompt, params,
                                                arrival_time, lora_request,
@@ -182,7 +216,10 @@
        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        return processed_outputs.request_outputs
        request_outputs = processed_outputs.request_outputs

        # 4) Process unfinished parallel sampling requests
        return self.parallel_manager.step(request_outputs)

    def get_model_config(self):
        return self.model_config
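
In the sync engine above, a parent request with `n>1` is expanded into child requests whose IDs derive from the parent ID, and `step()` folds the child outputs back into a single parent output via `SyncParallelSamplingManager.step`. A tiny illustration of the ID scheme used by `get_child_info` in the new file below (the parent ID is a placeholder):

parent_id = "req-123"  # illustrative parent request ID
n = 3
child_ids = [f"{i}_{parent_id}" for i in range(n)]
print(child_ids)  # ['0_req-123', '1_req-123', '2_req-123']
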
vllm/v1/engine/parallel_sampling.py (new file, 375 lines added)
@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0

from copy import copy
from typing import (AsyncGenerator, Dict, List, Mapping, Optional, Protocol,
                    Tuple, Union)

from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.utils import merge_async_iterators


class AsyncGenerateMethodType(Protocol):

    def __call__(self,
                 prompt: PromptType,
                 sampling_params: SamplingParams,
                 request_id: str,
                 lora_request: Optional[LoRARequest] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
                 priority: int = 0) -> AsyncGenerator[RequestOutput, None]:
        ...


class SyncAddRequestMethodType(Protocol):

    def __call__(self,
                 request_id: str,
                 prompt: PromptType,
                 params: Union[SamplingParams, PoolingParams],
                 arrival_time: Optional[float] = None,
                 lora_request: Optional[LoRARequest] = None,
                 trace_headers: Optional[Mapping[str, str]] = None,
                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
                 priority: int = 0) -> None:
        ...


class ParallelSamplingRequest:
    """Info, state & processing for parallel sampling request.

    Store parent request ID and sampling params.
    Facilitate generating child request sampling params.
    Transform child request outputs into parent request
    outputs.
    When stream mode is disabled, then `self.request_output`
    aggregates child request completions.
    """

    request_id: str
    sampling_params: SamplingParams
    cached_child_sampling_params: Optional[SamplingParams]
    request_output: Optional[RequestOutput]
    num_finished_completions: int

    def __init__(self, request_id: str,
                 sampling_params: SamplingParams) -> None:
        self.request_id = request_id
        self.sampling_params = sampling_params
        self.cached_child_sampling_params = None
        self.request_output = None
        self.num_finished_completions = 0

    def _get_child_sampling_params(
        self,
        index: int,
    ) -> SamplingParams:
        """Efficiently obtain child `sampling_params`

        If `sampling_params.seed` is not `None` then
        each child request requires a unique clone of
        parent `sampling_params` with a unique seed.

        Args:
          index: index within `n` child requests

        Returns:
          Child `sampling_params` instance.
        """
        seed = self.sampling_params.seed
        if self.cached_child_sampling_params:
            # Reuse child sampling_params data structure
            return self.cached_child_sampling_params
        # Build child sampling_params
        child_sampling_params = copy(self.sampling_params)
        child_sampling_params.n = 1
        if seed is None:
            # Cache child sampling_params for later reuse
            self.cached_child_sampling_params = child_sampling_params
        else:
            # Each child gets a clone with a unique seed
            child_sampling_params.seed = seed + index
        return child_sampling_params

    def _add_output(
        self,
        child_req_output: RequestOutput,
        index: int,
    ) -> None:
        """Aggregate a parallel sampling child
        request output.

        Non-stream-mode (`output_kind == FINAL_ONLY`)
        only. Inject correct parent request ID and
        completion index.

        Args:
          child_req_output: a single request output
                            from a parallel sampling
                            child request.
          index: index within `n` child
        """
        self.num_finished_completions += 1
        new_completion = child_req_output.outputs[0]
        new_completion.index = index
        if self.request_output is None:
            # Save the first request output; reinstate
            # original request ID; metrics are not
            # supported for parallel sampling
            child_req_output.request_id = self.request_id
            child_req_output.metrics = None
            self.request_output = child_req_output
        else:
            # Aggregate additional completion into request output
            # Note: will be sorted by index later
            self.request_output.outputs.append(new_completion)

    def _get_final_request_output(self) -> RequestOutput:
        """Invariant: parent completion outputs sorted by index"""
        assert self.request_output is not None
        self.request_output.finished = True
        self.request_output.outputs = sorted(self.request_output.outputs,
                                             key=lambda x: x.index)
        return self.request_output

    def get_child_info(self, index: int) -> Tuple[str, SamplingParams]:
        """Get child request ID and sampling params.

        Args:
          index: index within `n` child requests.

        Returns:
          (request ID, sampling_params) tuple
        """
        return (f"{index}_{self.request_id}",
                self._get_child_sampling_params(index))

    def process_output(
        self,
        child_req_output: RequestOutput,
        index: int,
    ) -> Optional[RequestOutput]:
        """Filter, aggregate and transform parallel sampling
        child request outputs.

        If the parent request has `stream=false`
        (`output_kind == FINAL_ONLY`), each child will also have
        `output_kind == FINAL_ONLY`. All child request outputs
        must be aggregated into a single request output, with
        multiple completions. This request output is only returned
        once `n` completions are aggregated.

        If the parent request has `stream=true`
        (`output_kind == DELTA`), each child will also have
        `output_kind == DELTA`. All child request outputs
        must be streamed directly to the caller.

        Args:
          child_req_output: a single child request output
          index: index within `n` child requests

        Returns:
          `None`, unless a processed request output is ready to
          send back to the caller.
        """
        if self.output_kind != RequestOutputKind.FINAL_ONLY:
            # stream=true: return child completions immediately
            child_req_output.request_id = self.request_id
            child_req_output.outputs[0].index = index
            if child_req_output.finished:
                # Parent request is complete if all child requests are
                # complete.
                self.num_finished_completions += 1
                child_req_output.finished = (
                    self.num_finished_completions == self.n)
            return child_req_output

        # stream=false: aggregate child completions
        self._add_output(child_req_output, index)
        if self.num_finished_completions == self.n:
            # Return aggregated request output after obtaining
            # all completions
            return self._get_final_request_output()
        return None
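
    # --- Editor's note: worked example (illustration only, not in the file) --
    # For a parent request with n=3:
    #   * stream=False (output_kind == FINAL_ONLY): the first two child outputs
    #     are absorbed and process_output() returns None; the third call
    #     returns a single RequestOutput carrying all three completions,
    #     sorted by index, with finished=True and the parent request ID
    #     restored.
    #   * stream=True (output_kind == DELTA): every child delta is returned
    #     immediately with its index rewritten to the child's position, and
    #     finished is only set on the output that completes the last child.
    # -------------------------------------------------------------------------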

    async def wrap_child_async_generator(
        self,
        child_gen: AsyncGenerator[RequestOutput, None],
        index: int,
    ) -> AsyncGenerator[RequestOutput, None]:
        """Output generator for a single parallel sampling
        child request.

        Each parallel sampling request triggers at
        least two child requests. This generator
        yields zero or more request outputs to
        return to the caller, as they become
        available.

        Args:
          child_gen: generator for child request
                     outputs.
          index: index within the `n` child requests

        Returns:
          Yields zero or more request outputs to return
          to the caller.
        """
        async for out in child_gen:
            if req_out := self.process_output(out, index):
                yield req_out

    @property
    def n(self) -> int:
        return self.sampling_params.n

    @property
    def output_kind(self) -> RequestOutputKind:
        return self.sampling_params.output_kind


class SyncParallelSamplingManager:

    def __init__(self):
        # Parent req ID -> parent request manager
        self.parent_reqs: Dict[str, ParallelSamplingRequest] = {}
        # Child req ID -> (child req index, parent req ID)
        self.child_reqs: Dict[str, Tuple[int, str]] = {}

    def _register_parent_request(self, req: ParallelSamplingRequest) -> None:
        """Register parallel sampling parent request."""
        self.parent_reqs[req.request_id] = req

    def _register_child_request(self, req_id: str, child_req_id: str,
                                index: int) -> None:
        """Register parallel sampling child request with parent.

        Args:
          req_id: parent request ID
          child_req_id: child request ID
          index: child request index within `n` child requests
        """
        self.child_reqs[child_req_id] = (index, req_id)

    def get_num_unfinished_requests(self, num_core_reqs: int) -> int:
        """Get the number of unfinished requests, correcting for parallel
        sampling.

        Args:
          num_core_reqs: The number of unfinished requests in the engine core.

        Returns:
          Number of unfinished requests, where each parallel sampling req
          counts as 1
        """
        return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs)
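
    # --- Editor's note: worked example (illustration only, not in the file) --
    # Two in-flight parent requests with n=3 each become 6 child requests in
    # the engine core, so num_core_reqs=6, len(parent_reqs)=2, and
    # len(child_reqs)=6; the method reports 6 + 2 - 6 = 2 unfinished requests,
    # i.e. one per parent, which is what callers of the public API expect.
    # -------------------------------------------------------------------------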

    def add_request_parallel_sampling(
        self,
        add_request: SyncAddRequestMethodType,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add sync parallel sampling request."""
        req = ParallelSamplingRequest(request_id, params)
        self._register_parent_request(req)
        # Add n child requests with unique request IDs & random seeds and n=1
        for idx in range(req.n):
            child_req_id, child_params = req.get_child_info(idx)
            self._register_child_request(request_id, child_req_id, idx)
            add_request(request_id=child_req_id,
                        prompt=prompt,
                        params=child_params,
                        arrival_time=arrival_time,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        prompt_adapter_request=prompt_adapter_request,
                        priority=priority)  # type: ignore

    def step(
        self,
        outputs: List[RequestOutput],
    ) -> List[RequestOutput]:
        """Build parallel sampling request outputs.

        Extract child request outputs, aggregate them
        into parent request output, and return parent
        output when complete.

        Do not modify `n=1` requests.

        Args:
          outputs: step request outputs. Mix of child request
                   outputs & `n=1` request outputs.

        Return:
          List of parallel sampling parent request outputs &
          unmodified `n=1` request outputs passed-thru from input.
        """
        if not (self.parent_reqs and outputs):
            # Return unmodified
            return outputs
        agg_outputs = []
        for output in outputs:
            req_id = output.request_id
            if child_req_entry := self.child_reqs.get(req_id, None):
                # For each parallel sampling child request output:
                (index, parent_req_id) = child_req_entry
                req = self.parent_reqs[parent_req_id]
                # Update parallel sampling request
                if out := req.process_output(output, index):
                    # Return parent request output if complete;
                    # cleanup parent request bookkeeping.
                    agg_outputs.append(out)
                    del self.parent_reqs[parent_req_id]
                # Cleanup child request bookkeeping.
                del self.child_reqs[req_id]
            else:
                # Not a parallel sampling request output
                agg_outputs.append(output)
        return agg_outputs


async def generate_parallel_sampling_async(
    generate: AsyncGenerateMethodType,
    prompt: PromptType,
    sampling_params: SamplingParams,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    trace_headers: Optional[Mapping[str, str]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
    """Generate completions for async parallel sampling requests."""
    parent_req = ParallelSamplingRequest(request_id, sampling_params)

    # Aggregate generators for n child requests
    gens: List[AsyncGenerator[RequestOutput, None]] = []
    for idx in range(parent_req.n):
        child_req_id, child_params = parent_req.get_child_info(idx)
        child_gen = generate(
            prompt=prompt,
            sampling_params=child_params,
            request_id=child_req_id,
            lora_request=lora_request,
            trace_headers=trace_headers,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )  # type: ignore
        gen = parent_req.wrap_child_async_generator(child_gen, idx)
        gens.append(gen)

    # Merge generators
    async for _, out in merge_async_iterators(*gens):
        yield out
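
The function above relies on `vllm.utils.merge_async_iterators` to flatten the per-child streams. As a self-contained illustration of that fan-out/merge pattern only (plain asyncio, not vLLM's implementation; the names are illustrative):

import asyncio
from typing import AsyncGenerator


async def child_stream(index: int) -> AsyncGenerator[str, None]:
    # Stand-in for one child request's output stream.
    for step in range(3):
        await asyncio.sleep(0)
        yield f"child={index} step={step}"


async def merged(n: int) -> AsyncGenerator[str, None]:
    # Pump every child generator into one queue and drain it until all
    # children have signalled completion.
    queue: asyncio.Queue = asyncio.Queue()
    done = object()

    async def pump(gen: AsyncGenerator[str, None]) -> None:
        async for item in gen:
            await queue.put(item)
        await queue.put(done)

    tasks = [asyncio.create_task(pump(child_stream(i))) for i in range(n)]
    finished = 0
    while finished < n:
        item = await queue.get()
        if item is done:
            finished += 1
        else:
            yield item
    await asyncio.gather(*tasks)


async def main() -> None:
    async for item in merged(3):
        print(item)


asyncio.run(main())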