[Frontend] Add logits_processors as an extra completion argument (#11150)

Signed-off-by: Brad Hilton <brad.hilton.nw@gmail.com>
Brad Hilton 2024-12-14 09:46:42 -07:00 committed by GitHub
parent 3cb5769883
commit 9c3dadd1c9
6 changed files with 127 additions and 39 deletions

View File

@@ -30,6 +30,7 @@ class MockModelConfig:
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
+    logits_processor_pattern = None


 @dataclass

View File

@@ -156,41 +156,45 @@ class ModelConfig:
         can not be gathered from the vllm arguments.
         override_pooler_config: Initialize non default pooling config or
             override default pooling config for the pooling model.
+        logits_processor_pattern: Optional regex pattern specifying valid
+            logits processor qualified names that can be passed with the
+            `logits_processors` extra completion argument. Defaults to None,
+            which allows no processors.
     """

-    def __init__(
-        self,
-        model: str,
-        task: Union[TaskOption, Literal["draft"]],
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        allowed_local_media_path: str = "",
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-        config_format: ConfigFormat = ConfigFormat.AUTO,
-        hf_overrides: Optional[HfOverrides] = None,
-        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        mm_cache_preprocessor: bool = False,
-        override_neuron_config: Optional[Dict[str, Any]] = None,
-        override_pooler_config: Optional["PoolerConfig"] = None) -> None:
+    def __init__(self,
+                 model: str,
+                 task: Union[TaskOption, Literal["draft"]],
+                 tokenizer: str,
+                 tokenizer_mode: str,
+                 trust_remote_code: bool,
+                 dtype: Union[str, torch.dtype],
+                 seed: int,
+                 allowed_local_media_path: str = "",
+                 revision: Optional[str] = None,
+                 code_revision: Optional[str] = None,
+                 rope_scaling: Optional[Dict[str, Any]] = None,
+                 rope_theta: Optional[float] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 max_model_len: Optional[int] = None,
+                 spec_target_max_model_len: Optional[int] = None,
+                 quantization: Optional[str] = None,
+                 quantization_param_path: Optional[str] = None,
+                 enforce_eager: Optional[bool] = None,
+                 max_seq_len_to_capture: Optional[int] = None,
+                 max_logprobs: int = 20,
+                 disable_sliding_window: bool = False,
+                 skip_tokenizer_init: bool = False,
+                 served_model_name: Optional[Union[str, List[str]]] = None,
+                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+                 use_async_output_proc: bool = True,
+                 config_format: ConfigFormat = ConfigFormat.AUTO,
+                 hf_overrides: Optional[HfOverrides] = None,
+                 mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+                 mm_cache_preprocessor: bool = False,
+                 override_neuron_config: Optional[Dict[str, Any]] = None,
+                 override_pooler_config: Optional["PoolerConfig"] = None,
+                 logits_processor_pattern: Optional[str] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -316,6 +320,7 @@ class ModelConfig:
         self.task: Final = task
         self.pooler_config = self._init_pooler_config(override_pooler_config)
+        self.logits_processor_pattern = logits_processor_pattern

         self._verify_quantization()
         self._verify_cuda_graph()

View File

@ -170,6 +170,7 @@ class EngineArgs:
enable_chunked_prefill: Optional[bool] = None enable_chunked_prefill: Optional[bool] = None
guided_decoding_backend: str = 'xgrammar' guided_decoding_backend: str = 'xgrammar'
logits_processor_pattern: Optional[str] = None
# Speculative decoding configuration. # Speculative decoding configuration.
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None speculative_model_quantization: Optional[str] = None
@@ -374,6 +375,14 @@ class EngineArgs:
             'https://github.com/noamgat/lm-format-enforcer.'
             ' Can be overridden per request via guided_decoding_backend'
             ' parameter.')
+        parser.add_argument(
+            '--logits-processor-pattern',
+            type=nullable_str,
+            default=None,
+            help='Optional regex pattern specifying valid logits processor '
+            'qualified names that can be passed with the `logits_processors` '
+            'extra completion argument. Defaults to None, which allows no '
+            'processors.')
         # Parallel arguments
         parser.add_argument(
             '--distributed-executor-backend',
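Note: the validation added later in this commit checks qualified names with `re.match`, so the pattern is anchored at the start of the name but not at the end; it behaves as a prefix match unless the pattern ends with `$`. A quick sketch of the matching semantics (the package and processor names below are hypothetical):

import re

# Hypothetical pattern permitting only processors under my_package.processors.
pattern = r"my_package\.processors\."

assert re.match(pattern, "my_package.processors.BanTokens")   # allowed
assert re.match(pattern, "os.system") is None                 # rejected
# re.match anchors only at the start, so any name sharing the prefix passes:
assert re.match(pattern, "my_package.processors.sub.anything")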
@@ -975,7 +984,7 @@ class EngineArgs:
             mm_cache_preprocessor=self.mm_cache_preprocessor,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
-        )
+            logits_processor_pattern=self.logits_processor_pattern)

     def create_load_config(self) -> LoadConfig:
         return LoadConfig(

View File

@ -1,5 +1,6 @@
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import re
import time import time
from argparse import Namespace from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
@@ -14,7 +15,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
-from vllm.utils import random_uuid
+from vllm.utils import random_uuid, resolve_obj_by_qualname

 logger = init_logger(__name__)
@@ -148,6 +149,46 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
     type: Literal["function"] = "function"


+class LogitsProcessorConstructor(BaseModel):
+    qualname: str
+    args: Optional[List[Any]] = None
+    kwargs: Optional[Dict[str, Any]] = None
+
+
+LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]
+
+
+def get_logits_processors(processors: Optional[LogitsProcessors],
+                          pattern: Optional[str]) -> Optional[List[Any]]:
+    if processors and pattern:
+        logits_processors = []
+        for processor in processors:
+            qualname = processor if isinstance(processor,
+                                               str) else processor.qualname
+            if not re.match(pattern, qualname):
+                raise ValueError(
+                    f"Logits processor '{qualname}' is not allowed by this "
+                    "server. See --logits-processor-pattern engine argument "
+                    "for more information.")
+            try:
+                logits_processor = resolve_obj_by_qualname(qualname)
+            except Exception as e:
+                raise ValueError(
+                    f"Logits processor '{qualname}' could not be resolved: {e}"
+                ) from e
+            if isinstance(processor, LogitsProcessorConstructor):
+                logits_processor = logits_processor(*processor.args or [],
+                                                    **processor.kwargs or {})
+            logits_processors.append(logits_processor)
+        return logits_processors
+    elif processors:
+        raise ValueError(
+            "The `logits_processors` argument is not supported by this "
+            "server. See --logits-processor-pattern engine argument "
+            "for more information.")
+    return None
+
+
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
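For reference, a bare qualified name must resolve (via `resolve_obj_by_qualname`) to a ready-made logits processor, while the constructor form resolves to a class or factory that is then called with the supplied `args`/`kwargs`. A minimal sketch of a constructor-form target, assuming the two-argument processor signature `(token_ids, logits) -> logits`; the module, class, and token id are illustrative, not part of this commit:

# my_module.py -- hypothetical module importable by the server process.
from typing import List

import torch


class BanTokens:
    """Usable via {"qualname": "my_module.BanTokens",
    "kwargs": {"token_ids": [128001]}}."""

    def __init__(self, token_ids: List[int]):
        self.token_ids = token_ids

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        # Mask the banned tokens so they can never be sampled.
        logits[self.token_ids] = float("-inf")
        return logits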
@@ -293,6 +334,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "The request_id related to this request. If the caller does "
             "not set it, a random_uuid will be generated. This id is used "
             "through out the inference process and return in response."))
+    logits_processors: Optional[LogitsProcessors] = Field(
+        default=None,
+        description=(
+            "A list of either qualified names of logits processors, or "
+            "constructor objects, to apply when sampling. A constructor is "
+            "a JSON object with a required 'qualname' field specifying the "
+            "qualified name of the processor class/factory, and optional "
+            "'args' and 'kwargs' fields containing positional and keyword "
+            "arguments. For example: {'qualname': "
+            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+            "{'param': 'value'}}."))

     # doc: end-chat-completion-extra-params
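Because `logits_processors` is not part of the official OpenAI schema, a client passes it through the extra request body. A hedged sketch with the `openai` Python client; the server URL, model name, and qualified names are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_body={
        "logits_processors": [
            # Bare qualified name: must resolve to a ready-made processor.
            "my_module.lowercase_only",
            # Constructor form: the resolved object is called with kwargs.
            {
                "qualname": "my_module.BanTokens",
                "kwargs": {"token_ids": [128001]},
            },
        ]
    },
)

The request succeeds only if the server was started with a --logits-processor-pattern that admits both qualified names.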
@@ -314,7 +366,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
             length_penalty=self.length_penalty,
             include_stop_str_in_output=self.include_stop_str_in_output)

-    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(
+            self, default_max_tokens: int,
+            logits_processor_pattern: Optional[str]) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
@@ -364,6 +418,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             min_tokens=self.min_tokens,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
+            logits_processors=get_logits_processors(self.logits_processors,
+                                                    logits_processor_pattern),
             include_stop_str_in_output=self.include_stop_str_in_output,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             output_kind=RequestOutputKind.DELTA if self.stream \
@@ -599,6 +655,17 @@ class CompletionRequest(OpenAIBaseModel):
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."))
+    logits_processors: Optional[LogitsProcessors] = Field(
+        default=None,
+        description=(
+            "A list of either qualified names of logits processors, or "
+            "constructor objects, to apply when sampling. A constructor is "
+            "a JSON object with a required 'qualname' field specifying the "
+            "qualified name of the processor class/factory, and optional "
+            "'args' and 'kwargs' fields containing positional and keyword "
+            "arguments. For example: {'qualname': "
+            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+            "{'param': 'value'}}."))

     # doc: end-completion-extra-params
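The completions endpoint accepts the same two shapes; continuing the hypothetical client sketch above, only the call changes:

response = client.completions.create(
    model="my-model",
    prompt="Hello!",
    extra_body={"logits_processors": ["my_module.lowercase_only"]},
)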
@@ -619,7 +686,9 @@ class CompletionRequest(OpenAIBaseModel):
             length_penalty=self.length_penalty,
             include_stop_str_in_output=self.include_stop_str_in_output)

-    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(
+            self, default_max_tokens: int,
+            logits_processor_pattern: Optional[str]) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
@@ -665,6 +734,8 @@ class CompletionRequest(OpenAIBaseModel):
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
+            logits_processors=get_logits_processors(self.logits_processors,
+                                                    logits_processor_pattern),
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             output_kind=RequestOutputKind.DELTA if self.stream \
                 else RequestOutputKind.FINAL_ONLY,

View File

@@ -197,7 +197,8 @@ class OpenAIServingChat(OpenAIServing):
                         default_max_tokens)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens)
+                        default_max_tokens,
+                        self.model_config.logits_processor_pattern)

                 self._log_inputs(request_id,
                                  request_prompts[i],

View File

@@ -123,7 +123,8 @@ class OpenAIServingCompletion(OpenAIServing):
                         default_max_tokens)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens)
+                        default_max_tokens,
+                        self.model_config.logits_processor_pattern)

                 request_id_item = f"{request_id}-{i}"