diff --git a/tests/multimodal/test_registry.py b/tests/multimodal/test_registry.py
deleted file mode 100644
index d31e75bc279f6..0000000000000
--- a/tests/multimodal/test_registry.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Unit tests for MultiModalRegistry.supports_multimodal_inputs and
-Qwen2.5-VL visual component loading behavior.
-"""
-
-import pytest
-
-from vllm.multimodal import MULTIMODAL_REGISTRY
-
-from ..models.utils import build_model_context
-
-
-@pytest.mark.parametrize(
-    "model_id,limit_mm_per_prompt,expected",
-    [
-        ("Qwen/Qwen2-0.5B-Instruct", {}, False),
-        ("Qwen/Qwen2.5-VL-3B-Instruct", {}, True),
-        ("Qwen/Qwen2.5-VL-3B-Instruct", {
-            "image": 0,
-            "video": 0
-        }, False),
-        ("Qwen/Qwen2.5-VL-3B-Instruct", {
-            "image": 0
-        }, True),
-    ],
-)
-@pytest.mark.core_model
-def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected):
-    """Test supports_multimodal_inputs returns correct boolean for various
-    configs."""
-    ctx = build_model_context(
-        model_id,
-        limit_mm_per_prompt=limit_mm_per_prompt,
-    )
-    assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(
-        ctx.model_config) is expected
\ No newline at end of file
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 69c05b75d3eb8..eaed6017cc58d 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -1695,6 +1695,15 @@ class ModelConfig:
 
         return mm_config.mm_processor_cache_gb > 0
 
+    @property
+    def enable_mm_input_cache(self) -> bool:
+        """Whether the multi-modal input cache should be enabled."""
+        mm_config = self.multimodal_config
+        if mm_config is None:
+            return False
+
+        return mm_config.mm_processor_cache_gb > 0
+
     def get_mm_input_cache_gb(self) -> int:
         mm_config = self.multimodal_config
         if mm_config is None:
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index cfc6ffd99af62..c863ba406422d 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -521,22 +521,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             config.projector_hidden_act = "gelu"
 
         # TODO: Optionally initializes this for supporting embeddings.
-        if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_tower = init_vision_tower_for_llava(
-                config,
-                quant_config,
-                require_post_norm=False,
-                prefix=maybe_prefix(prefix, "vision_tower"))
-            self.multi_modal_projector = LlavaMultiModalProjector(
-                vision_hidden_size=config.vision_config.hidden_size,
-                text_hidden_size=config.text_config.hidden_size,
-                projector_hidden_act=config.projector_hidden_act,
-                multimodal_projector_bias=config.multimodal_projector_bias,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "multi_modal_projector"))
-        else:
-            self.vision_tower = None
-            self.multi_modal_projector = None
+        self.vision_tower = init_vision_tower_for_llava(
+            config,
+            quant_config,
+            require_post_norm=False,
+            prefix=maybe_prefix(prefix, "vision_tower"))
+        self.multi_modal_projector = LlavaMultiModalProjector(
+            vision_hidden_size=config.vision_config.hidden_size,
+            text_hidden_size=config.text_config.hidden_size,
+            projector_hidden_act=config.projector_hidden_act,
+            multimodal_projector_bias=config.multimodal_projector_bias,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -760,11 +756,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.vision_tower is None and self.multi_modal_projector is None:
-            skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])
-
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index 9e29a96c6e44a..88c3823eaa193 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -428,24 +428,20 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
             config.projector_hidden_act = "gelu"
 
         # TODO: Optionally initializes this for supporting embeddings.
-        if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_tower = init_vision_tower_for_llava(
-                config,
-                quant_config,
-                require_post_norm=False,
-                prefix=maybe_prefix(prefix, "vision_tower"))
-            self.multi_modal_projector = Mistral3MultiModalProjector(
-                vision_hidden_size=config.vision_config.hidden_size,
-                text_hidden_size=config.text_config.hidden_size,
-                projector_hidden_act=config.projector_hidden_act,
-                spatial_merge_size=config.spatial_merge_size,
-                patch_size=config.vision_config.patch_size,
-                multimodal_projector_bias=config.multimodal_projector_bias,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "multi_modal_projector"))
-        else:
-            self.vision_tower = None
-            self.multi_modal_projector = None
+        self.vision_tower = init_vision_tower_for_llava(
+            config,
+            quant_config,
+            require_post_norm=False,
+            prefix=maybe_prefix(prefix, "vision_tower"))
+        self.multi_modal_projector = Mistral3MultiModalProjector(
+            vision_hidden_size=config.vision_config.hidden_size,
+            text_hidden_size=config.text_config.hidden_size,
+            projector_hidden_act=config.projector_hidden_act,
+            spatial_merge_size=config.spatial_merge_size,
+            patch_size=config.vision_config.patch_size,
+            multimodal_projector_bias=config.multimodal_projector_bias,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -615,11 +611,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.vision_tower is None and self.multi_modal_projector is None:
-            skip_prefixes = ["vision_tower.", "multi_modal_projector."]
-
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index b405dfca6d39b..e73dc0c2be82e 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -737,20 +737,16 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_model = Llama4VisionModel(
-                config.vision_config,
-                None,
-                prefix=maybe_prefix(prefix, "vision_model"),
-                use_data_parallel=self.use_data_parallel,
-            )
-            self.multi_modal_projector = Llama4MultiModalProjector(
-                self.config,
-                None,
-                prefix=maybe_prefix(prefix, "multi_modal_projector"))
-        else:
-            self.vision_model = None
-            self.multi_modal_projector = None
+        self.vision_model = Llama4VisionModel(
+            config.vision_config,
+            None,
+            prefix=maybe_prefix(prefix, "vision_model"),
+            use_data_parallel=self.use_data_parallel,
+        )
+        self.multi_modal_projector = Llama4MultiModalProjector(
+            self.config,
+            None,
+            prefix=maybe_prefix(prefix, "multi_modal_projector"))
         self.language_model = initialize_model(
             vllm_config=vllm_config.with_hf_config(config.text_config,
                                                    ["LlamaForCausalLM"]),
@@ -787,8 +783,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def _process_image_input(
             self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
-
-        assert self.vision_model and self.multi_modal_projector
         flat_data = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"].tolist()
 
@@ -1054,10 +1048,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
         language_model_weights, other_weights = (
             self._separate_and_rename_weights(weights))
 
-        # Skip loading vision model and projector if they're not initialized.
-        if self.vision_model is None and self.multi_modal_projector is None:
-            other_weights = []
-
         # Handle expert scale parameters
         regular_weights, expert_scale_weights, updated_params_from_experts = (
             self._handle_expert_scale_broadcasting(language_model_weights,
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index e95295c31885a..a3af541d20676 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -722,24 +722,13 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
                 "exactly same result as the transformers implementation "
                 "in the audio tower part.")
 
-        if multimodal_config.get_limit_per_prompt("audio"):
-            self.audio_tower = Qwen2_5OmniAudioEncoder(
-                thinker_config.audio_config)
-        else:
-            self.audio_tower = None
-
-        if multimodal_config.get_limit_per_prompt(
-                "image") or multimodal_config.get_limit_per_prompt("video"):
-            self.visual = Qwen2_5_VisionTransformer(
-                vision_config=thinker_config.vision_config,
-                norm_eps=getattr(thinker_config.text_config, "rms_norm_eps",
-                                 1e-6),
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "visual"),
-            )
-        else:
-            self.visual = None
-
+        self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
+        self.visual = Qwen2_5_VisionTransformer(
+            vision_config=thinker_config.vision_config,
+            norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "visual"),
+        )
         self.quant_config = quant_config
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -897,15 +886,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = ["talker.", "token2wav."]
-        if self.audio_tower is None:
-            skip_prefixes.extend(["audio_tower."])
-        if self.visual is None:
-            skip_prefixes.extend(["visual."])
-
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=skip_prefixes,
+            skip_prefixes=["talker.", "token2wav."],
         )
         loaded_weights = loader.load_weights(weights,
                                              mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 6bea180ffec90..79c5c77f6de69 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -843,17 +843,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if multimodal_config.get_limit_per_prompt("image") or \
-            multimodal_config.get_limit_per_prompt("video"):
-            self.visual = Qwen2_5_VisionTransformer(
-                config.vision_config,
-                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-                quant_config=self._maybe_ignore_quant_config(
-                    self.quant_config),
-                prefix=maybe_prefix(prefix, "visual"),
-            )
-        else:
-            self.visual = None
+        self.visual = Qwen2_5_VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            quant_config=self._maybe_ignore_quant_config(self.quant_config),
+            prefix=maybe_prefix(prefix, "visual"),
+        )
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -1157,10 +1152,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.visual is None:
-            skip_prefixes.extend(["visual."])
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index f2d438b3850b8..633f8598e879d 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -1049,16 +1049,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if multimodal_config.get_limit_per_prompt("image") or \
-            multimodal_config.get_limit_per_prompt("video"):
-            self.visual = Qwen2VisionTransformer(
-                config.vision_config,
-                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-                quant_config=self._maybe_ignore_quant_config(quant_config),
-                prefix=maybe_prefix(prefix, "visual"),
-            )
-        else:
-            self.visual = None
+        self.visual = Qwen2VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            quant_config=self._maybe_ignore_quant_config(quant_config),
+            prefix=maybe_prefix(prefix, "visual"),
+        )
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -1354,10 +1350,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.visual is None:
-            skip_prefixes.extend(["visual."])
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
@@ -1452,8 +1445,5 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
-        skip_prefixes = []
-        if self.visual is None:
-            skip_prefixes.extend(["visual."])
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index 41dba312cb422..363c12a4bf2b8 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -837,35 +837,27 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if multimodal_config.get_limit_per_prompt("image"):
-            self.vision_model = Step3VisionTransformer(config.vision_config,
-                                                       None,
-                                                       prefix=maybe_prefix(
-                                                           prefix,
-                                                           "vision_model"))
-            self.vit_downsampler = nn.Conv2d(
-                config.vision_config.hidden_size,
-                config.vision_config.output_hidden_size,
-                kernel_size=2,
-                stride=config.understand_projector_stride)
-            self.vit_downsampler2 = nn.Conv2d(
-                config.vision_config.output_hidden_size,
-                config.vision_config.output_hidden_size * 2,
-                kernel_size=3,
-                stride=2,
-                padding=1,
-            )
-            self.vit_large_projector = nn.Linear(
-                config.vision_config.output_hidden_size * 2,
-                config.hidden_size,
-                bias=config.projector_bias,
-            )
-        else:
-            self.vision_model = None
-            self.vit_downsampler = None
-            self.vit_downsampler2 = None
-            self.vit_large_projector = None
-
+        self.vision_model = Step3VisionTransformer(config.vision_config,
+                                                   None,
+                                                   prefix=maybe_prefix(
+                                                       prefix, "vision_model"))
+        self.vit_downsampler = nn.Conv2d(
+            config.vision_config.hidden_size,
+            config.vision_config.output_hidden_size,
+            kernel_size=2,
+            stride=config.understand_projector_stride)
+        self.vit_downsampler2 = nn.Conv2d(
+            config.vision_config.output_hidden_size,
+            config.vision_config.output_hidden_size * 2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        )
+        self.vit_large_projector = nn.Linear(
+            config.vision_config.output_hidden_size * 2,
+            config.hidden_size,
+            bias=config.projector_bias,
+        )
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.text_config,
@@ -1054,15 +1046,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
-        skip_prefixes = []
-        if self.vision_model is None and self.vit_large_projector is None:
-            skip_prefixes = [
-                "vision_model.", "vit_downsampler.", "vit_downsampler2.",
-                "vit_large_projector."
-            ]
-
-        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader = AutoWeightsLoader(self)
         loaded_weights = loader.load_weights(weights,
                                              mapper=self.hf_to_vllm_mapper)
         return loaded_weights
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index a101f2a55f5d1..565d54e1a264b 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -115,45 +115,6 @@ class MultiModalRegistry:
 
         return True  # Success
 
-    def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool:
-        """Whether the multi-modal input cache should be enabled.
-        NOTE: This is put under MultiModalRegistry on purpose to respect
-        text-only mode for multimodal models.
-        """
-
-        if not self.supports_multimodal_inputs(model_config):
-            return False
-
-        mm_config = model_config.get_multimodal_config()
-
-        return mm_config.mm_processor_cache_gb > 0
-
-    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
-        """
-        Checks if the model supports multimodal inputs.
-        Returns True if the model is multimodal with any non-zero supported
-        modalities, otherwise returns False, effectively running in
-        text-only mode.
-        """
-        if not model_config.is_multimodal_model:
-            return False
-
-        processor = self.create_processor(model_config, disable_cache=False)
-        supported_modalities = processor.info.get_supported_mm_limits()
-
-        mm_config = model_config.get_multimodal_config()
-
-        # Check if all supported modalities have limit == 0
-        if all(
-                mm_config.get_limit_per_prompt(modality) == 0
-                for modality in supported_modalities):
-            logger.info_once(
-                "All limits of multimodal modalities supported by the model "
-                "are set to 0, running in text-only mode.")
-            return False
-
-        return True
-
     def get_max_tokens_per_item_by_modality(
         self,
         model_config: "ModelConfig",
diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py
index faf5c132f8640..67ea3b007ecee 100644
--- a/vllm/v1/core/encoder_cache_manager.py
+++ b/vllm/v1/core/encoder_cache_manager.py
@@ -189,7 +189,7 @@ def compute_encoder_budget(
         in the input sequence.
     """
 
-    if not mm_registry.supports_multimodal_inputs(model_config):
+    if not model_config.is_multimodal_model:
        return 0, 0
 
    # TODO: handle encoder-decoder models once we support them.
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f92a3e43da1f2..78b8fe4ea676f 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -21,7 +21,6 @@ from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.logger import init_logger
 from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
-from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
@@ -126,7 +125,7 @@ class EngineCore:
         )
 
         self.mm_input_cache_server = MultiModalInputCacheServer(
-            vllm_config.model_config, MULTIMODAL_REGISTRY)
+            vllm_config.model_config)
 
         # Setup batch queue for pipeline parallelism.
         # Batch queue for scheduled batches. This enables us to asynchronously
diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py
index 0532cda03d9a7..279c9f0007bce 100644
--- a/vllm/v1/engine/mm_input_cache.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -3,7 +3,7 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Optional
 
-from vllm.multimodal import MultiModalKwargs, MultiModalRegistry
+from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
 from vllm.utils import is_list_of
 
@@ -46,11 +46,10 @@ if TYPE_CHECKING:
 class MultiModalInputCacheClient:
     """Used by P0 to check whether multi-modal kwargs are cached in P1."""
 
-    def __init__(self, model_config: "ModelConfig",
-                 mm_registry: MultiModalRegistry) -> None:
+    def __init__(self, model_config: "ModelConfig") -> None:
         super().__init__()
 
-        self.enabled = mm_registry.enable_mm_input_cache(model_config)
+        self.enabled = model_config.enable_mm_input_cache
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
             MultiModalCacheItemMetadata,
@@ -86,11 +85,10 @@ class MultiModalInputCacheClient:
 class MultiModalInputCacheServer:
     """Used by P1 to avoid requiring past multi-modal kwargs from P0."""
 
-    def __init__(self, model_config: "ModelConfig",
-                 mm_registry: MultiModalRegistry) -> None:
+    def __init__(self, model_config: "ModelConfig") -> None:
         super().__init__()
 
-        self.enabled = mm_registry.enable_mm_input_cache(model_config)
+        self.enabled = model_config.enable_mm_input_cache
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
             MultiModalKwargs,
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index b9419142caf6c..6e37ebeb87781 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -51,7 +51,7 @@ class Processor:
                                                     mm_registry)
 
         self.mm_input_cache_client = MultiModalInputCacheClient(
-            self.model_config, mm_registry)
+            self.model_config)
 
     @property
     def mm_registry(self):
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 48ff50fd6bd8c..08b253dcdb35c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -129,6 +129,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
 
+        self.is_multimodal_model = model_config.is_multimodal_model
         self.is_pooling_model = model_config.pooler_config is not None
         self.is_encoder_only_model = False
         self.is_multimodal_raw_input_supported = (
@@ -148,8 +149,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
-        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
-            model_config)
 
         # Sampler
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -331,8 +330,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.mm_registry,
             max_model_len=self.max_model_len,
             max_num_reqs=self.max_num_reqs,
-        ) if self.supports_mm_inputs \
-            else None)
+        ) if self.is_multimodal_model else None)
 
         self.reorder_batch_threshold: Optional[int] = None
 
@@ -1481,14 +1479,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
 
-        if self.supports_mm_inputs and get_pp_group().is_first_rank:
+        if self.is_multimodal_model and get_pp_group().is_first_rank:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
@@ -1819,7 +1817,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             else:
                 target_hidden_states = hidden_states[token_indices]
             mm_embeds = None
-            if self.supports_mm_inputs:
+            if self.is_multimodal_model:
                 mm_embeds = self._gather_mm_embeddings(scheduler_output,
                                                        shift_computed_tokens=1)
 
@@ -2211,7 +2209,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            if self.supports_mm_inputs:
+            if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
                 model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
@@ -2419,7 +2417,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             mm_budget = self.mm_budget
             assert mm_budget is not None
 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 442c0ea068b92..81252f9b606ae 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -157,6 +157,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 cache_config.cache_dtype]
         self._hidden_states_dtype = self.dtype
 
+        self.is_multimodal_model = model_config.is_multimodal_model
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
@@ -192,8 +193,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
-        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
-            model_config)
 
         # TODO: Support M-RoPE (e.g, Qwen2-VL)
         assert not self.uses_mrope, "TPU does not support M-RoPE yet."
@@ -294,7 +293,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.mm_registry,
             max_model_len=self.max_model_len,
             max_num_reqs=self.max_num_reqs,
-        ) if self.supports_mm_inputs else None)
+        ) if self.is_multimodal_model else None)
 
         if not self.use_spmd:
             self.sample_from_logits_func = torch.compile(
@@ -948,7 +947,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def _get_model_inputs(self, input_ids: torch.Tensor,
                           mm_embeds: list[torch.Tensor]):
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
@@ -980,7 +979,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return self.kv_connector_no_forward(scheduler_output,
                                                 self.vllm_config)
 
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
@@ -1231,7 +1230,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     @torch.no_grad()
     def _dummy_run(self, num_tokens: int, num_reqs: int,
                    num_blocks: int) -> None:
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
                                         dtype=self.dtype,
@@ -1272,7 +1271,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             _num_slices_per_kv_cache_update_block,
         )
 
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             torch._dynamo.mark_dynamic(inputs_embeds, 0)
         else:
             torch._dynamo.mark_dynamic(input_ids, 0)
@@ -1306,7 +1305,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             xm.mark_step()  # Captures metadata updates
 
     def _precompile_mm_encoder(self) -> None:
-        if not self.supports_mm_inputs:
+        if not self.is_multimodal_model:
             return
 
         # Pre-compile MM encoder for all supported data modalities.
@@ -1528,7 +1527,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_tokens: int,
     ) -> None:
         # Profile with multimodal encoder & encoder cache.
-        if self.supports_mm_inputs:
+        if self.is_multimodal_model:
             mm_budget = self.mm_budget
             assert mm_budget is not None
 
@@ -1685,11 +1684,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)
 
     def reset_dynamo_cache(self):
-
-        # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs`
-        # since the compiled model object of the language backbone of a
-        # multimodal model needs to be extracted via `get_language_model`.
-        if self.model_config.is_multimodal_model:
+        if self.is_multimodal_model:
             compiled_model = self.model.get_language_model().model
         else:
             compiled_model = self.model.model