Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-06 06:08:44 +08:00)
[Chore][1/2] Drop v0.14 deprecations (#31285)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 506eb0f454
commit 09dc7c690c
@@ -72,7 +72,6 @@ Internal data structures.
 - [vllm.multimodal.inputs.MultiModalFieldConfig][]
 - [vllm.multimodal.inputs.MultiModalKwargsItem][]
 - [vllm.multimodal.inputs.MultiModalKwargsItems][]
-- [vllm.multimodal.inputs.MultiModalKwargs][]
 - [vllm.multimodal.inputs.MultiModalInputs][]

 ### Data Parsing
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
 from vllm.outputs import CompletionOutput, RequestOutput
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.tokenizers import get_tokenizer
 from vllm.v1.engine.async_llm import AsyncLLM

 MODEL_NAME = "openai-community/gpt2"
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
 from vllm.outputs import CompletionOutput, RequestOutput
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.tokenizers import get_tokenizer
 from vllm.v1.engine.async_llm import AsyncLLM

 MODEL_NAME = "openai-community/gpt2"
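The tokenizer helper now lives in `vllm.tokenizers`, so both test hunks above are a one-line import swap. A minimal migration sketch (assuming vLLM is installed; the model name is the one used in these tests):

```python
# Before (old path, now deprecated):
#   from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.tokenizers import get_tokenizer

tokenizer = get_tokenizer("openai-community/gpt2")
```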
@@ -61,13 +61,13 @@ class MockLoRAResolver(LoRAResolver):
             return LoRARequest(
                 lora_name="test-lora",
                 lora_int_id=1,
-                lora_local_path="/fake/path/test-lora",
+                lora_path="/fake/path/test-lora",
             )
         elif lora_name == "invalid-lora":
             return LoRARequest(
                 lora_name="invalid-lora",
                 lora_int_id=2,
-                lora_local_path="/fake/path/invalid-lora",
+                lora_path="/fake/path/invalid-lora",
             )
         return None
@@ -41,9 +41,8 @@ from vllm.entrypoints.tool import Tool
 from vllm.entrypoints.tool_server import ToolServer
 from vllm.outputs import RequestOutput
 from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
-from vllm.tokenizers.protocol import TokenizerLike
+from vllm.tokenizers import TokenizerLike
 from vllm.tool_parsers.abstract_tool_parser import ToolParser
-from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid

 if TYPE_CHECKING:
@@ -259,8 +258,8 @@ class ParsableContext(ConversationContext):
         self,
         *,
         response_messages: list[ResponseInputOutputItem],
-        tokenizer: AnyTokenizer,
-        reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None,
+        tokenizer: TokenizerLike,
+        reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
         request: ResponsesRequest,
         available_tools: list[str] | None,
         tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
@@ -19,9 +19,8 @@ from vllm.entrypoints.constants import MCP_PREFIX
 from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
 from vllm.outputs import CompletionOutput
 from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
-from vllm.tokenizers.protocol import TokenizerLike
+from vllm.tokenizers import TokenizerLike
 from vllm.tool_parsers.abstract_tool_parser import ToolParser
-from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid

 logger = logging.getLogger(__name__)
@@ -33,8 +32,8 @@ class ResponsesParser:
     def __init__(
         self,
         *,
-        tokenizer: AnyTokenizer,
-        reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
+        tokenizer: TokenizerLike,
+        reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
         response_messages: list[ResponseInputOutputItem],
         request: ResponsesRequest,
         tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
@@ -150,8 +149,8 @@ class ResponsesParser:

 def get_responses_parser_for_simple_context(
     *,
-    tokenizer: AnyTokenizer,
-    reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
+    tokenizer: TokenizerLike,
+    reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
     response_messages: list[ResponseInputOutputItem],
     request: ResponsesRequest,
     tool_parser_cls,
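Across these hunks the pattern is mechanical: `AnyTokenizer` annotations become `TokenizerLike`, imported from the `vllm.tokenizers` package root rather than `vllm.tokenizers.protocol` or the old `vllm.transformers_utils.tokenizer` alias. A sketch of code written against the new protocol (`build_parser` is a hypothetical helper, not vLLM API):

```python
from collections.abc import Callable

from vllm.tokenizers import TokenizerLike


def build_parser(
    parser_cls: Callable[[TokenizerLike], object],
    tokenizer: TokenizerLike,
) -> object:
    # Parser classes are typed as taking the tokenizer protocol,
    # mirroring the `reasoning_parser_cls`/`tool_parser_cls` hints above.
    return parser_cls(tokenizer)
```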
@@ -119,7 +119,7 @@ class OpenAIServingModels:
         lora_cards = [
             ModelCard(
                 id=lora.lora_name,
-                root=lora.local_path,
+                root=lora.path,
                 parent=lora.base_model_name
                 if lora.base_model_name
                 else self.base_model_paths[0].name,
@@ -1,33 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import warnings
-
-
-def __getattr__(name: str):
-    if name == "ToolParser":
-        from vllm.tool_parsers import ToolParser
-
-        warnings.warn(
-            "`vllm.entrypoints.openai.tool_parsers.ToolParser` has been moved to "
-            "`vllm.tool_parsers.ToolParser`. "
-            "The old name will be removed in v0.14.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        return ToolParser
-    if name == "ToolParserManager":
-        from vllm.tool_parsers import ToolParserManager
-
-        warnings.warn(
-            "`vllm.entrypoints.openai.tool_parsers.ToolParserManager` "
-            "has been moved to `vllm.tool_parsers.ToolParserManager`. "
-            "The old name will be removed in v0.14.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        return ToolParserManager
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
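The file deleted above was a PEP 562 module-level `__getattr__` shim that re-exported the moved names with a `DeprecationWarning`. With it gone, the old import path stops working, so callers switch to the new module:

```python
# Before (worked via the shim, with a DeprecationWarning):
#   from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tool_parsers import ToolParser, ToolParserManager
```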
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import warnings

 import msgspec
@@ -21,7 +20,6 @@ class LoRARequest(
     lora_name: str
     lora_int_id: int
     lora_path: str = ""
-    lora_local_path: str | None = msgspec.field(default=None)
     long_lora_max_len: int | None = None
     base_model_name: str | None = msgspec.field(default=None)
     tensorizer_config_dict: dict | None = None
@@ -29,16 +27,6 @@ class LoRARequest(
     def __post_init__(self):
         if self.lora_int_id < 1:
             raise ValueError(f"id must be > 0, got {self.lora_int_id}")
-        if self.lora_local_path:
-            warnings.warn(
-                "The 'lora_local_path' attribute is deprecated "
-                "and will be removed in a future version. "
-                "Please use 'lora_path' instead.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
-            if not self.lora_path:
-                self.lora_path = self.lora_local_path or ""

         # Ensure lora_path is not empty
         assert self.lora_path, "lora_path cannot be empty"
@@ -55,28 +43,6 @@ class LoRARequest(
     def path(self):
         return self.lora_path

-    @property
-    def local_path(self):
-        warnings.warn(
-            "The 'local_path' attribute is deprecated "
-            "and will be removed in a future version. "
-            "Please use 'path' instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        return self.lora_path
-
-    @local_path.setter
-    def local_path(self, value):
-        warnings.warn(
-            "The 'local_path' attribute is deprecated "
-            "and will be removed in a future version. "
-            "Please use 'path' instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        self.lora_path = value
-
     def __eq__(self, value: object) -> bool:
         """
         Overrides the equality method to compare LoRARequest
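With the compatibility field and property gone, `lora_path` (and the read-only `path` property) are the only spellings `LoRARequest` accepts. A sketch of the migrated construction, reusing the fake values from the test hunk earlier in this diff:

```python
from vllm.lora.request import LoRARequest

request = LoRARequest(
    lora_name="test-lora",
    lora_int_id=1,
    lora_path="/fake/path/test-lora",  # formerly lora_local_path=...
)
assert request.path == "/fake/path/test-lora"  # `.local_path` no longer exists
```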
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections.abc import Callable, Iterable, Mapping, MutableSequence, Set
+from collections.abc import Callable, Iterable, Mapping, MutableSequence
 from typing import (
     TYPE_CHECKING,
     ClassVar,
@@ -100,17 +100,6 @@ class SupportsMultiModal(Protocol):
     in their raw form and not input embeddings.
     """

-    merge_by_field_config: ClassVar[bool | None] = None
-    """
-    [DEPRECATED] A flag that indicates which implementation of
-    `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
-    """
-
-    multimodal_cpu_fields: ClassVar[Set[str] | None] = None
-    """
-    [DEPRECATED] A set indicating CPU-only multimodal fields.
-    """
-
     _processor_factory: ClassVar[_ProcessorFactories]
     """
     Set internally by `MultiModalRegistry.register_processor`.
@@ -277,35 +266,7 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: ...
 def supports_multimodal(
     model: type[object] | object,
 ) -> TypeIs[type[SupportsMultiModal]] | TypeIs[SupportsMultiModal]:
-    res = getattr(model, "supports_multimodal", False)
-
-    if res:
-        # We can remove this starting from v0.14
-        merge_by_field_config = getattr(model, "merge_by_field_config", None)
-        if merge_by_field_config is False:
-            raise ValueError(
-                "`merge_by_field_config=False` is no longer effective, "
-                "please update your model to consider the new batching logic "
-                "in `group_mm_kwargs_by_modality` (refer to "
-                "https://github.com/vllm-project/vllm/issues/26149), "
-                "and then remove the override from your model."
-            )
-        if merge_by_field_config is True:
-            logger.warning_once(
-                "`merge_by_field_config=True` is redundant, "
-                "please remove the override from your model."
-            )
-
-        multimodal_cpu_fields = getattr(model, "multimodal_cpu_fields", None)
-        if multimodal_cpu_fields is not None:
-            raise ValueError(
-                "`multimodal_cpu_fields` is no longer effective, "
-                "please set `keep_on_cpu=True` in `MultiModalFieldConfig` "
-                "(refer to https://github.com/vllm-project/vllm/pull/30181), "
-                "and then remove the override from your model."
-            )
-
-    return res
+    return getattr(model, "supports_multimodal", False)


 def supports_multimodal_raw_input_only(model: type[object] | object) -> bool:
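`supports_multimodal` collapses to a plain attribute lookup; the interesting part is the `TypeIs` return annotation, which lets static checkers narrow the argument's type when the check passes. A generic sketch of that narrowing pattern (the `SupportsFoo` names are illustrative, not vLLM API):

```python
from typing_extensions import TypeIs


class SupportsFoo:
    supports_foo = True


def supports_foo(obj: object) -> TypeIs[SupportsFoo]:
    # A truthy marker attribute is enough; TypeIs tells the checker that
    # `obj` is a SupportsFoo in the branch where this returns True.
    return getattr(obj, "supports_foo", False)


def describe(obj: object) -> None:
    if supports_foo(obj):
        print(obj.supports_foo)  # narrowed: attribute access type-checks
```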
@@ -6,7 +6,6 @@ from .inputs import (
     ModalityData,
     MultiModalDataBuiltins,
     MultiModalDataDict,
-    MultiModalKwargs,
     MultiModalKwargsItems,
     MultiModalPlaceholderDict,
     MultiModalUUIDDict,
@@ -30,7 +29,6 @@ __all__ = [
     "MultiModalDataBuiltins",
     "MultiModalDataDict",
     "MultiModalHasher",
-    "MultiModalKwargs",
     "MultiModalKwargsItems",
     "MultiModalPlaceholderDict",
     "MultiModalUUIDDict",
@@ -20,7 +20,7 @@ from typing import (
 )

 import numpy as np
-from typing_extensions import NotRequired, TypeVar, deprecated
+from typing_extensions import NotRequired, TypeVar

 from vllm.utils.collection_utils import full_groupby, is_list_of
 from vllm.utils.import_utils import LazyLoader
@@ -356,8 +356,8 @@ class MultiModalFeatureSpec:
 @dataclass
 class MultiModalFieldElem:
     """
-    Represents a keyword argument corresponding to a multi-modal item
-    in [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs].
+    Represents a keyword argument inside a
+    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
     """

     modality: str
@@ -369,14 +369,14 @@ class MultiModalFieldElem:
     key: str
     """
     The key of this field in
-    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
+    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
     i.e. the name of the keyword argument to be passed to the model.
     """

     data: NestedTensors
     """
     The tensor data of this field in
-    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
+    [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
     i.e. the value of the keyword argument to be passed to the model.

     It may be set to `None` if it is determined that the item is cached
@@ -410,9 +410,9 @@ class MultiModalFieldElem:
 @dataclass(frozen=True, kw_only=True)
 class BaseMultiModalField(ABC):
     """
-    Defines how to interpret tensor data belonging to a keyword argument in
-    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs] for multiple
-    multi-modal items, and vice versa.
+    Defines how to interpret tensor data belonging to a keyword argument for
+    [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems],
+    and vice versa.
     """

     keep_on_cpu: bool = False
@@ -985,62 +985,6 @@ MultiModalKwargsOptionalItems: TypeAlias = (
 )


-@deprecated("`MultiModalKwargs` is deprecated and will be removed in v0.14.")
-class MultiModalKwargs(UserDict[str, NestedTensors]):
-    """
-    A dictionary that represents the keyword arguments to
-    [`torch.nn.Module.forward`][].
-    """
-
-    @staticmethod
-    @deprecated(
-        "`MultiModalKwargs.from_hf_inputs` is deprecated and "
-        "will be removed in v0.14. "
-        "Please use `MultiModalKwargsItems.from_hf_inputs` and "
-        "access the tensor data using `.get_data()`."
-    )
-    def from_hf_inputs(
-        hf_inputs: "BatchFeature",
-        config_by_key: Mapping[str, MultiModalFieldConfig],
-    ):
-        return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data()
-
-    @staticmethod
-    @deprecated(
-        "`MultiModalKwargs.from_items` is deprecated and "
-        "will be removed in v0.14. "
-        "Please use `MultiModalKwargsItems.from_seq` and "
-        "access the tensor data using `.get_data()`."
-    )
-    def from_items(
-        items: Sequence[MultiModalKwargsItem],
-        *,
-        pin_memory: bool = False,
-    ):
-        return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory)
-
-    def __getitem__(self, key: str):
-        if key not in self:
-            raise KeyError(
-                f"Keyword argument {key!r} not found. "
-                f"Available keys: {set(self.keys())}"
-            )
-
-        return super().__getitem__(key)
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, self.__class__):
-            return False
-
-        for k in self:
-            if k not in other:
-                return False
-            if not nested_tensors_equal(self[k], other[k]):
-                return False
-
-        return True
-
-
 MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
 """
 A dictionary containing placeholder ranges for each modality.
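The removed class's own deprecation messages spell out the replacement: build a `MultiModalKwargsItems` and call `.get_data()` for the plain tensor dict. A sketch (the empty list keeps it self-contained; in practice the items come from the multimodal processor):

```python
from vllm.multimodal.inputs import MultiModalKwargsItems

items = []  # Sequence[MultiModalKwargsItem] in real usage

# Before: MultiModalKwargs.from_items(items, pin_memory=True)
data = MultiModalKwargsItems.from_seq(items).get_data(pin_memory=True)
```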
@@ -4,7 +4,7 @@
 import asyncio
 import atexit
 import mimetypes
-from collections.abc import Generator, Set
+from collections.abc import Generator
 from concurrent.futures import ThreadPoolExecutor
 from itertools import groupby
 from pathlib import Path
@@ -462,8 +462,6 @@ def group_mm_kwargs_by_modality(
     *,
     device: torch.types.Device = None,
     pin_memory: bool = False,
-    merge_by_field_config: bool | None = None,
-    multimodal_cpu_fields: Set[str] | None = None,
 ) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.
@@ -476,17 +474,6 @@ def group_mm_kwargs_by_modality(
     Yields:
         A tuple `(modality, num_items, grouped_kwargs)`.
     """
-    if merge_by_field_config is not None:
-        logger.warning_once(
-            "The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
-            "is deprecated and will be removed in v0.14."
-        )
-    if multimodal_cpu_fields is not None:
-        logger.warning_once(
-            "The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
-            "is deprecated and will be removed in v0.14."
-        )
-
     from vllm.multimodal.inputs import MultiModalKwargsItems

     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
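Since the two keyword arguments are dropped from the signature rather than merely ignored, passing them now raises a `TypeError`. A call-site sketch using only the surviving parameters (the empty list is a placeholder for real `MultiModalKwargsItem` objects):

```python
from vllm.multimodal.utils import group_mm_kwargs_by_modality

mm_kwargs = []  # MultiModalKwargsItem objects, consecutive per modality

for modality, num_items, grouped_kwargs in group_mm_kwargs_by_modality(
    mm_kwargs,
    device="cpu",
    pin_memory=False,
):
    print(modality, num_items)
```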
@@ -7,7 +7,6 @@ from .registry import (
     cached_get_tokenizer,
     cached_tokenizer_from_config,
     get_tokenizer,
-    init_tokenizer_from_config,
 )

 __all__ = [
@@ -16,5 +15,4 @@ __all__ = [
     "cached_get_tokenizer",
     "get_tokenizer",
     "cached_tokenizer_from_config",
-    "init_tokenizer_from_config",
 ]
@@ -7,7 +7,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING

 import huggingface_hub
-from typing_extensions import TypeVar, assert_never, deprecated
+from typing_extensions import TypeVar, assert_never

 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -224,10 +224,3 @@ def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
         trust_remote_code=model_config.trust_remote_code,
         **kwargs,
     )
-
-
-@deprecated(
-    "Renamed to `cached_tokenizer_from_config`. The old name will be removed in v0.14."
-)
-def init_tokenizer_from_config(model_config: "ModelConfig"):
-    return cached_tokenizer_from_config(model_config)
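The removed alias simply forwarded, so behavior is unchanged after renaming the call (a sketch; `model_config` is assumed to be an existing `vllm.config.ModelConfig` instance):

```python
from vllm.tokenizers import cached_tokenizer_from_config

# Before: init_tokenizer_from_config(model_config)
tokenizer = cached_tokenizer_from_config(model_config)  # model_config: ModelConfig
```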
@@ -1,127 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import warnings
-from typing import Any
-
-from typing_extensions import deprecated
-
-from vllm.logger import init_logger
-from vllm.tokenizers import TokenizerLike
-
-logger = init_logger(__name__)


 def __getattr__(name: str):
-    if name == "AnyTokenizer":
-        warnings.warn(
-            "`vllm.transformers_utils.tokenizer.AnyTokenizer` has been moved to "
-            "`vllm.tokenizers.TokenizerLike`. "
-            "The old name will be removed in v0.14.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        return TokenizerLike
+    # Keep until lm-eval is updated
     if name == "get_tokenizer":
         from vllm.tokenizers import get_tokenizer

         warnings.warn(
             "`vllm.transformers_utils.tokenizer.get_tokenizer` "
             "has been moved to `vllm.tokenizers.get_tokenizer`. "
-            "The old name will be removed in v0.14.",
+            "The old name will be removed in a future version.",
             DeprecationWarning,
             stacklevel=2,
         )

         return get_tokenizer
-    if name == "cached_get_tokenizer":
-        from vllm.tokenizers import cached_get_tokenizer
-
-        warnings.warn(
-            "`vllm.transformers_utils.tokenizer.cached_get_tokenizer` "
-            "has been moved to `vllm.tokenizers.cached_get_tokenizer`. "
-            "The old name will be removed in v0.14.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        return cached_get_tokenizer
-    if name == "cached_tokenizer_from_config":
-        from vllm.tokenizers import cached_tokenizer_from_config
-
-        warnings.warn(
-            "`vllm.transformers_utils.tokenizer.cached_tokenizer_from_config` "
-            "has been moved to `vllm.tokenizers.cached_tokenizer_from_config`. "
-            "The old name will be removed in v0.14.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        return cached_tokenizer_from_config
-    if name == "init_tokenizer_from_configs":
-        from vllm.tokenizers import cached_tokenizer_from_config
-
-        warnings.warn(
-            "`vllm.transformers_utils.tokenizer.init_tokenizer_from_configs` "
-            "has been moved to `vllm.tokenizers.cached_tokenizer_from_config`. "
-            "The old name will be removed in v0.14.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        return cached_tokenizer_from_config

     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
-
-
-@deprecated("Will be removed in v0.14. Please use `tokenizer.decode()` instead.")
-def decode_tokens(
-    tokenizer: TokenizerLike,
-    token_ids: list[int],
-    *,
-    skip_special_tokens: bool | None = None,
-) -> str:
-    """
-    Backend-agnostic equivalent of HF's
-    `tokenizer.decode(token_ids, ...)`.
-
-    `skip_special_tokens=None` means to use the backend's default
-    settings.
-    """
-    kw_args: dict[str, Any] = {}
-
-    if skip_special_tokens is not None:
-        kw_args["skip_special_tokens"] = skip_special_tokens
-
-    return tokenizer.decode(token_ids, **kw_args)
-
-
-@deprecated("Will be removed in v0.14. Please use `tokenizer.encode()` instead.")
-def encode_tokens(
-    tokenizer: TokenizerLike,
-    text: str,
-    *,
-    truncation: bool | None = None,
-    max_length: int | None = None,
-    add_special_tokens: bool | None = None,
-) -> list[int]:
-    """
-    Backend-agnostic equivalent of HF's
-    `tokenizer.encode(text, ...)`.
-
-    `add_special_tokens=None` means to use the backend's default
-    settings.
-    """
-    kw_args: dict[str, Any] = {}
-    if max_length is not None:
-        kw_args["max_length"] = max_length
-
-    if truncation is not None:
-        kw_args["truncation"] = truncation
-
-    if add_special_tokens is not None:
-        kw_args["add_special_tokens"] = add_special_tokens
-
-    return tokenizer.encode(text, **kw_args)
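`decode_tokens` and `encode_tokens` were thin wrappers that only forwarded explicitly-set keyword arguments, so the migration is a direct method call with the same keywords (the model name is just an example):

```python
from vllm.tokenizers import get_tokenizer

tokenizer = get_tokenizer("openai-community/gpt2")

# Before: encode_tokens(tokenizer, "Hello!", add_special_tokens=False)
token_ids = tokenizer.encode("Hello!", add_special_tokens=False)

# Before: decode_tokens(tokenizer, token_ids, skip_special_tokens=True)
text = tokenizer.decode(token_ids, skip_special_tokens=True)
```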
@@ -1,33 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import warnings
-
-
-def __getattr__(name: str):
-    if name == "TokenizerBase":
-        from vllm.tokenizers import TokenizerLike
-
-        warnings.warn(
-            "`vllm.transformers_utils.tokenizer_base.TokenizerBase` has been "
-            "moved to `vllm.tokenizers.TokenizerLike`. "
-            "The old name will be removed in v0.14.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        return TokenizerLike
-    if name == "TokenizerRegistry":
-        from vllm.tokenizers import TokenizerRegistry
-
-        warnings.warn(
-            "`vllm.transformers_utils.tokenizer_base.TokenizerRegistry` has been "
-            "moved to `vllm.tokenizers.TokenizerRegistry`. "
-            "The old name will be removed in v0.14.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        return TokenizerRegistry
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -2,39 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import uuid
-import warnings
-from typing import Any

 import torch

-_DEPRECATED_MAPPINGS = {
-    "cprofile": "profiling",
-    "cprofile_context": "profiling",
-    # Used by lm-eval
-    "get_open_port": "network_utils",
-}
-
-
-def __getattr__(name: str) -> Any:  # noqa: D401 - short deprecation docstring
-    """Module-level getattr to handle deprecated utilities."""
-    if name in _DEPRECATED_MAPPINGS:
-        submodule_name = _DEPRECATED_MAPPINGS[name]
-        warnings.warn(
-            f"vllm.utils.{name} is deprecated and will be removed in a future version. "
-            f"Use vllm.utils.{submodule_name}.{name} instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        module = __import__(f"vllm.utils.{submodule_name}", fromlist=[submodule_name])
-        return getattr(module, name)
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
-
-
-def __dir__() -> list[str]:
-    # expose deprecated names in dir() for better UX/tab-completion
-    return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))
-
-
 MASK_64_BITS = (1 << 64) - 1
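`vllm.utils` no longer lazily forwards these names, so the submodule import the old warning pointed to becomes mandatory, including for `get_open_port` (previously kept around for lm-eval):

```python
# Before (via the deleted module-level __getattr__):
#   from vllm.utils import get_open_port
from vllm.utils.network_utils import get_open_port

port = get_open_port()
```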
@@ -11,7 +11,6 @@ from typing import Any, cast

 import numpy as np
 import torch
-from typing_extensions import deprecated

 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -190,14 +189,6 @@ class AsyncLLM(EngineClient):
         else:
             self.profiler = None

-    @property
-    @deprecated(
-        "`AsyncLLM.processor` has been renamed to `AsyncLLM.input_processor`. "
-        "The old name will be removed in v0.14."
-    )
-    def processor(self):
-        return self.input_processor
-
     @classmethod
     def from_vllm_config(
         cls,
@@ -7,7 +7,7 @@ from copy import copy
 from typing import Any, cast

 import torch.nn as nn
-from typing_extensions import TypeVar, deprecated
+from typing_extensions import TypeVar

 import vllm.envs as envs
 from vllm.config import ParallelConfig, VllmConfig
@@ -136,14 +136,6 @@ class LLMEngine:
         # Don't keep the dummy data in memory
         self.reset_mm_cache()

-    @property
-    @deprecated(
-        "`LLMEngine.processor` has been renamed to `LLMEngine.input_processor`. "
-        "The old name will be removed in v0.14."
-    )
-    def processor(self):
-        return self.input_processor
-
     @classmethod
     def from_vllm_config(
         cls,
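For both `AsyncLLM` and `LLMEngine` the deprecated `.processor` property is gone; `.input_processor` is the only accessor. A sketch, assuming an already-constructed engine instance named `engine`:

```python
# Before: engine.processor  (raised DeprecationWarning; removed above)
input_processor = engine.input_processor
```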
@@ -1,20 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import warnings
-
-
-def __getattr__(name: str):
-    if name == "Processor":
-        from .input_processor import InputProcessor
-
-        warnings.warn(
-            "`vllm.v1.engine.processor.Processor` has been moved to "
-            "`vllm.v1.engine.input_processor.InputProcessor`. "
-            "The old name will be removed in v0.14.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
-        return InputProcessor
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -1090,13 +1090,11 @@ class GPUModelRunner(
             mm_kwargs.append(feature.data)

         # Input all modalities at once
-        model = cast(SupportsMultiModal, self.model)
         mm_kwargs_combined: BatchedTensorInputs = {}
         for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
             mm_kwargs,
             device=self.device,
             pin_memory=self.pin_memory,
-            merge_by_field_config=model.merge_by_field_config,
         ):
             mm_kwargs_combined.update(mm_kwargs_group)