[Model] Merge SupportsMultiModalWithRawInput with SupportsMultiModal (#23749)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-08-28 01:01:50 +08:00 committed by GitHub
parent 4f35be10a9
commit 52883ed084
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 30 additions and 50 deletions

View File

@ -1698,6 +1698,10 @@ class ModelConfig:
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
@property
def is_multimodal_raw_input_only_model(self) -> bool:
return self._model_info.supports_multimodal_raw_input_only
@property
def is_cross_encoder(self) -> bool:
return (self._model_info.supports_cross_encoding
@ -1707,10 +1711,6 @@ class ModelConfig:
def is_pp_supported(self) -> bool:
return self._model_info.supports_pp
@property
def is_multimodal_raw_input_supported(self) -> bool:
return self._model_info.supports_multimodal_raw_input
@property
def is_attention_free(self) -> bool:
return self._model_info.is_attention_free

View File

@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
"""
supports_multimodal_raw_input_only: ClassVar[bool] = False
"""
A flag that indicates this model supports multi-modal inputs and processes
them in their raw form and not embeddings.
"""
supports_encoder_tp_data: ClassVar[bool] = False
"""
A flag that indicates whether this model supports
@ -143,45 +149,16 @@ def supports_multimodal(
return getattr(model, "supports_multimodal", False)
def supports_multimodal_raw_input_only(
model: Union[type[object], object]) -> bool:
return getattr(model, "supports_multimodal_raw_input_only", False)
def supports_multimodal_encoder_tp_data(
model: Union[type[object], object]) -> bool:
return getattr(model, "supports_encoder_tp_data", False)
@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
"""The interface required for all multi-modal models."""
supports_multimodal_raw_input: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multi-modal inputs and processes
them in their raw form and not embeddings.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
@overload
def supports_multimodal_raw_input(
model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
...
@overload
def supports_multimodal_raw_input(
model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
...
def supports_multimodal_raw_input(
model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
TypeIs[SupportsMultiModalWithRawInput]]:
return getattr(model, "supports_multimodal_raw_input", False)
@runtime_checkable
class SupportsScoreTemplate(Protocol):
"""The interface required for all models that support score template."""

View File

@ -41,7 +41,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
SupportsMultiModalWithRawInput)
SupportsMultiModal)
from .interfaces_base import default_pooling_type
@ -174,10 +174,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
info=PrithviGeoSpatialMAEProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
SupportsMultiModalWithRawInput):
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"""Prithvi Masked Autoencoder"""
supports_multimodal_raw_input_only = True
is_pooling_model = True
@classmethod

View File

@ -29,7 +29,7 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free,
is_hybrid, supports_cross_encoding,
supports_multimodal,
supports_multimodal_encoder_tp_data,
supports_multimodal_raw_input, supports_pp,
supports_multimodal_raw_input_only, supports_pp,
supports_transcription, supports_v0_only)
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
is_text_generation_model)
@ -326,7 +326,7 @@ class _ModelInfo:
default_pooling_type: str
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input: bool
supports_multimodal_raw_input_only: bool
supports_multimodal_encoder_tp_data: bool
supports_pp: bool
has_inner_state: bool
@ -346,7 +346,8 @@ class _ModelInfo:
default_pooling_type=get_default_pooling_type(model),
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
supports_multimodal_raw_input_only=
supports_multimodal_raw_input_only(model),
supports_multimodal_encoder_tp_data=
supports_multimodal_encoder_tp_data(model),
supports_pp=supports_pp(model),
@ -743,13 +744,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures, model_config)
return model_cls.supports_multimodal
def supports_multimodal_raw_input(
def is_multimodal_raw_input_only_model(
self,
architectures: Union[str, list[str]],
model_config: ModelConfig,
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures, model_config)
return model_cls.supports_multimodal_raw_input
return model_cls.supports_multimodal_raw_input_only
def is_pp_supported_model(
self,

View File

@ -139,8 +139,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cache_config.cache_dtype]
self.is_pooling_model = model_config.pooler_config is not None
self.is_multimodal_raw_input_supported = (
model_config.is_multimodal_raw_input_supported)
self.is_multimodal_raw_input_only_model = (
model_config.is_multimodal_raw_input_only_model)
self.max_model_len = model_config.max_model_len
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
@ -612,7 +613,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
) -> BatchedTensorInputs:
if not self.is_multimodal_raw_input_supported or not scheduler_output: # noqa: SIM102
if not scheduler_output or not self.is_multimodal_raw_input_only_model:
return {}
mm_kwargs = list[MultiModalKwargsItem]()
@ -631,8 +632,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return mm_kwargs_combined
def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
if not self.is_multimodal_raw_input_supported:
if not self.is_multimodal_raw_input_only_model:
return {}
mm_budget = self.mm_budget
assert mm_budget is not None