diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 313ab2fa8038b..9638791ab5caa 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -233,7 +233,7 @@ def _test_processing_correctness( ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) - factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] + factories = model_cls._processor_factory ctx = InputProcessingContext( model_config, tokenizer=cached_tokenizer_from_config(model_config), diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 687d1ef349f84..a287d5b87d1b7 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -193,7 +193,7 @@ def test_model_tensor_schema(model_id: str): model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) assert supports_multimodal(model_cls) - factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] + factories = model_cls._processor_factory inputs_parse_methods = [] for attr_name in dir(model_cls): diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index cee0b79e5e5ac..2218d688e59f6 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -32,11 +32,13 @@ if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.model_executor.models.utils import WeightsMapper from vllm.multimodal.inputs import MultiModalFeatureSpec + from vllm.multimodal.registry import _ProcessorFactories from vllm.sequence import IntermediateTensors else: VllmConfig = object WeightsMapper = object MultiModalFeatureSpec = object + _ProcessorFactories = object IntermediateTensors = object logger = init_logger(__name__) @@ -87,6 +89,11 @@ class SupportsMultiModal(Protocol): A set indicating CPU-only multimodal fields. """ + _processor_factory: ClassVar[_ProcessorFactories] + """ + Set internally by `MultiModalRegistry.register_processor`. + """ + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: """ diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 8f9276e846407..a7eafa76ad17e 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -2,14 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Protocol, TypeVar - -import torch.nn as nn +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config -from vllm.utils.collection_utils import ClassRegistry from .cache import BaseMultiModalProcessorCache from .processing import ( @@ -26,10 +23,11 @@ from .profiling import ( if TYPE_CHECKING: from vllm.config import ModelConfig + from vllm.model_executor.models.interfaces import SupportsMultiModal logger = init_logger(__name__) -N = TypeVar("N", bound=type[nn.Module]) +N = TypeVar("N", bound=type["SupportsMultiModal"]) _I = TypeVar("_I", bound=BaseProcessingInfo) _I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True) @@ -95,9 +93,6 @@ class MultiModalRegistry: A registry that dispatches data processing according to the model. """ - def __init__(self) -> None: - self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - def _extract_mm_options( self, model_config: "ModelConfig", @@ -207,7 +202,7 @@ class MultiModalRegistry: """ def wrapper(model_cls: N) -> N: - if self._processor_factories.contains(model_cls, strict=True): + if "_processor_factory" in model_cls.__dict__: logger.warning( "Model class %s already has a multi-modal processor " "registered to %s. It is overwritten by the new one.", @@ -215,7 +210,7 @@ class MultiModalRegistry: self, ) - self._processor_factories[model_cls] = _ProcessorFactories( + model_cls._processor_factory = _ProcessorFactories( info=info, dummy_inputs=dummy_inputs, processor=processor, @@ -225,12 +220,13 @@ class MultiModalRegistry: return wrapper - def _get_model_cls(self, model_config: "ModelConfig"): + def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal": # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - return model_cls + assert hasattr(model_cls, "_processor_factory") + return cast("SupportsMultiModal", model_cls) def _create_processing_ctx( self, @@ -248,7 +244,7 @@ class MultiModalRegistry: tokenizer: AnyTokenizer | None = None, ) -> BaseProcessingInfo: model_cls = self._get_model_cls(model_config) - factories = self._processor_factories[model_cls] + factories = model_cls._processor_factory ctx = self._create_processing_ctx(model_config, tokenizer) return factories.info(ctx) @@ -266,7 +262,7 @@ class MultiModalRegistry: raise ValueError(f"{model_config.model} is not a multimodal model") model_cls = self._get_model_cls(model_config) - factories = self._processor_factories[model_cls] + factories = model_cls._processor_factory ctx = self._create_processing_ctx(model_config, tokenizer) diff --git a/vllm/utils/collection_utils.py b/vllm/utils/collection_utils.py index 57271311828cd..3b19e1bd78197 100644 --- a/vllm/utils/collection_utils.py +++ b/vllm/utils/collection_utils.py @@ -6,64 +6,37 @@ Contains helpers that are applied to collections. This is similar in concept to the `collections` module. """ -from collections import UserDict, defaultdict +from collections import defaultdict from collections.abc import Callable, Generator, Hashable, Iterable, Mapping from typing import Generic, Literal, TypeVar from typing_extensions import TypeIs, assert_never T = TypeVar("T") -U = TypeVar("U") _K = TypeVar("_K", bound=Hashable) _V = TypeVar("_V") -class ClassRegistry(UserDict[type[T], _V]): - """ - A registry that acts like a dictionary but searches for other classes - in the MRO if the original class is not found. - """ - - def __getitem__(self, key: type[T]) -> _V: - for cls in key.mro(): - if cls in self.data: - return self.data[cls] - - raise KeyError(key) - - def __contains__(self, key: object) -> bool: - return self.contains(key) - - def contains(self, key: object, *, strict: bool = False) -> bool: - if not isinstance(key, type): - return False - - if strict: - return key in self.data - - return any(cls in self.data for cls in key.mro()) - - -class LazyDict(Mapping[str, T], Generic[T]): +class LazyDict(Mapping[str, _V], Generic[_V]): """ Evaluates dictionary items only when they are accessed. Adapted from: https://stackoverflow.com/a/47212782/5082708 """ - def __init__(self, factory: dict[str, Callable[[], T]]): + def __init__(self, factory: dict[str, Callable[[], _V]]): self._factory = factory - self._dict: dict[str, T] = {} + self._dict: dict[str, _V] = {} - def __getitem__(self, key: str) -> T: + def __getitem__(self, key: str) -> _V: if key not in self._dict: if key not in self._factory: raise KeyError(key) self._dict[key] = self._factory[key]() return self._dict[key] - def __setitem__(self, key: str, value: Callable[[], T]): + def __setitem__(self, key: str, value: Callable[[], _V]): self._factory[key] = value def __iter__(self):