mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:27:19 +08:00
[Bugfix] Validate custom logits processor xargs for online serving (#27560)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
6cae1e5332
commit
3f5a4b6473
@ -254,7 +254,15 @@ The previous sections alluded to the interfaces which vLLM logits processors mus
|
||||
changes to the batch makeup.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@classmethod
def validate_params(cls, sampling_params: SamplingParams):
    """Check the sampling params of a request for this logits processor.

    This default implementation accepts every `sampling_params` and
    returns `None`.

    Raises:
        ValueError: Expected from implementations that find invalid
            params; never raised by this default.
    """
|
||||
|
||||
```
|
||||
|
||||
A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) the following methods:
|
||||
@ -279,6 +287,10 @@ A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum)
|
||||
* Use the `BatchUpdate` members to update logits processor internal state
|
||||
* **Note:** batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added.
|
||||
|
||||
* `validate_params(cls, sampling_params: SamplingParams)`:
|
||||
* Raise `ValueError` if `SamplingParams` contains invalid arguments (especially custom arguments) used by the logits processor.
|
||||
* When a request is sent to the entrypoint, `validate_params()` validates its `SamplingParams`, and requests with invalid arguments are rejected.
|
||||
|
||||
### `BatchUpdate` data structure
|
||||
|
||||
The `BatchUpdate` abstraction models the persistent batch as a list of requests, supporting the following operations to change batch state (note that the order in which the operations are mentioned below reflects the order in which they should be processed in `update_state()`):
|
||||
|
||||
@ -4,6 +4,9 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t
|
||||
|
||||
Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.
|
||||
|
||||
!!! note
|
||||
Make sure your custom logits processor implements `validate_params` for its custom arguments. Otherwise, invalid custom arguments can cause unexpected behaviour.
|
||||
|
||||
## Offline Custom Arguments
|
||||
|
||||
Custom arguments passed to `SamplingParams.extra_args` as a `dict` will be visible to any code which has access to `SamplingParams`:
|
||||
|
||||
@ -18,6 +18,11 @@ In vLLM, logits processors operate at batch granularity. During a given engine s
|
||||
|
||||
Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods:
|
||||
|
||||
* `validate_params(cls, sampling_params: SamplingParams)`:
|
||||
* Raise `ValueError` if `SamplingParams` contains invalid arguments (especially custom arguments) used by the logits processor.
|
||||
* When a request is sent to the entrypoint, `validate_params()` validates its `SamplingParams`, and requests with invalid arguments are rejected.
|
||||
* **Note:** it is important to implement `validate_params()` so that invalid parameters never reach your custom logits processor. Otherwise, requests with invalid parameters can cause unexpected behaviour in the custom logits processor.
|
||||
|
||||
* `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)`
|
||||
* `vllm_config`: engine configuration data structure
|
||||
* `device`: hardware accelerator device info
|
||||
@ -103,6 +108,14 @@ The contrived example below implements a custom logits processor which consumes
|
||||
class DummyLogitsProcessor(LogitsProcessor):
|
||||
"""Fake logit processor to support unit testing and examples"""
|
||||
|
||||
@classmethod
def validate_params(cls, params: SamplingParams):
    """Validate the `target_token` custom argument of a request.

    Args:
        params: Sampling parameters of the request; only
            `params.extra_args` is inspected.

    Raises:
        ValueError: If `target_token` is present but is not an int.
    """
    # Treat a missing *or empty* `extra_args` as "no custom arguments".
    # `params.extra_args and ...` would evaluate to the empty dict itself
    # and wrongly reject a request that never set `target_token`.
    extra_args = params.extra_args or {}
    target_token = extra_args.get("target_token")
    if target_token is not None and not isinstance(target_token, int):
        raise ValueError(f"target_token value {target_token} is not int")
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
|
||||
is_pin_memory: bool):
|
||||
self.req_info: dict[int, int] = {}
|
||||
@ -118,6 +131,7 @@ The contrived example below implements a custom logits processor which consumes
|
||||
# Process added requests.
|
||||
for index, params, _, _ in batch_update.added:
|
||||
assert params is not None
|
||||
self.validate_params(params)
|
||||
if params.extra_args and (target_token :=
|
||||
params.extra_args.get("target_token")):
|
||||
self.req_info[index] = target_token
|
||||
@ -157,6 +171,7 @@ The contrived example below implements a custom logits processor which consumes
|
||||
logits[rows, cols] = values_to_keep
|
||||
|
||||
return logits
|
||||
|
||||
```
|
||||
|
||||
In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor.
|
||||
@ -180,7 +195,13 @@ RequestLogitsProcessor = Union[
|
||||
|
||||
While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above.
|
||||
|
||||
You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
|
||||
You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.):
|
||||
|
||||
* Override `AdapterLogitsProcessor.validate_params(cls,params)` to validate the request's sampling parameters.
|
||||
|
||||
* Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit.
|
||||
|
||||
* Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
|
||||
|
||||
??? code "Example of Wrapping a Request-Level Logits Processor"
|
||||
|
||||
@ -220,6 +241,16 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
|
||||
"""Example of wrapping a fake request-level logit processor to create a
|
||||
batch-level logits processor"""
|
||||
|
||||
@classmethod
def validate_params(cls, params: SamplingParams):
    """Validate the `target_token` custom argument of a request.

    Args:
        params: Sampling parameters of the request; only
            `params.extra_args` is inspected.

    Raises:
        ValueError: If `target_token` is present but is not an int.
    """
    # Treat a missing *or empty* `extra_args` as "no custom arguments".
    # `params.extra_args and ...` would evaluate to the empty dict itself
    # and wrongly reject a request that never set `target_token`.
    extra_args: dict[str, Any] = params.extra_args or {}
    target_token: Any | None = extra_args.get("target_token")
    if target_token is not None and not isinstance(target_token, int):
        raise ValueError(
            f"target_token value {target_token} is not int"
        )
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
return False
|
||||
|
||||
@ -240,18 +271,11 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
|
||||
Returns:
|
||||
`Callable` request logits processor, or None
|
||||
"""
|
||||
target_token: Optional[Any] = params.extra_args and params.extra_args.get(
|
||||
target_token: Any | None = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is None:
|
||||
return None
|
||||
if not isinstance(target_token, int):
|
||||
logger.warning(
|
||||
"target_token value %s is not int; not applying logits"
|
||||
" processor to request.",
|
||||
target_token,
|
||||
)
|
||||
return None
|
||||
return DummyPerReqLogitsProcessor(target_token)
|
||||
```
|
||||
|
||||
|
||||
@ -33,6 +33,8 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
|
||||
------------------------------------------------------------
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
@ -48,6 +50,16 @@ from vllm.v1.sample.logits_processor.builtin import process_dict_updates
|
||||
class DummyLogitsProcessor(LogitsProcessor):
|
||||
"""Fake logit processor to support unit testing and examples"""
|
||||
|
||||
@classmethod
def validate_params(cls, params: SamplingParams):
    """Validate the `target_token` custom argument of a request.

    Args:
        params: Sampling parameters of the request; only
            `params.extra_args` is inspected.

    Raises:
        ValueError: If `target_token` is present but is not an int.
    """
    # Treat a missing *or empty* `extra_args` as "no custom arguments".
    # `params.extra_args and ...` would evaluate to the empty dict itself
    # and wrongly reject a request that never set `target_token`.
    extra_args: dict[str, Any] = params.extra_args or {}
    target_token: Any | None = extra_args.get("target_token")
    if target_token is not None and not isinstance(target_token, int):
        raise ValueError(
            f"target_token value {target_token} {type(target_token)} is not int"
        )
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
@ -57,14 +69,17 @@ class DummyLogitsProcessor(LogitsProcessor):
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: BatchUpdate | None):
|
||||
def extract_extra_arg(params: SamplingParams) -> int | None:
|
||||
self.validate_params(params)
|
||||
return params.extra_args and params.extra_args.get("target_token")
|
||||
|
||||
process_dict_updates(
|
||||
self.req_info,
|
||||
batch_update,
|
||||
# This function returns the LP's per-request state based on the
|
||||
# request details, or None if this LP does not apply to the
|
||||
# request.
|
||||
lambda params, _, __: params.extra_args
|
||||
and (params.extra_args.get("target_token")),
|
||||
lambda params, _, __: extract_extra_arg(params),
|
||||
)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -76,6 +76,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
"""Example of wrapping a fake request-level logit processor to create a
|
||||
batch-level logits processor"""
|
||||
|
||||
@classmethod
def validate_params(cls, params: SamplingParams):
    """Validate the `target_token` custom argument of a request.

    Args:
        params: Sampling parameters of the request; only
            `params.extra_args` is inspected.

    Raises:
        ValueError: If `target_token` is present but is not an int.
    """
    # Treat a missing *or empty* `extra_args` as "no custom arguments".
    # `params.extra_args and ...` would evaluate to the empty dict itself
    # and wrongly reject a request that never set `target_token`.
    extra_args: dict[str, Any] = params.extra_args or {}
    target_token: Any | None = extra_args.get("target_token")
    if target_token is not None and not isinstance(target_token, int):
        raise ValueError(f"target_token value {target_token} is not int")
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
return False
|
||||
|
||||
@ -101,13 +109,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
)
|
||||
if target_token is None:
|
||||
return None
|
||||
if not isinstance(target_token, int):
|
||||
logger.warning(
|
||||
"target_token value %s is not int; not applying logits"
|
||||
" processor to request.",
|
||||
target_token,
|
||||
)
|
||||
return None
|
||||
return DummyPerReqLogitsProcessor(target_token)
|
||||
|
||||
|
||||
|
||||
@ -77,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
"""Example of overriding the wrapper class `__init__()` in order to utilize
|
||||
info about the device type"""
|
||||
|
||||
@classmethod
def validate_params(cls, params: SamplingParams):
    """Validate the `target_token` custom argument of a request.

    Args:
        params: Sampling parameters of the request; only
            `params.extra_args` is inspected.

    Raises:
        ValueError: If `target_token` is present but is not an int.
    """
    # Treat a missing *or empty* `extra_args` as "no custom arguments".
    # `params.extra_args and ...` would evaluate to the empty dict itself
    # and wrongly reject a request that never set `target_token`.
    extra_args = params.extra_args or {}
    target_token = extra_args.get("target_token")
    if target_token is not None and not isinstance(target_token, int):
        raise ValueError(
            f"`target_token` has to be an integer, got {target_token}."
        )
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
@ -113,13 +121,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
is None
|
||||
):
|
||||
return None
|
||||
if not isinstance(target_token, int):
|
||||
logger.warning(
|
||||
"target_token value %s is not int; not applying logits"
|
||||
" processor to request.",
|
||||
target_token,
|
||||
)
|
||||
return None
|
||||
return DummyPerReqLogitsProcessor(target_token)
|
||||
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ class MockModelConfig:
|
||||
tokenizer_revision: str | None = None
|
||||
multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
|
||||
hf_config: MockHFConfig = field(default_factory=MockHFConfig)
|
||||
logits_processors: list[str] | None = None
|
||||
logits_processor_pattern: str | None = None
|
||||
diff_sampling_param: dict | None = None
|
||||
allowed_local_media_path: str = ""
|
||||
|
||||
@ -353,6 +353,7 @@ class MockModelConfig:
|
||||
tokenizer_revision = None
|
||||
multimodal_config = MultiModalConfig()
|
||||
hf_config = MockHFConfig()
|
||||
logits_processors: list[str] | None = None
|
||||
logits_processor_pattern = None
|
||||
diff_sampling_param: dict | None = None
|
||||
allowed_local_media_path: str = ""
|
||||
|
||||
@ -177,3 +177,32 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
|
||||
|
||||
# Alternate whether to activate dummy logitproc for each request
|
||||
use_dummy_logitproc = not use_dummy_logitproc
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_invalid_custom_logitsproc_arg(
    client: openai.AsyncOpenAI, model_name: str
):
    """Test that request with invalid custom logitsproc is rejected"""

    # The dummy logits processor is handed a string where it requires an
    # int, so its argument validation has something to reject.
    invalid_xargs = {DUMMY_LOGITPROC_ARG: "invalid_target_token_value"}
    request_keyword_args: dict[str, Any] = {
        **api_keyword_args,
        "extra_body": {"vllm_xargs": invalid_xargs},
    }

    with pytest.raises(openai.OpenAIError) as exc_info:
        await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            **request_keyword_args,
        )

    # The client-visible error must contain the validation message.
    error_text = str(exc_info.value)
    assert "is not int" in error_text
|
||||
|
||||
@ -52,6 +52,16 @@ prompts = [
|
||||
class DummyLogitsProcessor(LogitsProcessor):
|
||||
"""Fake logit processor to support unit testing and examples"""
|
||||
|
||||
@classmethod
def validate_params(cls, params: SamplingParams):
    """Validate the `target_token` custom argument of a request.

    Args:
        params: Sampling parameters of the request; only
            `params.extra_args` is inspected.

    Raises:
        ValueError: If `target_token` is present but is not an int.
    """
    # Treat a missing *or empty* `extra_args` as "no custom arguments".
    # `params.extra_args and ...` would evaluate to the empty dict itself
    # and wrongly reject a request that never set `target_token`.
    extra_args = params.extra_args or {}
    # Deliberately not annotated as `int | None`: the value is untrusted
    # until the isinstance check below has passed.
    target_token = extra_args.get("target_token")
    if target_token is not None and not isinstance(target_token, int):
        raise ValueError(
            f"target_token value {target_token} {type(target_token)} is not int"
        )
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
@ -62,11 +72,14 @@ class DummyLogitsProcessor(LogitsProcessor):
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: BatchUpdate | None):
|
||||
def extract_extra_arg(params: SamplingParams) -> int | None:
|
||||
self.validate_params(params)
|
||||
return params.extra_args and params.extra_args.get("target_token")
|
||||
|
||||
process_dict_updates(
|
||||
self.req_info,
|
||||
batch_update,
|
||||
lambda params, _, __: params.extra_args
|
||||
and (params.extra_args.get("target_token")),
|
||||
lambda params, _, __: extract_extra_arg(params),
|
||||
)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -772,10 +772,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
|
||||
|
||||
vllm_xargs: dict[str, str | int | float] | None = Field(
|
||||
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional request parameters with string or "
|
||||
"Additional request parameters with (list of) string or "
|
||||
"numeric values, used by custom extensions."
|
||||
),
|
||||
)
|
||||
|
||||
@ -68,6 +68,7 @@ from vllm.transformers_utils.tokenizers import (
|
||||
validate_request_params,
|
||||
)
|
||||
from vllm.utils.collection_utils import as_list
|
||||
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -107,6 +108,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
self.enable_log_outputs = enable_log_outputs
|
||||
|
||||
# set up logits processors
|
||||
self.logits_processors = self.model_config.logits_processors
|
||||
|
||||
# set up reasoning parser
|
||||
self.reasoning_parser = self._get_reasoning_parser(
|
||||
reasoning_parser_name=reasoning_parser
|
||||
@ -291,6 +295,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self.model_config.logits_processor_pattern,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
validate_logits_processors_parameters(
|
||||
self.logits_processors,
|
||||
sampling_params,
|
||||
)
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
|
||||
@ -36,6 +36,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.collection_utils import as_list
|
||||
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -59,6 +60,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||
log_error_stack=log_error_stack,
|
||||
)
|
||||
|
||||
# set up logits processors
|
||||
self.logits_processors = self.model_config.logits_processors
|
||||
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
self.default_sampling_params = self.model_config.get_diff_sampling_param()
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
@ -181,6 +186,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
self.model_config.logits_processor_pattern,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
validate_logits_processors_parameters(
|
||||
self.logits_processors,
|
||||
sampling_params,
|
||||
)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
|
||||
@ -131,10 +131,34 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
"""Example of overriding the wrapper class `__init__()` in order to utilize
|
||||
info about the device type"""
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
super().__init__(vllm_config, device, is_pin_memory)
|
||||
@classmethod
def validate_params(cls, params: SamplingParams):
    """Validate the n-gram custom arguments of a request.

    Inspects `ngram_size`, `window_size` (default 100) and
    `whitelist_token_ids` in `params.extra_args`.

    Args:
        params: Sampling parameters of the request.

    Raises:
        ValueError: If `ngram_size` is given and it or `window_size` is not
            a strictly positive int, or `whitelist_token_ids` is given but
            is not iterable.
    """
    # Treat a missing *or empty* `extra_args` as "no custom arguments".
    # `params.extra_args and ...` would evaluate to the empty dict itself,
    # which is neither None nor an int, and would reject the request.
    extra_args = params.extra_args or {}
    ngram_size = extra_args.get("ngram_size")

    # if ngram_size is not provided, skip validation because the processor
    # will not be used.
    if ngram_size is None:
        return None

    window_size = extra_args.get("window_size", 100)
    whitelist_token_ids = extra_args.get("whitelist_token_ids", None)

    if not isinstance(ngram_size, int) or ngram_size <= 0:
        raise ValueError(
            f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
        )
    if not isinstance(window_size, int) or window_size <= 0:
        raise ValueError(
            "`window_size` has to be a strictly positive integer, "
            f"got {window_size}."
        )
    if whitelist_token_ids is not None and not isinstance(
        whitelist_token_ids, Iterable
    ):
        raise ValueError(
            "`whitelist_token_ids` has to be a sequence of integers, "
            f"got {whitelist_token_ids}."
        )
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
return True
|
||||
@ -150,26 +174,8 @@ class NGramPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
)
|
||||
if ngram_size is None:
|
||||
return None
|
||||
if not isinstance(ngram_size, int) or ngram_size <= 0:
|
||||
raise ValueError(
|
||||
f"`ngram_size` has to be a strictly positive integer, got {ngram_size}."
|
||||
)
|
||||
if not isinstance(window_size, int) or window_size <= 0:
|
||||
raise ValueError(
|
||||
"`window_size` has to be a strictly positive integer, "
|
||||
f"got {window_size}."
|
||||
)
|
||||
if whitelist_token_ids is not None and not isinstance(
|
||||
whitelist_token_ids, Iterable
|
||||
):
|
||||
raise ValueError(
|
||||
"`whitelist_token_ids` has to be a set of integers, "
|
||||
f"got {whitelist_token_ids}."
|
||||
)
|
||||
else:
|
||||
whitelist_token_ids = (
|
||||
set(whitelist_token_ids) if whitelist_token_ids else None
|
||||
)
|
||||
|
||||
whitelist_token_ids = set(whitelist_token_ids) if whitelist_token_ids else None
|
||||
return NoRepeatNGramLogitsProcessor(
|
||||
ngram_size=ngram_size,
|
||||
window_size=window_size,
|
||||
|
||||
@ -218,3 +218,9 @@ class DeepseekVLV2Config(PretrainedConfig):
|
||||
self.global_view_pos = global_view_pos
|
||||
self.candidate_resolutions = candidate_resolutions
|
||||
self.vocab_size = self.text_config.vocab_size
|
||||
|
||||
# update model_type for OCR model
|
||||
if "DeepseekOCRForCausalLM" in (
|
||||
self.architectures or kwargs.get("architectures", [])
|
||||
):
|
||||
self.model_type = "deepseek_ocr"
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import importlib.metadata
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Callable, Collection
|
||||
from functools import lru_cache
|
||||
@ -68,6 +69,33 @@ def set_default_torch_num_threads(num_threads: int):
|
||||
torch.set_num_threads(old_num_threads)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def guard_cuda_initialization():
    """Avoid unexpected CUDA initialization.

    On CUDA platforms, `CUDA_VISIBLE_DEVICES` is set to an empty string for
    the duration of the `with` body and restored afterwards. On other
    platforms the body runs unchanged.

    Raises:
        RuntimeError: Wraps any `Exception` raised by the body on CUDA
            platforms. If the original message contains
            "No CUDA GPUs are available", the message becomes
            "CUDA initialization is blocked."; otherwise the original
            message is kept. The original exception is chained as the cause.
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        yield
        return

    # `None` means the variable was not set at all, because environment
    # values are always strings.
    old_value = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    try:
        yield
    except Exception as e:
        if "No CUDA GPUs are available" in str(e):
            err_msg = "CUDA initialization is blocked."
        else:
            err_msg = str(e)
        raise RuntimeError(err_msg) from e
    finally:
        if old_value is None:
            # Tolerate the guarded code having removed the variable itself;
            # a KeyError here would mask the exception being propagated.
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = old_value
|
||||
|
||||
|
||||
def get_dtype_size(dtype: torch.dtype) -> int:
|
||||
"""Get the size of the data type in bytes."""
|
||||
return torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
@ -13,6 +13,7 @@ import torch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.torch_utils import guard_cuda_initialization
|
||||
from vllm.v1.sample.logits_processor.builtin import (
|
||||
LogitBiasLogitsProcessor,
|
||||
MinPLogitsProcessor,
|
||||
@ -72,8 +73,10 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
|
||||
entrypoint.name,
|
||||
entrypoint.value,
|
||||
)
|
||||
classes.append(entrypoint.load())
|
||||
with guard_cuda_initialization():
|
||||
classes.append(entrypoint.load())
|
||||
except Exception as e:
|
||||
logger.error("Failed to load LogitsProcessor plugin %s: %s", entrypoint, e)
|
||||
raise RuntimeError(
|
||||
f"Failed to load LogitsProcessor plugin {entrypoint}"
|
||||
) from e
|
||||
@ -126,8 +129,15 @@ def _load_logitsprocs_by_fqcns(
|
||||
|
||||
try:
|
||||
# Load module
|
||||
module = importlib.import_module(module_path)
|
||||
with guard_cuda_initialization():
|
||||
module = importlib.import_module(module_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load %sth LogitsProcessor plugin %s: %s",
|
||||
ldx,
|
||||
logitproc,
|
||||
e,
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
|
||||
) from e
|
||||
@ -206,6 +216,14 @@ def build_logitsprocs(
|
||||
)
|
||||
|
||||
|
||||
def validate_logits_processors_parameters(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
    sampling_params: SamplingParams,
):
    """Run every custom logits processor's parameter validation.

    Args:
        logits_processors: Custom logits processors to resolve via
            `_load_custom_logitsprocs`, or None.
        sampling_params: Sampling parameters of the request to validate.

    Any exception raised by a processor's `validate_params` propagates to
    the caller.
    """
    custom_classes = _load_custom_logitsprocs(logits_processors)
    for processor_cls in custom_classes:
        processor_cls.validate_params(sampling_params)
|
||||
|
||||
|
||||
class AdapterLogitsProcessor(LogitsProcessor):
|
||||
"""Wrapper for per-request logits processors
|
||||
|
||||
|
||||
@ -58,6 +58,14 @@ class BatchUpdate:
|
||||
|
||||
|
||||
class LogitsProcessor(ABC):
|
||||
@classmethod
def validate_params(cls, sampling_params: SamplingParams):
    """Check the sampling params of a request for this logits processor.

    This default implementation accepts every `sampling_params` and
    returns `None`.

    Raises:
        ValueError: Expected from implementations that find invalid
            params; never raised by this default.
    """
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user