[v1] Add PrefixLM support to FlexAttention backend (#27938)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py 2025-12-07 23:51:36 +08:00 committed by GitHub
parent 541a2ef892
commit b952f4d3c3
16 changed files with 173 additions and 25 deletions

View File

@ -740,23 +740,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
!!! warning
Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:
V0 correctly implements the model's attention pattern:
- Uses bidirectional attention between the image tokens corresponding to the same image
- Uses causal attention for other tokens
- Implemented via (naive) PyTorch SDPA with masking tensors
- Note: May use significant memory for long prompts with images
V1 currently uses a simplified attention pattern:
- Uses causal attention for all tokens, including image tokens
- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}` is enabled
- Will be updated in the future to support the correct behavior
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
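For reference, the pattern V0 implements can be sketched with plain PyTorch SDPA roughly as below (a minimal illustration, not vLLM code; `gemma3_style_mask` and the span values are hypothetical):

```python
import torch

def gemma3_style_mask(seq_len: int, image_spans: list[tuple[int, int]]) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask: causal, but bidirectional inside each image span."""
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # causal base
    for start, end in image_spans:  # inclusive token indices of one image
        mask[start : end + 1, start : end + 1] = True  # full attention within the image
    return mask

q = k = v = torch.randn(1, 8, 10, 64)   # [batch, heads, tokens, head_dim]
mask = gemma3_style_mask(10, [(3, 6)])  # tokens 3..6 belong to one image
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```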
!!! note
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
MobileNet-v5 vision backbone.
@ -776,9 +759,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: <https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630>
!!! warning
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
!!! note
For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.

View File

@ -382,7 +382,6 @@ VLM_TEST_SETTINGS = {
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
num_logprobs=10,
),
"glm4v": VLMTestInfo(
models=["zai-org/glm-4v-9b"],

View File

@ -30,5 +30,6 @@ class DummyPlatform(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501

View File

@ -166,6 +166,10 @@ class AttentionBackend(ABC):
def supports_sink(cls) -> bool:
return False
@classmethod
def supports_mm_prefix(cls) -> bool:
return False
@classmethod
def is_sparse(cls) -> bool:
return False
@ -207,6 +211,7 @@ class AttentionBackend(ABC):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
@ -219,6 +224,10 @@ class AttentionBackend(ABC):
invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported")
if use_mm_prefix and not cls.supports_mm_prefix():
invalid_reasons.append(
"bidirectional attention over multimodal prefix tokens not supported"
)
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")

View File

@ -230,6 +230,10 @@ class Attention(nn.Module, AttentionLayerBase):
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
# NOTE: model_config may be None during certain tests
model_config = vllm_config.model_config
self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
@ -241,6 +245,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size,
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
attn_type=attn_type,
)
else:

View File

@ -27,6 +27,7 @@ def get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
@ -52,6 +53,7 @@ def get_attn_backend(
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
attn_type=attn_type,
)
@ -66,6 +68,7 @@ def _cached_get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
@ -87,6 +90,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
else:
@ -99,6 +103,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
if not attention_cls:

View File

@ -4,6 +4,7 @@
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch
@ -1217,6 +1218,19 @@ class ModelConfig:
)
return False
@cached_property
def is_mm_prefix_lm(self) -> bool:
"""Whether to use bidirectional attention for mm positions."""
MM_PREFIX_LM_MODELS = (
"gemma3",
# TODO(Isotr0py): Disable paligemma for now until
# we support soft cap attention for FlexAttention
# "paligemma",
)
if not hasattr(self.hf_config, "model_type"):
return False
return self.hf_config.model_type in MM_PREFIX_LM_MODELS
def get_head_size(self) -> int:
# TODO remove hard code
if self.is_deepseek_mla:

View File

@ -175,6 +175,31 @@ class PlaceholderRange:
return int(self.is_embed.sum().item())
def extract_embeds_range(self) -> list[tuple[int, int]]:
"""Extract the start and end indices of the embedded region in prompt.
For example, given `PlaceholderRange(offset=2, length=5)` and
`is_embed = [False, True, False, True, True]`, the output is
`[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.
Returns:
A list of `(start, end)` tuples giving the start and end
indices (inclusive) of each embedded region.
Returns the full placeholder range if `is_embed` is `None`.
"""
if self.is_embed is None:
return [(self.offset, self.offset + self.length)]
mask_i = self.is_embed.int()
starts = torch.nonzero(
torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
).flatten()
ends = torch.nonzero(
torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
).flatten()
ranges = torch.stack((starts, ends), dim=1) + self.offset
return [tuple(x) for x in ranges.tolist()]
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
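A quick worked example of the diff-based extraction above (a standalone sketch mirroring the method body, using the values from the docstring):

```python
import torch

# Mirrors PlaceholderRange(offset=2, length=5) with the docstring's mask.
offset = 2
is_embed = torch.tensor([False, True, False, True, True])

mask_i = is_embed.int()  # [0, 1, 0, 1, 1]
# A 0 -> 1 transition marks the start of an embedded run, 1 -> 0 marks its end.
starts = torch.nonzero(torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1).flatten()
ends = torch.nonzero(torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1).flatten()
ranges = torch.stack((starts, ends), dim=1) + offset
print([tuple(x) for x in ranges.tolist()])  # [(3, 3), (5, 6)]
```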

View File

@ -133,6 +133,7 @@ class CpuPlatform(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:

View File

@ -233,6 +233,20 @@ class CudaPlatformBase(Platform):
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing
if (
model_config is not None
and model_config.is_mm_prefix_lm
and scheduler_config.is_multimodal_model
and not scheduler_config.disable_chunked_mm_input
):
logger.warning(
"Forcing --disable_chunked_mm_input for models "
"with multimodal-bidirectional attention."
)
scheduler_config.disable_chunked_mm_input = True
@classmethod
def get_current_memory_usage(
cls, device: torch.types.Device | None = None
@ -268,6 +282,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) -> tuple[
@ -289,6 +304,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
)
@ -312,6 +328,7 @@ class CudaPlatformBase(Platform):
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if attn_type is None:
@ -332,6 +349,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
)
@ -356,6 +374,7 @@ class CudaPlatformBase(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
)

View File

@ -239,6 +239,7 @@ class Platform:
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
"""Get the attention backend class of a device."""

View File

@ -216,6 +216,7 @@ class RocmPlatform(Platform):
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops

View File

@ -62,8 +62,9 @@ class TpuPlatform(Platform):
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink,
use_sparse,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if use_sparse:

View File

@ -48,7 +48,8 @@ class XPUPlatform(Platform):
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout

View File

@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import (
and_masks,
create_block_mask,
flex_attention,
or_masks,
)
from vllm.attention.backends.abstract import (
@ -42,6 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
torch._dynamo.config.recompile_limit = 16
create_block_mask_compiled = torch.compile(
create_block_mask, fullgraph=True, mode="reduce-overhead"
)
@ -91,6 +93,11 @@ class FlexAttentionBackend(AttentionBackend):
"""FlexAttention supports both decoder and encoder-only attention."""
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@classmethod
def supports_mm_prefix(cls) -> bool:
"""FlexAttention supports full attention for image tokens."""
return True
@staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl
@ -316,6 +323,7 @@ class FlexAttentionMetadata:
kv_block_size: int = 16
transformed_score_mod: _score_mod_signature | None = None
sliding_window: int | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
@cached_property
def logical_block_ids(self):
@ -443,6 +451,45 @@ class FlexAttentionMetadata:
return final_mask_mod if self.causal else sliding_window_mask_mod
def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
"""Creates the prefix LM mask_mod function for FlexAttention."""
assert self.doc_ids is not None
request_lookup = self.doc_ids
def prefix_lm_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
cu_q_idx: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
):
mask = torch.zeros_like(q_idx, dtype=torch.bool)
for req, doc_range_lst in (self.mm_prefix_range or {}).items():
req_mask = request_lookup[cu_q_idx] == req
for start, end in doc_range_lst:
doc_mask_q = (q_idx >= start) & (q_idx <= end)
doc_mask_kv = (kv_idx >= start) & (kv_idx <= end)
mask = mask | (req_mask & doc_mask_q & doc_mask_kv)
return mask
def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx, logical_kv_idx) = (
self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
)
return torch.where(
is_valid,
prefix_lm_mask_mod(b, h, q_idx, logical_q_idx, logical_kv_idx),
False,
)
return final_mask_mod
def get_mask_mod(self):
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
@ -456,6 +503,10 @@ class FlexAttentionMetadata:
# Add sliding window mask for sliding window attention
sliding_window_mask_mod = self.get_sliding_window_mask_mod()
mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
if self.mm_prefix_range:
# Add prefix LM mask for vision-language prefix LM attention
prefix_lm_mask_mod = self.get_prefix_lm_mask_mod()
mask_mod = or_masks(mask_mod, prefix_lm_mask_mod)
return mask_mod
def get_transformed_score_mod(self) -> _score_mod_signature | None:
@ -709,6 +760,7 @@ class FlexAttentionImpl(AttentionImpl):
sliding_window: int | None
alibi_slopes: torch.Tensor | None
logits_soft_cap: float | None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
def __init__(
self,
@ -810,11 +862,21 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
needs_rebuild_block_mask = False
if attn_metadata.sliding_window != self.sliding_window:
attn_metadata.sliding_window = self.sliding_window
if attn_metadata.direct_build:
# update mask mod in attention metadata
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True
if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None):
self.mm_prefix_range = attn_metadata.mm_prefix_range
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True
if needs_rebuild_block_mask:
if attn_metadata.direct_build and attn_metadata.causal:
attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
else:
attn_metadata.block_mask = attn_metadata.build_block_mask()
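A minimal standalone sketch of the mask composition used by the backend: `or_masks` relaxes the causal mask to full bidirectional attention inside each image-token span. Here `image_spans` is a hypothetical fixed list standing in for the per-request `mm_prefix_range`, and only logical indices are shown (the physical-to-logical conversion is omitted):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, or_masks

image_spans = [(16, 47)]  # hypothetical inclusive (start, end) image-token range


def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def prefix_lm_mask_mod(b, h, q_idx, kv_idx):
    mask = torch.zeros_like(q_idx, dtype=torch.bool)
    for start, end in image_spans:
        in_q = (q_idx >= start) & (q_idx <= end)
        in_kv = (kv_idx >= start) & (kv_idx <= end)
        mask = mask | (in_q & in_kv)  # bidirectional within each span
    return mask


# Causal everywhere, except fully bidirectional inside each image span.
mask_mod = or_masks(causal_mask_mod, prefix_lm_mask_mod)
block_mask = create_block_mask(
    mask_mod, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cpu"
)
```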

View File

@ -48,7 +48,10 @@ from vllm.distributed.parallel_state import (
is_global_first_rank,
prepare_communication_buffer_for_model,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.forward_context import (
BatchDescriptor,
set_forward_context,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import (
@ -329,6 +332,7 @@ class GPUModelRunner(
self.use_alibi = model_config.uses_alibi
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
self.is_mm_prefix_lm = self.model_config.is_mm_prefix_lm
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
@ -1700,6 +1704,26 @@ class GPUModelRunner(
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
if self.is_mm_prefix_lm:
req_doc_ranges = {}
for req_id in self.input_batch.req_ids:
image_doc_ranges = []
req_state = self.requests[req_id]
for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
img_doc_range = pos_info.extract_embeds_range()
image_doc_ranges.extend(img_doc_range)
req_idx = self.input_batch.req_id_to_index[req_id]
req_doc_ranges[req_idx] = image_doc_ranges
if isinstance(attn_metadata, list):
for ub_metadata in attn_metadata:
for _metadata in ub_metadata.values():
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
else:
for _metadata in attn_metadata.values():
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
if spec_decode_common_attn_metadata is not None and (
num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
):
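The structure the runner builds here is a mapping from each request's batch index to the inclusive `(start, end)` spans of its embedded multimodal tokens, which `FlexAttentionMetadata.mm_prefix_range` then consumes. A hypothetical example of the resulting shape:

```python
# Hypothetical values for two requests in the batch.
req_doc_ranges = {
    0: [(3, 6)],            # request 0: one image covering prompt tokens 3..6
    1: [(2, 9), (15, 22)],  # request 1: two images
}
```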