mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 07:27:04 +08:00
[Misc] Remove redundant ClassRegistry (#29681)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent
7c1ed45848
commit
7675ba30de
@ -233,7 +233,7 @@ def _test_processing_correctness(
|
||||
)
|
||||
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
factories = model_cls._processor_factory
|
||||
ctx = InputProcessingContext(
|
||||
model_config,
|
||||
tokenizer=cached_tokenizer_from_config(model_config),
|
||||
|
||||
@ -193,7 +193,7 @@ def test_model_tensor_schema(model_id: str):
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
assert supports_multimodal(model_cls)
|
||||
|
||||
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
factories = model_cls._processor_factory
|
||||
|
||||
inputs_parse_methods = []
|
||||
for attr_name in dir(model_cls):
|
||||
|
||||
@ -32,11 +32,13 @@ if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.multimodal.registry import _ProcessorFactories
|
||||
from vllm.sequence import IntermediateTensors
|
||||
else:
|
||||
VllmConfig = object
|
||||
WeightsMapper = object
|
||||
MultiModalFeatureSpec = object
|
||||
_ProcessorFactories = object
|
||||
IntermediateTensors = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -87,6 +89,11 @@ class SupportsMultiModal(Protocol):
|
||||
A set indicating CPU-only multimodal fields.
|
||||
"""
|
||||
|
||||
_processor_factory: ClassVar[_ProcessorFactories]
|
||||
"""
|
||||
Set internally by `MultiModalRegistry.register_processor`.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
"""
|
||||
|
||||
@ -2,14 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
|
||||
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
|
||||
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
|
||||
from vllm.utils.collection_utils import ClassRegistry
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .processing import (
|
||||
@ -26,10 +23,11 @@ from .profiling import (
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
N = TypeVar("N", bound=type[nn.Module])
|
||||
N = TypeVar("N", bound=type["SupportsMultiModal"])
|
||||
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
|
||||
|
||||
@ -95,9 +93,6 @@ class MultiModalRegistry:
|
||||
A registry that dispatches data processing according to the model.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]()
|
||||
|
||||
def _extract_mm_options(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
@ -207,7 +202,7 @@ class MultiModalRegistry:
|
||||
"""
|
||||
|
||||
def wrapper(model_cls: N) -> N:
|
||||
if self._processor_factories.contains(model_cls, strict=True):
|
||||
if "_processor_factory" in model_cls.__dict__:
|
||||
logger.warning(
|
||||
"Model class %s already has a multi-modal processor "
|
||||
"registered to %s. It is overwritten by the new one.",
|
||||
@ -215,7 +210,7 @@ class MultiModalRegistry:
|
||||
self,
|
||||
)
|
||||
|
||||
self._processor_factories[model_cls] = _ProcessorFactories(
|
||||
model_cls._processor_factory = _ProcessorFactories(
|
||||
info=info,
|
||||
dummy_inputs=dummy_inputs,
|
||||
processor=processor,
|
||||
@ -225,12 +220,13 @@ class MultiModalRegistry:
|
||||
|
||||
return wrapper
|
||||
|
||||
def _get_model_cls(self, model_config: "ModelConfig"):
|
||||
def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
return model_cls
|
||||
assert hasattr(model_cls, "_processor_factory")
|
||||
return cast("SupportsMultiModal", model_cls)
|
||||
|
||||
def _create_processing_ctx(
|
||||
self,
|
||||
@ -248,7 +244,7 @@ class MultiModalRegistry:
|
||||
tokenizer: AnyTokenizer | None = None,
|
||||
) -> BaseProcessingInfo:
|
||||
model_cls = self._get_model_cls(model_config)
|
||||
factories = self._processor_factories[model_cls]
|
||||
factories = model_cls._processor_factory
|
||||
ctx = self._create_processing_ctx(model_config, tokenizer)
|
||||
return factories.info(ctx)
|
||||
|
||||
@ -266,7 +262,7 @@ class MultiModalRegistry:
|
||||
raise ValueError(f"{model_config.model} is not a multimodal model")
|
||||
|
||||
model_cls = self._get_model_cls(model_config)
|
||||
factories = self._processor_factories[model_cls]
|
||||
factories = model_cls._processor_factory
|
||||
|
||||
ctx = self._create_processing_ctx(model_config, tokenizer)
|
||||
|
||||
|
||||
@ -6,64 +6,37 @@ Contains helpers that are applied to collections.
|
||||
This is similar in concept to the `collections` module.
|
||||
"""
|
||||
|
||||
from collections import UserDict, defaultdict
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from typing_extensions import TypeIs, assert_never
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
_K = TypeVar("_K", bound=Hashable)
|
||||
_V = TypeVar("_V")
|
||||
|
||||
|
||||
class ClassRegistry(UserDict[type[T], _V]):
|
||||
"""
|
||||
A registry that acts like a dictionary but searches for other classes
|
||||
in the MRO if the original class is not found.
|
||||
"""
|
||||
|
||||
def __getitem__(self, key: type[T]) -> _V:
|
||||
for cls in key.mro():
|
||||
if cls in self.data:
|
||||
return self.data[cls]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return self.contains(key)
|
||||
|
||||
def contains(self, key: object, *, strict: bool = False) -> bool:
|
||||
if not isinstance(key, type):
|
||||
return False
|
||||
|
||||
if strict:
|
||||
return key in self.data
|
||||
|
||||
return any(cls in self.data for cls in key.mro())
|
||||
|
||||
|
||||
class LazyDict(Mapping[str, T], Generic[T]):
|
||||
class LazyDict(Mapping[str, _V], Generic[_V]):
|
||||
"""
|
||||
Evaluates dictionary items only when they are accessed.
|
||||
|
||||
Adapted from: https://stackoverflow.com/a/47212782/5082708
|
||||
"""
|
||||
|
||||
def __init__(self, factory: dict[str, Callable[[], T]]):
|
||||
def __init__(self, factory: dict[str, Callable[[], _V]]):
|
||||
self._factory = factory
|
||||
self._dict: dict[str, T] = {}
|
||||
self._dict: dict[str, _V] = {}
|
||||
|
||||
def __getitem__(self, key: str) -> T:
|
||||
def __getitem__(self, key: str) -> _V:
|
||||
if key not in self._dict:
|
||||
if key not in self._factory:
|
||||
raise KeyError(key)
|
||||
self._dict[key] = self._factory[key]()
|
||||
return self._dict[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Callable[[], T]):
|
||||
def __setitem__(self, key: str, value: Callable[[], _V]):
|
||||
self._factory[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user