diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 1089de87b994e..ec3ba4474c192 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -740,23 +740,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
E Pre-computed embeddings can be inputted for this modality.
+ Multiple items can be inputted per text prompt for this modality.
-!!! warning
- Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
- However, there are differences in how they handle text + image inputs:
-
- V0 correctly implements the model's attention pattern:
- - Uses bidirectional attention between the image tokens corresponding to the same image
- - Uses causal attention for other tokens
- - Implemented via (naive) PyTorch SDPA with masking tensors
- - Note: May use significant memory for long prompts with image
-
- V1 currently uses a simplified attention pattern:
- - Uses causal attention for all tokens, including image tokens
- - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}`
- - Will be updated in the future to support the correct behavior
-
- This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
-
!!! note
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
MobileNet-v5 vision backbone.
@@ -776,9 +759,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
For more details, please see:
-!!! warning
- Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
-
!!! note
For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.
diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index fd26b838ae209..c5a0b6748f797 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -382,7 +382,6 @@ VLM_TEST_SETTINGS = {
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
- num_logprobs=10,
),
"glm4v": VLMTestInfo(
models=["zai-org/glm-4v-9b"],
diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
index a80617a366cab..8448003e70531 100644
--- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
@@ -30,5 +30,6 @@ class DummyPlatform(Platform):
use_mla,
has_sink,
use_sparse,
+ use_mm_prefix,
):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 84cca8e686075..03f4c40302eb8 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -166,6 +166,10 @@ class AttentionBackend(ABC):
def supports_sink(cls) -> bool:
return False
+ @classmethod
+ def supports_mm_prefix(cls) -> bool:
+ return False
+
@classmethod
def is_sparse(cls) -> bool:
return False
@@ -207,6 +211,7 @@ class AttentionBackend(ABC):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
+ use_mm_prefix: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
@@ -219,6 +224,10 @@ class AttentionBackend(ABC):
invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported")
+ if use_mm_prefix and not cls.supports_mm_prefix():
+ invalid_reasons.append(
+                "bidirectional multimodal prefix attention not supported"
+ )
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 8a522deedf3ce..340b161ea1e15 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -230,6 +230,10 @@ class Attention(nn.Module, AttentionLayerBase):
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
+ # NOTE: model_config may be None during certain tests
+ model_config = vllm_config.model_config
+ self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm
+
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
@@ -241,6 +245,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size,
use_mla=False,
has_sink=self.has_sink,
+ use_mm_prefix=self.use_mm_prefix,
attn_type=attn_type,
)
else:
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index aeb130dfe8726..f6aba271d2e96 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -27,6 +27,7 @@ def get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
+ use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
@@ -52,6 +53,7 @@ def get_attn_backend(
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
+ use_mm_prefix=use_mm_prefix,
attn_type=attn_type,
)
@@ -66,6 +68,7 @@ def _cached_get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
+ use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
@@ -87,6 +90,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
+ use_mm_prefix,
attn_type,
)
else:
@@ -99,6 +103,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
+ use_mm_prefix,
attn_type,
)
if not attention_cls:
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 509a9c5e162f7..583904a949ea1 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -4,6 +4,7 @@
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field
+from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch
@@ -1217,6 +1218,19 @@ class ModelConfig:
)
return False
+ @cached_property
+ def is_mm_prefix_lm(self) -> bool:
+ """Whether to use bidirectional attention for mm positions."""
+ MM_PREFIX_LM_MODELS = (
+ "gemma3",
+            # TODO(Isotr0py): PaliGemma stays disabled until we support
+            # soft-capped attention in FlexAttention
+ # "paligemma",
+ )
+ if not hasattr(self.hf_config, "model_type"):
+ return False
+ return self.hf_config.model_type in MM_PREFIX_LM_MODELS
+
def get_head_size(self) -> int:
# TODO remove hard code
if self.is_deepseek_mla:
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index d9118f5b9e9a5..2ed66554e358e 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -175,6 +175,31 @@ class PlaceholderRange:
return int(self.is_embed.sum().item())
+ def extract_embeds_range(self) -> list[tuple[int, int]]:
+        """Extract start and end indices of the embedded regions in the prompt.
+
+ For example, given `PlaceholderRange(offset=2, length=5)` and
+ `is_embed = [False, True, False, True, True]`, the output is
+ `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.
+
+ Returns:
+            A list of `(start, end)` tuples giving the inclusive start and
+            end indices of each embedded region.
+            Returns the full placeholder range if `is_embed` is `None`.
+ """
+ if self.is_embed is None:
+            # End index is inclusive, matching the masked case below.
+            return [(self.offset, self.offset + self.length - 1)]
+
+ mask_i = self.is_embed.int()
+ starts = torch.nonzero(
+ torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
+ ).flatten()
+ ends = torch.nonzero(
+ torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
+ ).flatten()
+ ranges = torch.stack((starts, ends), dim=1) + self.offset
+ return [tuple(x) for x in ranges.tolist()]
+
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index a2518d5fd3dc4..a49b6e92df00d 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -133,6 +133,7 @@ class CpuPlatform(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
+ use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 37c95f48669f8..39101c43142f7 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -233,6 +233,20 @@ class CudaPlatformBase(Platform):
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
+ scheduler_config = vllm_config.scheduler_config
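+        # Chunking a multimodal item across prefill steps would split the
+        # span that needs bidirectional attention, so chunked mm input is
+        # disabled for these models.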
+ # Note: model_config may be None during testing
+ if (
+ model_config is not None
+ and model_config.is_mm_prefix_lm
+ and scheduler_config.is_multimodal_model
+ and not scheduler_config.disable_chunked_mm_input
+ ):
+ logger.warning(
+                "Forcing --disable-chunked-mm-input for models "
+                "with bidirectional multimodal (prefix LM) attention."
+ )
+ scheduler_config.disable_chunked_mm_input = True
+
@classmethod
def get_current_memory_usage(
cls, device: torch.types.Device | None = None
@@ -268,6 +282,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
+ use_mm_prefix,
device_capability,
attn_type,
) -> tuple[
@@ -289,6 +304,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
+ use_mm_prefix,
device_capability,
attn_type,
)
@@ -312,6 +328,7 @@ class CudaPlatformBase(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
+ use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if attn_type is None:
@@ -332,6 +349,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
+ use_mm_prefix,
device_capability,
attn_type,
)
@@ -356,6 +374,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
+ use_mm_prefix,
device_capability,
attn_type,
)
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 27c6fac09f498..f04e94e425257 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -239,6 +239,7 @@ class Platform:
use_mla: bool,
has_sink: bool,
use_sparse: bool,
+ use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
"""Get the attention backend class of a device."""
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 32c7f8e536639..ff0fc78517876 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -216,6 +216,7 @@ class RocmPlatform(Platform):
use_mla,
has_sink,
use_sparse,
+ use_mm_prefix,
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index cbc0a996f3661..d6998e7a308af 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -62,8 +62,9 @@ class TpuPlatform(Platform):
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
- has_sink,
- use_sparse,
+ has_sink: bool,
+ use_sparse: bool,
+ use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if use_sparse:
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 768714fb16726..0a05750764d8d 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -48,7 +48,8 @@ class XPUPlatform(Platform):
block_size: int,
use_mla: bool,
has_sink: bool,
- use_sparse,
+ use_sparse: bool,
+ use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index a2a6eeeb16b24..d8dbe4cbae013 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import (
and_masks,
create_block_mask,
flex_attention,
+ or_masks,
)
from vllm.attention.backends.abstract import (
@@ -42,6 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
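+# Mask mods can change between batches (e.g. per-request multimodal prefix
+# ranges), which may trigger extra recompiles of the compiled mask builders;
+# raise the recompile limit to accommodate this.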
+torch._dynamo.config.recompile_limit = 16
create_block_mask_compiled = torch.compile(
create_block_mask, fullgraph=True, mode="reduce-overhead"
)
@@ -91,6 +93,11 @@ class FlexAttentionBackend(AttentionBackend):
"""FlexAttention supports both decoder and encoder-only attention."""
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
+ @classmethod
+ def supports_mm_prefix(cls) -> bool:
+        """FlexAttention supports full attention over mm prefix tokens."""
+ return True
+
@staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl
@@ -316,6 +323,7 @@ class FlexAttentionMetadata:
kv_block_size: int = 16
transformed_score_mod: _score_mod_signature | None = None
sliding_window: int | None = None
+ mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
@cached_property
def logical_block_ids(self):
@@ -443,6 +451,45 @@ class FlexAttentionMetadata:
return final_mask_mod if self.causal else sliding_window_mask_mod
+ def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
+ """Creates the prefix LM mask_mod function for FlexAttention."""
+
+ assert self.doc_ids is not None
+ request_lookup = self.doc_ids
+
+ def prefix_lm_mask_mod(
+ b: torch.Tensor,
+ h: torch.Tensor,
+ cu_q_idx: torch.Tensor,
+ q_idx: torch.Tensor,
+ kv_idx: torch.Tensor,
+ ):
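+            # cu_q_idx indexes the packed batch of query tokens and is only
+            # used to look up the owning request; q_idx / kv_idx are the
+            # logical indices supplied by final_mask_mod below.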
+ mask = torch.zeros_like(q_idx, dtype=torch.bool)
+ for req, doc_range_lst in (self.mm_prefix_range or {}).items():
+ req_mask = request_lookup[cu_q_idx] == req
+ for start, end in doc_range_lst:
+ doc_mask_q = (q_idx >= start) & (q_idx <= end)
+ doc_mask_kv = (kv_idx >= start) & (kv_idx <= end)
+ mask = mask | (req_mask & doc_mask_q & doc_mask_kv)
+ return mask
+
+ def final_mask_mod(
+ b: torch.Tensor,
+ h: torch.Tensor,
+ q_idx: torch.Tensor,
+ physical_kv_idx: torch.Tensor,
+ ) -> torch.Tensor:
+ (is_valid, logical_q_idx, logical_kv_idx) = (
+ self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
+ )
+ return torch.where(
+ is_valid,
+ prefix_lm_mask_mod(b, h, q_idx, logical_q_idx, logical_kv_idx),
+ False,
+ )
+
+ return final_mask_mod
+
def get_mask_mod(self):
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
@@ -456,6 +503,10 @@ class FlexAttentionMetadata:
# Add sliding window mask for sliding window attention
sliding_window_mask_mod = self.get_sliding_window_mask_mod()
mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
+ if self.mm_prefix_range:
+            # OR in the prefix LM mask so mm spans attend bidirectionally
+ prefix_lm_mask_mod = self.get_prefix_lm_mask_mod()
+ mask_mod = or_masks(mask_mod, prefix_lm_mask_mod)
return mask_mod
def get_transformed_score_mod(self) -> _score_mod_signature | None:
@@ -709,6 +760,7 @@ class FlexAttentionImpl(AttentionImpl):
sliding_window: int | None
alibi_slopes: torch.Tensor | None
logits_soft_cap: float | None
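+    # Latest per-request mm prefix ranges seen by this layer; forward()
+    # rebuilds the block mask whenever the metadata carries different ranges.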
+ mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
def __init__(
self,
@@ -810,11 +862,21 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
+ needs_rebuild_block_mask = False
if attn_metadata.sliding_window != self.sliding_window:
attn_metadata.sliding_window = self.sliding_window
if attn_metadata.direct_build:
# update mask mod in attention metadata
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
+ needs_rebuild_block_mask = True
+
+ if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None):
+ self.mm_prefix_range = attn_metadata.mm_prefix_range
+ attn_metadata.mask_mod = attn_metadata.get_mask_mod()
+ needs_rebuild_block_mask = True
+
+ if needs_rebuild_block_mask:
+ if attn_metadata.direct_build and attn_metadata.causal:
attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
else:
attn_metadata.block_mask = attn_metadata.build_block_mask()
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a50360ab08694..22a3f9d8d2dda 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -48,7 +48,10 @@ from vllm.distributed.parallel_state import (
is_global_first_rank,
prepare_communication_buffer_for_model,
)
-from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.forward_context import (
+ BatchDescriptor,
+ set_forward_context,
+)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import (
@@ -329,6 +332,7 @@ class GPUModelRunner(
self.use_alibi = model_config.uses_alibi
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
+ self.is_mm_prefix_lm = self.model_config.is_mm_prefix_lm
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
@@ -1700,6 +1704,26 @@ class GPUModelRunner(
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
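+        # Record, per request, the prompt token ranges covered by multimodal
+        # embeddings so the attention backend can attend bidirectionally
+        # within them.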
+ if self.is_mm_prefix_lm:
+ req_doc_ranges = {}
+ for req_id in self.input_batch.req_ids:
+ image_doc_ranges = []
+ req_state = self.requests[req_id]
+ for mm_feature in req_state.mm_features:
+ pos_info = mm_feature.mm_position
+ img_doc_range = pos_info.extract_embeds_range()
+ image_doc_ranges.extend(img_doc_range)
+ req_idx = self.input_batch.req_id_to_index[req_id]
+ req_doc_ranges[req_idx] = image_doc_ranges
+
+ if isinstance(attn_metadata, list):
+ for ub_metadata in attn_metadata:
+ for _metadata in ub_metadata.values():
+ _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
+ else:
+ for _metadata in attn_metadata.values():
+ _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
+
if spec_decode_common_attn_metadata is not None and (
num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
):