[Frontend] Add logits_processors as an extra completion argument (#11150)

Signed-off-by: Brad Hilton <brad.hilton.nw@gmail.com>

parent 3cb5769883
commit 9c3dadd1c9
@@ -30,6 +30,7 @@ class MockModelConfig:
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
+    logits_processor_pattern = None
 
 
 @dataclass
@@ -156,41 +156,45 @@ class ModelConfig:
             can not be gathered from the vllm arguments.
         override_pooler_config: Initialize non default pooling config or
             override default pooling config for the pooling model.
+        logits_processor_pattern: Optional regex pattern specifying valid
+            logits processor qualified names that can be passed with the
+            `logits_processors` extra completion argument. Defaults to None,
+            which allows no processors.
     """
 
-    def __init__(
-        self,
-        model: str,
-        task: Union[TaskOption, Literal["draft"]],
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        allowed_local_media_path: str = "",
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-        config_format: ConfigFormat = ConfigFormat.AUTO,
-        hf_overrides: Optional[HfOverrides] = None,
-        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        mm_cache_preprocessor: bool = False,
-        override_neuron_config: Optional[Dict[str, Any]] = None,
-        override_pooler_config: Optional["PoolerConfig"] = None) -> None:
+    def __init__(self,
+                 model: str,
+                 task: Union[TaskOption, Literal["draft"]],
+                 tokenizer: str,
+                 tokenizer_mode: str,
+                 trust_remote_code: bool,
+                 dtype: Union[str, torch.dtype],
+                 seed: int,
+                 allowed_local_media_path: str = "",
+                 revision: Optional[str] = None,
+                 code_revision: Optional[str] = None,
+                 rope_scaling: Optional[Dict[str, Any]] = None,
+                 rope_theta: Optional[float] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 max_model_len: Optional[int] = None,
+                 spec_target_max_model_len: Optional[int] = None,
+                 quantization: Optional[str] = None,
+                 quantization_param_path: Optional[str] = None,
+                 enforce_eager: Optional[bool] = None,
+                 max_seq_len_to_capture: Optional[int] = None,
+                 max_logprobs: int = 20,
+                 disable_sliding_window: bool = False,
+                 skip_tokenizer_init: bool = False,
+                 served_model_name: Optional[Union[str, List[str]]] = None,
+                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+                 use_async_output_proc: bool = True,
+                 config_format: ConfigFormat = ConfigFormat.AUTO,
+                 hf_overrides: Optional[HfOverrides] = None,
+                 mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+                 mm_cache_preprocessor: bool = False,
+                 override_neuron_config: Optional[Dict[str, Any]] = None,
+                 override_pooler_config: Optional["PoolerConfig"] = None,
+                 logits_processor_pattern: Optional[str] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -316,6 +320,7 @@ class ModelConfig:
         self.task: Final = task
 
         self.pooler_config = self._init_pooler_config(override_pooler_config)
+        self.logits_processor_pattern = logits_processor_pattern
 
         self._verify_quantization()
         self._verify_cuda_graph()
@@ -170,6 +170,7 @@ class EngineArgs:
     enable_chunked_prefill: Optional[bool] = None
 
     guided_decoding_backend: str = 'xgrammar'
+    logits_processor_pattern: Optional[str] = None
     # Speculative decoding configuration.
     speculative_model: Optional[str] = None
    speculative_model_quantization: Optional[str] = None
@@ -374,6 +375,14 @@ class EngineArgs:
             'https://github.com/noamgat/lm-format-enforcer.'
             ' Can be overridden per request via guided_decoding_backend'
             ' parameter.')
+        parser.add_argument(
+            '--logits-processor-pattern',
+            type=nullable_str,
+            default=None,
+            help='Optional regex pattern specifying valid logits processor '
+            'qualified names that can be passed with the `logits_processors` '
+            'extra completion argument. Defaults to None, which allows no '
+            'processors.')
         # Parallel arguments
         parser.add_argument(
             '--distributed-executor-backend',
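Note: this flag is the per-deployment opt-in. A minimal sketch of the programmatic equivalent, assuming a hypothetical `my_module` package; because the validation code later in this diff uses `re.match`, the pattern is anchored at the start of the qualname and effectively acts as a prefix filter:

    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model="facebook/opt-125m",
        # Same effect as `--logits-processor-pattern "my_module\."` on the CLI:
        # only processors under the (hypothetical) my_module package pass.
        logits_processor_pattern=r"my_module\.",
    )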
@@ -975,7 +984,7 @@ class EngineArgs:
             mm_cache_preprocessor=self.mm_cache_preprocessor,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
-        )
+            logits_processor_pattern=self.logits_processor_pattern)
 
     def create_load_config(self) -> LoadConfig:
         return LoadConfig(
@@ -1,5 +1,6 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import re
 import time
 from argparse import Namespace
 from typing import Any, Dict, List, Literal, Optional, Union
@@ -14,7 +15,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
-from vllm.utils import random_uuid
+from vllm.utils import random_uuid, resolve_obj_by_qualname
 
 logger = init_logger(__name__)
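Note: the helper added in the next hunk leans on `resolve_obj_by_qualname` from `vllm.utils`, which imports the module portion of a dotted path and returns the named attribute. A quick illustration (the `json.loads` path is an arbitrary example, not part of the diff):

    from vllm.utils import resolve_obj_by_qualname

    loads = resolve_obj_by_qualname("json.loads")  # same object as json.loads
    assert loads('{"a": 1}') == {"a": 1}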
@@ -148,6 +149,46 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
     type: Literal["function"] = "function"
 
 
+class LogitsProcessorConstructor(BaseModel):
+    qualname: str
+    args: Optional[List[Any]] = None
+    kwargs: Optional[Dict[str, Any]] = None
+
+
+LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]
+
+
+def get_logits_processors(processors: Optional[LogitsProcessors],
+                          pattern: Optional[str]) -> Optional[List[Any]]:
+    if processors and pattern:
+        logits_processors = []
+        for processor in processors:
+            qualname = processor if isinstance(processor,
+                                               str) else processor.qualname
+            if not re.match(pattern, qualname):
+                raise ValueError(
+                    f"Logits processor '{qualname}' is not allowed by this "
+                    "server. See --logits-processor-pattern engine argument "
+                    "for more information.")
+            try:
+                logits_processor = resolve_obj_by_qualname(qualname)
+            except Exception as e:
+                raise ValueError(
+                    f"Logits processor '{qualname}' could not be resolved: {e}"
+                ) from e
+            if isinstance(processor, LogitsProcessorConstructor):
+                logits_processor = logits_processor(*processor.args or [],
+                                                    **processor.kwargs or {})
+            logits_processors.append(logits_processor)
+        return logits_processors
+    elif processors:
+        raise ValueError(
+            "The `logits_processors` argument is not supported by this "
+            "server. See --logits-processor-pattern engine argument "
+            "for more information.")
+    return None
+
+
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
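Note: a bare string qualname must resolve to a ready-to-use callable, while a constructor object is resolved and then called with its `args`/`kwargs`. A minimal sketch of a processor class that could be exposed this way, assuming the hypothetical module path `my_module` and vLLM's per-request processor calling convention of (generated token ids, logits tensor) -> logits tensor:

    from typing import List

    import torch

    class MyLogitsProcessor:
        # Hypothetical example; instantiated server-side when passed as
        # {"qualname": "my_module.MyLogitsProcessor", "args": [...], ...}.
        def __init__(self, token_id: int, bias: float):
            self.token_id = token_id
            self.bias = bias

        def __call__(self, token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
            # Nudge a single vocabulary entry before sampling.
            logits[self.token_id] += self.bias
            return logits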
@@ -293,6 +334,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
         "The request_id related to this request. If the caller does "
         "not set it, a random_uuid will be generated. This id is used "
         "through out the inference process and return in response."))
+    logits_processors: Optional[LogitsProcessors] = Field(
+        default=None,
+        description=(
+            "A list of either qualified names of logits processors, or "
+            "constructor objects, to apply when sampling. A constructor is "
+            "a JSON object with a required 'qualname' field specifying the "
+            "qualified name of the processor class/factory, and optional "
+            "'args' and 'kwargs' fields containing positional and keyword "
+            "arguments. For example: {'qualname': "
+            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+            "{'param': 'value'}}."))
 
     # doc: end-chat-completion-extra-params
 
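Note: since `logits_processors` is an extra (non-OpenAI) parameter, the official Python client passes it through `extra_body`. A hedged sketch reusing the hypothetical processor above; the qualname must match the server's `--logits-processor-pattern`:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    completion = client.chat.completions.create(
        model="facebook/opt-125m",
        messages=[{"role": "user", "content": "Hello!"}],
        extra_body={
            "logits_processors": [
                # Constructor form: resolved, then called with args/kwargs.
                {"qualname": "my_module.MyLogitsProcessor",
                 "args": [50256],
                 "kwargs": {"bias": -5.0}},
            ]
        },
    )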
@@ -314,7 +366,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
             length_penalty=self.length_penalty,
             include_stop_str_in_output=self.include_stop_str_in_output)
 
-    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(
+            self, default_max_tokens: int,
+            logits_processor_pattern: Optional[str]) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
@@ -364,6 +418,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             min_tokens=self.min_tokens,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
+            logits_processors=get_logits_processors(self.logits_processors,
+                                                    logits_processor_pattern),
             include_stop_str_in_output=self.include_stop_str_in_output,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             output_kind=RequestOutputKind.DELTA if self.stream \
@@ -599,6 +655,17 @@ class CompletionRequest(OpenAIBaseModel):
         "The priority of the request (lower means earlier handling; "
         "default: 0). Any priority other than 0 will raise an error "
         "if the served model does not use priority scheduling."))
+    logits_processors: Optional[LogitsProcessors] = Field(
+        default=None,
+        description=(
+            "A list of either qualified names of logits processors, or "
+            "constructor objects, to apply when sampling. A constructor is "
+            "a JSON object with a required 'qualname' field specifying the "
+            "qualified name of the processor class/factory, and optional "
+            "'args' and 'kwargs' fields containing positional and keyword "
+            "arguments. For example: {'qualname': "
+            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+            "{'param': 'value'}}."))
 
     # doc: end-completion-extra-params
 
@@ -619,7 +686,9 @@ class CompletionRequest(OpenAIBaseModel):
             length_penalty=self.length_penalty,
             include_stop_str_in_output=self.include_stop_str_in_output)
 
-    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(
+            self, default_max_tokens: int,
+            logits_processor_pattern: Optional[str]) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
@@ -665,6 +734,8 @@ class CompletionRequest(OpenAIBaseModel):
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
+            logits_processors=get_logits_processors(self.logits_processors,
+                                                    logits_processor_pattern),
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             output_kind=RequestOutputKind.DELTA if self.stream \
                 else RequestOutputKind.FINAL_ONLY,
@@ -197,7 +197,8 @@ class OpenAIServingChat(OpenAIServing):
                         default_max_tokens)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens)
+                        default_max_tokens,
+                        self.model_config.logits_processor_pattern)
 
                 self._log_inputs(request_id,
                                  request_prompts[i],
@@ -123,7 +123,8 @@ class OpenAIServingCompletion(OpenAIServing):
                         default_max_tokens)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens)
+                        default_max_tokens,
+                        self.model_config.logits_processor_pattern)
 
                 request_id_item = f"{request_id}-{i}"
 
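Note: end to end, the bare-string form also works on the completions endpoint; the qualname then has to name something directly usable as a processor. A hedged sketch against a locally running server (hypothetical qualname); if the server was started without `--logits-processor-pattern`, `get_logits_processors` rejects the request instead of applying anything:

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "facebook/opt-125m",
            "prompt": "The capital of France is",
            "max_tokens": 8,
            # Bare qualname: must resolve to a ready-made callable.
            "logits_processors": ["my_module.my_logits_processor"],
        },
    )
    print(resp.json())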