[Misc] Remove redundant ClassRegistry (#29681)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Cyrus Leung 2025-11-29 07:24:47 +08:00 committed by GitHub
parent 7c1ed45848
commit 7675ba30de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 25 additions and 49 deletions

View File

@ -233,7 +233,7 @@ def _test_processing_correctness(
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
factories = model_cls._processor_factory
ctx = InputProcessingContext(
model_config,
tokenizer=cached_tokenizer_from_config(model_config),

View File

@ -193,7 +193,7 @@ def test_model_tensor_schema(model_id: str):
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
assert supports_multimodal(model_cls)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
factories = model_cls._processor_factory
inputs_parse_methods = []
for attr_name in dir(model_cls):

View File

@ -32,11 +32,13 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.multimodal.registry import _ProcessorFactories
from vllm.sequence import IntermediateTensors
else:
VllmConfig = object
WeightsMapper = object
MultiModalFeatureSpec = object
_ProcessorFactories = object
IntermediateTensors = object
logger = init_logger(__name__)
@ -87,6 +89,11 @@ class SupportsMultiModal(Protocol):
A set indicating CPU-only multimodal fields.
"""
_processor_factory: ClassVar[_ProcessorFactories]
"""
Set internally by `MultiModalRegistry.register_processor`.
"""
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
"""

View File

@ -2,14 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
import torch.nn as nn
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
from vllm.utils.collection_utils import ClassRegistry
from .cache import BaseMultiModalProcessorCache
from .processing import (
@ -26,10 +23,11 @@ from .profiling import (
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__)
N = TypeVar("N", bound=type[nn.Module])
N = TypeVar("N", bound=type["SupportsMultiModal"])
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
@ -95,9 +93,6 @@ class MultiModalRegistry:
A registry that dispatches data processing according to the model.
"""
def __init__(self) -> None:
self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]()
def _extract_mm_options(
self,
model_config: "ModelConfig",
@ -207,7 +202,7 @@ class MultiModalRegistry:
"""
def wrapper(model_cls: N) -> N:
if self._processor_factories.contains(model_cls, strict=True):
if "_processor_factory" in model_cls.__dict__:
logger.warning(
"Model class %s already has a multi-modal processor "
"registered to %s. It is overwritten by the new one.",
@ -215,7 +210,7 @@ class MultiModalRegistry:
self,
)
self._processor_factories[model_cls] = _ProcessorFactories(
model_cls._processor_factory = _ProcessorFactories(
info=info,
dummy_inputs=dummy_inputs,
processor=processor,
@ -225,12 +220,13 @@ class MultiModalRegistry:
return wrapper
def _get_model_cls(self, model_config: "ModelConfig"):
def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
return model_cls
assert hasattr(model_cls, "_processor_factory")
return cast("SupportsMultiModal", model_cls)
def _create_processing_ctx(
self,
@ -248,7 +244,7 @@ class MultiModalRegistry:
tokenizer: AnyTokenizer | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.info(ctx)
@ -266,7 +262,7 @@ class MultiModalRegistry:
raise ValueError(f"{model_config.model} is not a multimodal model")
model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)

View File

@ -6,64 +6,37 @@ Contains helpers that are applied to collections.
This is similar in concept to the `collections` module.
"""
from collections import UserDict, defaultdict
from collections import defaultdict
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from typing import Generic, Literal, TypeVar
from typing_extensions import TypeIs, assert_never
T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class ClassRegistry(UserDict[type[T], _V]):
"""
A registry that acts like a dictionary but searches for other classes
in the MRO if the original class is not found.
"""
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
raise KeyError(key)
def __contains__(self, key: object) -> bool:
return self.contains(key)
def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type):
return False
if strict:
return key in self.data
return any(cls in self.data for cls in key.mro())
class LazyDict(Mapping[str, T], Generic[T]):
class LazyDict(Mapping[str, _V], Generic[_V]):
"""
Evaluates dictionary items only when they are accessed.
Adapted from: https://stackoverflow.com/a/47212782/5082708
"""
def __init__(self, factory: dict[str, Callable[[], T]]):
def __init__(self, factory: dict[str, Callable[[], _V]]):
self._factory = factory
self._dict: dict[str, T] = {}
self._dict: dict[str, _V] = {}
def __getitem__(self, key: str) -> T:
def __getitem__(self, key: str) -> _V:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
self._dict[key] = self._factory[key]()
return self._dict[key]
def __setitem__(self, key: str, value: Callable[[], T]):
def __setitem__(self, key: str, value: Callable[[], _V]):
self._factory[key] = value
def __iter__(self):