mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:44:54 +08:00
[Frontend] Expose custom args in OpenAI APIs (#16862)
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com> Signed-off-by: Andrew Feldman <afeldman@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
ed33349738
commit
dfada85eee
@ -4,12 +4,12 @@ import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size_triton,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||
|
||||
@ -326,8 +326,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
)
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."),
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."),
|
||||
)
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
@ -414,6 +415,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.")
|
||||
|
||||
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
|
||||
default=None,
|
||||
description=("Additional request parameters with string or "
|
||||
"numeric values, used by custom extensions."),
|
||||
)
|
||||
|
||||
# --8<-- [end:chat-completion-extra-params]
|
||||
|
||||
# Default sampling parameters for chat completion requests
|
||||
@ -523,6 +530,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
structural_tag=self.structural_tag,
|
||||
)
|
||||
|
||||
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||
if self.kv_transfer_params:
|
||||
# Pass in kv_transfer_params via extra_args
|
||||
extra_args["kv_transfer_params"] = self.kv_transfer_params
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
best_of=self.best_of,
|
||||
@ -553,8 +564,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
logit_bias=self.logit_bias,
|
||||
bad_words= self.bad_words,
|
||||
allowed_token_ids=self.allowed_token_ids,
|
||||
extra_args=({"kv_transfer_params": self.kv_transfer_params}
|
||||
if self.kv_transfer_params else None))
|
||||
extra_args=extra_args or None,
|
||||
)
|
||||
|
||||
def _get_guided_json_from_tool(
|
||||
self) -> Optional[Union[str, dict, BaseModel]]:
|
||||
@ -871,6 +882,12 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.")
|
||||
|
||||
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
|
||||
default=None,
|
||||
description=("Additional request parameters with string or "
|
||||
"numeric values, used by custom extensions."),
|
||||
)
|
||||
|
||||
# --8<-- [end:completion-extra-params]
|
||||
|
||||
# Default sampling parameters for completion requests
|
||||
@ -968,6 +985,10 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
whitespace_pattern=self.guided_whitespace_pattern,
|
||||
)
|
||||
|
||||
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||
if self.kv_transfer_params:
|
||||
# Pass in kv_transfer_params via extra_args
|
||||
extra_args["kv_transfer_params"] = self.kv_transfer_params
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
best_of=self.best_of,
|
||||
@ -997,8 +1018,8 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
guided_decoding=guided_decoding,
|
||||
logit_bias=self.logit_bias,
|
||||
allowed_token_ids=self.allowed_token_ids,
|
||||
extra_args=({"kv_transfer_params": self.kv_transfer_params}
|
||||
if self.kv_transfer_params else None))
|
||||
extra_args=extra_args or None,
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@ -1117,8 +1138,9 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
)
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."),
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."),
|
||||
)
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
@ -1623,8 +1645,9 @@ class TokenizeChatRequest(OpenAIBaseModel):
|
||||
)
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."),
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."),
|
||||
)
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
@ -1736,6 +1759,12 @@ class TranscriptionRequest(OpenAIBaseModel):
|
||||
# Flattened stream option to simplify form data.
|
||||
stream_include_usage: Optional[bool] = False
|
||||
stream_continuous_usage_stats: Optional[bool] = False
|
||||
|
||||
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
|
||||
default=None,
|
||||
description=("Additional request parameters with string or "
|
||||
"numeric values, used by custom extensions."),
|
||||
)
|
||||
# --8<-- [end:transcription-extra-params]
|
||||
|
||||
# --8<-- [start:transcription-sampling-params]
|
||||
@ -1823,7 +1852,8 @@ class TranscriptionRequest(OpenAIBaseModel):
|
||||
presence_penalty=self.presence_penalty,
|
||||
output_kind=RequestOutputKind.DELTA
|
||||
if self.stream \
|
||||
else RequestOutputKind.FINAL_ONLY)
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
extra_args=self.vllm_xargs)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
||||
@ -198,8 +198,8 @@ class SamplingParams(
|
||||
processor which only retains scores for the given token ids.
|
||||
Defaults to None.
|
||||
extra_args: Arbitrary additional args, that can be used by custom
|
||||
sampling implementations. Not used by any in-tree sampling
|
||||
implementations.
|
||||
sampling implementations, plugins, etc. Not used by any in-tree
|
||||
sampling implementations.
|
||||
"""
|
||||
|
||||
n: int = 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user