[Perf] Add skip_clone to SamplingParams for internal request handling (#31041)
Signed-off-by: mgoin <mgoin64@gmail.com>
commit bc5ef333e0
parent 09dc7c690c
@@ -60,7 +60,8 @@ async def generate(request: Request) -> Response:
 async def _generate(request_dict: dict, raw_request: Request) -> Response:
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
-    sampling_params = SamplingParams(**request_dict)
+    # Since SamplingParams is created fresh per request, safe to skip clone
+    sampling_params = SamplingParams(**request_dict, skip_clone=True)
     request_id = random_uuid()

     assert engine is not None
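For context, a hedged sketch (not part of the commit) of what this buys: a SamplingParams built fresh for a single request and marked skip_clone=True takes the cheap shallow-copy path when downstream code later calls clone(), instead of paying for a per-request deep copy. The field values below are arbitrary illustrative choices.

# Hedged usage sketch, assuming only that SamplingParams accepts the
# skip_clone kwarg and exposes clone() as shown in the hunks below.
from vllm import SamplingParams

params = SamplingParams(max_tokens=16, temperature=0.7, skip_clone=True)
cheap_copy = params.clone()  # shallow copy via the fast path added below
print(cheap_copy.max_tokens, cheap_copy.temperature)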
@@ -642,7 +642,10 @@ class LLM:
         # following the huggingface transformers implementation
         # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
         beam_search_params = SamplingParams(
-            logprobs=2 * beam_width, max_tokens=1, temperature=temperature
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+            skip_clone=True,  # Internal beam search, safe to skip clone
         )
         instances: list[BeamSearchInstance] = []

@@ -474,6 +474,7 @@ class ResponsesRequest(OpenAIBaseModel):
             ),
             structured_outputs=structured_outputs,
             logit_bias=self.logit_bias,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     def is_include_output_logprobs(self) -> bool:
@@ -876,6 +877,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             bad_words=self.bad_words,
             allowed_token_ids=self.allowed_token_ids,
             extra_args=extra_args or None,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     @model_validator(mode="before")
@@ -1316,6 +1318,7 @@ class CompletionRequest(OpenAIBaseModel):
             logit_bias=self.logit_bias,
             allowed_token_ids=self.allowed_token_ids,
             extra_args=extra_args or None,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     @model_validator(mode="before")
@@ -2182,6 +2185,7 @@ class TranscriptionRequest(OpenAIBaseModel):
             if self.stream
             else RequestOutputKind.FINAL_ONLY,
             extra_args=self.vllm_xargs,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     @model_validator(mode="before")
@@ -2409,6 +2413,7 @@ class TranslationRequest(OpenAIBaseModel):
             output_kind=RequestOutputKind.DELTA
             if self.stream
             else RequestOutputKind.FINAL_ONLY,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     @model_validator(mode="before")
@@ -219,6 +219,7 @@ class OpenAISpeechToText(OpenAIServing):
         dummy_params = SamplingParams(
             max_tokens=1,
             temperature=0.0,
+            skip_clone=True,  # Internal warmup, safe to skip clone
         )

         # Process the dummy input through the input processor
@@ -211,6 +211,12 @@ class SamplingParams(
     set to an integer k, will use only the last k tokens from the prompt
     (i.e., left truncation). If set to `None`, truncation is disabled."""
     output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
+    skip_clone: bool = False
+    """Internal flag indicating that this SamplingParams instance is safe to
+    reuse without cloning. When True, clone() will return self without
+    performing a deep copy. This should only be set when the params object
+    is guaranteed to be dedicated to a single request and won't be modified
+    in ways that would affect other uses."""

     # The below fields are not supposed to be used as an input.
     # They are set in post_init.
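The condition in this docstring is the important part. A hedged illustration of the distinction (names and values are mine, not from the commit): a caller-owned params object that is reused should keep the default so clone() keeps handing out independent copies, while a params object built for exactly one request, as in the server and request classes above, can opt into the cheap path.

# Hedged sketch of when to set skip_clone and when to leave it alone.
from vllm import SamplingParams

# Reused by the caller across many prompts: keep the default (False) so that
# clone() returns independent per-request copies.
shared = SamplingParams(max_tokens=32, temperature=0.8)
copies = [shared.clone() for _ in range(4)]

# Constructed for exactly one request: the same clone() call becomes cheap.
fresh = SamplingParams(max_tokens=32, temperature=0.8, skip_clone=True)
cheap = fresh.clone()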
@@ -270,6 +276,7 @@ class SamplingParams(
         logit_bias: dict[int, float] | dict[str, float] | None = None,
         allowed_token_ids: list[int] | None = None,
         extra_args: dict[str, Any] | None = None,
+        skip_clone: bool = False,
     ) -> "SamplingParams":
         if logit_bias is not None:
             # Convert token_id to integer
@@ -310,6 +317,7 @@ class SamplingParams(
             logit_bias=logit_bias,
             allowed_token_ids=allowed_token_ids,
             extra_args=extra_args,
+            skip_clone=skip_clone,
         )

     def __post_init__(self) -> None:
@@ -540,8 +548,13 @@ class SamplingParams(
         data that is expensive to copy. However, if not copied, the processor
         needs to support parallel decoding for multiple sequences
         See https://github.com/vllm-project/vllm/issues/3087
+
+        If skip_clone is True, uses shallow copy instead of deep copy.
         """

+        if self.skip_clone:
+            return copy.copy(self)
+
         logit_processor_refs = (
             None
             if self.logits_processors is None
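To make the [Perf] motivation concrete, an illustrative micro-benchmark with a hypothetical stand-in class (FakeParams is not vLLM code, and the numbers say nothing about SamplingParams itself); it only shows the general gap between copy.deepcopy and the copy.copy fast path this commit adds.

# Hypothetical stand-in, for illustration only.
import copy
import timeit
from dataclasses import dataclass, field

@dataclass
class FakeParams:
    # A few scalar fields plus one larger field, loosely mimicking a
    # per-request params object.
    max_tokens: int = 16
    temperature: float = 0.7
    stop: list = field(default_factory=lambda: ["\n"] * 8)
    extra: dict = field(default_factory=lambda: {i: float(i) for i in range(256)})
    skip_clone: bool = False

    def clone(self) -> "FakeParams":
        # Same shape as the fast path above: shallow copy when the object is
        # dedicated to a single request, deep copy otherwise.
        if self.skip_clone:
            return copy.copy(self)
        return copy.deepcopy(self)

deep = FakeParams()
shallow = FakeParams(skip_clone=True)
print("deepcopy path:", timeit.timeit(deep.clone, number=1_000))
print("shallow path :", timeit.timeit(shallow.clone, number=1_000))

The exact ratio depends on what the params object carries; the point is only that the deep-copy path scales with that payload while the shallow path does not.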