[Perf] Add skip_clone to SamplingParams for internal request handling (#31041)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin, 2025-12-24 17:35:57 -05:00 (committed by GitHub)
Commit: bc5ef333e0 (parent: 09dc7c690c)
5 changed files with 25 additions and 2 deletions


@@ -60,7 +60,8 @@ async def generate(request: Request) -> Response:
 async def _generate(request_dict: dict, raw_request: Request) -> Response:
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
-    sampling_params = SamplingParams(**request_dict)
+    # Since SamplingParams is created fresh per request, safe to skip clone
+    sampling_params = SamplingParams(**request_dict, skip_clone=True)
     request_id = random_uuid()
     assert engine is not None

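This first hunk is vLLM's simple demo API server: every JSON key other than "prompt" and "stream" is forwarded into a brand-new SamplingParams, so no other request can hold a reference to the object. A hypothetical client call illustrating that per-request construction (endpoint and keys follow the demo server; none of this is part of the diff):

# Hypothetical client for the demo /generate endpoint shown above. All
# keys besides "prompt" and "stream" become kwargs of a freshly
# constructed, single-request SamplingParams on the server side.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Hello, my name is", "max_tokens": 8, "temperature": 0.0},
)
print(resp.json())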

@@ -642,7 +642,10 @@ class LLM:
         # following the huggingface transformers implementation
         # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
         beam_search_params = SamplingParams(
-            logprobs=2 * beam_width, max_tokens=1, temperature=temperature
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+            skip_clone=True,  # Internal beam search, safe to skip clone
         )
         instances: list[BeamSearchInstance] = []


@@ -474,6 +474,7 @@ class ResponsesRequest(OpenAIBaseModel):
             ),
             structured_outputs=structured_outputs,
             logit_bias=self.logit_bias,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     def is_include_output_logprobs(self) -> bool:
@@ -876,6 +877,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             bad_words=self.bad_words,
             allowed_token_ids=self.allowed_token_ids,
             extra_args=extra_args or None,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     @model_validator(mode="before")
@@ -1316,6 +1318,7 @@ class CompletionRequest(OpenAIBaseModel):
             logit_bias=self.logit_bias,
             allowed_token_ids=self.allowed_token_ids,
             extra_args=extra_args or None,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     @model_validator(mode="before")
@@ -2182,6 +2185,7 @@ class TranscriptionRequest(OpenAIBaseModel):
             if self.stream
             else RequestOutputKind.FINAL_ONLY,
             extra_args=self.vllm_xargs,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     @model_validator(mode="before")
@@ -2409,6 +2413,7 @@ class TranslationRequest(OpenAIBaseModel):
             output_kind=RequestOutputKind.DELTA
             if self.stream
             else RequestOutputKind.FINAL_ONLY,
+            skip_clone=True,  # Created fresh per request, safe to skip clone
         )

     @model_validator(mode="before")


@@ -219,6 +219,7 @@ class OpenAISpeechToText(OpenAIServing):
         dummy_params = SamplingParams(
             max_tokens=1,
             temperature=0.0,
+            skip_clone=True,  # Internal warmup, safe to skip clone
         )

         # Process the dummy input through the input processor


@@ -211,6 +211,7 @@ class SamplingParams(
     set to an integer k, will use only the last k tokens from the prompt
     (i.e., left truncation). If set to `None`, truncation is disabled."""
     output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
+    skip_clone: bool = False
+    """Internal flag indicating that this SamplingParams instance is safe to
+    reuse without cloning. When True, clone() will return a shallow copy
+    instead of performing a deep copy. This should only be set when the
+    params object is guaranteed to be dedicated to a single request and
+    won't be modified in ways that would affect other uses."""

     # The below fields are not supposed to be used as an input.
     # They are set in post_init.
@@ -270,6 +276,7 @@ class SamplingParams(
         logit_bias: dict[int, float] | dict[str, float] | None = None,
         allowed_token_ids: list[int] | None = None,
         extra_args: dict[str, Any] | None = None,
+        skip_clone: bool = False,
     ) -> "SamplingParams":
         if logit_bias is not None:
             # Convert token_id to integer
@@ -310,6 +317,7 @@ class SamplingParams(
             logit_bias=logit_bias,
             allowed_token_ids=allowed_token_ids,
             extra_args=extra_args,
+            skip_clone=skip_clone,
         )

     def __post_init__(self) -> None:
@@ -540,8 +548,13 @@ class SamplingParams(
         data that is expensive to copy. However, if not copied, the processor
         needs to support parallel decoding for multiple sequences
         See https://github.com/vllm-project/vllm/issues/3087
+
+        If skip_clone is True, uses shallow copy instead of deep copy.
         """
+        if self.skip_clone:
+            return copy.copy(self)
+
         logit_processor_refs = (
             None
             if self.logits_processors is None
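
Taken together, the change adds a fast path to clone(). A minimal sketch of the semantics, using a hypothetical stripped-down ParamsSketch dataclass rather than vLLM's actual SamplingParams:

import copy
from dataclasses import dataclass


@dataclass
class ParamsSketch:
    # Hypothetical stand-in for SamplingParams, for illustration only.
    max_tokens: int = 16
    logit_bias: dict[int, float] | None = None
    skip_clone: bool = False  # safe only for single-request params

    def clone(self) -> "ParamsSketch":
        if self.skip_clone:
            # Shallow copy: cheap, but mutable fields are shared with the
            # original, so the original must be dedicated to exactly one
            # request and must not be mutated afterwards.
            return copy.copy(self)
        # Default: deep copy, fully independent of the original.
        return copy.deepcopy(self)


# Safe: built fresh for one request, nothing else references it.
per_request = ParamsSketch(max_tokens=1, skip_clone=True)
child = per_request.clone()  # fast shallow copy

# A params object reused as a template across requests should keep the
# default skip_clone=False so each clone stays fully independent.
template = ParamsSketch(logit_bias={0: 1.0})
independent = template.clone()
independent.logit_bias[0] = -1.0  # deep copy: does not affect template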