[Bugfix] Validate custom logits processor xargs for online serving (#27560)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py 2025-11-06 00:53:33 +08:00 committed by GitHub
parent 6cae1e5332
commit 3f5a4b6473
18 changed files with 239 additions and 56 deletions

View File

@@ -254,7 +254,15 @@ The previous sections alluded to the interfaces which vLLM logits processors mus
        changes to the batch makeup.
        """
        raise NotImplementedError
    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        """Validate sampling params for this logits processor.
        Raise ValueError for invalid ones.
        """
        return None
```

A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) the following methods:
@@ -279,6 +287,10 @@ A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum)
    * Use the `BatchUpdate` members to update logits processor internal state
    * **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added.
* `validate_params(cls, sampling_params: SamplingParams)`:
    * Raise `ValueError` if `SamplingParams` contains invalid arguments (especially custom arguments) used by the logits processor.
    * When a request is sent to an entrypoint, `validate_params()` validates its `SamplingParams` and rejects the request if any arguments are invalid, as shown in the sketch after this list.
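For illustration, a minimal sketch of a `validate_params()` override, assuming a hypothetical `my_flag` custom argument (the remaining required methods are omitted):

```python
from vllm import SamplingParams
from vllm.v1.sample.logits_processor import LogitsProcessor


class MyLogitsProcessor(LogitsProcessor):
    # __init__ / is_argmax_invariant / update_state / apply omitted for brevity.

    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        # `my_flag` is a hypothetical custom argument carried in extra_args.
        my_flag = sampling_params.extra_args and sampling_params.extra_args.get(
            "my_flag"
        )
        if my_flag is not None and not isinstance(my_flag, bool):
            raise ValueError(f"my_flag value {my_flag} is not bool")
```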
### `BatchUpdate` data structure

The `BatchUpdate` abstraction models the persistent batch as a list of requests, supporting the following operations to change batch state (note that the order in which the operations are mentioned below reflects the order in which they should be processed in `update_state()`):

View File

@@ -4,6 +4,9 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t
Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.
!!! note
    Make sure your custom logits processor implements `validate_params` for its custom arguments. Otherwise, invalid custom arguments can cause unexpected behaviour.
## Offline Custom Arguments

Custom arguments passed to `SamplingParams.extra_args` as a `dict` will be visible to any code which has access to `SamplingParams`:
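For example, a minimal offline sketch (the model name and `target_token` value are illustrative; `DummyLogitsProcessor` and the `logits_processors` engine argument are the ones used in the custom logits processor docs):

```python
from vllm import LLM, SamplingParams

# DummyLogitsProcessor (defined in the custom logits processor docs)
# consumes the custom "target_token" argument from extra_args.
llm = LLM(model="facebook/opt-125m", logits_processors=[DummyLogitsProcessor])
params = SamplingParams(extra_args={"target_token": 67})
outputs = llm.generate(["Hello, my name is"], sampling_params=params)
```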

View File

@@ -18,6 +18,11 @@ In vLLM, logits processors operate at batch granularity. During a given engine s
Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods:
* `validate_params(cls, sampling_params: SamplingParams)`:
    * Raise `ValueError` if `SamplingParams` contains invalid arguments (especially custom arguments) used by the logits processor.
    * When a request is sent to an entrypoint, `validate_params()` validates its `SamplingParams` and rejects the request if any arguments are invalid.
    * **Note:** it's important to implement `validate_params()` to reject invalid parameters for your custom logits processor. Otherwise, requests with invalid parameters can cause unexpected behaviour in the logits processor.
* `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)`
    * `vllm_config`: engine configuration data structure
    * `device`: hardware accelerator device info
@@ -103,6 +108,14 @@ The contrived example below implements a custom logits processor which consumes
class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: int | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(f"target_token value {target_token} is not int")

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        self.req_info: dict[int, int] = {}
@@ -118,6 +131,7 @@ The contrived example below implements a custom logits processor which consumes
        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            self.validate_params(params)
            if params.extra_args and (target_token :=
                                      params.extra_args.get("target_token")):
                self.req_info[index] = target_token
@@ -157,6 +171,7 @@ The contrived example below implements a custom logits processor which consumes
        logits[rows, cols] = values_to_keep

        return logits
```

In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor.
@@ -180,7 +195,13 @@ RequestLogitsProcessor = Union[
While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above.

You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped):
* Override `AdapterLogitsProcessor.validate_params(cls, params)` to validate a request's sampling parameters.
* Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit.
* Override `AdapterLogitsProcessor.new_req_logits_processor(self, params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
??? code "Example of Wrapping a Request-Level Logits Processor"
@@ -220,6 +241,16 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
    """Example of wrapping a fake request-level logit processor to create a
    batch-level logits processor"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(
                f"target_token value {target_token} is not int"
            )

    def is_argmax_invariant(self) -> bool:
        return False
@@ -240,18 +271,11 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
        Returns:
            `Callable` request logits processor, or None
        """
        target_token: Optional[Any] = params.extra_args and params.extra_args.get(
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is None:
            return None
        if not isinstance(target_token, int):
            logger.warning(
                "target_token value %s is not int; not applying logits"
                " processor to request.",
                target_token,
            )
            return None
        return DummyPerReqLogitsProcessor(target_token)
    ```

View File

@@ -33,6 +33,8 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""

from typing import Any

import torch

from vllm import LLM, SamplingParams
@@ -48,6 +50,16 @@ from vllm.v1.sample.logits_processor.builtin import process_dict_updates
class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(
                f"target_token value {target_token} {type(target_token)} is not int"
            )

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
@@ -57,14 +69,17 @@ class DummyLogitsProcessor(LogitsProcessor):
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        def extract_extra_arg(params: SamplingParams) -> int | None:
            self.validate_params(params)
            return params.extra_args and params.extra_args.get("target_token")

        process_dict_updates(
            self.req_info,
            batch_update,
            # This function returns the LP's per-request state based on the
            # request details, or None if this LP does not apply to the
            # request.
            lambda params, _, __: params.extra_args
            and (params.extra_args.get("target_token")),
            lambda params, _, __: extract_extra_arg(params),
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:

View File

@@ -76,6 +76,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of wrapping a fake request-level logit processor to create a
    batch-level logits processor"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: Any | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(f"target_token value {target_token} is not int")

    def is_argmax_invariant(self) -> bool:
        return False
@@ -101,13 +109,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
        )
        if target_token is None:
            return None
        if not isinstance(target_token, int):
            logger.warning(
                "target_token value %s is not int; not applying logits"
                " processor to request.",
                target_token,
            )
            return None
        return DummyPerReqLogitsProcessor(target_token)

View File

@@ -77,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of overriding the wrapper class `__init__()` in order to utilize
    info about the device type"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token = params.extra_args and params.extra_args.get("target_token")
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(
                f"`target_token` has to be an integer, got {target_token}."
            )

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
@@ -113,13 +121,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
            is None
        ):
            return None
        if not isinstance(target_token, int):
            logger.warning(
                "target_token value %s is not int; not applying logits"
                " processor to request.",
                target_token,
            )
            return None
        return DummyPerReqLogitsProcessor(target_token)

View File

@@ -40,6 +40,7 @@ class MockModelConfig:
    tokenizer_revision: str | None = None
    multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
    hf_config: MockHFConfig = field(default_factory=MockHFConfig)
    logits_processors: list[str] | None = None
    logits_processor_pattern: str | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""

View File

@@ -353,6 +353,7 @@ class MockModelConfig:
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    logits_processors: list[str] | None = None
    logits_processor_pattern = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""

View File

@@ -177,3 +177,32 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
        # Alternate whether to activate dummy logitproc for each request
        use_dummy_logitproc = not use_dummy_logitproc
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_invalid_custom_logitsproc_arg(
    client: openai.AsyncOpenAI, model_name: str
):
    """Test that a request with an invalid custom logitsproc arg is rejected"""
    prompt = "Hello, my name is"

    # Pass invalid (non-int) target_token value to dummy logits processor
    request_keyword_args: dict[str, Any] = {
        **api_keyword_args,
        "extra_body": {
            "vllm_xargs": {DUMMY_LOGITPROC_ARG: "invalid_target_token_value"}
        },
    }
    with pytest.raises(openai.OpenAIError) as exc_info:
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            **request_keyword_args,
        )

    assert "is not int" in str(exc_info.value)

View File

@@ -52,6 +52,16 @@ prompts = [
class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token: int | None = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(
                f"target_token value {target_token} {type(target_token)} is not int"
            )

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
@@ -62,11 +72,14 @@ class DummyLogitsProcessor(LogitsProcessor):
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        def extract_extra_arg(params: SamplingParams) -> int | None:
            self.validate_params(params)
            return params.extra_args and params.extra_args.get("target_token")

        process_dict_updates(
            self.req_info,
            batch_update,
            lambda params, _, __: params.extra_args
            and (params.extra_args.get("target_token")),
            lambda params, _, __: extract_extra_arg(params),
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:

View File

@@ -772,10 +772,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
        description="KVTransfer parameters used for disaggregated serving.",
    )
    vllm_xargs: dict[str, str | int | float] | None = Field(
    vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "Additional request parameters with (list of) string or "
            "numeric values, used by custom extensions."
        ),
    )
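With this widened type, clients can pass list-valued custom arguments. A sketch of a request body fragment (the argument names follow the NGram example processor later in this diff; the values are illustrative):

```python
# Sketch: list values are now accepted in vllm_xargs.
extra_body = {
    "vllm_xargs": {
        "ngram_size": 2,
        "window_size": 100,
        "whitelist_token_ids": [13, 198],  # illustrative token ids
    }
}
```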

View File

@@ -68,6 +68,7 @@ from vllm.transformers_utils.tokenizers import (
    validate_request_params,
)
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters

logger = init_logger(__name__)
@@ -107,6 +108,9 @@ class OpenAIServingChat(OpenAIServing):
        self.trust_request_chat_template = trust_request_chat_template
        self.enable_log_outputs = enable_log_outputs

        # set up logits processors
        self.logits_processors = self.model_config.logits_processors

        # set up reasoning parser
        self.reasoning_parser = self._get_reasoning_parser(
            reasoning_parser_name=reasoning_parser
@@ -291,6 +295,10 @@ class OpenAIServingChat(OpenAIServing):
                self.model_config.logits_processor_pattern,
                self.default_sampling_params,
            )
            validate_logits_processors_parameters(
                self.logits_processors,
                sampling_params,
            )

            self._log_inputs(
                request_id,

View File

@@ -36,6 +36,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters

logger = init_logger(__name__)
@@ -59,6 +60,10 @@ class OpenAIServingCompletion(OpenAIServing):
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        # set up logits processors
        self.logits_processors = self.model_config.logits_processors

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.enable_force_include_usage = enable_force_include_usage
@@ -181,6 +186,10 @@ class OpenAIServingCompletion(OpenAIServing):
                self.model_config.logits_processor_pattern,
                self.default_sampling_params,
            )
            validate_logits_processors_parameters(
                self.logits_processors,
                sampling_params,
            )

            request_id_item = f"{request_id}-{i}"

View File

@@ -131,10 +131,34 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of overriding the wrapper class `__init__()` in order to utilize
    info about the device type"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        ngram_size = params.extra_args and params.extra_args.get("ngram_size")
        window_size = params.extra_args and params.extra_args.get("window_size", 100)
        whitelist_token_ids = params.extra_args and params.extra_args.get(
            "whitelist_token_ids", None
        )
        # if ngram_size is not provided, skip validation because the processor
        # will not be used.
        if ngram_size is None:
            return None
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(
                f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
            )
        if not isinstance(window_size, int) or window_size <= 0:
            raise ValueError(
                "`window_size` has to be a strictly positive integer, "
                f"got {window_size}."
            )
        if whitelist_token_ids is not None and not isinstance(
            whitelist_token_ids, Iterable
        ):
            raise ValueError(
                "`whitelist_token_ids` has to be a sequence of integers, "
                f"got {whitelist_token_ids}."
            )

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        super().__init__(vllm_config, device, is_pin_memory)

    def is_argmax_invariant(self) -> bool:
        return True
@@ -150,26 +174,8 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
        )
        if ngram_size is None:
            return None
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(
                f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
            )
        if not isinstance(window_size, int) or window_size <= 0:
            raise ValueError(
                "`window_size` has to be a strictly positive integer, "
                f"got {window_size}."
            )
        if whitelist_token_ids is not None and not isinstance(
            whitelist_token_ids, Iterable
        ):
            raise ValueError(
                "`whitelist_token_ids` has to be a set of integers, "
                f"got {whitelist_token_ids}."
            )
        else:
            whitelist_token_ids = (
                set(whitelist_token_ids) if whitelist_token_ids else None
            )

        whitelist_token_ids = set(whitelist_token_ids) if whitelist_token_ids else None

        return NoRepeatNGramLogitsProcessor(
            ngram_size=ngram_size,
            window_size=window_size,

View File

@@ -218,3 +218,9 @@ class DeepseekVLV2Config(PretrainedConfig):
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
        self.vocab_size = self.text_config.vocab_size

        # update model_type for OCR model
        if "DeepseekOCRForCausalLM" in (
            self.architectures or kwargs.get("architectures", [])
        ):
            self.model_type = "deepseek_ocr"

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import importlib.metadata
import os
import threading
from collections.abc import Callable, Collection
from functools import lru_cache
@@ -68,6 +69,33 @@ def set_default_torch_num_threads(num_threads: int):
        torch.set_num_threads(old_num_threads)
@contextlib.contextmanager
def guard_cuda_initialization():
    """Avoid unexpected CUDA initialization."""
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        yield
        return

    had_key = "CUDA_VISIBLE_DEVICES" in os.environ
    old_value = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    try:
        yield
    except Exception as e:
        if "No CUDA GPUs are available" in str(e):
            err_msg = "CUDA initialization is blocked."
        else:
            err_msg = str(e)
        raise RuntimeError(err_msg) from e
    finally:
        if had_key:
            os.environ["CUDA_VISIBLE_DEVICES"] = old_value
        else:
            os.environ.pop("CUDA_VISIBLE_DEVICES")
def get_dtype_size(dtype: torch.dtype) -> int:
    """Get the size of the data type in bytes."""
    return torch.tensor([], dtype=dtype).element_size()
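A minimal usage sketch of the new context manager, mirroring how the logits-processor loader below wraps plugin imports (the plugin module name is hypothetical):

```python
from vllm.utils.torch_utils import guard_cuda_initialization

# Importing third-party plugin code must not initialize CUDA; with the
# guard active, any CUDA initialization inside the block surfaces as
# RuntimeError("CUDA initialization is blocked.").
with guard_cuda_initialization():
    import my_logitsproc_plugin  # hypothetical plugin module
```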

View File

@@ -13,6 +13,7 @@ import torch
from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
from vllm.sampling_params import SamplingParams
from vllm.utils.torch_utils import guard_cuda_initialization
from vllm.v1.sample.logits_processor.builtin import (
    LogitBiasLogitsProcessor,
    MinPLogitsProcessor,
@@ -72,8 +73,10 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
                entrypoint.name,
                entrypoint.value,
            )
            classes.append(entrypoint.load())
            with guard_cuda_initialization():
                classes.append(entrypoint.load())
        except Exception as e:
            logger.error("Failed to load LogitsProcessor plugin %s: %s", entrypoint, e)
            raise RuntimeError(
                f"Failed to load LogitsProcessor plugin {entrypoint}"
            ) from e
@@ -126,8 +129,15 @@ def _load_logitsprocs_by_fqcns(
        try:
            # Load module
            module = importlib.import_module(module_path)
            with guard_cuda_initialization():
                module = importlib.import_module(module_path)
        except Exception as e:
            logger.error(
                "Failed to load %sth LogitsProcessor plugin %s: %s",
                ldx,
                logitproc,
                e,
            )
            raise RuntimeError(
                f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
            ) from e
@@ -206,6 +216,14 @@ def build_logitsprocs(
    )


def validate_logits_processors_parameters(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
    sampling_params: SamplingParams,
):
    for logits_procs in _load_custom_logitsprocs(logits_processors):
        logits_procs.validate_params(sampling_params)
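A sketch of how the serving layer uses this helper, mirroring the `serving_chat.py` and `serving_completion.py` changes above (`model_config` and `sampling_params` are assumed from that context):

```python
# Sketch: called by the entrypoints before a request reaches the engine.
# A ValueError raised by any processor's validate_params() is surfaced to
# the client as a request error.
validate_logits_processors_parameters(
    model_config.logits_processors,  # e.g. ["my_pkg.MyLogitsProcessor"]
    sampling_params,
)
```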
class AdapterLogitsProcessor(LogitsProcessor):
    """Wrapper for per-request logits processors

View File

@@ -58,6 +58,14 @@ class BatchUpdate:
class LogitsProcessor(ABC):
    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        """Validate sampling params for this logits processor.
        Raise ValueError for invalid ones.
        """
        return None

    @abstractmethod
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool