Implicit language-model-only mode via limit-mm-per-prompt (#22299)

Signed-off-by: Roger Wang <hey@rogerw.me>
Signed-off-by: Andy Xie <andy.xning@gmail.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: XIn Li <xinli@nvidia.com>
Signed-off-by: Junhao Li <junhao@ubicloud.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: zitian.zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: zitian zhao <zitian.zhao@tencentmusic.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: iAmir97 <Amir.balwel@embeddedllm.com>
Signed-off-by: iAmir97 <71513472+iAmir97@users.noreply.github.com>
Signed-off-by: Linkun <github@lkchen.net>
Co-authored-by: Ning Xie <andy.xning@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Co-authored-by: Zhiyu <zhiyuc@nvidia.com>
Co-authored-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: XIn Li <xinli@nvidia.com>
Co-authored-by: Junhao Li <streaver91@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: Yuxuan Zhang <2448370773@qq.com>
Co-authored-by: ZiTian Zhao <zitian.zhao@tencentmusic.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Po-Han Huang (NVIDIA) <53919306+nvpohanh@users.noreply.github.com>
Co-authored-by: iAmir97 <71513472+iAmir97@users.noreply.github.com>
Co-authored-by: iAmir97 <Amir.balwel@embeddedllm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Hong Hanh <hanh.usth@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: lkchen <github@lkchen.net>
Roger Wang 2025-08-08 22:21:40 -07:00 committed by GitHub
parent 429e4e2d42
commit 08b751ba74
16 changed files with 271 additions and 116 deletions
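
This change makes text-only operation implicit: when every multimodal limit of a multimodal model is set to 0, vLLM skips initializing the vision/audio components and runs the model as a pure language model, also skipping their checkpoint weights at load time. A minimal offline-inference sketch of the new behavior (the model name is taken from the tests below; pass limits only for modalities the model actually supports):

from vllm import LLM, SamplingParams

# Zeroing every supported modality limit triggers the implicit
# language-model-only mode: no vision tower or projector is built,
# and their checkpoint weights are skipped during loading.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    limit_mm_per_prompt={"image": 0, "video": 0},
)

outputs = llm.generate(
    ["Explain rotary position embeddings in one sentence."],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)

The same dictionary can be supplied through the existing --limit-mm-per-prompt engine argument when serving. Leaving any one modality non-zero (for example only {"image": 0}) keeps the multimodal path enabled, as the new unit test below verifies.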

View File

@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for MultiModalRegistry.supports_multimodal_inputs and
Qwen2.5-VL visual component loading behavior.
"""
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ..models.utils import build_model_context
@pytest.mark.parametrize(
"model_id,limit_mm_per_prompt,expected",
[
("Qwen/Qwen2-0.5B-Instruct", {}, False),
("Qwen/Qwen2.5-VL-3B-Instruct", {}, True),
("Qwen/Qwen2.5-VL-3B-Instruct", {
"image": 0,
"video": 0
}, False),
("Qwen/Qwen2.5-VL-3B-Instruct", {
"image": 0
}, True),
],
)
@pytest.mark.core_model
def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected):
"""Test supports_multimodal_inputs returns correct boolean for various
configs."""
ctx = build_model_context(
model_id,
limit_mm_per_prompt=limit_mm_per_prompt,
)
assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(
ctx.model_config) is expected

View File

@@ -1695,15 +1695,6 @@ class ModelConfig:
return mm_config.mm_processor_cache_gb > 0
@property
def enable_mm_input_cache(self) -> bool:
"""Whether the multi-modal input cache should be enabled."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return mm_config.mm_processor_cache_gb > 0
def get_mm_input_cache_gb(self) -> int:
mm_config = self.multimodal_config
if mm_config is None:

View File

@@ -521,18 +521,22 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
if multimodal_config.get_limit_per_prompt("image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
else:
self.vision_tower = None
self.multi_modal_projector = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -756,7 +760,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
skip_prefixes = []
if self.vision_tower is None and self.multi_modal_projector is None:
skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
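
This load_weights change, and the analogous ones in the model files that follow, all rely on the same prefix-skipping idea: when the visual components were never constructed, their checkpoint tensors must not be routed into the model. A rough standalone sketch of what skip_prefixes accomplishes (illustrative only, not the actual AutoWeightsLoader implementation):

from collections.abc import Iterable

import torch


def filter_skipped(
    weights: Iterable[tuple[str, torch.Tensor]],
    skip_prefixes: list[str],
) -> Iterable[tuple[str, torch.Tensor]]:
    """Yield only weights whose names do not start with a skipped prefix."""
    for name, tensor in weights:
        # e.g. "vision_tower.encoder.layers.0.attn.qkv.weight" is dropped
        # when "vision_tower." is in skip_prefixes.
        if any(name.startswith(prefix) for prefix in skip_prefixes):
            continue
        yield name, tensor

With skip_prefixes left empty, loading behaves exactly as before, so multimodal deployments are unaffected.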

View File

@@ -428,20 +428,24 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = Mistral3MultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
spatial_merge_size=config.spatial_merge_size,
patch_size=config.vision_config.patch_size,
multimodal_projector_bias=config.multimodal_projector_bias,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
if multimodal_config.get_limit_per_prompt("image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = Mistral3MultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
spatial_merge_size=config.spatial_merge_size,
patch_size=config.vision_config.patch_size,
multimodal_projector_bias=config.multimodal_projector_bias,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
else:
self.vision_tower = None
self.multi_modal_projector = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -611,7 +615,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
skip_prefixes = []
if self.vision_tower is None and self.multi_modal_projector is None:
skip_prefixes = ["vision_tower.", "multi_modal_projector."]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:

View File

@@ -737,16 +737,20 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.vision_model = Llama4VisionModel(
config.vision_config,
None,
prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
)
self.multi_modal_projector = Llama4MultiModalProjector(
self.config,
None,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
if multimodal_config.get_limit_per_prompt("image"):
self.vision_model = Llama4VisionModel(
config.vision_config,
None,
prefix=maybe_prefix(prefix, "vision_model"),
use_data_parallel=self.use_data_parallel,
)
self.multi_modal_projector = Llama4MultiModalProjector(
self.config,
None,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
else:
self.vision_model = None
self.multi_modal_projector = None
self.language_model = initialize_model(
vllm_config=vllm_config.with_hf_config(config.text_config,
["LlamaForCausalLM"]),
@@ -783,6 +787,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_input(
self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
assert self.vision_model and self.multi_modal_projector
flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist()
@@ -1048,6 +1054,10 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
language_model_weights, other_weights = (
self._separate_and_rename_weights(weights))
# Skip loading vision model and projector if they're not initialized.
if self.vision_model is None and self.multi_modal_projector is None:
other_weights = []
# Handle expert scale parameters
regular_weights, expert_scale_weights, updated_params_from_experts = (
self._handle_expert_scale_broadcasting(language_model_weights,

View File

@@ -722,13 +722,24 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
"exactly same result as the transformers implementation "
"in the audio tower part.")
self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
self.visual = Qwen2_5_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
if multimodal_config.get_limit_per_prompt("audio"):
self.audio_tower = Qwen2_5OmniAudioEncoder(
thinker_config.audio_config)
else:
self.audio_tower = None
if multimodal_config.get_limit_per_prompt(
"image") or multimodal_config.get_limit_per_prompt("video"):
self.visual = Qwen2_5_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps",
1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = None
self.quant_config = quant_config
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -886,9 +897,15 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
skip_prefixes = ["talker.", "token2wav."]
if self.audio_tower is None:
skip_prefixes.extend(["audio_tower."])
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(
self,
skip_prefixes=["talker.", "token2wav."],
skip_prefixes=skip_prefixes,
)
loaded_weights = loader.load_weights(weights,
mapper=self.hf_to_vllm_mapper)

View File

@@ -843,12 +843,17 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
self.visual = Qwen2_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(self.quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
if multimodal_config.get_limit_per_prompt("image") or \
multimodal_config.get_limit_per_prompt("video"):
self.visual = Qwen2_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(
self.quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -1152,7 +1157,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:

View File

@@ -1049,12 +1049,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
if multimodal_config.get_limit_per_prompt("image") or \
multimodal_config.get_limit_per_prompt("video"):
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -1350,7 +1354,10 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
@@ -1445,5 +1452,8 @@ class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

View File

@@ -837,27 +837,35 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
self.vision_model = Step3VisionTransformer(config.vision_config,
None,
prefix=maybe_prefix(
prefix, "vision_model"))
self.vit_downsampler = nn.Conv2d(
config.vision_config.hidden_size,
config.vision_config.output_hidden_size,
kernel_size=2,
stride=config.understand_projector_stride)
self.vit_downsampler2 = nn.Conv2d(
config.vision_config.output_hidden_size,
config.vision_config.output_hidden_size * 2,
kernel_size=3,
stride=2,
padding=1,
)
self.vit_large_projector = nn.Linear(
config.vision_config.output_hidden_size * 2,
config.hidden_size,
bias=config.projector_bias,
)
if multimodal_config.get_limit_per_prompt("image"):
self.vision_model = Step3VisionTransformer(config.vision_config,
None,
prefix=maybe_prefix(
prefix,
"vision_model"))
self.vit_downsampler = nn.Conv2d(
config.vision_config.hidden_size,
config.vision_config.output_hidden_size,
kernel_size=2,
stride=config.understand_projector_stride)
self.vit_downsampler2 = nn.Conv2d(
config.vision_config.output_hidden_size,
config.vision_config.output_hidden_size * 2,
kernel_size=3,
stride=2,
padding=1,
)
self.vit_large_projector = nn.Linear(
config.vision_config.output_hidden_size * 2,
config.hidden_size,
bias=config.projector_bias,
)
else:
self.vision_model = None
self.vit_downsampler = None
self.vit_downsampler2 = None
self.vit_large_projector = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
@@ -1046,7 +1054,15 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
skip_prefixes = []
if self.vision_model is None and self.vit_large_projector is None:
skip_prefixes = [
"vision_model.", "vit_downsampler.", "vit_downsampler2.",
"vit_large_projector."
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loaded_weights = loader.load_weights(weights,
mapper=self.hf_to_vllm_mapper)
return loaded_weights

View File

@@ -115,6 +115,45 @@ class MultiModalRegistry:
return True # Success
def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool:
"""Whether the multi-modal input cache should be enabled.
NOTE: This is put under MultiModalRegistry on purpose to respect
text-only mode for multimodal models.
"""
if not self.supports_multimodal_inputs(model_config):
return False
mm_config = model_config.get_multimodal_config()
return mm_config.mm_processor_cache_gb > 0
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
"""
Checks if the model supports multimodal inputs.
Returns True if the model is multimodal with any non-zero supported
modalities, otherwise returns False, effectively running in
text-only mode.
"""
if not model_config.is_multimodal_model:
return False
processor = self.create_processor(model_config, disable_cache=False)
supported_modalities = processor.info.get_supported_mm_limits()
mm_config = model_config.get_multimodal_config()
# Check if all supported modalities have limit == 0
if all(
mm_config.get_limit_per_prompt(modality) == 0
for modality in supported_modalities):
logger.info_once(
"All limits of multimodal modalities supported by the model "
"are set to 0, running in text-only mode.")
return False
return True
def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",

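Callers elsewhere in this commit (the encoder-budget computation, the multimodal input caches, and the GPU/TPU model runners) consult these registry helpers instead of the ModelConfig properties they previously used. A hypothetical helper, not part of the PR, showing how the two new methods compose:

from vllm.multimodal import MULTIMODAL_REGISTRY


def describe_mm_mode(model_config) -> str:
    """Illustrative only: summarize how the engine will treat this model."""
    if not MULTIMODAL_REGISTRY.supports_multimodal_inputs(model_config):
        return "text-only (not multimodal, or all modality limits are 0)"
    if MULTIMODAL_REGISTRY.enable_mm_input_cache(model_config):
        return "multimodal, processor cache enabled"
    return "multimodal, processor cache disabled"
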
View File

@@ -189,7 +189,7 @@ def compute_encoder_budget(
in the input sequence.
"""
if not model_config.is_multimodal_model:
if not mm_registry.supports_multimodal_inputs(model_config):
return 0, 0
# TODO: handle encoder-decoder models once we support them.

View File

@@ -21,6 +21,7 @@ from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@@ -125,7 +126,7 @@ class EngineCore:
)
self.mm_input_cache_server = MultiModalInputCacheServer(
vllm_config.model_config)
vllm_config.model_config, MULTIMODAL_REGISTRY)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously

View File

@@ -3,7 +3,7 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal import MultiModalKwargs, MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.utils import is_list_of
@@ -46,10 +46,11 @@ if TYPE_CHECKING:
class MultiModalInputCacheClient:
"""Used by P0 to check whether multi-modal kwargs are cached in P1."""
def __init__(self, model_config: "ModelConfig") -> None:
def __init__(self, model_config: "ModelConfig",
mm_registry: MultiModalRegistry) -> None:
super().__init__()
self.enabled = model_config.enable_mm_input_cache
self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalCacheItemMetadata,
@@ -85,10 +86,11 @@ class MultiModalInputCacheClient:
class MultiModalInputCacheServer:
"""Used by P1 to avoid requiring past multi-modal kwargs from P0."""
def __init__(self, model_config: "ModelConfig") -> None:
def __init__(self, model_config: "ModelConfig",
mm_registry: MultiModalRegistry) -> None:
super().__init__()
self.enabled = model_config.enable_mm_input_cache
self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalKwargs,

View File

@@ -51,7 +51,7 @@ class Processor:
mm_registry)
self.mm_input_cache_client = MultiModalInputCacheClient(
self.model_config)
self.model_config, mm_registry)
@property
def mm_registry(self):

View File

@@ -129,7 +129,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None
self.is_encoder_only_model = False
self.is_multimodal_raw_input_supported = (
@@ -149,6 +148,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config)
# Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -330,7 +331,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.mm_registry,
max_model_len=self.max_model_len,
max_num_reqs=self.max_num_reqs,
) if self.is_multimodal_model else None)
) if self.supports_mm_inputs \
else None)
self.reorder_batch_threshold: Optional[int] = None
@@ -1479,14 +1481,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
if self.is_multimodal_model:
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
else:
mm_embeds = []
if self.is_multimodal_model and get_pp_group().is_first_rank:
if self.supports_mm_inputs and get_pp_group().is_first_rank:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
@@ -1817,7 +1819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
target_hidden_states = hidden_states[token_indices]
mm_embeds = None
if self.is_multimodal_model:
if self.supports_mm_inputs:
mm_embeds = self._gather_mm_embeddings(scheduler_output,
shift_computed_tokens=1)
@@ -2209,7 +2211,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
if self.is_multimodal_model:
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
@@ -2417,7 +2419,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def profile_run(self) -> None:
# Profile with multimodal encoder & encoder cache.
if self.is_multimodal_model:
if self.supports_mm_inputs:
mm_budget = self.mm_budget
assert mm_budget is not None

View File

@@ -157,7 +157,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cache_config.cache_dtype]
self._hidden_states_dtype = self.dtype
self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
@@ -193,6 +192,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config)
# TODO: Support M-RoPE (e.g, Qwen2-VL)
assert not self.uses_mrope, "TPU does not support M-RoPE yet."
@@ -293,7 +294,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.mm_registry,
max_model_len=self.max_model_len,
max_num_reqs=self.max_num_reqs,
) if self.is_multimodal_model else None)
) if self.supports_mm_inputs else None)
if not self.use_spmd:
self.sample_from_logits_func = torch.compile(
@@ -947,7 +948,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_model_inputs(self, input_ids: torch.Tensor,
mm_embeds: list[torch.Tensor]):
if self.is_multimodal_model:
if self.supports_mm_inputs:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
@@ -979,7 +980,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output,
self.vllm_config)
if self.is_multimodal_model:
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
@@ -1230,7 +1231,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
@torch.no_grad()
def _dummy_run(self, num_tokens: int, num_reqs: int,
num_blocks: int) -> None:
if self.is_multimodal_model:
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
dtype=self.dtype,
@@ -1271,7 +1272,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
_num_slices_per_kv_cache_update_block,
)
if self.is_multimodal_model:
if self.supports_mm_inputs:
torch._dynamo.mark_dynamic(inputs_embeds, 0)
else:
torch._dynamo.mark_dynamic(input_ids, 0)
@@ -1305,7 +1306,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
xm.mark_step() # Captures metadata updates
def _precompile_mm_encoder(self) -> None:
if not self.is_multimodal_model:
if not self.supports_mm_inputs:
return
# Pre-compile MM encoder for all supported data modalities.
@@ -1527,7 +1528,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens: int,
) -> None:
# Profile with multimodal encoder & encoder cache.
if self.is_multimodal_model:
if self.supports_mm_inputs:
mm_budget = self.mm_budget
assert mm_budget is not None
@@ -1684,7 +1685,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)
def reset_dynamo_cache(self):
if self.is_multimodal_model:
# NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs`
# since the compiled model object of the language backbone of a
# multimodal model needs to be extracted via `get_language_model`.
if self.model_config.is_multimodal_model:
compiled_model = self.model.get_language_model().model
else:
compiled_model = self.model.model