[Frontend] Expose custom args in OpenAI APIs (#16862)

Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
afeldman-nm 2025-06-18 20:41:11 -04:00 committed by GitHub
parent ed33349738
commit dfada85eee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 44 additions and 14 deletions

View File

@ -4,12 +4,12 @@ import argparse
import itertools
import torch
import triton
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size_triton,
)
from vllm.triton_utils import triton
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:

View File

@ -326,8 +326,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
)
chat_template_kwargs: Optional[dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."),
)
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
default=None,
@ -414,6 +415,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description="KVTransfer parameters used for disaggregated serving.")
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
default=None,
description=("Additional request parameters with string or "
"numeric values, used by custom extensions."),
)
# --8<-- [end:chat-completion-extra-params]
# Default sampling parameters for chat completion requests
@ -523,6 +530,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
structural_tag=self.structural_tag,
)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
@ -553,8 +564,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
logit_bias=self.logit_bias,
bad_words= self.bad_words,
allowed_token_ids=self.allowed_token_ids,
extra_args=({"kv_transfer_params": self.kv_transfer_params}
if self.kv_transfer_params else None))
extra_args=extra_args or None,
)
def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]:
@ -871,6 +882,12 @@ class CompletionRequest(OpenAIBaseModel):
default=None,
description="KVTransfer parameters used for disaggregated serving.")
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
default=None,
description=("Additional request parameters with string or "
"numeric values, used by custom extensions."),
)
# --8<-- [end:completion-extra-params]
# Default sampling parameters for completion requests
@ -968,6 +985,10 @@ class CompletionRequest(OpenAIBaseModel):
whitespace_pattern=self.guided_whitespace_pattern,
)
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
@ -997,8 +1018,8 @@ class CompletionRequest(OpenAIBaseModel):
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids,
extra_args=({"kv_transfer_params": self.kv_transfer_params}
if self.kv_transfer_params else None))
extra_args=extra_args or None,
)
@model_validator(mode="before")
@classmethod
@ -1117,8 +1138,9 @@ class EmbeddingChatRequest(OpenAIBaseModel):
)
chat_template_kwargs: Optional[dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."),
)
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
default=None,
@ -1623,8 +1645,9 @@ class TokenizeChatRequest(OpenAIBaseModel):
)
chat_template_kwargs: Optional[dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."),
)
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
default=None,
@ -1736,6 +1759,12 @@ class TranscriptionRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data.
stream_include_usage: Optional[bool] = False
stream_continuous_usage_stats: Optional[bool] = False
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
default=None,
description=("Additional request parameters with string or "
"numeric values, used by custom extensions."),
)
# --8<-- [end:transcription-extra-params]
# --8<-- [start:transcription-sampling-params]
@ -1823,7 +1852,8 @@ class TranscriptionRequest(OpenAIBaseModel):
presence_penalty=self.presence_penalty,
output_kind=RequestOutputKind.DELTA
if self.stream \
else RequestOutputKind.FINAL_ONLY)
else RequestOutputKind.FINAL_ONLY,
extra_args=self.vllm_xargs)
@model_validator(mode="before")
@classmethod

View File

@ -198,8 +198,8 @@ class SamplingParams(
processor which only retains scores for the given token ids.
Defaults to None.
extra_args: Arbitrary additional args, that can be used by custom
sampling implementations. Not used by any in-tree sampling
implementations.
sampling implementations, plugins, etc. Not used by any in-tree
sampling implementations.
"""
n: int = 1