[V1] Refactor parallel sampling support (#13774)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
commit 4167252eaf
parent f35f8e2242
vllm/v1/engine/async_llm.py

@@ -25,7 +25,7 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
-from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async
+from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
@@ -145,25 +145,30 @@ class AsyncLLM(EngineClient):
         """Add new request to the AsyncLLM."""
 
         # 1) Create a new output queue for the request.
-        if self.output_processor.is_request_active(request_id):
-            raise ValueError(f"Request id {request_id} already running.")
         queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
 
-        # 2) Convert Input --> Request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        # 2) Fan out child requests (for n>1)
+        parent_req = ParentRequest.from_params(request_id, params)
+        n = params.n if isinstance(params, SamplingParams) else 1
+        for idx in range(n):
+            if parent_req is not None:
+                request_id, params = parent_req.get_child_info(idx)
 
-        # 3) Add the request to OutputProcessor (this process).
-        self.output_processor.add_request(request, queue)
+            # 3) Convert Input --> Request.
+            request = self.processor.process_inputs(request_id, prompt, params,
+                                                    arrival_time, lora_request,
+                                                    trace_headers,
+                                                    prompt_adapter_request,
+                                                    priority)
 
-        # 4) Add the EngineCoreRequest to EngineCore (separate process).
-        await self.engine_core.add_request_async(request)
+            # 4) Add the request to OutputProcessor (this process).
+            self.output_processor.add_request(request, parent_req, idx, queue)
 
-        if self.log_requests:
-            logger.info("Added request %s.", request_id)
+            # 5) Add the EngineCoreRequest to EngineCore (separate process).
+            await self.engine_core.add_request_async(request)
+
+            if self.log_requests:
+                logger.info("Added request %s.", request_id)
 
         return queue
@@ -172,7 +177,7 @@ class AsyncLLM(EngineClient):
     # requests we don't need to send multiple messages to core proc,
     # and so we don't need multiple streams which then get
     # re-multiplexed in the API server anyhow.
-    async def _generate(
+    async def generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
@@ -243,30 +248,6 @@ class AsyncLLM(EngineClient):
             await self.abort(request_id)
             raise
 
-    def generate(
-        self,
-        prompt: PromptType,
-        sampling_params: SamplingParams,
-        request_id: str,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> AsyncGenerator[RequestOutput, None]:
-        kwargs = dict(prompt=prompt,
-                      sampling_params=sampling_params,
-                      request_id=request_id,
-                      lora_request=lora_request,
-                      trace_headers=trace_headers,
-                      prompt_adapter_request=prompt_adapter_request,
-                      priority=priority)
-        if sampling_params.n is None or sampling_params.n == 1:
-            return self._generate(**kwargs)
-        else:
-            # Special handling for parallel sampling requests
-            return generate_parallel_sampling_async(generate=self._generate,
-                                                    **kwargs)
-
     async def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
 
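Note: the fan-out introduced above replaces the old generate()/_generate() split — add_request() now submits one engine request per sample. Below is a rough, self-contained sketch of that loop's shape; the stub class and helper are illustrative stand-ins, not the real vLLM types.

class ParentRequestStub:
    """Minimal stand-in for the real ParentRequest class (illustrative only)."""

    def __init__(self, request_id, n):
        self.request_id = request_id
        self.n = n

    @classmethod
    def from_params(cls, request_id, n):
        # n == 1 requests get no parent object, so the loop below runs once, unchanged.
        return cls(request_id, n) if n > 1 else None

    def get_child_info(self, idx):
        # Child request IDs embed the child index; children always sample n=1.
        return f"{idx}_{self.request_id}", 1


def add_request(request_id, n):
    """Shape of the new add_request() loop: one engine request per child."""
    parent_req = ParentRequestStub.from_params(request_id, n)
    added = []
    for idx in range(n):
        if parent_req is not None:
            request_id, _ = parent_req.get_child_info(idx)
        # ...process_inputs / output_processor.add_request / engine_core here...
        added.append(request_id)
    return added


assert add_request("one", 1) == ["one"]
assert add_request("par", 3) == ["0_par", "1_par", "2_par"]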
vllm/v1/engine/llm_engine.py

@@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer_group import (
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
-from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
+from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
 
@@ -50,9 +50,6 @@ class LLMEngine:
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
-        # Bookkeeping for parallel sampling requests
-        self.parallel_manager = SyncParallelSamplingManager()
-
         # important: init dp group before init the engine_core
         self.parallel_config = vllm_config.parallel_config
         self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
@@ -120,8 +117,7 @@ class LLMEngine:
             multiprocess_mode=enable_multiprocessing)
 
     def get_num_unfinished_requests(self) -> int:
-        return self.parallel_manager.get_num_unfinished_requests(
-            self.output_processor.get_num_unfinished_requests())
+        return self.output_processor.get_num_unfinished_requests()
 
     def has_unfinished_requests(self) -> bool:
         has_unfinished = self.output_processor.has_unfinished_requests()
@@ -157,48 +153,25 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> None:
         """Add request."""
-        kwargs = dict(request_id=request_id,
-                      prompt=prompt,
-                      params=params,
-                      arrival_time=arrival_time,
-                      lora_request=lora_request,
-                      trace_headers=trace_headers,
-                      prompt_adapter_request=prompt_adapter_request,
-                      priority=priority)
-        # Handle parallel sampling requests differently.
-        if params is None or isinstance(params,
-                                        PoolingParams) or params.n == 1:
-            self._add_request(**kwargs)
-        else:
-            # Special handling for parallel sampling requests
-            self.parallel_manager.add_request_parallel_sampling(
-                add_request=self._add_request, **kwargs)
-
-    def _add_request(
-        self,
-        request_id: str,
-        prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> None:
-        """Add request, `n=1`"""
-        # 1) Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        # 1) Fan out child requests (for n>1)
+        parent_req = ParentRequest.from_params(request_id, params)
+        n = params.n if isinstance(params, SamplingParams) else 1
+        for idx in range(n):
+            if parent_req is not None:
+                request_id, params = parent_req.get_child_info(idx)
+
+            # 2) Process raw inputs into the request.
+            request = self.processor.process_inputs(request_id, prompt, params,
+                                                    arrival_time, lora_request,
+                                                    trace_headers,
+                                                    prompt_adapter_request,
+                                                    priority)
 
-        # 2) Make a new RequestState and queue.
-        self.output_processor.add_request(request)
+            # 3) Make a new RequestState and queue.
+            self.output_processor.add_request(request, parent_req, idx)
 
-        # 3) Add the request to EngineCore.
-        self.engine_core.add_request(request)
+            # 3) Add the request to EngineCore.
+            self.engine_core.add_request(request)
 
     def step(self) -> list[RequestOutput]:
@@ -217,10 +190,7 @@ class LLMEngine:
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
-        request_outputs = processed_outputs.request_outputs
-
-        # 4) Process unfinished parallel sampling requests
-        return self.parallel_manager.step(request_outputs)
+        return processed_outputs.request_outputs
 
     def get_model_config(self):
         return self.model_config
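Note: both engines now rely on ParentRequest.get_child_info() to derive per-child request IDs and sampling params. A minimal sketch of that derivation follows, assuming only the "{index}_{request_id}" ID scheme and the seed + index behaviour visible in the parallel_sampling.py hunks below; the Params dataclass is a stand-in for vllm.SamplingParams, not the real class.

from copy import copy
from dataclasses import dataclass
from typing import Optional


@dataclass
class Params:
    """Stand-in for vllm.SamplingParams (only the fields this sketch needs)."""
    n: int = 1
    seed: Optional[int] = None


def get_child_info(request_id: str, params: Params, index: int):
    """Per-child ID and params, mirroring ParentRequest.get_child_info():
    the child ID embeds the index, n is forced to 1, and seeded parents
    give child `index` the seed `parent.seed + index`."""
    child = copy(params)
    child.n = 1
    if params.seed is not None:
        child.seed = params.seed + index
    return f"{index}_{request_id}", child


print(get_child_info("abc", Params(n=3, seed=42), 2))
# ('2_abc', Params(n=1, seed=44))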
vllm/v1/engine/output_processor.py

@@ -4,13 +4,14 @@ import asyncio
 from dataclasses import dataclass
 from typing import Optional, Union
 
-from vllm.outputs import RequestOutput
+from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
+from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                    RequestStateStats)
@@ -27,6 +28,8 @@ class RequestState:
     def __init__(
         self,
         request_id: str,
+        parent_req: Optional[ParentRequest],
+        request_index: int,
         lora_name: Optional[str],
         output_kind: RequestOutputKind,
         prompt: Optional[str],
@@ -38,6 +41,8 @@ class RequestState:
         log_stats: bool,
     ):
         self.request_id = request_id
+        self.parent_req = parent_req
+        self.request_index = request_index
         self.lora_name = lora_name
         self.output_kind = output_kind
         self.prompt = prompt
@@ -56,11 +61,15 @@ class RequestState:
         cls,
         tokenizer: AnyTokenizer,
         request: EngineCoreRequest,
+        parent_req: Optional[ParentRequest],
+        request_index: int,
         queue: Optional[asyncio.Queue[RequestOutput]],
         log_stats: bool,
     ) -> "RequestState":
         return cls(
             request_id=request.request_id,
+            parent_req=parent_req,
+            request_index=request_index,
             lora_name=(request.lora_request.name
                        if request.lora_request is not None else None),
             output_kind=request.sampling_params.output_kind,
@@ -79,6 +88,88 @@ class RequestState:
             log_stats=log_stats,
         )
 
+    def make_request_output(
+        self,
+        new_token_ids: list[int],
+        finish_reason: Optional[FinishReason],
+        stop_reason: Union[int, str, None],
+    ) -> Optional[RequestOutput]:
+
+        finished = finish_reason is not None
+        output_kind = self.output_kind
+        final_only = output_kind == RequestOutputKind.FINAL_ONLY
+
+        # In follow up, we will switch to invariant where EngineCore
+        # does not stream partial prefills.
+        if not finished and (self.is_prefilling or final_only):
+            # Only the final output is required in FINAL_ONLY mode.
+            return None
+
+        def new_request_output(request_id: str) -> RequestOutput:
+            return self._new_request_output(request_id, finished)
+
+        completion_output = self._new_completion_output(
+            new_token_ids, finish_reason, stop_reason)
+
+        if self.parent_req is not None:
+            return self.parent_req.make_request_output(final_only,
+                                                       completion_output,
+                                                       new_request_output)
+
+        request_output = new_request_output(self.request_id)
+        request_output.outputs.append(completion_output)
+        return request_output
+
+    def _new_request_output(
+        self,
+        request_id: str,
+        finished: bool,
+    ) -> RequestOutput:
+
+        if self.output_kind == RequestOutputKind.DELTA:
+            # Side effect: logprobs processor forgets prompt logprobs
+            prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
+        else:
+            prompt_logprobs = self.logprobs_processor.prompt_logprobs
+
+        return RequestOutput(
+            request_id=request_id,
+            prompt=self.prompt,
+            prompt_token_ids=self.prompt_token_ids,
+            prompt_logprobs=prompt_logprobs,
+            outputs=[],
+            finished=finished,
+        )
+
+    def _new_completion_output(
+        self,
+        token_ids: list[int],
+        finish_reason: Optional[FinishReason],
+        stop_reason: Union[int, str, None],
+    ) -> CompletionOutput:
+
+        finished = finish_reason is not None
+        delta = self.output_kind == RequestOutputKind.DELTA
+
+        # Prepare text and token_ids, based on delta mode
+        text = self.detokenizer.get_next_output_text(finished, delta)
+        if not delta:
+            token_ids = self.detokenizer.output_token_ids
+
+        # Prepare logprobs, based on delta mode
+        logprobs = self.logprobs_processor.logprobs
+        if delta and logprobs:
+            logprobs = logprobs[-len(token_ids):]
+
+        return CompletionOutput(
+            index=self.request_index,
+            text=text,
+            token_ids=token_ids,
+            logprobs=logprobs,
+            cumulative_logprob=self.logprobs_processor.cumulative_logprob,
+            finish_reason=str(finish_reason) if finished else None,
+            stop_reason=stop_reason if finished else None)
+
 
 class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""
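Note: the early-return guard at the top of RequestState.make_request_output() above decides whether anything is emitted on a given step. A small, runnable restatement of just that predicate (the function name here is illustrative, not part of the diff):

def should_emit(finished: bool, is_prefilling: bool, final_only: bool) -> bool:
    """Restates the guard in RequestState.make_request_output(): nothing is
    emitted while the request is still prefilling, or when the caller asked
    for FINAL_ONLY output, unless the request has just finished."""
    return finished or not (is_prefilling or final_only)


# Streaming (DELTA) request mid-generation: emitted every step.
assert should_emit(finished=False, is_prefilling=False, final_only=False)
# FINAL_ONLY request mid-generation: suppressed until it finishes.
assert not should_emit(finished=False, is_prefilling=False, final_only=True)
# A request that just finished is always emitted.
assert should_emit(finished=True, is_prefilling=True, final_only=True)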
@@ -93,9 +184,6 @@ class OutputProcessor:
         self.request_states: dict[str, RequestState] = {}
         self.lora_states = LoRARequestStates()
 
-    def is_request_active(self, request_id: str) -> bool:
-        return request_id in self.request_states
-
     def get_num_unfinished_requests(self):
         return len(self.request_states)
 
@@ -114,6 +202,8 @@ class OutputProcessor:
     def add_request(
         self,
         request: EngineCoreRequest,
+        parent_req: Optional[ParentRequest] = None,
+        request_index: int = 0,
         queue: Optional[asyncio.Queue[RequestOutput]] = None,
     ) -> None:
         request_id = request.request_id
@@ -123,6 +213,8 @@ class OutputProcessor:
         req_state = RequestState.from_new_request(
             tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
             request=request,
+            parent_req=parent_req,
+            request_index=request_index,
             queue=queue,
             log_stats=self.log_stats)
         self.request_states[request_id] = req_state
@@ -202,8 +294,8 @@ class OutputProcessor:
             req_state.logprobs_processor.update_from_output(engine_core_output)
 
             # 4) Create and handle RequestOutput objects.
-            if request_output := self._make_request_output(
-                    req_state, new_token_ids, finish_reason, stop_reason):
+            if request_output := req_state.make_request_output(
+                    new_token_ids, finish_reason, stop_reason):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put_nowait(request_output)
@@ -211,18 +303,17 @@ class OutputProcessor:
                     # LLMEngine: return list of RequestOutputs.
                     request_outputs.append(request_output)
 
-                # Free completed requests.
-                if request_output.finished:
-                    self.request_states.pop(req_id)
-                    if not engine_core_output.finished:
-                        # If req not finished in EngineCore, but Detokenizer
-                        # detected stop string, abort needed in EngineCore.
-                        reqs_to_abort.append(req_id)
+            # Free completed requests.
+            if finish_reason is not None:
+                self.request_states.pop(req_id)
+                if not engine_core_output.finished:
+                    # If req not finished in EngineCore, but Detokenizer
+                    # detected stop string, abort needed in EngineCore.
+                    reqs_to_abort.append(req_id)
 
-                    # Track per-request stats
-                    self._update_stats_from_finished(req_state, request_output,
-                                                     finish_reason,
-                                                     iteration_stats)
+                # Track per-request stats
+                self._update_stats_from_finished(req_state, finish_reason,
+                                                 iteration_stats)
 
         self.lora_states.update_iteration_stats(iteration_stats)
 
@@ -249,7 +340,6 @@ class OutputProcessor:
                                               req_state.stats, lora_stats)
 
     def _update_stats_from_finished(self, req_state: RequestState,
-                                    request_output: RequestOutput,
                                     finish_reason: Optional[FinishReason],
                                     iteration_stats: Optional[IterationStats]):
         if iteration_stats is None:
@@ -257,55 +347,8 @@ class OutputProcessor:
 
         assert finish_reason is not None
         assert req_state.stats is not None
-        iteration_stats.update_from_finished_request(finish_reason,
-                                                     request_output,
-                                                     req_state.stats)
+        iteration_stats.update_from_finished_request(
+            finish_reason=finish_reason,
+            num_prompt_tokens=len(req_state.prompt_token_ids),
+            req_stats=req_state.stats)
         self.lora_states.finish_request(req_state)
-
-    @staticmethod
-    def _make_request_output(
-        request_state: RequestState,
-        new_token_ids: list[int],
-        finish_reason: Optional[FinishReason],
-        stop_reason: Union[int, str, None],
-    ) -> Optional[RequestOutput]:
-
-        finished = finish_reason is not None
-        output_kind = request_state.output_kind
-        # In follow up, we will switch to invariant where EngineCore
-        # does not stream partial prefills.
-        if not finished and (request_state.is_prefilling
-                             or output_kind == RequestOutputKind.FINAL_ONLY):
-            # Only the final output is required in FINAL_ONLY mode.
-            return None
-
-        detokenizer = request_state.detokenizer
-        logprobs_processor = request_state.logprobs_processor
-
-        delta = output_kind == RequestOutputKind.DELTA
-        logprobs = logprobs_processor.logprobs
-        if delta:
-            if logprobs:
-                logprobs = logprobs[-len(new_token_ids):]
-            # Side effect: logprobs processor forgets prompt logprobs
-            prompt_logprobs = logprobs_processor.pop_prompt_logprobs()
-        else:
-            prompt_logprobs = logprobs_processor.prompt_logprobs
-
-        request_output = RequestOutput.new(
-            request_id=request_state.request_id,
-            prompt=request_state.prompt,
-            prompt_token_ids=request_state.prompt_token_ids,
-            text=detokenizer.get_next_output_text(finished, delta),
-            token_ids=new_token_ids if delta else detokenizer.output_token_ids,
-            logprobs=logprobs,
-            prompt_logprobs=prompt_logprobs,
-            cumulative_logprob=logprobs_processor.cumulative_logprob,
-            finished=finished,
-        )
-        if finished:
-            completion_output = request_output.outputs[0]
-            completion_output.finish_reason = str(finish_reason)
-            completion_output.stop_reason = stop_reason
-
-        return request_output
vllm/v1/engine/parallel_sampling.py

@@ -1,69 +1,46 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from collections.abc import AsyncGenerator, Mapping
 from copy import copy
-from typing import Optional, Protocol, Union
+from typing import Callable, Optional, Union
 
-from vllm.inputs import PromptType
-from vllm.lora.request import LoRARequest
-from vllm.outputs import RequestOutput
+from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
-from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import RequestOutputKind, SamplingParams
-from vllm.utils import merge_async_iterators
+from vllm.sampling_params import SamplingParams
 
 
-class AsyncGenerateMethodType(Protocol):
-
-    def __call__(self,
-                 prompt: PromptType,
-                 sampling_params: SamplingParams,
-                 request_id: str,
-                 lora_request: Optional[LoRARequest] = None,
-                 trace_headers: Optional[Mapping[str, str]] = None,
-                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-                 priority: int = 0) -> AsyncGenerator[RequestOutput, None]:
-        ...
-
-
-class SyncAddRequestMethodType(Protocol):
-
-    def __call__(self,
-                 request_id: str,
-                 prompt: PromptType,
-                 params: Union[SamplingParams, PoolingParams],
-                 arrival_time: Optional[float] = None,
-                 lora_request: Optional[LoRARequest] = None,
-                 trace_headers: Optional[Mapping[str, str]] = None,
-                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-                 priority: int = 0) -> None:
-        ...
-
-
-class ParallelSamplingRequest:
+class ParentRequest:
     """Info, state & processing for parallel sampling request.
 
     Store parent request ID and sampling params.
     Facilitate generating child request sampling params.
-    Transform child request outputs into parent request
-    outputs.
-    When stream mode is disabled, then `self.request_output`
-    aggregates child request completions.
     """
 
     request_id: str
     sampling_params: SamplingParams
+
+    # To aggregate child completions when not streaming
+    output_aggregator: Optional[RequestOutput]
+
+    # To efficiently obtain child sampling params
     cached_child_sampling_params: Optional[SamplingParams]
-    request_output: Optional[RequestOutput]
-    num_finished_completions: int
 
     def __init__(self, request_id: str,
                  sampling_params: SamplingParams) -> None:
         self.request_id = request_id
         self.sampling_params = sampling_params
+
+        self.output_aggregator = None
         self.cached_child_sampling_params = None
-        self.request_output = None
-        self.num_finished_completions = 0
+
+    @classmethod
+    def from_params(
+        cls,
+        request_id: str,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> Optional['ParentRequest']:
+        if not isinstance(params, SamplingParams) or params.n == 1:
+            return None
+        return cls(request_id, params)
 
     def _get_child_sampling_params(
         self,
@@ -96,47 +73,6 @@ class ParallelSamplingRequest:
             child_sampling_params.seed = seed + index
         return child_sampling_params
 
-    def _add_output(
-        self,
-        child_req_output: RequestOutput,
-        index: int,
-    ) -> None:
-        """Aggregate a parallel sampling child
-        request output.
-
-        Non-stream-mode (`output_kind == FINAL_ONLY`)
-        only. Inject correct parent request ID and
-        completion index.
-
-        Args:
-          child_req_output: a single request output
-                            from a parallel sampling
-                            child request.
-          index: index within `n` child
-        """
-        self.num_finished_completions += 1
-        new_completion = child_req_output.outputs[0]
-        new_completion.index = index
-        if self.request_output is None:
-            # Save the first request output; reinstate
-            # original request ID; metrics are not
-            # supported for parallel sampling
-            child_req_output.request_id = self.request_id
-            child_req_output.metrics = None
-            self.request_output = child_req_output
-        else:
-            # Aggregate additional completion into request output
-            # Note: will be sorted by index later
-            self.request_output.outputs.append(new_completion)
-
-    def _get_final_request_output(self) -> RequestOutput:
-        """Invariant: parent completion outputs sorted by index"""
-        assert self.request_output is not None
-        self.request_output.finished = True
-        self.request_output.outputs = sorted(self.request_output.outputs,
-                                             key=lambda x: x.index)
-        return self.request_output
-
     def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
         """Get child request ID and sampling params.
 
@@ -149,227 +85,35 @@ class ParallelSamplingRequest:
         return (f"{index}_{self.request_id}",
                 self._get_child_sampling_params(index))
 
-    def process_output(
-        self,
-        child_req_output: RequestOutput,
-        index: int,
-    ) -> Optional[RequestOutput]:
-        """Filter, aggregate and transform parallel sampling
-        child request outputs.
-
-        If the parent request has `stream=false`
-        (`output_kind == FINAL_ONLY`), each child will also have
-        `output_kind == FINAL_ONLY`. All child request outputs
-        must be aggregated into a single request output, with
-        multiple completions. This request output is only returned
-        once `n` completions are aggregated.
-
-        If the parent request has `stream=true`
-        (`output_kind == DELTA`), each child will also have
-        `output_kind == DELTA`. All child request outputs
-        must be streamed directly to the caller.
-
-        Args:
-          child_req_output: a single child request output
-          index: index within `n` child requests
-
-        Returns:
-          `None`, unless a processed request output is ready to
-          send back to the caller.
-        """
-        if self.output_kind != RequestOutputKind.FINAL_ONLY:
-            # stream=true: return child completions immediately
-            child_req_output.request_id = self.request_id
-            child_req_output.outputs[0].index = index
-            if child_req_output.finished:
-                # Parent request is complete if all child requests are
-                # complete.
-                self.num_finished_completions += 1
-                child_req_output.finished = (
-                    self.num_finished_completions == self.n)
-            return child_req_output
-
-        # stream=false: aggregate child completions
-        self._add_output(child_req_output, index)
-        if self.num_finished_completions == self.n:
-            # Return aggregated request output after obtaining
-            # all completions
-            return self._get_final_request_output()
-        return None
-
-    async def wrap_child_async_generator(
-        self,
-        child_gen: AsyncGenerator[RequestOutput, None],
-        index: int,
-    ) -> AsyncGenerator[RequestOutput, None]:
-        """Output generator for a single parallel sampling
-        child request.
-
-        Each parallel sampling request triggers at
-        least two child requests. This generator
-        yields zero or more request outputs to
-        return to the caller, as they become
-        available.
-
-        Args:
-          child_gen: generator for child request
-                     outputs.
-          index: index within the `n` child requests
-
-        Returns:
-          Yields zero or more request outputs to return
-          to the caller.
-        """
-        async for out in child_gen:
-            if req_out := self.process_output(out, index):
-                yield req_out
+    def make_request_output(
+        self,
+        final_only: bool,
+        completion_output: CompletionOutput,
+        new_request_output: Callable[[str], RequestOutput],
+    ) -> Optional[RequestOutput]:
+        # Use an existing RequestOutput if we're aggregating
+        request_output = self.output_aggregator
+
+        # Make new RequestOutput otherwise
+        if request_output is None:
+            request_output = new_request_output(self.request_id)
+
+        # Add a new completion
+        request_output.outputs.append(completion_output)
+
+        # If not streaming, aggregate until all child requests complete
+        if final_only and len(request_output.outputs) != self.n:
+            self.output_aggregator = request_output
+            return None
+
+        # We're done aggregating
+        self.output_aggregator = None
+
+        # Parent completion output list must be sorted by index
+        request_output.outputs = sorted(request_output.outputs,
+                                        key=lambda x: x.index)
+        return request_output
 
     @property
     def n(self) -> int:
         return self.sampling_params.n
 
-    @property
-    def output_kind(self) -> RequestOutputKind:
-        return self.sampling_params.output_kind
-
-
-class SyncParallelSamplingManager:
-
-    def __init__(self):
-        # Parent req ID -> parent request manager
-        self.parent_reqs: dict[str, ParallelSamplingRequest] = {}
-        # Child req ID -> (child req index, parent req ID)
-        self.child_reqs: dict[str, tuple[int, str]] = {}
-
-    def _register_parent_request(self, req: ParallelSamplingRequest) -> None:
-        """Register parallel sampling parent request."""
-        self.parent_reqs[req.request_id] = req
-
-    def _register_child_request(self, req_id: str, child_req_id: str,
-                                index: int) -> None:
-        """Register parallel sampling child request with parent.
-
-        Args:
-          req_id: parent request ID
-          child_req_id: child request ID
-          index: child request index within `n` child requests
-        """
-        self.child_reqs[child_req_id] = (index, req_id)
-
-    def get_num_unfinished_requests(self, num_core_reqs: int) -> int:
-        """Get the number of unfinished requests, correcting for parallel
-        sampling.
-
-        Args:
-          num_core_reqs: The number of unfinished requests in the engine core.
-
-        Returns:
-          Number of unfinished requests, where each parallel sampling req
-          counts as 1
-        """
-        return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs)
-
-    def add_request_parallel_sampling(
-        self,
-        add_request: SyncAddRequestMethodType,
-        request_id: str,
-        prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> None:
-        """Add sync parallel sampling request."""
-        req = ParallelSamplingRequest(request_id, params)
-        self._register_parent_request(req)
-        # Add n child requests with unique request IDs & random seeds and n=1
-        for idx in range(req.n):
-            child_req_id, child_params = req.get_child_info(idx)
-            self._register_child_request(request_id, child_req_id, idx)
-            add_request(request_id=child_req_id,
-                        prompt=prompt,
-                        params=child_params,
-                        arrival_time=arrival_time,
-                        lora_request=lora_request,
-                        trace_headers=trace_headers,
-                        prompt_adapter_request=prompt_adapter_request,
-                        priority=priority)  # type: ignore
-
-    def step(
-        self,
-        outputs: list[RequestOutput],
-    ) -> list[RequestOutput]:
-        """Build parallel sampling request outputs.
-
-        Extract child request outputs, aggregate them
-        into parent request output, and return parent
-        output when complete.
-
-        Do not modify `n=1` requests.
-
-        Args:
-          outputs: step request outputs. Mix of child request
-                   outputs & `n=1` request outputs.
-
-        Return:
-          List of parallel sampling parent request outputs &
-          unmodified `n=1` request outputs passed-thru from input.
-        """
-        if not (self.parent_reqs and outputs):
-            # Return unmodified
-            return outputs
-        agg_outputs = []
-        for output in outputs:
-            req_id = output.request_id
-            if child_req_entry := self.child_reqs.get(req_id, None):
-                # For each parallel sampling child request output:
-                (index, parent_req_id) = child_req_entry
-                req = self.parent_reqs[parent_req_id]
-                # Update parallel sampling request
-                if out := req.process_output(output, index):
-                    # Return parent request output if complete;
-                    # cleanup parent request bookkeeping.
-                    agg_outputs.append(out)
-                    del self.parent_reqs[parent_req_id]
-                    # Cleanup child request bookkeeping.
-                    del self.child_reqs[req_id]
-            else:
-                # Not a parallel sampling request output
-                agg_outputs.append(output)
-        return agg_outputs
-
-
-async def generate_parallel_sampling_async(
-    generate: AsyncGenerateMethodType,
-    prompt: PromptType,
-    sampling_params: SamplingParams,
-    request_id: str,
-    lora_request: Optional[LoRARequest] = None,
-    trace_headers: Optional[Mapping[str, str]] = None,
-    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    priority: int = 0,
-) -> AsyncGenerator[RequestOutput, None]:
-    """Generate completions for async parallel sampling requests."""
-    parent_req = ParallelSamplingRequest(request_id, sampling_params)
-
-    # Aggregate generators for n child requests
-    gens: list[AsyncGenerator[RequestOutput, None]] = []
-    for idx in range(parent_req.n):
-        child_req_id, child_params = parent_req.get_child_info(idx)
-        child_gen = generate(
-            prompt=prompt,
-            sampling_params=child_params,
-            request_id=child_req_id,
-            lora_request=lora_request,
-            trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request,
-            priority=priority,
-        )  # type: ignore
-        gen = parent_req.wrap_child_async_generator(child_gen, idx)
-        gens.append(gen)
-
-    # Merge generators
-    async for _, out in merge_async_iterators(*gens):
-        yield out
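Note: for non-streaming parents, ParentRequest.make_request_output() above holds completions in output_aggregator until all n children have finished, then returns one RequestOutput with completions sorted by child index. A toy aggregator sketching that behaviour follows; the Completion and Output dataclasses are stand-ins for the real vllm.outputs classes.

from dataclasses import dataclass, field


@dataclass
class Completion:
    """Stand-in for vllm.outputs.CompletionOutput."""
    index: int
    text: str


@dataclass
class Output:
    """Stand-in for vllm.outputs.RequestOutput."""
    request_id: str
    outputs: list = field(default_factory=list)


class Aggregator:
    """Toy version of ParentRequest.make_request_output() in FINAL_ONLY mode:
    stash completions until all n children have reported, then hand back a
    single output with completions sorted by child index."""

    def __init__(self, request_id: str, n: int):
        self.request_id = request_id
        self.n = n
        self.output_aggregator = None

    def add(self, completion: Completion):
        out = self.output_aggregator or Output(self.request_id)
        out.outputs.append(completion)
        if len(out.outputs) != self.n:
            self.output_aggregator = out   # keep aggregating
            return None
        self.output_aggregator = None      # done aggregating
        out.outputs.sort(key=lambda c: c.index)
        return out


agg = Aggregator("req-0", n=2)
assert agg.add(Completion(1, "b")) is None
final = agg.add(Completion(0, "a"))
assert [c.index for c in final.outputs] == [0, 1]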
vllm/v1/metrics/stats.py

@@ -5,7 +5,6 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Optional
 
 if TYPE_CHECKING:
-    from vllm.outputs import RequestOutput
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
     from vllm.v1.output_processor import RequestState
 
@@ -150,7 +149,7 @@ class IterationStats:
         self.num_preempted_reqs += 1
 
     def update_from_finished_request(self, finish_reason: "FinishReason",
-                                     request_output: "RequestOutput",
+                                     num_prompt_tokens: int,
                                      req_stats: RequestStateStats):
         e2e_latency = self._time_since(req_stats.arrival_time)
 
@@ -172,7 +171,7 @@ class IterationStats:
         finished_req = \
             FinishedRequestStats(finish_reason=finish_reason,
                                  e2e_latency=e2e_latency,
-                                 num_prompt_tokens=len(request_output.prompt_token_ids),
+                                 num_prompt_tokens=num_prompt_tokens,
                                  num_generation_tokens=req_stats.num_generation_tokens,
                                  queued_time=queued_time,
                                  prefill_time=prefill_time,