diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 1089de87b994e..ec3ba4474c192 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -740,23 +740,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
 E Pre-computed embeddings can be inputted for this modality.
 + Multiple items can be inputted per text prompt for this modality.

-!!! warning
-    Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
-    However, there are differences in how they handle text + image inputs:
-
-    V0 correctly implements the model's attention pattern:
-    - Uses bidirectional attention between the image tokens corresponding to the same image
-    - Uses causal attention for other tokens
-    - Implemented via (naive) PyTorch SDPA with masking tensors
-    - Note: May use significant memory for long prompts with image
-
-    V1 currently uses a simplified attention pattern:
-    - Uses causal attention for all tokens, including image tokens
-    - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}`
-    - Will be updated in the future to support the correct behavior
-
-    This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
-
 !!! note
     `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its MobileNet-v5 vision backbone.

@@ -776,9 +759,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
     The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. For more details, please see:

-!!! warning
-    Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
-
 !!! note
     For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.
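The removed warning becomes obsolete because the rest of this patch adds the missing mixed attention pattern to the V1 FlexAttention backend: causal attention everywhere, plus bidirectional attention restricted to image-token spans, composed from FlexAttention `mask_mod`s. The snippet below is a minimal, standalone sketch of that composition using PyTorch's FlexAttention helpers; the sequence length and the image span boundaries are invented for illustration, and none of this is vLLM's actual metadata plumbing.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, or_masks

SEQ_LEN = 128
IMG_START, IMG_END = 16, 47  # hypothetical image placeholder span (inclusive)


def causal_mask(b, h, q_idx, kv_idx):
    # Standard decoder rule: a query attends only to itself and earlier positions.
    return q_idx >= kv_idx


def image_span_mask(b, h, q_idx, kv_idx):
    # Full (bidirectional) attention, allowed only when both the query and the
    # key fall inside the same image placeholder span.
    q_in = (q_idx >= IMG_START) & (q_idx <= IMG_END)
    kv_in = (kv_idx >= IMG_START) & (kv_idx <= IMG_END)
    return q_in & kv_in


# OR-composition: causal everywhere, plus extra bidirectional edges inside the span.
prefix_lm_mask = or_masks(causal_mask, image_span_mask)
block_mask = create_block_mask(
    prefix_lm_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, device="cpu"
)
print(block_mask)
```

This mirrors the `or_masks` composition that `FlexAttentionMetadata.get_mask_mod` applies later in this diff when `mm_prefix_range` is populated, except that the real code derives the spans per request from the multimodal placeholder metadata.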
diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index fd26b838ae209..c5a0b6748f797 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -382,7 +382,6 @@ VLM_TEST_SETTINGS = {
         auto_cls=AutoModelForImageTextToText,
         vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
         patch_hf_runner=model_utils.gemma3_patch_hf_runner,
-        num_logprobs=10,
     ),
     "glm4v": VLMTestInfo(
         models=["zai-org/glm-4v-9b"],
diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
index a80617a366cab..8448003e70531 100644
--- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
@@ -30,5 +30,6 @@ class DummyPlatform(Platform):
         use_mla,
         has_sink,
         use_sparse,
+        use_mm_prefix,
     ):
         return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 84cca8e686075..03f4c40302eb8 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -166,6 +166,10 @@ class AttentionBackend(ABC):
     def supports_sink(cls) -> bool:
         return False

+    @classmethod
+    def supports_mm_prefix(cls) -> bool:
+        return False
+
     @classmethod
     def is_sparse(cls) -> bool:
         return False
@@ -207,6 +211,7 @@ class AttentionBackend(ABC):
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
+        use_mm_prefix: bool,
         device_capability: "DeviceCapability",
         attn_type: str,
     ) -> list[str]:
@@ -219,6 +224,10 @@ class AttentionBackend(ABC):
             invalid_reasons.append("kv_cache_dtype not supported")
         if not cls.supports_block_size(block_size):
             invalid_reasons.append("block_size not supported")
+        if use_mm_prefix and not cls.supports_mm_prefix():
+            invalid_reasons.append(
+                "bidirectional multimodal (mm prefix) attention not supported"
+            )
         if use_mla != cls.is_mla():
             if use_mla:
                 invalid_reasons.append("MLA not supported")
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 8a522deedf3ce..340b161ea1e15 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -230,6 +230,10 @@ class Attention(nn.Module, AttentionLayerBase):
         self.sliding_window = sliding_window
         self.has_sink = extra_impl_args.get("sinks") is not None

+        # NOTE: model_config may be None during certain tests
+        model_config = vllm_config.model_config
+        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm
+
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
@@ -241,6 +245,7 @@ class Attention(nn.Module, AttentionLayerBase):
                 block_size,
                 use_mla=False,
                 has_sink=self.has_sink,
+                use_mm_prefix=self.use_mm_prefix,
                 attn_type=attn_type,
             )
         else:
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index aeb130dfe8726..f6aba271d2e96 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -27,6 +27,7 @@ def get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    use_mm_prefix: bool = False,
     attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
@@ -52,6 +53,7 @@ def get_attn_backend(
         use_mla=use_mla,
         has_sink=has_sink,
         use_sparse=use_sparse,
+        use_mm_prefix=use_mm_prefix,
         attn_type=attn_type,
     )

@@ -66,6 +68,7 @@ def _cached_get_attn_backend(
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
+    use_mm_prefix: bool = False,
     attn_type: str | None = None,
 ) -> type[AttentionBackend]:
     from vllm.platforms import current_platform
@@ -87,6 +90,7 @@ def _cached_get_attn_backend(
             use_mla,
             has_sink,
             use_sparse,
+            use_mm_prefix,
             attn_type,
         )
     else:
@@ -99,6 +103,7 @@ def _cached_get_attn_backend(
             use_mla,
             has_sink,
             use_sparse,
+            use_mm_prefix,
             attn_type,
         )
     if not attention_cls:
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 509a9c5e162f7..583904a949ea1 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -4,6 +4,7 @@
 import warnings
 from collections.abc import Callable
 from dataclasses import InitVar, field
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, Literal, cast, get_args

 import torch
@@ -1217,6 +1218,19 @@ class ModelConfig:
             )
             return False

+    @cached_property
+    def is_mm_prefix_lm(self) -> bool:
+        """Whether to use bidirectional attention for mm positions."""
+        MM_PREFIX_LM_MODELS = (
+            "gemma3",
+            # TODO(Isotr0py): paligemma stays disabled until we
+            # support soft cap attention for FlexAttention
+            # "paligemma",
+        )
+        if not hasattr(self.hf_config, "model_type"):
+            return False
+        return self.hf_config.model_type in MM_PREFIX_LM_MODELS
+
     def get_head_size(self) -> int:
         # TODO remove hard code
         if self.is_deepseek_mla:
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index d9118f5b9e9a5..2ed66554e358e 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -175,6 +175,31 @@ class PlaceholderRange:

         return int(self.is_embed.sum().item())

+    def extract_embeds_range(self) -> list[tuple[int, int]]:
+        """Extract the start and end indices of the embedded regions in the prompt.
+
+        For example, given `PlaceholderRange(offset=2, length=5)` and
+        `is_embed = [False, True, False, True, True]`, the output is
+        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.
+
+        Returns:
+            A list of `(start, end)` tuples giving the start and end
+            indices (inclusive) of each embedded region.
+            Returns the full placeholder range if `is_embed` is `None`.
+ """ + if self.is_embed is None: + return [(self.offset, self.offset + self.length)] + + mask_i = self.is_embed.int() + starts = torch.nonzero( + torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1 + ).flatten() + ends = torch.nonzero( + torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1 + ).flatten() + ranges = torch.stack((starts, ends), dim=1) + self.offset + return [tuple(x) for x in ranges.tolist()] + def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index a2518d5fd3dc4..a49b6e92df00d 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -133,6 +133,7 @@ class CpuPlatform(Platform): use_mla: bool, has_sink: bool, use_sparse: bool, + use_mm_prefix: bool, attn_type: str | None = None, ) -> str: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 37c95f48669f8..39101c43142f7 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -233,6 +233,20 @@ class CudaPlatformBase(Platform): "Forcing kv cache block size to 64 for FlashMLASparse backend." ) + scheduler_config = vllm_config.scheduler_config + # Note: model_config may be None during testing + if ( + model_config is not None + and model_config.is_mm_prefix_lm + and scheduler_config.is_multimodal_model + and not scheduler_config.disable_chunked_mm_input + ): + logger.warning( + "Forcing --disable_chunked_mm_input for models " + "with multimodal-bidirectional attention." + ) + scheduler_config.disable_chunked_mm_input = True + @classmethod def get_current_memory_usage( cls, device: torch.types.Device | None = None @@ -268,6 +282,7 @@ class CudaPlatformBase(Platform): use_mla, has_sink, use_sparse, + use_mm_prefix, device_capability, attn_type, ) -> tuple[ @@ -289,6 +304,7 @@ class CudaPlatformBase(Platform): use_mla, has_sink, use_sparse, + use_mm_prefix, device_capability, attn_type, ) @@ -312,6 +328,7 @@ class CudaPlatformBase(Platform): use_mla: bool, has_sink: bool, use_sparse: bool, + use_mm_prefix: bool, attn_type: str | None = None, ) -> str: if attn_type is None: @@ -332,6 +349,7 @@ class CudaPlatformBase(Platform): use_mla, has_sink, use_sparse, + use_mm_prefix, device_capability, attn_type, ) @@ -356,6 +374,7 @@ class CudaPlatformBase(Platform): use_mla, has_sink, use_sparse, + use_mm_prefix, device_capability, attn_type, ) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 27c6fac09f498..f04e94e425257 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -239,6 +239,7 @@ class Platform: use_mla: bool, has_sink: bool, use_sparse: bool, + use_mm_prefix: bool, attn_type: str | None = None, ) -> str: """Get the attention backend class of a device.""" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 32c7f8e536639..ff0fc78517876 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -216,6 +216,7 @@ class RocmPlatform(Platform): use_mla, has_sink, use_sparse, + use_mm_prefix, attn_type: str | None = None, ) -> str: from vllm._aiter_ops import rocm_aiter_ops diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index cbc0a996f3661..d6998e7a308af 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -62,8 +62,9 @@ class TpuPlatform(Platform): kv_cache_dtype: str | None, block_size: int, use_mla: bool, - has_sink, - use_sparse, + has_sink: bool, + use_sparse: bool, + use_mm_prefix: bool, attn_type: str | None = None, ) 
     ) -> str:
         if use_sparse:
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 768714fb16726..0a05750764d8d 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -48,7 +48,8 @@ class XPUPlatform(Platform):
         block_size: int,
         use_mla: bool,
         has_sink: bool,
-        use_sparse,
+        use_sparse: bool,
+        use_mm_prefix: bool,
         attn_type: str | None = None,
     ) -> str:
         from vllm.v1.attention.backends.utils import set_kv_cache_layout
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index a2a6eeeb16b24..d8dbe4cbae013 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import (
     and_masks,
     create_block_mask,
     flex_attention,
+    or_masks,
 )

 from vllm.attention.backends.abstract import (
@@ -42,6 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec

 logger = init_logger(__name__)

+torch._dynamo.config.recompile_limit = 16
 create_block_mask_compiled = torch.compile(
     create_block_mask, fullgraph=True, mode="reduce-overhead"
 )
@@ -91,6 +93,11 @@ class FlexAttentionBackend(AttentionBackend):
         """FlexAttention supports both decoder and encoder-only attention."""
         return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)

+    @classmethod
+    def supports_mm_prefix(cls) -> bool:
+        """FlexAttention supports full (bidirectional) attention over image tokens."""
+        return True
+
     @staticmethod
     def get_impl_cls() -> type["FlexAttentionImpl"]:
         return FlexAttentionImpl
@@ -316,6 +323,7 @@ class FlexAttentionMetadata:
     kv_block_size: int = 16
     transformed_score_mod: _score_mod_signature | None = None
     sliding_window: int | None = None
+    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None

     @cached_property
     def logical_block_ids(self):
@@ -443,6 +451,45 @@ class FlexAttentionMetadata:

         return final_mask_mod if self.causal else sliding_window_mask_mod

+    def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
+        """Creates the prefix LM mask_mod function for FlexAttention."""
+
+        assert self.doc_ids is not None
+        request_lookup = self.doc_ids
+
+        def prefix_lm_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            cu_q_idx: torch.Tensor,
+            q_idx: torch.Tensor,
+            kv_idx: torch.Tensor,
+        ):
+            mask = torch.zeros_like(q_idx, dtype=torch.bool)
+            for req, doc_range_lst in (self.mm_prefix_range or {}).items():
+                req_mask = request_lookup[cu_q_idx] == req
+                for start, end in doc_range_lst:
+                    doc_mask_q = (q_idx >= start) & (q_idx <= end)
+                    doc_mask_kv = (kv_idx >= start) & (kv_idx <= end)
+                    mask = mask | (req_mask & doc_mask_q & doc_mask_kv)
+            return mask
+
+        def final_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            physical_kv_idx: torch.Tensor,
+        ) -> torch.Tensor:
+            (is_valid, logical_q_idx, logical_kv_idx) = (
+                self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
+            )
+            return torch.where(
+                is_valid,
+                prefix_lm_mask_mod(b, h, q_idx, logical_q_idx, logical_kv_idx),
+                False,
+            )
+
+        return final_mask_mod
+
     def get_mask_mod(self):
         # Stage-1: initialize the base mask_mod
         # (causal mask for decoder or bidirectional mask for encoder)
@@ -456,6 +503,10 @@ class FlexAttentionMetadata:
             # Add sliding window mask for sliding window attention
             sliding_window_mask_mod = self.get_sliding_window_mask_mod()
             mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
+        if self.mm_prefix_range:
+            # Add prefix-LM mask: bidirectional attention within multimodal spans
+            prefix_lm_mask_mod = self.get_prefix_lm_mask_mod()
+            mask_mod = or_masks(mask_mod, prefix_lm_mask_mod)
         return mask_mod

     def get_transformed_score_mod(self) -> _score_mod_signature | None:
@@ -709,6 +760,7 @@ class FlexAttentionImpl(AttentionImpl):
     sliding_window: int | None
     alibi_slopes: torch.Tensor | None
     logits_soft_cap: float | None
+    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None

     def __init__(
         self,
@@ -810,11 +862,21 @@ class FlexAttentionImpl(AttentionImpl):

         num_actual_tokens = attn_metadata.num_actual_tokens

+        needs_rebuild_block_mask = False
         if attn_metadata.sliding_window != self.sliding_window:
             attn_metadata.sliding_window = self.sliding_window
             if attn_metadata.direct_build:
                 # update mask mod in attention metadata
                 attn_metadata.mask_mod = attn_metadata.get_mask_mod()
+            needs_rebuild_block_mask = True
+
+        if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None):
+            self.mm_prefix_range = attn_metadata.mm_prefix_range
+            attn_metadata.mask_mod = attn_metadata.get_mask_mod()
+            needs_rebuild_block_mask = True
+
+        if needs_rebuild_block_mask:
+            if attn_metadata.direct_build and attn_metadata.causal:
                 attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
             else:
                 attn_metadata.block_mask = attn_metadata.build_block_mask()
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a50360ab08694..22a3f9d8d2dda 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -48,7 +48,10 @@ from vllm.distributed.parallel_state import (
     is_global_first_rank,
     prepare_communication_buffer_for_model,
 )
-from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.forward_context import (
+    BatchDescriptor,
+    set_forward_context,
+)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.rotary_embedding import (
@@ -329,6 +332,7 @@ class GPUModelRunner(
         self.use_alibi = model_config.uses_alibi

         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
+        self.is_mm_prefix_lm = self.model_config.is_mm_prefix_lm

         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
@@ -1700,6 +1704,26 @@
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i

+        if self.is_mm_prefix_lm:
+            req_doc_ranges = {}
+            for req_id in self.input_batch.req_ids:
+                image_doc_ranges = []
+                req_state = self.requests[req_id]
+                for mm_feature in req_state.mm_features:
+                    pos_info = mm_feature.mm_position
+                    img_doc_range = pos_info.extract_embeds_range()
+                    image_doc_ranges.extend(img_doc_range)
+                req_idx = self.input_batch.req_id_to_index[req_id]
+                req_doc_ranges[req_idx] = image_doc_ranges
+
+            if isinstance(attn_metadata, list):
+                for ub_metadata in attn_metadata:
+                    for _metadata in ub_metadata.values():
+                        _metadata.mm_prefix_range = req_doc_ranges  # type: ignore[attr-defined]
+            else:
+                for _metadata in attn_metadata.values():
+                    _metadata.mm_prefix_range = req_doc_ranges  # type: ignore[attr-defined]
+
         if spec_decode_common_attn_metadata is not None and (
             num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
         ):
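For reference, the run-extraction idea behind the new `PlaceholderRange.extract_embeds_range` helper can be reproduced in isolation: `torch.diff` over a 0/1 mask marks run starts (+1) and run ends (-1), yielding inclusive `(start, end)` pairs. The sketch below is standalone (the function name is made up and nothing is imported from vLLM) and checks itself against the example in the new docstring.

```python
import torch


def extract_true_runs(is_embed: torch.Tensor, offset: int = 0) -> list[tuple[int, int]]:
    mask_i = is_embed.int()
    # A 0 -> 1 transition marks the first index of a True run.
    starts = torch.nonzero(torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1).flatten()
    # A 1 -> 0 transition (with a trailing 0 appended) marks the last index of a run.
    ends = torch.nonzero(torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1).flatten()
    ranges = torch.stack((starts, ends), dim=1) + offset
    return [tuple(x) for x in ranges.tolist()]


# Matches the docstring example: offset=2, is_embed=[F, T, F, T, T]
mask = torch.tensor([False, True, False, True, True])
print(extract_true_runs(mask, offset=2))  # [(3, 3), (5, 6)]
```

These inclusive ranges are what the model runner forwards to the FlexAttention metadata as `mm_prefix_range`, where `get_prefix_lm_mask_mod` turns each `(start, end)` pair into a bidirectional attention region for the owning request.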