[Bugfix] Validate custom logits processor xargs for online serving (#27560)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-11-06 00:53:33 +08:00 committed by GitHub
parent 6cae1e5332
commit 3f5a4b6473
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 239 additions and 56 deletions

View File

@@ -254,7 +254,15 @@ The previous sections alluded to the interfaces which vLLM logits processors mus
changes to the batch makeup.
"""
raise NotImplementedError
@classmethod
def validate_params(cls, sampling_params: SamplingParams):
"""Validate sampling params for this logits processor.
Raise ValueError for invalid ones.
"""
return None
```
A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) the following methods:
@@ -279,6 +287,10 @@ A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum)
* Use the `BatchUpdate` members to update logits processor internal state
* **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added.
* `validate_params(cls, sampling_params: SamplingParams)`:
* Raise `ValueError` if `SamplingParams` contains invalid arguments (especially custom arguments) used by the logits processor.
* When a request is sent to an entrypoint, `validate_params()` validates the request's `SamplingParams` and rejects the request if any arguments are invalid.
### `BatchUpdate` data structure
The `BatchUpdate` abstraction models the persistent batch as a list of requests, supporting the following operations to change batch state (note that the order in which the operations are mentioned below reflects the order in which they should be processed in `update_state()`):

View File

@@ -4,6 +4,9 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t
Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.
!!! note
Make sure your custom logits processor implements `validate_params` for custom arguments. Otherwise, invalid custom arguments can cause unexpected behaviour.
## Offline Custom Arguments
Custom arguments passed to `SamplingParams.extra_args` as a `dict` will be visible to any code which has access to `SamplingParams`:
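For illustration, a `SimpleNamespace` can stand in for `SamplingParams` to show how downstream code typically reads a custom argument out of `extra_args` (a sketch, not vLLM code):

```python
from types import SimpleNamespace

# Stand-in for SamplingParams(extra_args={"target_token": 128})
params = SimpleNamespace(extra_args={"target_token": 128})

# Downstream code guards against extra_args being None before reading it
target_token = params.extra_args.get("target_token") if params.extra_args else None
assert target_token == 128
```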

View File

@@ -18,6 +18,11 @@ In vLLM, logits processors operate at batch granularity. During a given engine s
Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods:
* `validate_params(cls, sampling_params: SamplingParams)`:
* Raise `ValueError` if `SamplingParams` contains invalid arguments (especially custom arguments) used by the logits processor.
* When a request is sent to an entrypoint, `validate_params()` validates the request's `SamplingParams` and rejects the request if any arguments are invalid.
* **Note:** it is important to implement `validate_params()` so that invalid parameters are rejected up front; otherwise, requests with invalid parameters can cause unexpected behaviour in the custom logits processor.
* `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)`
* `vllm_config`: engine configuration data structure
* `device`: hardware accelerator device info
@@ -103,6 +108,14 @@ The contrived example below implements a custom logits processor which consumes
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
self.req_info: dict[int, int] = {}
@@ -118,6 +131,7 @@ The contrived example below implements a custom logits processor which consumes
# Process added requests.
for index, params, _, _ in batch_update.added:
assert params is not None
self.validate_params(params)
if params.extra_args and (target_token :=
params.extra_args.get("target_token")):
self.req_info[index] = target_token
@@ -157,6 +171,7 @@ The contrived example below implements a custom logits processor which consumes
logits[rows, cols] = values_to_keep
return logits
```
In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor.
@@ -180,7 +195,13 @@ RequestLogitsProcessor = Union[
While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above.
You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped):
* Override `AdapterLogitsProcessor.validate_params(cls,params)` to validate a request's sampling parameters.
* Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit.
* Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
??? code "Example of Wrapping a Request-Level Logits Processor"
@@ -220,6 +241,16 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} is not int"
)
def is_argmax_invariant(self) -> bool:
return False
@@ -240,18 +271,11 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
Returns:
`Callable` request logits processor, or None
"""
target_token: Optional[Any] = params.extra_args and params.extra_args.get(
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is None:
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)
```

View File

@@ -33,6 +33,8 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""
from typing import Any
import torch
from vllm import LLM, SamplingParams
@@ -48,6 +50,16 @@ from vllm.v1.sample.logits_processor.builtin import process_dict_updates
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)
def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
@@ -57,14 +69,17 @@ class DummyLogitsProcessor(LogitsProcessor):
return False
def update_state(self, batch_update: BatchUpdate | None):
def extract_extra_arg(params: SamplingParams) -> int | None:
self.validate_params(params)
return params.extra_args and params.extra_args.get("target_token")
process_dict_updates(
self.req_info,
batch_update,
# This function returns the LP's per-request state based on the
# request details, or None if this LP does not apply to the
# request.
lambda params, _, __: params.extra_args
and (params.extra_args.get("target_token")),
lambda params, _, __: extract_extra_arg(params),
)
def apply(self, logits: torch.Tensor) -> torch.Tensor:

View File

@@ -76,6 +76,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of wrapping a fake request-level logit processor to create a
batch-level logits processor"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: Any | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(f"target_token value {target_token} is not int")
def is_argmax_invariant(self) -> bool:
return False
@@ -101,13 +109,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
)
if target_token is None:
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)

View File

@@ -77,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token = params.extra_args and params.extra_args.get("target_token")
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"`target_token` has to be an integer, got {target_token}."
)
def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
@@ -113,13 +121,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
is None
):
return None
if not isinstance(target_token, int):
logger.warning(
"target_token value %s is not int; not applying logits"
" processor to request.",
target_token,
)
return None
return DummyPerReqLogitsProcessor(target_token)

View File

@@ -40,6 +40,7 @@ class MockModelConfig:
tokenizer_revision: str | None = None
multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
hf_config: MockHFConfig = field(default_factory=MockHFConfig)
logits_processors: list[str] | None = None
logits_processor_pattern: str | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""

View File

@@ -353,6 +353,7 @@ class MockModelConfig:
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processors: list[str] | None = None
logits_processor_pattern = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""

View File

@@ -177,3 +177,32 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
# Alternate whether to activate dummy logitproc for each request
use_dummy_logitproc = not use_dummy_logitproc
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_invalid_custom_logitsproc_arg(
client: openai.AsyncOpenAI, model_name: str
):
"""Test that a request with an invalid custom logitsproc argument is rejected"""
prompt = "Hello, my name is"
# Pass invalid (non-int) target_token value to dummy logits processor
request_keyword_args: dict[str, Any] = {
**api_keyword_args,
"extra_body": {
"vllm_xargs": {DUMMY_LOGITPROC_ARG: "invalid_target_token_value"}
},
}
with pytest.raises(openai.OpenAIError) as exc_info:
await client.completions.create(
model=model_name,
prompt=prompt,
**request_keyword_args,
)
assert "is not int" in str(exc_info.value)

View File

@@ -52,6 +52,16 @@ prompts = [
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)
def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
@@ -62,11 +72,14 @@ class DummyLogitsProcessor(LogitsProcessor):
return False
def update_state(self, batch_update: BatchUpdate | None):
def extract_extra_arg(params: SamplingParams) -> int | None:
self.validate_params(params)
return params.extra_args and params.extra_args.get("target_token")
process_dict_updates(
self.req_info,
batch_update,
lambda params, _, __: params.extra_args
and (params.extra_args.get("target_token")),
lambda params, _, __: extract_extra_arg(params),
)
def apply(self, logits: torch.Tensor) -> torch.Tensor:

View File

@@ -772,10 +772,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
description="KVTransfer parameters used for disaggregated serving.",
)
vllm_xargs: dict[str, str | int | float] | None = Field(
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
default=None,
description=(
"Additional request parameters with string or "
"Additional request parameters with (list of) string or "
"numeric values, used by custom extensions."
),
)
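The widened annotation accepts scalar values and flat lists of scalars. A rough stand-alone equivalent of that shape check (hypothetical helper, not part of this change):

```python
def is_valid_xarg_value(value) -> bool:
    # Mirrors dict[str, str | int | float | list[str | int | float]] values:
    # a scalar, or a flat list of scalars. Nested containers are rejected.
    scalars = (str, int, float)
    if isinstance(value, scalars):
        return True
    return isinstance(value, list) and all(isinstance(v, scalars) for v in value)
```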

View File

@@ -68,6 +68,7 @@ from vllm.transformers_utils.tokenizers import (
validate_request_params,
)
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
logger = init_logger(__name__)
@@ -107,6 +108,9 @@ class OpenAIServingChat(OpenAIServing):
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs
# set up logits processors
self.logits_processors = self.model_config.logits_processors
# set up reasoning parser
self.reasoning_parser = self._get_reasoning_parser(
reasoning_parser_name=reasoning_parser
@@ -291,6 +295,10 @@ class OpenAIServingChat(OpenAIServing):
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)
self._log_inputs(
request_id,

View File

@@ -36,6 +36,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
logger = init_logger(__name__)
@@ -59,6 +60,10 @@ class OpenAIServingCompletion(OpenAIServing):
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
# set up logits processors
self.logits_processors = self.model_config.logits_processors
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.enable_force_include_usage = enable_force_include_usage
@@ -181,6 +186,10 @@ class OpenAIServingCompletion(OpenAIServing):
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
validate_logits_processors_parameters(
self.logits_processors,
sampling_params,
)
request_id_item = f"{request_id}-{i}"

View File

@@ -131,10 +131,34 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
"""Example of overriding the wrapper class `__init__()` in order to utilize
info about the device type"""
def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
super().__init__(vllm_config, device, is_pin_memory)
@classmethod
def validate_params(cls, params: SamplingParams):
ngram_size = params.extra_args and params.extra_args.get("ngram_size")
window_size = params.extra_args and params.extra_args.get("window_size", 100)
whitelist_token_ids = params.extra_args and params.extra_args.get(
"whitelist_token_ids", None
)
# if ngram_size is not provided, skip validation because the processor
# will not be used.
if ngram_size is None:
return None
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(
f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
)
if not isinstance(window_size, int) or window_size <= 0:
raise ValueError(
"`window_size` has to be a strictly positive integer, "
f"got {window_size}."
)
if whitelist_token_ids is not None and not isinstance(
whitelist_token_ids, Iterable
):
raise ValueError(
"`whitelist_token_ids` has to be a sequence of integers, "
f"got {whitelist_token_ids}."
)
def is_argmax_invariant(self) -> bool:
return True
@@ -150,26 +174,8 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
)
if ngram_size is None:
return None
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(
f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
)
if not isinstance(window_size, int) or window_size <= 0:
raise ValueError(
"`window_size` has to be a strictly positive integer, "
f"got {window_size}."
)
if whitelist_token_ids is not None and not isinstance(
whitelist_token_ids, Iterable
):
raise ValueError(
"`whitelist_token_ids` has to be a set of integers, "
f"got {whitelist_token_ids}."
)
else:
whitelist_token_ids = (
set(whitelist_token_ids) if whitelist_token_ids else None
)
whitelist_token_ids = set(whitelist_token_ids) if whitelist_token_ids else None
return NoRepeatNGramLogitsProcessor(
ngram_size=ngram_size,
window_size=window_size,

View File

@@ -218,3 +218,9 @@ class DeepseekVLV2Config(PretrainedConfig):
self.global_view_pos = global_view_pos
self.candidate_resolutions = candidate_resolutions
self.vocab_size = self.text_config.vocab_size
# update model_type for OCR model
if "DeepseekOCRForCausalLM" in (
self.architectures or kwargs.get("architectures", [])
):
self.model_type = "deepseek_ocr"

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import importlib.metadata
import os
import threading
from collections.abc import Callable, Collection
from functools import lru_cache
@@ -68,6 +69,33 @@ def set_default_torch_num_threads(num_threads: int):
torch.set_num_threads(old_num_threads)
@contextlib.contextmanager
def guard_cuda_initialization():
"""Avoid unexpected CUDA initialization."""
from vllm.platforms import current_platform
if not current_platform.is_cuda():
yield
return
had_key = "CUDA_VISIBLE_DEVICES" in os.environ
old_value = os.environ.get("CUDA_VISIBLE_DEVICES")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
try:
yield
except Exception as e:
if "No CUDA GPUs are available" in str(e):
err_msg = "CUDA initialization is blocked."
else:
err_msg = str(e)
raise RuntimeError(err_msg) from e
finally:
if had_key:
os.environ["CUDA_VISIBLE_DEVICES"] = old_value
else:
os.environ.pop("CUDA_VISIBLE_DEVICES")
def get_dtype_size(dtype: torch.dtype) -> int:
"""Get the size of the data type in bytes."""
return torch.tensor([], dtype=dtype).element_size()
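The `CUDA_VISIBLE_DEVICES` save/blank/restore in `guard_cuda_initialization` above follows a general pattern; a generic sketch for an arbitrary environment variable (illustrative helper, not vLLM code):

```python
import contextlib
import os

@contextlib.contextmanager
def temporarily_clear_env(name: str):
    # Blank out an environment variable for the duration of the block,
    # then restore it (or remove it if it was absent), mirroring the
    # CUDA_VISIBLE_DEVICES handling above.
    had_key = name in os.environ
    old_value = os.environ.get(name)
    os.environ[name] = ""
    try:
        yield
    finally:
        if had_key:
            os.environ[name] = old_value
        else:
            os.environ.pop(name, None)
```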

View File

@@ -13,6 +13,7 @@ import torch
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
from vllm.sampling_params import SamplingParams
from vllm.utils.torch_utils import guard_cuda_initialization
from vllm.v1.sample.logits_processor.builtin import (
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
@@ -72,8 +73,10 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
entrypoint.name,
entrypoint.value,
)
classes.append(entrypoint.load())
with guard_cuda_initialization():
classes.append(entrypoint.load())
except Exception as e:
logger.error("Failed to load LogitsProcessor plugin %s: %s", entrypoint, e)
raise RuntimeError(
f"Failed to load LogitsProcessor plugin {entrypoint}"
) from e
@@ -126,8 +129,15 @@ def _load_logitsprocs_by_fqcns(
try:
# Load module
module = importlib.import_module(module_path)
with guard_cuda_initialization():
module = importlib.import_module(module_path)
except Exception as e:
logger.error(
"Failed to load %sth LogitsProcessor plugin %s: %s",
ldx,
logitproc,
e,
)
raise RuntimeError(
f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
) from e
@@ -206,6 +216,14 @@ def build_logitsprocs(
)
def validate_logits_processors_parameters(
logits_processors: Sequence[str | type[LogitsProcessor]] | None,
sampling_params: SamplingParams,
):
for logits_procs in _load_custom_logitsprocs(logits_processors):
logits_procs.validate_params(sampling_params)
class AdapterLogitsProcessor(LogitsProcessor):
"""Wrapper for per-request logits processors

View File

@@ -58,6 +58,14 @@ class BatchUpdate:
class LogitsProcessor(ABC):
@classmethod
def validate_params(cls, sampling_params: SamplingParams):
"""Validate sampling params for this logits processor.
Raise ValueError for invalid ones.
"""
return None
@abstractmethod
def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool