mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 02:15:57 +08:00
137 lines
4.6 KiB
Python
137 lines
4.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from copy import copy
|
|
from typing import Optional
|
|
|
|
from vllm.outputs import CompletionOutput
|
|
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
|
from vllm.v1.metrics.stats import IterationStats
|
|
|
|
|
|
class ParentRequest:
|
|
"""Info, state & processing for parallel sampling request.
|
|
|
|
Store parent request ID and sampling params.
|
|
Facilitate generating child request sampling params.
|
|
"""
|
|
|
|
request_id: str
|
|
sampling_params: SamplingParams
|
|
|
|
# To track the completion of child requests
|
|
child_requests: set[str]
|
|
|
|
# To aggregate child completions when not streaming
|
|
output_aggregator: list[CompletionOutput]
|
|
|
|
# To find the max number of generated tokens across all children
|
|
max_num_generation_tokens: int
|
|
|
|
# To efficiently obtain child sampling params
|
|
cached_child_sampling_params: Optional[SamplingParams]
|
|
|
|
def __init__(self, request_id: str, sampling_params: SamplingParams) -> None:
|
|
self.request_id = request_id
|
|
self.sampling_params = sampling_params
|
|
|
|
self.child_requests = set()
|
|
self.output_aggregator = (
|
|
[None] * sampling_params.n
|
|
if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
|
|
else []
|
|
)
|
|
self.max_num_generation_tokens = 0
|
|
self.cached_child_sampling_params = None
|
|
|
|
def _get_child_sampling_params(
|
|
self,
|
|
index: int,
|
|
) -> SamplingParams:
|
|
"""Efficiently obtain child `sampling_params`
|
|
|
|
If `sampling_params.seed` is not `None` then
|
|
each child request requires a unique clone of
|
|
parent `sampling_params` with a unique seed.
|
|
|
|
Args:
|
|
index: index within `n` child requests
|
|
|
|
Returns:
|
|
Child `sampling_params` instance.
|
|
"""
|
|
seed = self.sampling_params.seed
|
|
if self.cached_child_sampling_params:
|
|
# Reuse child sampling_params data structure
|
|
return self.cached_child_sampling_params
|
|
# Build child sampling_params
|
|
child_sampling_params = copy(self.sampling_params)
|
|
child_sampling_params.n = 1
|
|
if seed is None:
|
|
# Cache child sampling_params for later reuse
|
|
self.cached_child_sampling_params = child_sampling_params
|
|
else:
|
|
# Each child gets a clone with a unique seed
|
|
child_sampling_params.seed = seed + index
|
|
return child_sampling_params
|
|
|
|
def get_child_info(self, index: int) -> tuple[str, SamplingParams]:
|
|
"""Get child request ID and sampling params.
|
|
|
|
Args:
|
|
index: index within `n` child requests.
|
|
|
|
Returns:
|
|
(request ID, sampling_params) tuple
|
|
"""
|
|
child_req_id = f"{index}_{self.request_id}"
|
|
self.child_requests.add(child_req_id)
|
|
return child_req_id, self._get_child_sampling_params(index)
|
|
|
|
@property
|
|
def n(self) -> int:
|
|
return self.sampling_params.n
|
|
|
|
def get_outputs(
|
|
self,
|
|
child_request_id: str,
|
|
completion_output: CompletionOutput,
|
|
) -> tuple[str, list[CompletionOutput], bool]:
|
|
if completion_output.finished():
|
|
self.child_requests.remove(child_request_id)
|
|
|
|
if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
|
|
# If streaming, just return the current output.
|
|
outputs = [completion_output]
|
|
else:
|
|
# If not streaming, aggregate the n final outputs.
|
|
self.output_aggregator[completion_output.index] = completion_output
|
|
outputs = [] if self.child_requests else self.output_aggregator
|
|
|
|
finished = not self.child_requests
|
|
return self.request_id, outputs, finished
|
|
|
|
def observe_num_generation_tokens(self, num_generation_tokens: int):
|
|
self.max_num_generation_tokens = max(
|
|
num_generation_tokens, self.max_num_generation_tokens
|
|
)
|
|
return self.max_num_generation_tokens
|
|
|
|
@staticmethod
|
|
def observe_finished_request(
|
|
parent_req: Optional["ParentRequest"],
|
|
iteration_stats: IterationStats,
|
|
num_generation_tokens: int,
|
|
):
|
|
n_param = parent_req.n if parent_req is not None else 1
|
|
|
|
if parent_req is not None:
|
|
num_generation_tokens = parent_req.observe_num_generation_tokens(
|
|
num_generation_tokens
|
|
)
|
|
|
|
# Child requests finished, we can now record to iteration stats
|
|
if parent_req is None or not parent_req.child_requests:
|
|
iteration_stats.max_num_generation_tokens_iter.append(num_generation_tokens)
|
|
iteration_stats.n_params_iter.append(n_param)
|