From 52883ed08461943ff55d5dd3cf12a28c00902fa7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 28 Aug 2025 01:01:50 +0800 Subject: [PATCH] [Model] Merge `SupportsMultiModalWithRawInput` with `SupportsMultiModal` (#23749) Signed-off-by: DarkLight1337 --- vllm/config/__init__.py | 8 ++-- vllm/model_executor/models/interfaces.py | 45 +++++-------------- .../models/prithvi_geospatial_mae.py | 6 +-- vllm/model_executor/models/registry.py | 11 ++--- vllm/v1/worker/gpu_model_runner.py | 10 +++-- 5 files changed, 30 insertions(+), 50 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index e3fb6d796def5..351833d3f02d0 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1698,6 +1698,10 @@ class ModelConfig: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + @property + def is_multimodal_raw_input_only_model(self) -> bool: + return self._model_info.supports_multimodal_raw_input_only + @property def is_cross_encoder(self) -> bool: return (self._model_info.supports_cross_encoding @@ -1707,10 +1711,6 @@ class ModelConfig: def is_pp_supported(self) -> bool: return self._model_info.supports_pp - @property - def is_multimodal_raw_input_supported(self) -> bool: - return self._model_info.supports_multimodal_raw_input - @property def is_attention_free(self) -> bool: return self._model_info.is_attention_free diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 506732fed3614..2ee966fb5c0c8 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ + supports_multimodal_raw_input_only: ClassVar[bool] = False + """ + A flag that indicates this model supports multi-modal inputs and processes + them in their raw form and not embeddings. + """ + supports_encoder_tp_data: ClassVar[bool] = False """ A flag that indicates whether this model supports @@ -143,45 +149,16 @@ def supports_multimodal( return getattr(model, "supports_multimodal", False) +def supports_multimodal_raw_input_only( + model: Union[type[object], object]) -> bool: + return getattr(model, "supports_multimodal_raw_input_only", False) + + def supports_multimodal_encoder_tp_data( model: Union[type[object], object]) -> bool: return getattr(model, "supports_encoder_tp_data", False) -@runtime_checkable -class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): - """The interface required for all multi-modal models.""" - - supports_multimodal_raw_input: ClassVar[Literal[True]] = True - """ - A flag that indicates this model supports multi-modal inputs and processes - them in their raw form and not embeddings. - - Note: - There is no need to redefine this flag if this class is in the - MRO of your model class. - """ - - -@overload -def supports_multimodal_raw_input( - model: object) -> TypeIs[SupportsMultiModalWithRawInput]: - ... - - -@overload -def supports_multimodal_raw_input( - model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: - ... - - -def supports_multimodal_raw_input( - model: Union[type[object], object] -) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], - TypeIs[SupportsMultiModalWithRawInput]]: - return getattr(model, "supports_multimodal_raw_input", False) - - @runtime_checkable class SupportsScoreTemplate(Protocol): """The interface required for all models that support score template.""" diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index f46d6375e1f61..2d14fe6d5892f 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -41,7 +41,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import (IsAttentionFree, MultiModalEmbeddings, - SupportsMultiModalWithRawInput) + SupportsMultiModal) from .interfaces_base import default_pooling_type @@ -174,10 +174,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): info=PrithviGeoSpatialMAEProcessingInfo, dummy_inputs=PrithviGeoSpatialMAEInputBuilder, ) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, - SupportsMultiModalWithRawInput): +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal): """Prithvi Masked Autoencoder""" + supports_multimodal_raw_input_only = True is_pooling_model = True @classmethod diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 02ef301a52a43..12c0c77784db8 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -29,7 +29,7 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, supports_multimodal_encoder_tp_data, - supports_multimodal_raw_input, supports_pp, + supports_multimodal_raw_input_only, supports_pp, supports_transcription, supports_v0_only) from .interfaces_base import (get_default_pooling_type, is_pooling_model, is_text_generation_model) @@ -326,7 +326,7 @@ class _ModelInfo: default_pooling_type: str supports_cross_encoding: bool supports_multimodal: bool - supports_multimodal_raw_input: bool + supports_multimodal_raw_input_only: bool supports_multimodal_encoder_tp_data: bool supports_pp: bool has_inner_state: bool @@ -346,7 +346,8 @@ class _ModelInfo: default_pooling_type=get_default_pooling_type(model), supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), - supports_multimodal_raw_input=supports_multimodal_raw_input(model), + supports_multimodal_raw_input_only= + supports_multimodal_raw_input_only(model), supports_multimodal_encoder_tp_data= supports_multimodal_encoder_tp_data(model), supports_pp=supports_pp(model), @@ -743,13 +744,13 @@ class _ModelRegistry: model_cls, _ = self.inspect_model_cls(architectures, model_config) return model_cls.supports_multimodal - def supports_multimodal_raw_input( + def is_multimodal_raw_input_only_model( self, architectures: Union[str, list[str]], model_config: ModelConfig, ) -> bool: model_cls, _ = self.inspect_model_cls(architectures, model_config) - return model_cls.supports_multimodal_raw_input + return model_cls.supports_multimodal_raw_input_only def is_pp_supported_model( self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d93460d618e7c..20d2d20ba0967 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -139,8 +139,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): cache_config.cache_dtype] self.is_pooling_model = model_config.pooler_config is not None - self.is_multimodal_raw_input_supported = ( - model_config.is_multimodal_raw_input_supported) + self.is_multimodal_raw_input_only_model = ( + model_config.is_multimodal_raw_input_only_model) + self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -612,7 +613,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self, scheduler_output: "SchedulerOutput", ) -> BatchedTensorInputs: - if not self.is_multimodal_raw_input_supported or not scheduler_output: # noqa: SIM102 + if not scheduler_output or not self.is_multimodal_raw_input_only_model: return {} mm_kwargs = list[MultiModalKwargsItem]() @@ -631,8 +632,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return mm_kwargs_combined def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: - if not self.is_multimodal_raw_input_supported: + if not self.is_multimodal_raw_input_only_model: return {} + mm_budget = self.mm_budget assert mm_budget is not None