diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 93660e6118ca..5b40a04db15e 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -30,6 +30,7 @@ class MockModelConfig:
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
+    logits_processor_pattern = None


 @dataclass
diff --git a/vllm/config.py b/vllm/config.py
index 12ed80c366e4..37d062f7eb07 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -156,41 +156,45 @@ class ModelConfig:
             can not be gathered from the vllm arguments.
         override_pooler_config: Initialize non default pooling config or
             override default pooling config for the pooling model.
+        logits_processor_pattern: Optional regex pattern specifying valid
+            logits processor qualified names that can be passed with the
+            `logits_processors` extra completion argument. Defaults to None,
+            which allows no processors.
     """

-    def __init__(
-        self,
-        model: str,
-        task: Union[TaskOption, Literal["draft"]],
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        allowed_local_media_path: str = "",
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-        config_format: ConfigFormat = ConfigFormat.AUTO,
-        hf_overrides: Optional[HfOverrides] = None,
-        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        mm_cache_preprocessor: bool = False,
-        override_neuron_config: Optional[Dict[str, Any]] = None,
-        override_pooler_config: Optional["PoolerConfig"] = None) -> None:
+    def __init__(self,
+                 model: str,
+                 task: Union[TaskOption, Literal["draft"]],
+                 tokenizer: str,
+                 tokenizer_mode: str,
+                 trust_remote_code: bool,
+                 dtype: Union[str, torch.dtype],
+                 seed: int,
+                 allowed_local_media_path: str = "",
+                 revision: Optional[str] = None,
+                 code_revision: Optional[str] = None,
+                 rope_scaling: Optional[Dict[str, Any]] = None,
+                 rope_theta: Optional[float] = None,
+                 tokenizer_revision: Optional[str] = None,
+                 max_model_len: Optional[int] = None,
+                 spec_target_max_model_len: Optional[int] = None,
+                 quantization: Optional[str] = None,
+                 quantization_param_path: Optional[str] = None,
+                 enforce_eager: Optional[bool] = None,
+                 max_seq_len_to_capture: Optional[int] = None,
+                 max_logprobs: int = 20,
+                 disable_sliding_window: bool = False,
+                 skip_tokenizer_init: bool = False,
+                 served_model_name: Optional[Union[str, List[str]]] = None,
+                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+                 use_async_output_proc: bool = True,
+                 config_format: ConfigFormat = ConfigFormat.AUTO,
+                 hf_overrides: Optional[HfOverrides] = None,
+                 mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+                 mm_cache_preprocessor: bool = False,
+                 override_neuron_config: Optional[Dict[str, Any]] = None,
+                 override_pooler_config: Optional["PoolerConfig"] = None,
+                 logits_processor_pattern: Optional[str] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -316,6 +320,7 @@ class ModelConfig:
         self.task: Final = task

         self.pooler_config = self._init_pooler_config(override_pooler_config)
+        self.logits_processor_pattern = logits_processor_pattern

         self._verify_quantization()
         self._verify_cuda_graph()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 0098648b1cd6..5a73c6ee02e0 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -170,6 +170,7 @@ class EngineArgs:
     enable_chunked_prefill: Optional[bool] = None

     guided_decoding_backend: str = 'xgrammar'
+    logits_processor_pattern: Optional[str] = None
     # Speculative decoding configuration.
     speculative_model: Optional[str] = None
     speculative_model_quantization: Optional[str] = None
@@ -374,6 +375,14 @@
             'https://github.com/noamgat/lm-format-enforcer.'
             ' Can be overridden per request via guided_decoding_backend'
             ' parameter.')
+        parser.add_argument(
+            '--logits-processor-pattern',
+            type=nullable_str,
+            default=None,
+            help='Optional regex pattern specifying valid logits processor '
+            'qualified names that can be passed with the `logits_processors` '
+            'extra completion argument. Defaults to None, which allows no '
+            'processors.')
         # Parallel arguments
         parser.add_argument(
             '--distributed-executor-backend',
@@ -975,7 +984,7 @@ class EngineArgs:
             mm_cache_preprocessor=self.mm_cache_preprocessor,
             override_neuron_config=self.override_neuron_config,
             override_pooler_config=self.override_pooler_config,
-        )
+            logits_processor_pattern=self.logits_processor_pattern)

     def create_load_config(self) -> LoadConfig:
         return LoadConfig(
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index f4e7740ea0cf..dfb7c977dbd4 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1,5 +1,6 @@
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+import re
 import time
 from argparse import Namespace
 from typing import Any, Dict, List, Literal, Optional, Union
@@ -14,7 +15,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
-from vllm.utils import random_uuid
+from vllm.utils import random_uuid, resolve_obj_by_qualname

 logger = init_logger(__name__)
@@ -148,6 +149,46 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
     type: Literal["function"] = "function"


+class LogitsProcessorConstructor(BaseModel):
+    qualname: str
+    args: Optional[List[Any]] = None
+    kwargs: Optional[Dict[str, Any]] = None
+
+
+LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]
+
+
+def get_logits_processors(processors: Optional[LogitsProcessors],
+                          pattern: Optional[str]) -> Optional[List[Any]]:
+    if processors and pattern:
+        logits_processors = []
+        for processor in processors:
+            qualname = processor if isinstance(processor,
+                                               str) else processor.qualname
+            if not re.match(pattern, qualname):
+                raise ValueError(
+                    f"Logits processor '{qualname}' is not allowed by this "
+                    "server. See --logits-processor-pattern engine argument "
+                    "for more information.")
+            try:
+                logits_processor = resolve_obj_by_qualname(qualname)
+            except Exception as e:
+                raise ValueError(
+                    f"Logits processor '{qualname}' could not be resolved: {e}"
+                ) from e
+            if isinstance(processor, LogitsProcessorConstructor):
+                logits_processor = logits_processor(*processor.args or [],
+                                                    **processor.kwargs or {})
+            logits_processors.append(logits_processor)
+        return logits_processors
+    elif processors:
+        raise ValueError(
+            "The `logits_processors` argument is not supported by this "
+            "server. See --logits-processor-pattern engine argument "
+            "for more information.")
+    return None
+
+
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
@@ -293,6 +334,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "The request_id related to this request. If the caller does "
             "not set it, a random_uuid will be generated. This id is used "
             "through out the inference process and return in response."))
+    logits_processors: Optional[LogitsProcessors] = Field(
+        default=None,
+        description=(
+            "A list of either qualified names of logits processors, or "
+            "constructor objects, to apply when sampling. A constructor is "
+            "a JSON object with a required 'qualname' field specifying the "
+            "qualified name of the processor class/factory, and optional "
+            "'args' and 'kwargs' fields containing positional and keyword "
+            "arguments. For example: {'qualname': "
+            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+            "{'param': 'value'}}."))

     # doc: end-chat-completion-extra-params

@@ -314,7 +366,9 @@
             length_penalty=self.length_penalty,
             include_stop_str_in_output=self.include_stop_str_in_output)

-    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(
+            self, default_max_tokens: int,
+            logits_processor_pattern: Optional[str]) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
@@ -364,6 +418,8 @@
             min_tokens=self.min_tokens,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
+            logits_processors=get_logits_processors(self.logits_processors,
+                                                    logits_processor_pattern),
             include_stop_str_in_output=self.include_stop_str_in_output,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             output_kind=RequestOutputKind.DELTA if self.stream \
@@ -599,6 +655,17 @@
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."))
+    logits_processors: Optional[LogitsProcessors] = Field(
+        default=None,
+        description=(
+            "A list of either qualified names of logits processors, or "
+            "constructor objects, to apply when sampling. A constructor is "
+            "a JSON object with a required 'qualname' field specifying the "
+            "qualified name of the processor class/factory, and optional "
+            "'args' and 'kwargs' fields containing positional and keyword "
+            "arguments. For example: {'qualname': "
+            "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
+            "{'param': 'value'}}."))

     # doc: end-completion-extra-params

@@ -619,7 +686,9 @@
             length_penalty=self.length_penalty,
             include_stop_str_in_output=self.include_stop_str_in_output)

-    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(
+            self, default_max_tokens: int,
+            logits_processor_pattern: Optional[str]) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
@@ -665,6 +734,8 @@
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
+            logits_processors=get_logits_processors(self.logits_processors,
+                                                    logits_processor_pattern),
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             output_kind=RequestOutputKind.DELTA if self.stream \
             else RequestOutputKind.FINAL_ONLY,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index a5e7b4ac3bb3..527418c63509 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -197,7 +197,8 @@ class OpenAIServingChat(OpenAIServing):
                         default_max_tokens)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens)
+                        default_max_tokens,
+                        self.model_config.logits_processor_pattern)

                 self._log_inputs(request_id,
                                  request_prompts[i],
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index b3436773062f..bd39a4c42e93 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -123,7 +123,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     default_max_tokens)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens)
+                        default_max_tokens,
+                        self.model_config.logits_processor_pattern)

                 request_id_item = f"{request_id}-{i}"
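
Not part of the patch itself: a minimal end-to-end sketch of how the new extra argument is meant to be used. The module `my_logits_module`, the `SuppressTokens` class, the banned token id, and the served model name below are hypothetical placeholders; only the `--logits-processor-pattern` engine flag, the `logits_processors` request field, and the `qualname`/`args`/`kwargs` constructor form come from the diff above, and the request assumes the server was started with a pattern that admits the module.

# --- my_logits_module.py (hypothetical module) ---------------------------
# A vLLM logits processor is a callable that receives the token ids generated
# so far and the logits for the next token, and returns (possibly modified)
# logits. resolve_obj_by_qualname() loads the class below; because the request
# uses the constructor form, it is instantiated with 'args'/'kwargs'.
from typing import List

import torch


class SuppressTokens:

    def __init__(self, banned_token_ids: List[int]):
        self.banned = banned_token_ids

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        logits[self.banned] = float("-inf")  # never sample the banned ids
        return logits


# --- client side (hypothetical script) -----------------------------------
# Assumes the server was launched with something like:
#   python -m vllm.entrypoints.openai.api_server --model <model> \
#       --logits-processor-pattern "my_logits_module\..*"
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="<served-model-name>",
    prompt="Hello, my name is",
    extra_body={
        "logits_processors": [{
            "qualname": "my_logits_module.SuppressTokens",
            "args": [[42]],  # hypothetical banned token id list
        }],
    },
)
print(completion.choices[0].text)

A bare string entry (e.g. "other_module.my_processor_fn") is also accepted; per get_logits_processors above, the resolved object is then used directly, so it must already be a callable with the processor signature.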