[Frontend] Add logits_processors as an extra completion argument (#11150)

Signed-off-by: Brad Hilton <brad.hilton.nw@gmail.com>
Brad Hilton 2024-12-14 09:46:42 -07:00 committed by GitHub
parent 3cb5769883
commit 9c3dadd1c9
6 changed files with 127 additions and 39 deletions

tests/entrypoints/openai/test_serving_chat.py

@@ -30,6 +30,7 @@ class MockModelConfig:
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
@dataclass

vllm/config.py

@@ -156,41 +156,45 @@ class ModelConfig:
cannot be gathered from the vllm arguments.
override_pooler_config: Initialize non default pooling config or
override default pooling config for the pooling model.
logits_processor_pattern: Optional regex pattern specifying valid
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
"""
def __init__(
self,
model: str,
task: Union[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_cache_preprocessor: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None) -> None:
def __init__(self,
model: str,
task: Union[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
mm_cache_preprocessor: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -316,6 +320,7 @@ class ModelConfig:
self.task: Final = task
self.pooler_config = self._init_pooler_config(override_pooler_config)
self.logits_processor_pattern = logits_processor_pattern
self._verify_quantization()
self._verify_cuda_graph()
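Because the request layer validates names with re.match (see the protocol changes below), the pattern is implicitly anchored at the start of the qualified name but not at the end. A minimal sketch of the matching semantics, using a hypothetical package name:

import re

# Hypothetical allow-list: admit any processor under my_pkg.processors.
pattern = r"my_pkg\.processors\."

# re.match anchors only at the beginning of the string, so a prefix pattern
# accepts every qualified name under the package and rejects all others.
assert re.match(pattern, "my_pkg.processors.ScaleLogits") is not None
assert re.match(pattern, "os.system") is None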

vllm/engine/arg_utils.py

@@ -170,6 +170,7 @@ class EngineArgs:
enable_chunked_prefill: Optional[bool] = None
guided_decoding_backend: str = 'xgrammar'
logits_processor_pattern: Optional[str] = None
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
@@ -374,6 +375,14 @@ class EngineArgs:
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.')
parser.add_argument(
'--logits-processor-pattern',
type=nullable_str,
default=None,
help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` '
'extra completion argument. Defaults to None, which allows no '
'processors.')
# Parallel arguments
parser.add_argument(
'--distributed-executor-backend',
@@ -975,7 +984,7 @@ class EngineArgs:
mm_cache_preprocessor=self.mm_cache_preprocessor,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
)
logits_processor_pattern=self.logits_processor_pattern)
def create_load_config(self) -> LoadConfig:
return LoadConfig(
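Operators enable the feature by supplying a pattern; leaving it unset keeps the default behavior of rejecting all requested processors. A sketch using the new EngineArgs field (the model name and pattern below are hypothetical):

from vllm.engine.arg_utils import EngineArgs

# Equivalent to passing --logits-processor-pattern on the command line.
engine_args = EngineArgs(
    model="facebook/opt-125m",
    logits_processor_pattern=r"my_pkg\.processors\.",
)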

vllm/entrypoints/openai/protocol.py

@@ -1,5 +1,6 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import re
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
@@ -14,7 +15,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.utils import random_uuid
from vllm.utils import random_uuid, resolve_obj_by_qualname
logger = init_logger(__name__)
@@ -148,6 +149,46 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
type: Literal["function"] = "function"
class LogitsProcessorConstructor(BaseModel):
qualname: str
args: Optional[List[Any]] = None
kwargs: Optional[Dict[str, Any]] = None
LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]
def get_logits_processors(processors: Optional[LogitsProcessors],
pattern: Optional[str]) -> Optional[List[Any]]:
if processors and pattern:
logits_processors = []
for processor in processors:
qualname = processor if isinstance(processor,
str) else processor.qualname
if not re.match(pattern, qualname):
raise ValueError(
f"Logits processor '{qualname}' is not allowed by this "
"server. See --logits-processor-pattern engine argument "
"for more information.")
try:
logits_processor = resolve_obj_by_qualname(qualname)
except Exception as e:
raise ValueError(
f"Logits processor '{qualname}' could not be resolved: {e}"
) from e
if isinstance(processor, LogitsProcessorConstructor):
logits_processor = logits_processor(*processor.args or [],
**processor.kwargs or {})
logits_processors.append(logits_processor)
return logits_processors
elif processors:
raise ValueError(
"The `logits_processors` argument is not supported by this "
"server. See --logits-processor-pattern engine argugment "
"for more information.")
return None
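Note the asymmetry in the resolution above: a bare string is resolved and used as-is, so it must already name a callable with the logits-processor signature, whereas a constructor object is resolved and then instantiated with its args/kwargs. A sketch of both shapes in a hypothetical my_pkg/processors.py:

import torch
from typing import List

def ban_token_zero(token_ids: List[int],
                   logits: torch.Tensor) -> torch.Tensor:
    # Referenced by the bare qualified name
    # "my_pkg.processors.ban_token_zero"; used without instantiation.
    logits[0] = float("-inf")
    return logits

class ScaleLogits:
    # Referenced as {"qualname": "my_pkg.processors.ScaleLogits",
    # "args": [0.5]}; resolved, then instantiated with args/kwargs.
    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        return logits * self.factor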
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
@@ -293,6 +334,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."))
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))
# doc: end-chat-completion-extra-params
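From an OpenAI-compatible client, the new field travels in extra_body. A sketch against a local server, assuming the hypothetical processors above are importable on the server and admitted by its --logits-processor-pattern:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_body={
        "logits_processors": [
            "my_pkg.processors.ban_token_zero",
            {"qualname": "my_pkg.processors.ScaleLogits", "args": [0.5]},
        ]
    },
)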
@@ -314,7 +366,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
@@ -364,6 +418,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_tokens=self.min_tokens,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
logits_processors=get_logits_processors(self.logits_processors,
logits_processor_pattern),
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
@@ -599,6 +655,17 @@ class CompletionRequest(OpenAIBaseModel):
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))
# doc: end-completion-extra-params
@@ -619,7 +686,9 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
@@ -665,6 +734,8 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
logits_processors=get_logits_processors(self.logits_processors,
logits_processor_pattern),
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,

vllm/entrypoints/openai/serving_chat.py

@@ -197,7 +197,8 @@ class OpenAIServingChat(OpenAIServing):
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
default_max_tokens,
self.model_config.logits_processor_pattern)
self._log_inputs(request_id,
request_prompts[i],

vllm/entrypoints/openai/serving_completion.py

@@ -123,7 +123,8 @@ class OpenAIServingCompletion(OpenAIServing):
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
default_max_tokens,
self.model_config.logits_processor_pattern)
request_id_item = f"{request_id}-{i}"