mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 20:05:01 +08:00
[Model] Merge SupportsMultiModalWithRawInput with SupportsMultiModal (#23749)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
4f35be10a9
commit
52883ed084
@ -1698,6 +1698,10 @@ class ModelConfig:
|
||||
def is_multimodal_model(self) -> bool:
|
||||
return self.multimodal_config is not None
|
||||
|
||||
@property
|
||||
def is_multimodal_raw_input_only_model(self) -> bool:
|
||||
return self._model_info.supports_multimodal_raw_input_only
|
||||
|
||||
@property
|
||||
def is_cross_encoder(self) -> bool:
|
||||
return (self._model_info.supports_cross_encoding
|
||||
@ -1707,10 +1711,6 @@ class ModelConfig:
|
||||
def is_pp_supported(self) -> bool:
|
||||
return self._model_info.supports_pp
|
||||
|
||||
@property
|
||||
def is_multimodal_raw_input_supported(self) -> bool:
|
||||
return self._model_info.supports_multimodal_raw_input
|
||||
|
||||
@property
|
||||
def is_attention_free(self) -> bool:
|
||||
return self._model_info.is_attention_free
|
||||
|
||||
@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol):
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
supports_multimodal_raw_input_only: ClassVar[bool] = False
|
||||
"""
|
||||
A flag that indicates this model supports multi-modal inputs and processes
|
||||
them in their raw form and not embeddings.
|
||||
"""
|
||||
|
||||
supports_encoder_tp_data: ClassVar[bool] = False
|
||||
"""
|
||||
A flag that indicates whether this model supports
|
||||
@ -143,45 +149,16 @@ def supports_multimodal(
|
||||
return getattr(model, "supports_multimodal", False)
|
||||
|
||||
|
||||
def supports_multimodal_raw_input_only(
|
||||
model: Union[type[object], object]) -> bool:
|
||||
return getattr(model, "supports_multimodal_raw_input_only", False)
|
||||
|
||||
|
||||
def supports_multimodal_encoder_tp_data(
|
||||
model: Union[type[object], object]) -> bool:
|
||||
return getattr(model, "supports_encoder_tp_data", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
|
||||
"""The interface required for all multi-modal models."""
|
||||
|
||||
supports_multimodal_raw_input: ClassVar[Literal[True]] = True
|
||||
"""
|
||||
A flag that indicates this model supports multi-modal inputs and processes
|
||||
them in their raw form and not embeddings.
|
||||
|
||||
Note:
|
||||
There is no need to redefine this flag if this class is in the
|
||||
MRO of your model class.
|
||||
"""
|
||||
|
||||
|
||||
@overload
|
||||
def supports_multimodal_raw_input(
|
||||
model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_multimodal_raw_input(
|
||||
model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
|
||||
...
|
||||
|
||||
|
||||
def supports_multimodal_raw_input(
|
||||
model: Union[type[object], object]
|
||||
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
|
||||
TypeIs[SupportsMultiModalWithRawInput]]:
|
||||
return getattr(model, "supports_multimodal_raw_input", False)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsScoreTemplate(Protocol):
|
||||
"""The interface required for all models that support score template."""
|
||||
|
||||
@ -41,7 +41,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
|
||||
SupportsMultiModalWithRawInput)
|
||||
SupportsMultiModal)
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
|
||||
@ -174,10 +174,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
||||
info=PrithviGeoSpatialMAEProcessingInfo,
|
||||
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
|
||||
)
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
|
||||
SupportsMultiModalWithRawInput):
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
"""Prithvi Masked Autoencoder"""
|
||||
|
||||
supports_multimodal_raw_input_only = True
|
||||
is_pooling_model = True
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -29,7 +29,7 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free,
|
||||
is_hybrid, supports_cross_encoding,
|
||||
supports_multimodal,
|
||||
supports_multimodal_encoder_tp_data,
|
||||
supports_multimodal_raw_input, supports_pp,
|
||||
supports_multimodal_raw_input_only, supports_pp,
|
||||
supports_transcription, supports_v0_only)
|
||||
from .interfaces_base import (get_default_pooling_type, is_pooling_model,
|
||||
is_text_generation_model)
|
||||
@ -326,7 +326,7 @@ class _ModelInfo:
|
||||
default_pooling_type: str
|
||||
supports_cross_encoding: bool
|
||||
supports_multimodal: bool
|
||||
supports_multimodal_raw_input: bool
|
||||
supports_multimodal_raw_input_only: bool
|
||||
supports_multimodal_encoder_tp_data: bool
|
||||
supports_pp: bool
|
||||
has_inner_state: bool
|
||||
@ -346,7 +346,8 @@ class _ModelInfo:
|
||||
default_pooling_type=get_default_pooling_type(model),
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
|
||||
supports_multimodal_raw_input_only=
|
||||
supports_multimodal_raw_input_only(model),
|
||||
supports_multimodal_encoder_tp_data=
|
||||
supports_multimodal_encoder_tp_data(model),
|
||||
supports_pp=supports_pp(model),
|
||||
@ -743,13 +744,13 @@ class _ModelRegistry:
|
||||
model_cls, _ = self.inspect_model_cls(architectures, model_config)
|
||||
return model_cls.supports_multimodal
|
||||
|
||||
def supports_multimodal_raw_input(
|
||||
def is_multimodal_raw_input_only_model(
|
||||
self,
|
||||
architectures: Union[str, list[str]],
|
||||
model_config: ModelConfig,
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures, model_config)
|
||||
return model_cls.supports_multimodal_raw_input
|
||||
return model_cls.supports_multimodal_raw_input_only
|
||||
|
||||
def is_pp_supported_model(
|
||||
self,
|
||||
|
||||
@ -139,8 +139,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cache_config.cache_dtype]
|
||||
|
||||
self.is_pooling_model = model_config.pooler_config is not None
|
||||
self.is_multimodal_raw_input_supported = (
|
||||
model_config.is_multimodal_raw_input_supported)
|
||||
self.is_multimodal_raw_input_only_model = (
|
||||
model_config.is_multimodal_raw_input_only_model)
|
||||
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
@ -612,7 +613,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> BatchedTensorInputs:
|
||||
if not self.is_multimodal_raw_input_supported or not scheduler_output: # noqa: SIM102
|
||||
if not scheduler_output or not self.is_multimodal_raw_input_only_model:
|
||||
return {}
|
||||
|
||||
mm_kwargs = list[MultiModalKwargsItem]()
|
||||
@ -631,8 +632,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return mm_kwargs_combined
|
||||
|
||||
def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
|
||||
if not self.is_multimodal_raw_input_supported:
|
||||
if not self.is_multimodal_raw_input_only_model:
|
||||
return {}
|
||||
|
||||
mm_budget = self.mm_budget
|
||||
assert mm_budget is not None
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user