mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 08:31:18 +08:00
[Frontend] Expose custom args in OpenAI APIs (#16862)
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com> Signed-off-by: Andrew Feldman <afeldman@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
ed33349738
commit
dfada85eee
@ -4,12 +4,12 @@ import argparse
|
|||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||||
moe_align_block_size_triton,
|
moe_align_block_size_triton,
|
||||||
)
|
)
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
|
||||||
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||||
|
|||||||
@ -326,8 +326,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
)
|
)
|
||||||
chat_template_kwargs: Optional[dict[str, Any]] = Field(
|
chat_template_kwargs: Optional[dict[str, Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description=("Additional kwargs to pass to the template renderer. "
|
description=(
|
||||||
"Will be accessible by the chat template."),
|
"Additional keyword args to pass to the template renderer. "
|
||||||
|
"Will be accessible by the chat template."),
|
||||||
)
|
)
|
||||||
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@ -414,6 +415,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="KVTransfer parameters used for disaggregated serving.")
|
description="KVTransfer parameters used for disaggregated serving.")
|
||||||
|
|
||||||
|
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=("Additional request parameters with string or "
|
||||||
|
"numeric values, used by custom extensions."),
|
||||||
|
)
|
||||||
|
|
||||||
# --8<-- [end:chat-completion-extra-params]
|
# --8<-- [end:chat-completion-extra-params]
|
||||||
|
|
||||||
# Default sampling parameters for chat completion requests
|
# Default sampling parameters for chat completion requests
|
||||||
@ -523,6 +530,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
structural_tag=self.structural_tag,
|
structural_tag=self.structural_tag,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||||
|
if self.kv_transfer_params:
|
||||||
|
# Pass in kv_transfer_params via extra_args
|
||||||
|
extra_args["kv_transfer_params"] = self.kv_transfer_params
|
||||||
return SamplingParams.from_optional(
|
return SamplingParams.from_optional(
|
||||||
n=self.n,
|
n=self.n,
|
||||||
best_of=self.best_of,
|
best_of=self.best_of,
|
||||||
@ -553,8 +564,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
logit_bias=self.logit_bias,
|
logit_bias=self.logit_bias,
|
||||||
bad_words= self.bad_words,
|
bad_words= self.bad_words,
|
||||||
allowed_token_ids=self.allowed_token_ids,
|
allowed_token_ids=self.allowed_token_ids,
|
||||||
extra_args=({"kv_transfer_params": self.kv_transfer_params}
|
extra_args=extra_args or None,
|
||||||
if self.kv_transfer_params else None))
|
)
|
||||||
|
|
||||||
def _get_guided_json_from_tool(
|
def _get_guided_json_from_tool(
|
||||||
self) -> Optional[Union[str, dict, BaseModel]]:
|
self) -> Optional[Union[str, dict, BaseModel]]:
|
||||||
@ -871,6 +882,12 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="KVTransfer parameters used for disaggregated serving.")
|
description="KVTransfer parameters used for disaggregated serving.")
|
||||||
|
|
||||||
|
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=("Additional request parameters with string or "
|
||||||
|
"numeric values, used by custom extensions."),
|
||||||
|
)
|
||||||
|
|
||||||
# --8<-- [end:completion-extra-params]
|
# --8<-- [end:completion-extra-params]
|
||||||
|
|
||||||
# Default sampling parameters for completion requests
|
# Default sampling parameters for completion requests
|
||||||
@ -968,6 +985,10 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
whitespace_pattern=self.guided_whitespace_pattern,
|
whitespace_pattern=self.guided_whitespace_pattern,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
|
||||||
|
if self.kv_transfer_params:
|
||||||
|
# Pass in kv_transfer_params via extra_args
|
||||||
|
extra_args["kv_transfer_params"] = self.kv_transfer_params
|
||||||
return SamplingParams.from_optional(
|
return SamplingParams.from_optional(
|
||||||
n=self.n,
|
n=self.n,
|
||||||
best_of=self.best_of,
|
best_of=self.best_of,
|
||||||
@ -997,8 +1018,8 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
guided_decoding=guided_decoding,
|
guided_decoding=guided_decoding,
|
||||||
logit_bias=self.logit_bias,
|
logit_bias=self.logit_bias,
|
||||||
allowed_token_ids=self.allowed_token_ids,
|
allowed_token_ids=self.allowed_token_ids,
|
||||||
extra_args=({"kv_transfer_params": self.kv_transfer_params}
|
extra_args=extra_args or None,
|
||||||
if self.kv_transfer_params else None))
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1117,8 +1138,9 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
|||||||
)
|
)
|
||||||
chat_template_kwargs: Optional[dict[str, Any]] = Field(
|
chat_template_kwargs: Optional[dict[str, Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description=("Additional kwargs to pass to the template renderer. "
|
description=(
|
||||||
"Will be accessible by the chat template."),
|
"Additional keyword args to pass to the template renderer. "
|
||||||
|
"Will be accessible by the chat template."),
|
||||||
)
|
)
|
||||||
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@ -1623,8 +1645,9 @@ class TokenizeChatRequest(OpenAIBaseModel):
|
|||||||
)
|
)
|
||||||
chat_template_kwargs: Optional[dict[str, Any]] = Field(
|
chat_template_kwargs: Optional[dict[str, Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description=("Additional kwargs to pass to the template renderer. "
|
description=(
|
||||||
"Will be accessible by the chat template."),
|
"Additional keyword args to pass to the template renderer. "
|
||||||
|
"Will be accessible by the chat template."),
|
||||||
)
|
)
|
||||||
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@ -1736,6 +1759,12 @@ class TranscriptionRequest(OpenAIBaseModel):
|
|||||||
# Flattened stream option to simplify form data.
|
# Flattened stream option to simplify form data.
|
||||||
stream_include_usage: Optional[bool] = False
|
stream_include_usage: Optional[bool] = False
|
||||||
stream_continuous_usage_stats: Optional[bool] = False
|
stream_continuous_usage_stats: Optional[bool] = False
|
||||||
|
|
||||||
|
vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=("Additional request parameters with string or "
|
||||||
|
"numeric values, used by custom extensions."),
|
||||||
|
)
|
||||||
# --8<-- [end:transcription-extra-params]
|
# --8<-- [end:transcription-extra-params]
|
||||||
|
|
||||||
# --8<-- [start:transcription-sampling-params]
|
# --8<-- [start:transcription-sampling-params]
|
||||||
@ -1823,7 +1852,8 @@ class TranscriptionRequest(OpenAIBaseModel):
|
|||||||
presence_penalty=self.presence_penalty,
|
presence_penalty=self.presence_penalty,
|
||||||
output_kind=RequestOutputKind.DELTA
|
output_kind=RequestOutputKind.DELTA
|
||||||
if self.stream \
|
if self.stream \
|
||||||
else RequestOutputKind.FINAL_ONLY)
|
else RequestOutputKind.FINAL_ONLY,
|
||||||
|
extra_args=self.vllm_xargs)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -198,8 +198,8 @@ class SamplingParams(
|
|||||||
processor which only retains scores for the given token ids.
|
processor which only retains scores for the given token ids.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
extra_args: Arbitrary additional args, that can be used by custom
|
extra_args: Arbitrary additional args, that can be used by custom
|
||||||
sampling implementations. Not used by any in-tree sampling
|
sampling implementations, plugins, etc. Not used by any in-tree
|
||||||
implementations.
|
sampling implementations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
n: int = 1
|
n: int = 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user