diff --git a/tests/multimodal/test_registry.py b/tests/multimodal/test_registry.py
new file mode 100644
index 000000000000..d31e75bc279f
--- /dev/null
+++ b/tests/multimodal/test_registry.py
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Unit tests for MultiModalRegistry.supports_multimodal_inputs and
+Qwen2.5-VL visual component loading behavior.
+"""
+
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from ..models.utils import build_model_context
+
+
+@pytest.mark.parametrize(
+    "model_id,limit_mm_per_prompt,expected",
+    [
+        ("Qwen/Qwen2-0.5B-Instruct", {}, False),
+        ("Qwen/Qwen2.5-VL-3B-Instruct", {}, True),
+        ("Qwen/Qwen2.5-VL-3B-Instruct", {
+            "image": 0,
+            "video": 0
+        }, False),
+        ("Qwen/Qwen2.5-VL-3B-Instruct", {
+            "image": 0
+        }, True),
+    ],
+)
+@pytest.mark.core_model
+def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected):
+    """Test supports_multimodal_inputs returns correct boolean for various
+    configs."""
+    ctx = build_model_context(
+        model_id,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+    )
+    assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(
+        ctx.model_config) is expected
\ No newline at end of file
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index eaed6017cc58..69c05b75d3eb 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -1695,15 +1695,6 @@ class ModelConfig:
 
         return mm_config.mm_processor_cache_gb > 0
 
-    @property
-    def enable_mm_input_cache(self) -> bool:
-        """Whether the multi-modal input cache should be enabled."""
-        mm_config = self.multimodal_config
-        if mm_config is None:
-            return False
-
-        return mm_config.mm_processor_cache_gb > 0
-
     def get_mm_input_cache_gb(self) -> int:
         mm_config = self.multimodal_config
         if mm_config is None:
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index c863ba406422..cfc6ffd99af6 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -521,18 +521,22 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             config.projector_hidden_act = "gelu"
 
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = init_vision_tower_for_llava(
-            config,
-            quant_config,
-            require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"))
-        self.multi_modal_projector = LlavaMultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
-            text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act,
-            multimodal_projector_bias=config.multimodal_projector_bias,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_tower = init_vision_tower_for_llava(
+                config,
+                quant_config,
+                require_post_norm=False,
+                prefix=maybe_prefix(prefix, "vision_tower"))
+            self.multi_modal_projector = LlavaMultiModalProjector(
+                vision_hidden_size=config.vision_config.hidden_size,
+                text_hidden_size=config.text_config.hidden_size,
+                projector_hidden_act=config.projector_hidden_act,
+                multimodal_projector_bias=config.multimodal_projector_bias,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        else:
+            self.vision_tower = None
+            self.multi_modal_projector = None
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -756,7 +760,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.vision_tower is None and self.multi_modal_projector is None:
+            skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index 88c3823eaa19..9e29a96c6e44 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -428,20 +428,24 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
             config.projector_hidden_act = "gelu"
 
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = init_vision_tower_for_llava(
-            config,
-            quant_config,
-            require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"))
-        self.multi_modal_projector = Mistral3MultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
-            text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act,
-            spatial_merge_size=config.spatial_merge_size,
-            patch_size=config.vision_config.patch_size,
-            multimodal_projector_bias=config.multimodal_projector_bias,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_tower = init_vision_tower_for_llava(
+                config,
+                quant_config,
+                require_post_norm=False,
+                prefix=maybe_prefix(prefix, "vision_tower"))
+            self.multi_modal_projector = Mistral3MultiModalProjector(
+                vision_hidden_size=config.vision_config.hidden_size,
+                text_hidden_size=config.text_config.hidden_size,
+                projector_hidden_act=config.projector_hidden_act,
+                spatial_merge_size=config.spatial_merge_size,
+                patch_size=config.vision_config.patch_size,
+                multimodal_projector_bias=config.multimodal_projector_bias,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        else:
+            self.vision_tower = None
+            self.multi_modal_projector = None
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -611,7 +615,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.vision_tower is None and self.multi_modal_projector is None:
+            skip_prefixes = ["vision_tower.", "multi_modal_projector."]
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index e73dc0c2be82..b405dfca6d39 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -737,16 +737,20 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.vision_model = Llama4VisionModel(
-            config.vision_config,
-            None,
-            prefix=maybe_prefix(prefix, "vision_model"),
-            use_data_parallel=self.use_data_parallel,
-        )
-        self.multi_modal_projector = Llama4MultiModalProjector(
-            self.config,
-            None,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_model = Llama4VisionModel(
+                config.vision_config,
+                None,
+                prefix=maybe_prefix(prefix, "vision_model"),
+                use_data_parallel=self.use_data_parallel,
+            )
+            self.multi_modal_projector = Llama4MultiModalProjector(
+                self.config,
+                None,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None
         self.language_model = initialize_model(
             vllm_config=vllm_config.with_hf_config(config.text_config,
                                                    ["LlamaForCausalLM"]),
@@ -783,6 +787,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def _process_image_input(
             self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
+
+        assert self.vision_model and self.multi_modal_projector
         flat_data = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"].tolist()
 
@@ -1048,6 +1054,10 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         language_model_weights, other_weights = (
             self._separate_and_rename_weights(weights))
 
+        # Skip loading vision model and projector if they're not initialized.
+        if self.vision_model is None and self.multi_modal_projector is None:
+            other_weights = []
+
         # Handle expert scale parameters
         regular_weights, expert_scale_weights, updated_params_from_experts = (
             self._handle_expert_scale_broadcasting(language_model_weights,
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index a3af541d2067..e95295c31885 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -722,13 +722,24 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
                 "exactly same result as the transformers implementation "
                 "in the audio tower part.")
 
-        self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
-        self.visual = Qwen2_5_VisionTransformer(
-            vision_config=thinker_config.vision_config,
-            norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "visual"),
-        )
+        if multimodal_config.get_limit_per_prompt("audio"):
+            self.audio_tower = Qwen2_5OmniAudioEncoder(
+                thinker_config.audio_config)
+        else:
+            self.audio_tower = None
+
+        if multimodal_config.get_limit_per_prompt(
+                "image") or multimodal_config.get_limit_per_prompt("video"):
+            self.visual = Qwen2_5_VisionTransformer(
+                vision_config=thinker_config.vision_config,
+                norm_eps=getattr(thinker_config.text_config, "rms_norm_eps",
+                                 1e-6),
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "visual"),
+            )
+        else:
+            self.visual = None
+
         self.quant_config = quant_config
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -886,9 +897,15 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        skip_prefixes = ["talker.", "token2wav."]
+        if self.audio_tower is None:
+            skip_prefixes.extend(["audio_tower."])
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=["talker.", "token2wav."],
+            skip_prefixes=skip_prefixes,
         )
         loaded_weights = loader.load_weights(weights,
                                              mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 79c5c77f6de6..6bea180ffec9 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -843,12 +843,17 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        self.visual = Qwen2_5_VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(self.quant_config),
-            prefix=maybe_prefix(prefix, "visual"),
-        )
+        if multimodal_config.get_limit_per_prompt("image") or \
+                multimodal_config.get_limit_per_prompt("video"):
+            self.visual = Qwen2_5_VisionTransformer(
+                config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=self._maybe_ignore_quant_config(
+                    self.quant_config),
+                prefix=maybe_prefix(prefix, "visual"),
+            )
+        else:
+            self.visual = None
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -1152,7 +1157,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 633f8598e879..f2d438b3850b 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -1049,12 +1049,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        self.visual = Qwen2VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
-            prefix=maybe_prefix(prefix, "visual"),
-        )
+        if multimodal_config.get_limit_per_prompt("image") or \
+                multimodal_config.get_limit_per_prompt("video"):
+            self.visual = Qwen2VisionTransformer(
+                config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=self._maybe_ignore_quant_config(quant_config),
+                prefix=maybe_prefix(prefix, "visual"),
+            )
+        else:
+            self.visual = None
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -1350,7 +1354,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
@@ -1445,5 +1452,8 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index 363c12a4bf2b..41dba312cb42 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -837,27 +837,35 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        self.vision_model = Step3VisionTransformer(config.vision_config,
-                                                   None,
-                                                   prefix=maybe_prefix(
-                                                       prefix, "vision_model"))
-        self.vit_downsampler = nn.Conv2d(
-            config.vision_config.hidden_size,
-            config.vision_config.output_hidden_size,
-            kernel_size=2,
-            stride=config.understand_projector_stride)
-        self.vit_downsampler2 = nn.Conv2d(
-            config.vision_config.output_hidden_size,
-            config.vision_config.output_hidden_size * 2,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-        )
-        self.vit_large_projector = nn.Linear(
-            config.vision_config.output_hidden_size * 2,
-            config.hidden_size,
-            bias=config.projector_bias,
-        )
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_model = Step3VisionTransformer(config.vision_config,
+                                                       None,
+                                                       prefix=maybe_prefix(
+                                                           prefix,
+                                                           "vision_model"))
+            self.vit_downsampler = nn.Conv2d(
+                config.vision_config.hidden_size,
+                config.vision_config.output_hidden_size,
+                kernel_size=2,
+                stride=config.understand_projector_stride)
+            self.vit_downsampler2 = nn.Conv2d(
+                config.vision_config.output_hidden_size,
+                config.vision_config.output_hidden_size * 2,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+            )
+            self.vit_large_projector = nn.Linear(
+                config.vision_config.output_hidden_size * 2,
+                config.hidden_size,
+                bias=config.projector_bias,
+            )
+        else:
+            self.vision_model = None
+            self.vit_downsampler = None
+            self.vit_downsampler2 = None
+            self.vit_large_projector = None
+
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.text_config,
@@ -1046,7 +1054,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        loader = AutoWeightsLoader(self)
+
+        skip_prefixes = []
+        if self.vision_model is None and self.vit_large_projector is None:
+            skip_prefixes = [
+                "vision_model.", "vit_downsampler.", "vit_downsampler2.",
+                "vit_large_projector."
+            ]
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         loaded_weights = loader.load_weights(weights,
                                              mapper=self.hf_to_vllm_mapper)
         return loaded_weights
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 565d54e1a264..a101f2a55f5d 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -115,6 +115,45 @@ class MultiModalRegistry:
 
         return True  # Success
 
+    def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool:
+        """Whether the multi-modal input cache should be enabled.
+        NOTE: This is put under MultiModalRegistry on purpose to respect
+        text-only mode for multimodal models.
+        """
+
+        if not self.supports_multimodal_inputs(model_config):
+            return False
+
+        mm_config = model_config.get_multimodal_config()
+
+        return mm_config.mm_processor_cache_gb > 0
+
+    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
+        """
+        Checks if the model supports multimodal inputs.
+        Returns True if the model is multimodal with any non-zero supported
+        modalities, otherwise returns False, effectively running in
+        text-only mode.
+        """
+        if not model_config.is_multimodal_model:
+            return False
+
+        processor = self.create_processor(model_config, disable_cache=False)
+        supported_modalities = processor.info.get_supported_mm_limits()
+
+        mm_config = model_config.get_multimodal_config()
+
+        # Check if all supported modalities have limit == 0
+        if all(
+                mm_config.get_limit_per_prompt(modality) == 0
+                for modality in supported_modalities):
+            logger.info_once(
+                "All limits of multimodal modalities supported by the model "
+                "are set to 0, running in text-only mode.")
+            return False
+
+        return True
+
     def get_max_tokens_per_item_by_modality(
         self,
         model_config: "ModelConfig",
diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py
index 67ea3b007ece..faf5c132f864 100644
--- a/vllm/v1/core/encoder_cache_manager.py
+++ b/vllm/v1/core/encoder_cache_manager.py
@@ -189,7 +189,7 @@ def compute_encoder_budget(
         in the input sequence.
     """
 
-    if not model_config.is_multimodal_model:
+    if not mm_registry.supports_multimodal_inputs(model_config):
         return 0, 0
 
     # TODO: handle encoder-decoder models once we support them.
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 78b8fe4ea676..f92a3e43da1f 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -21,6 +21,7 @@ from vllm.distributed import stateless_destroy_torch_distributed_process_group
 from vllm.logger import init_logger
 from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
@@ -125,7 +126,7 @@ class EngineCore:
         )
 
         self.mm_input_cache_server = MultiModalInputCacheServer(
-            vllm_config.model_config)
+            vllm_config.model_config, MULTIMODAL_REGISTRY)
 
         # Setup batch queue for pipeline parallelism.
         # Batch queue for scheduled batches. This enables us to asynchronously
diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py
index 279c9f0007bc..0532cda03d9a 100644
--- a/vllm/v1/engine/mm_input_cache.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -3,7 +3,7 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Optional
 
-from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal import MultiModalKwargs, MultiModalRegistry
 from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
 from vllm.utils import is_list_of
 
@@ -46,10 +46,11 @@ if TYPE_CHECKING:
 class MultiModalInputCacheClient:
     """Used by P0 to check whether multi-modal kwargs are cached in P1."""
 
-    def __init__(self, model_config: "ModelConfig") -> None:
+    def __init__(self, model_config: "ModelConfig",
+                 mm_registry: MultiModalRegistry) -> None:
         super().__init__()
 
-        self.enabled = model_config.enable_mm_input_cache
+        self.enabled = mm_registry.enable_mm_input_cache(model_config)
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
             MultiModalCacheItemMetadata,
@@ -85,10 +86,11 @@ class MultiModalInputCacheClient:
 class MultiModalInputCacheServer:
     """Used by P1 to avoid requiring past multi-modal kwargs from P0."""
 
-    def __init__(self, model_config: "ModelConfig") -> None:
+    def __init__(self, model_config: "ModelConfig",
+                 mm_registry: MultiModalRegistry) -> None:
         super().__init__()
 
-        self.enabled = model_config.enable_mm_input_cache
+        self.enabled = mm_registry.enable_mm_input_cache(model_config)
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
             MultiModalKwargs,
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 6e37ebeb8778..b9419142caf6 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -51,7 +51,7 @@ class Processor:
                                                         mm_registry)
 
         self.mm_input_cache_client = MultiModalInputCacheClient(
-            self.model_config)
+            self.model_config, mm_registry)
 
     @property
     def mm_registry(self):
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 08b253dcdb35..48ff50fd6bd8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -129,7 +129,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             cache_config.cache_dtype]
 
-        self.is_multimodal_model = model_config.is_multimodal_model
         self.is_pooling_model = model_config.pooler_config is not None
         self.is_encoder_only_model = False
         self.is_multimodal_raw_input_supported = (
@@ -149,6 +148,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
+        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
+            model_config)
 
         # Sampler
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -330,7 +331,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.mm_registry,
             max_model_len=self.max_model_len,
             max_num_reqs=self.max_num_reqs,
-        ) if self.is_multimodal_model else None)
+        ) if self.supports_mm_inputs \
+            else None)
 
         self.reorder_batch_threshold: Optional[int] = None
 
@@ -1479,14 +1481,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
-        if self.is_multimodal_model:
+        if self.supports_mm_inputs:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
 
-        if self.is_multimodal_model and get_pp_group().is_first_rank:
+        if self.supports_mm_inputs and get_pp_group().is_first_rank:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
@@ -1817,7 +1819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             else:
                 target_hidden_states = hidden_states[token_indices]
             mm_embeds = None
-            if self.is_multimodal_model:
+            if self.supports_mm_inputs:
                 mm_embeds = self._gather_mm_embeddings(scheduler_output,
                                                        shift_computed_tokens=1)
 
@@ -2209,7 +2211,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            if self.is_multimodal_model:
+            if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
                 model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
@@ -2417,7 +2419,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
-        if self.is_multimodal_model:
+        if self.supports_mm_inputs:
             mm_budget = self.mm_budget
             assert mm_budget is not None
 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 81252f9b606a..442c0ea068b9 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -157,7 +157,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             cache_config.cache_dtype]
         self._hidden_states_dtype = self.dtype
 
-        self.is_multimodal_model = model_config.is_multimodal_model
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
@@ -193,6 +192,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
+        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
+            model_config)
 
         # TODO: Support M-RoPE (e.g, Qwen2-VL)
         assert not self.uses_mrope, "TPU does not support M-RoPE yet."
@@ -293,7 +294,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.mm_registry,
             max_model_len=self.max_model_len,
             max_num_reqs=self.max_num_reqs,
-        ) if self.is_multimodal_model else None)
+        ) if self.supports_mm_inputs else None)
 
         if not self.use_spmd:
             self.sample_from_logits_func = torch.compile(
@@ -947,7 +948,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def _get_model_inputs(self, input_ids: torch.Tensor,
                           mm_embeds: list[torch.Tensor]):
-        if self.is_multimodal_model:
+        if self.supports_mm_inputs:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
@@ -979,7 +980,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return self.kv_connector_no_forward(scheduler_output,
                                                 self.vllm_config)
 
-        if self.is_multimodal_model:
+        if self.supports_mm_inputs:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
@@ -1230,7 +1231,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     @torch.no_grad()
     def _dummy_run(self, num_tokens: int, num_reqs: int,
                    num_blocks: int) -> None:
-        if self.is_multimodal_model:
+        if self.supports_mm_inputs:
            input_ids = None
            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
                                        dtype=self.dtype,
@@ -1271,7 +1272,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             _num_slices_per_kv_cache_update_block,
         )
 
-        if self.is_multimodal_model:
+        if self.supports_mm_inputs:
             torch._dynamo.mark_dynamic(inputs_embeds, 0)
         else:
             torch._dynamo.mark_dynamic(input_ids, 0)
@@ -1305,7 +1306,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         xm.mark_step()  # Captures metadata updates
 
     def _precompile_mm_encoder(self) -> None:
-        if not self.is_multimodal_model:
+        if not self.supports_mm_inputs:
             return
 
         # Pre-compile MM encoder for all supported data modalities.
@@ -1527,7 +1528,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_tokens: int,
     ) -> None:
         # Profile with multimodal encoder & encoder cache.
-        if self.is_multimodal_model:
+        if self.supports_mm_inputs:
             mm_budget = self.mm_budget
             assert mm_budget is not None
 
@@ -1684,7 +1685,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)
 
     def reset_dynamo_cache(self):
-        if self.is_multimodal_model:
+
+        # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs`
+        # since the compiled model object of the language backbone of a
+        # multimodal model needs to be extracted via `get_language_model`.
+        if self.model_config.is_multimodal_model:
             compiled_model = self.model.get_language_model().model
         else:
             compiled_model = self.model.model
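
Usage sketch (not part of the patch): with these changes, a multimodal model whose per-prompt limits are all zero never constructs its vision/audio encoders, skips their checkpoint weights at load time, and the engine treats it as text-only. The snippet below exercises that path through the existing `vllm.LLM` entry point and its `limit_mm_per_prompt` argument, the same knob used by the new test; the prompt and sampling settings are arbitrary.

from vllm import LLM, SamplingParams

# All modalities supported by Qwen2.5-VL are capped at 0, so the visual
# encoder is never constructed and "visual." weights are skipped on load.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    limit_mm_per_prompt={"image": 0, "video": 0},
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)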