[Frontend] Add logits_processors as an extra completion argument (#11150)
Signed-off-by: Brad Hilton <brad.hilton.nw@gmail.com>
Parent commit: 3cb5769883
This commit: 9c3dadd1c9
@@ -30,6 +30,7 @@ class MockModelConfig:
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
+    logits_processor_pattern = None


 @dataclass
@@ -156,41 +156,45 @@ class ModelConfig:
             can not be gathered from the vllm arguments.
         override_pooler_config: Initialize non default pooling config or
             override default pooling config for the pooling model.
+        logits_processor_pattern: Optional regex pattern specifying valid
+            logits processor qualified names that can be passed with the
+            `logits_processors` extra completion argument. Defaults to None,
+            which allows no processors.
     """

-    def __init__(
-        self,
-        model: str,
-        task: Union[TaskOption, Literal["draft"]],
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        allowed_local_media_path: str = "",
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-        config_format: ConfigFormat = ConfigFormat.AUTO,
-        hf_overrides: Optional[HfOverrides] = None,
-        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        mm_cache_preprocessor: bool = False,
-        override_neuron_config: Optional[Dict[str, Any]] = None,
-        override_pooler_config: Optional["PoolerConfig"] = None) -> None:
+    def __init__(self,
+                 model: str,
+                 task: Union[TaskOption, Literal["draft"]],
+                 tokenizer: str,
+                 tokenizer_mode: str,
+                 trust_remote_code: bool,
+                 dtype: Union[str, torch.dtype],
+                 seed: int,
+                 allowed_local_media_path: str = "",
+                 revision: Optional[str] = None,
+                 code_revision: Optional[str] = None,
+                 rope_scaling: Optional[Dict[str, Any]] = None,
+                 rope_theta: Optional[float] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 max_model_len: Optional[int] = None,
+                 spec_target_max_model_len: Optional[int] = None,
+                 quantization: Optional[str] = None,
+                 quantization_param_path: Optional[str] = None,
+                 enforce_eager: Optional[bool] = None,
+                 max_seq_len_to_capture: Optional[int] = None,
+                 max_logprobs: int = 20,
+                 disable_sliding_window: bool = False,
+                 skip_tokenizer_init: bool = False,
+                 served_model_name: Optional[Union[str, List[str]]] = None,
+                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+                 use_async_output_proc: bool = True,
+                 config_format: ConfigFormat = ConfigFormat.AUTO,
+                 hf_overrides: Optional[HfOverrides] = None,
+                 mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+                 mm_cache_preprocessor: bool = False,
+                 override_neuron_config: Optional[Dict[str, Any]] = None,
+                 override_pooler_config: Optional["PoolerConfig"] = None,
+                 logits_processor_pattern: Optional[str] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -316,6 +320,7 @@ class ModelConfig:
         self.task: Final = task

         self.pooler_config = self._init_pooler_config(override_pooler_config)
+        self.logits_processor_pattern = logits_processor_pattern

         self._verify_quantization()
         self._verify_cuda_graph()
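For illustration, the stored pattern is an ordinary Python regex matched against each requested processor's qualified name. A minimal standalone sketch (the module and pattern names below are hypothetical, not part of this change):

    import re

    # Hypothetical server-side pattern: only names under `my_plugins.logits` pass.
    pattern = r"my_plugins\.logits\..*"

    # Qualified names a client might send via the `logits_processors` field.
    requested = ["my_plugins.logits.BanTokenLogitsProcessor", "os.system"]

    for qualname in requested:
        # Same gate as the re.match() check in get_logits_processors() below.
        allowed = re.match(pattern, qualname) is not None
        print(f"{qualname}: {'allowed' if allowed else 'rejected'}")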
@@ -170,6 +170,7 @@ class EngineArgs:
    enable_chunked_prefill: Optional[bool] = None

    guided_decoding_backend: str = 'xgrammar'
+    logits_processor_pattern: Optional[str] = None
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    speculative_model_quantization: Optional[str] = None
@@ -374,6 +375,14 @@ class EngineArgs:
             'https://github.com/noamgat/lm-format-enforcer.'
             ' Can be overridden per request via guided_decoding_backend'
             ' parameter.')
+        parser.add_argument(
+            '--logits-processor-pattern',
+            type=nullable_str,
+            default=None,
+            help='Optional regex pattern specifying valid logits processor '
+            'qualified names that can be passed with the `logits_processors` '
+            'extra completion argument. Defaults to None, which allows no '
+            'processors.')
         # Parallel arguments
         parser.add_argument(
             '--distributed-executor-backend',
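A rough sketch of how the new flag behaves at the argparse level; plain str stands in here for the nullable_str helper used above, and the pattern value is a placeholder:

    import argparse

    parser = argparse.ArgumentParser()
    # Mirrors the flag added above; omitting it leaves the pattern as None,
    # which disables the `logits_processors` request argument entirely.
    parser.add_argument("--logits-processor-pattern", type=str, default=None)

    args = parser.parse_args(
        ["--logits-processor-pattern", r"my_plugins\.logits\..*"])
    print(args.logits_processor_pattern)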
@@ -975,7 +984,7 @@ class EngineArgs:
            mm_cache_preprocessor=self.mm_cache_preprocessor,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
-        )
+            logits_processor_pattern=self.logits_processor_pattern)

    def create_load_config(self) -> LoadConfig:
        return LoadConfig(
@@ -1,5 +1,6 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import re
 import time
 from argparse import Namespace
 from typing import Any, Dict, List, Literal, Optional, Union
@@ -14,7 +15,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
-from vllm.utils import random_uuid
+from vllm.utils import random_uuid, resolve_obj_by_qualname

 logger = init_logger(__name__)
@@ -148,6 +149,46 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
     type: Literal["function"] = "function"


+class LogitsProcessorConstructor(BaseModel):
+    qualname: str
+    args: Optional[List[Any]] = None
+    kwargs: Optional[Dict[str, Any]] = None
+
+
+LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]
+
+
+def get_logits_processors(processors: Optional[LogitsProcessors],
+                          pattern: Optional[str]) -> Optional[List[Any]]:
+    if processors and pattern:
+        logits_processors = []
+        for processor in processors:
+            qualname = processor if isinstance(processor,
+                                               str) else processor.qualname
+            if not re.match(pattern, qualname):
+                raise ValueError(
+                    f"Logits processor '{qualname}' is not allowed by this "
+                    "server. See --logits-processor-pattern engine argument "
+                    "for more information.")
+            try:
+                logits_processor = resolve_obj_by_qualname(qualname)
+            except Exception as e:
+                raise ValueError(
+                    f"Logits processor '{qualname}' could not be resolved: {e}"
+                ) from e
+            if isinstance(processor, LogitsProcessorConstructor):
+                logits_processor = logits_processor(*processor.args or [],
+                                                    **processor.kwargs or {})
+            logits_processors.append(logits_processor)
+        return logits_processors
+    elif processors:
+        raise ValueError(
+            "The `logits_processors` argument is not supported by this "
+            "server. See --logits-processor-pattern engine argument "
+            "for more information.")
+    return None
+
+
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
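For context, a processor referenced by qualified name is expected to be a callable (or a class/factory producing one) in vLLM's usual logits-processor shape. The sketch below assumes the two-argument form (generated token ids, logits tensor) and uses hypothetical module and class names:

    from typing import List

    import torch


    class BanTokenLogitsProcessor:
        """Constructor-style processor: built from 'args'/'kwargs', called per step."""

        def __init__(self, banned_token_ids: List[int]):
            self.banned = banned_token_ids

        def __call__(self, token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
            # Push banned token logits to -inf so they can never be sampled.
            logits[self.banned] = float("-inf")
            return logits

If this class were importable as, say, my_plugins.logits.BanTokenLogitsProcessor, a request could name it directly or pass it as a constructor object with kwargs, provided the server's --logits-processor-pattern admits that name.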
@@ -293,6 +334,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "The request_id related to this request. If the caller does "
             "not set it, a random_uuid will be generated. This id is used "
             "through out the inference process and return in response."))
+    logits_processors: Optional[LogitsProcessors] = Field(
+        default=None,
+        description=(
+            "A list of either qualified names of logits processors, or "
+            "constructor objects, to apply when sampling. A constructor is "
+            "a JSON object with a required 'qualname' field specifying the "
+            "qualified name of the processor class/factory, and optional "
+            "'args' and 'kwargs' fields containing positional and keyword "
+            "arguments. For example: {'qualname': "
+            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+            "{'param': 'value'}}."))

     # doc: end-chat-completion-extra-params
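A hedged client-side sketch of passing the new field through the OpenAI Python SDK's extra_body escape hatch; the base URL, model name, and processor qualnames are placeholders:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    completion = client.chat.completions.create(
        model="my-model",
        messages=[{"role": "user", "content": "Hello!"}],
        extra_body={
            "logits_processors": [
                # Bare qualified name of a processor needing no setup...
                "my_plugins.logits.lowercase_only",
                # ...or a constructor object with optional args/kwargs.
                {
                    "qualname": "my_plugins.logits.BanTokenLogitsProcessor",
                    "kwargs": {"banned_token_ids": [42]},
                },
            ]
        },
    )
    print(completion.choices[0].message.content)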
@@ -314,7 +366,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
             length_penalty=self.length_penalty,
             include_stop_str_in_output=self.include_stop_str_in_output)

-    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(
+            self, default_max_tokens: int,
+            logits_processor_pattern: Optional[str]) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
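The widened signature exists so the request can turn qualified-name strings into concrete callables at this point. A standalone sketch of that resolution step, with importlib standing in for vllm.utils.resolve_obj_by_qualname and a stdlib name used purely for demonstration:

    import importlib


    def resolve_qualname(qualname: str):
        # "pkg.module.attr" -> import pkg.module, then fetch attr from it.
        module_name, _, attr = qualname.rpartition(".")
        return getattr(importlib.import_module(module_name), attr)


    fn = resolve_qualname("math.sqrt")
    print(fn(9.0))  # 3.0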
@@ -364,6 +418,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             min_tokens=self.min_tokens,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
+            logits_processors=get_logits_processors(self.logits_processors,
+                                                    logits_processor_pattern),
             include_stop_str_in_output=self.include_stop_str_in_output,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             output_kind=RequestOutputKind.DELTA if self.stream \
@@ -599,6 +655,17 @@ class CompletionRequest(OpenAIBaseModel):
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."))
+    logits_processors: Optional[LogitsProcessors] = Field(
+        default=None,
+        description=(
+            "A list of either qualified names of logits processors, or "
+            "constructor objects, to apply when sampling. A constructor is "
+            "a JSON object with a required 'qualname' field specifying the "
+            "qualified name of the processor class/factory, and optional "
+            "'args' and 'kwargs' fields containing positional and keyword "
+            "arguments. For example: {'qualname': "
+            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+            "{'param': 'value'}}."))

     # doc: end-completion-extra-params
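The constructor form of an entry is just a small pydantic model. The sketch below re-declares it locally so it runs standalone, assuming pydantic v2 for model_validate; the field layout mirrors LogitsProcessorConstructor above and the values are placeholders:

    from typing import Any, Dict, List, Optional

    from pydantic import BaseModel


    class LogitsProcessorConstructor(BaseModel):
        qualname: str
        args: Optional[List[Any]] = None
        kwargs: Optional[Dict[str, Any]] = None


    entry = LogitsProcessorConstructor.model_validate({
        "qualname": "my_plugins.logits.BanTokenLogitsProcessor",
        "kwargs": {"banned_token_ids": [42]},
    })
    print(entry.qualname, entry.args, entry.kwargs)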
@@ -619,7 +686,9 @@ class CompletionRequest(OpenAIBaseModel):
             length_penalty=self.length_penalty,
             include_stop_str_in_output=self.include_stop_str_in_output)

-    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(
+            self, default_max_tokens: int,
+            logits_processor_pattern: Optional[str]) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
@@ -665,6 +734,8 @@ class CompletionRequest(OpenAIBaseModel):
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
+            logits_processors=get_logits_processors(self.logits_processors,
+                                                    logits_processor_pattern),
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             output_kind=RequestOutputKind.DELTA if self.stream \
                 else RequestOutputKind.FINAL_ONLY,
@@ -197,7 +197,8 @@ class OpenAIServingChat(OpenAIServing):
                        default_max_tokens)
                else:
                    sampling_params = request.to_sampling_params(
-                        default_max_tokens)
+                        default_max_tokens,
+                        self.model_config.logits_processor_pattern)

                self._log_inputs(request_id,
                                 request_prompts[i],
@@ -123,7 +123,8 @@ class OpenAIServingCompletion(OpenAIServing):
                        default_max_tokens)
                else:
                    sampling_params = request.to_sampling_params(
-                        default_max_tokens)
+                        default_max_tokens,
+                        self.model_config.logits_processor_pattern)

                request_id_item = f"{request_id}-{i}"