[Model] Merge SupportsMultiModalWithRawInput with SupportsMultiModal (#23749)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-08-28 01:01:50 +08:00 committed by GitHub
parent 4f35be10a9
commit 52883ed084
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 30 additions and 50 deletions

View File

@ -1698,6 +1698,10 @@ class ModelConfig:
def is_multimodal_model(self) -> bool: def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None return self.multimodal_config is not None
@property
def is_multimodal_raw_input_only_model(self) -> bool:
return self._model_info.supports_multimodal_raw_input_only
@property @property
def is_cross_encoder(self) -> bool: def is_cross_encoder(self) -> bool:
return (self._model_info.supports_cross_encoding return (self._model_info.supports_cross_encoding
@ -1707,10 +1711,6 @@ class ModelConfig:
def is_pp_supported(self) -> bool: def is_pp_supported(self) -> bool:
return self._model_info.supports_pp return self._model_info.supports_pp
@property
def is_multimodal_raw_input_supported(self) -> bool:
return self._model_info.supports_multimodal_raw_input
@property @property
def is_attention_free(self) -> bool: def is_attention_free(self) -> bool:
return self._model_info.is_attention_free return self._model_info.is_attention_free

View File

@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol):
MRO of your model class. MRO of your model class.
""" """
supports_multimodal_raw_input_only: ClassVar[bool] = False
"""
A flag that indicates this model supports multi-modal inputs and processes
them in their raw form and not embeddings.
"""
supports_encoder_tp_data: ClassVar[bool] = False supports_encoder_tp_data: ClassVar[bool] = False
""" """
A flag that indicates whether this model supports A flag that indicates whether this model supports
@ -143,45 +149,16 @@ def supports_multimodal(
return getattr(model, "supports_multimodal", False) return getattr(model, "supports_multimodal", False)
def supports_multimodal_raw_input_only(
model: Union[type[object], object]) -> bool:
return getattr(model, "supports_multimodal_raw_input_only", False)
def supports_multimodal_encoder_tp_data( def supports_multimodal_encoder_tp_data(
model: Union[type[object], object]) -> bool: model: Union[type[object], object]) -> bool:
return getattr(model, "supports_encoder_tp_data", False) return getattr(model, "supports_encoder_tp_data", False)
@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
"""The interface required for all multi-modal models."""
supports_multimodal_raw_input: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multi-modal inputs and processes
them in their raw form and not embeddings.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
@overload
def supports_multimodal_raw_input(
model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
...
@overload
def supports_multimodal_raw_input(
model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
...
def supports_multimodal_raw_input(
model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
TypeIs[SupportsMultiModalWithRawInput]]:
return getattr(model, "supports_multimodal_raw_input", False)
@runtime_checkable @runtime_checkable
class SupportsScoreTemplate(Protocol): class SupportsScoreTemplate(Protocol):
"""The interface required for all models that support score template.""" """The interface required for all models that support score template."""

View File

@ -41,7 +41,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import (IsAttentionFree, MultiModalEmbeddings, from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
SupportsMultiModalWithRawInput) SupportsMultiModal)
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
@ -174,10 +174,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
info=PrithviGeoSpatialMAEProcessingInfo, info=PrithviGeoSpatialMAEProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder, dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
) )
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
SupportsMultiModalWithRawInput):
"""Prithvi Masked Autoencoder""" """Prithvi Masked Autoencoder"""
supports_multimodal_raw_input_only = True
is_pooling_model = True is_pooling_model = True
@classmethod @classmethod

View File

@ -29,7 +29,7 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free,
is_hybrid, supports_cross_encoding, is_hybrid, supports_cross_encoding,
supports_multimodal, supports_multimodal,
supports_multimodal_encoder_tp_data, supports_multimodal_encoder_tp_data,
supports_multimodal_raw_input, supports_pp, supports_multimodal_raw_input_only, supports_pp,
supports_transcription, supports_v0_only) supports_transcription, supports_v0_only)
from .interfaces_base import (get_default_pooling_type, is_pooling_model, from .interfaces_base import (get_default_pooling_type, is_pooling_model,
is_text_generation_model) is_text_generation_model)
@ -326,7 +326,7 @@ class _ModelInfo:
default_pooling_type: str default_pooling_type: str
supports_cross_encoding: bool supports_cross_encoding: bool
supports_multimodal: bool supports_multimodal: bool
supports_multimodal_raw_input: bool supports_multimodal_raw_input_only: bool
supports_multimodal_encoder_tp_data: bool supports_multimodal_encoder_tp_data: bool
supports_pp: bool supports_pp: bool
has_inner_state: bool has_inner_state: bool
@ -346,7 +346,8 @@ class _ModelInfo:
default_pooling_type=get_default_pooling_type(model), default_pooling_type=get_default_pooling_type(model),
supports_cross_encoding=supports_cross_encoding(model), supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model), supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input=supports_multimodal_raw_input(model), supports_multimodal_raw_input_only=
supports_multimodal_raw_input_only(model),
supports_multimodal_encoder_tp_data= supports_multimodal_encoder_tp_data=
supports_multimodal_encoder_tp_data(model), supports_multimodal_encoder_tp_data(model),
supports_pp=supports_pp(model), supports_pp=supports_pp(model),
@ -743,13 +744,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures, model_config) model_cls, _ = self.inspect_model_cls(architectures, model_config)
return model_cls.supports_multimodal return model_cls.supports_multimodal
def supports_multimodal_raw_input( def is_multimodal_raw_input_only_model(
self, self,
architectures: Union[str, list[str]], architectures: Union[str, list[str]],
model_config: ModelConfig, model_config: ModelConfig,
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures, model_config) model_cls, _ = self.inspect_model_cls(architectures, model_config)
return model_cls.supports_multimodal_raw_input return model_cls.supports_multimodal_raw_input_only
def is_pp_supported_model( def is_pp_supported_model(
self, self,

View File

@ -139,8 +139,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cache_config.cache_dtype] cache_config.cache_dtype]
self.is_pooling_model = model_config.pooler_config is not None self.is_pooling_model = model_config.pooler_config is not None
self.is_multimodal_raw_input_supported = ( self.is_multimodal_raw_input_only_model = (
model_config.is_multimodal_raw_input_supported) model_config.is_multimodal_raw_input_only_model)
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs self.max_num_reqs = scheduler_config.max_num_seqs
@ -612,7 +613,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> BatchedTensorInputs: ) -> BatchedTensorInputs:
if not self.is_multimodal_raw_input_supported or not scheduler_output: # noqa: SIM102 if not scheduler_output or not self.is_multimodal_raw_input_only_model:
return {} return {}
mm_kwargs = list[MultiModalKwargsItem]() mm_kwargs = list[MultiModalKwargsItem]()
@ -631,8 +632,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return mm_kwargs_combined return mm_kwargs_combined
def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
if not self.is_multimodal_raw_input_supported: if not self.is_multimodal_raw_input_only_model:
return {} return {}
mm_budget = self.mm_budget mm_budget = self.mm_budget
assert mm_budget is not None assert mm_budget is not None