[v1] Add PrefixLM support to FlexAttention backend (#27938)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 541a2ef892
commit b952f4d3c3
@@ -740,23 +740,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.

!!! warning
    Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
    However, there are differences in how they handle text + image inputs:

    V0 correctly implements the model's attention pattern:
    - Uses bidirectional attention between the image tokens corresponding to the same image
    - Uses causal attention for other tokens
    - Implemented via (naive) PyTorch SDPA with masking tensors
    - Note: May use significant memory for long prompts with images

    V1 currently uses a simplified attention pattern:
    - Uses causal attention for all tokens, including image tokens
    - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}`
    - Will be updated in the future to support the correct behavior

    This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.

!!! note
    `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching, and it depends on `timm>=1.0.17` for its MobileNet-v5 vision backbone.

@@ -776,9 +759,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
    The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
    For more details, please see: <https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630>

!!! warning
    Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.

!!! note
    For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.
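The mixed pattern described in the warning above (bidirectional among the tokens of one image, causal everywhere else) can be sketched with plain PyTorch SDPA and a boolean mask. This is an illustrative, self-contained example rather than vLLM's implementation; the sequence length and image span below are made up.

```python
# Sketch of a prefix-LM style mask: image tokens attend bidirectionally within
# their own image span, all other positions follow the usual causal rule.
import torch
import torch.nn.functional as F

seq_len, num_heads, head_dim = 8, 2, 16
image_spans = [(2, 4)]  # hypothetical image token span, inclusive

# Start from a causal mask (True = may attend), then open up each image block.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
for start, end in image_spans:
    mask[start : end + 1, start : end + 1] = True

q = torch.randn(1, num_heads, seq_len, head_dim)
k = torch.randn(1, num_heads, seq_len, head_dim)
v = torch.randn(1, num_heads, seq_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```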
@@ -382,7 +382,6 @@ VLM_TEST_SETTINGS = {
        auto_cls=AutoModelForImageTextToText,
        vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
        patch_hf_runner=model_utils.gemma3_patch_hf_runner,
        num_logprobs=10,
    ),
    "glm4v": VLMTestInfo(
        models=["zai-org/glm-4v-9b"],
@@ -30,5 +30,6 @@ class DummyPlatform(Platform):
        use_mla,
        has_sink,
        use_sparse,
        use_mm_prefix,
    ):
        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
@@ -166,6 +166,10 @@ class AttentionBackend(ABC):
    def supports_sink(cls) -> bool:
        return False

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        return False
@@ -207,6 +211,7 @@ class AttentionBackend(ABC):
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        device_capability: "DeviceCapability",
        attn_type: str,
    ) -> list[str]:
@@ -219,6 +224,10 @@ class AttentionBackend(ABC):
            invalid_reasons.append("kv_cache_dtype not supported")
        if not cls.supports_block_size(block_size):
            invalid_reasons.append("block_size not supported")
        if use_mm_prefix and not cls.supports_mm_prefix():
            invalid_reasons.append(
                "partial multimodal token full attention not supported"
            )
        if use_mla != cls.is_mla():
            if use_mla:
                invalid_reasons.append("MLA not supported")
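To make the new capability flag concrete, here is a hedged sketch of how a backend opts in and how the validation shown above treats it. The subclass name and the helper function are hypothetical; only `AttentionBackend.supports_mm_prefix` comes from this diff, and the import assumes vLLM at this commit is installed.

```python
# Hypothetical illustration, not part of the diff.
from vllm.attention.backends.abstract import AttentionBackend


class PrefixAwareBackend(AttentionBackend):
    @classmethod
    def supports_mm_prefix(cls) -> bool:
        # Declare support for bidirectional attention over multimodal prefix spans.
        return True


def mm_prefix_reasons(backend: type[AttentionBackend], use_mm_prefix: bool) -> list[str]:
    # Mirrors the validation check above: a backend that cannot handle
    # mm-prefix attention is reported as invalid for such models.
    if use_mm_prefix and not backend.supports_mm_prefix():
        return ["partial multimodal token full attention not supported"]
    return []
```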
@@ -230,6 +230,10 @@ class Attention(nn.Module, AttentionLayerBase):
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
@@ -241,6 +245,7 @@ class Attention(nn.Module, AttentionLayerBase):
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )
        else:
@@ -27,6 +27,7 @@ def get_attn_backend(
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""
@@ -52,6 +53,7 @@ def get_attn_backend(
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        use_mm_prefix=use_mm_prefix,
        attn_type=attn_type,
    )

@@ -66,6 +68,7 @@ def _cached_get_attn_backend(
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    attn_type: str | None = None,
) -> type[AttentionBackend]:
    from vllm.platforms import current_platform
@@ -87,6 +90,7 @@ def _cached_get_attn_backend(
            use_mla,
            has_sink,
            use_sparse,
            use_mm_prefix,
            attn_type,
        )
    else:
@@ -99,6 +103,7 @@ def _cached_get_attn_backend(
            use_mla,
            has_sink,
            use_sparse,
            use_mm_prefix,
            attn_type,
        )
    if not attention_cls:
@@ -4,6 +4,7 @@
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args

import torch
@@ -1217,6 +1218,19 @@ class ModelConfig:
            )
        return False

    @cached_property
    def is_mm_prefix_lm(self) -> bool:
        """Whether to use bidirectional attention for mm positions."""
        MM_PREFIX_LM_MODELS = (
            "gemma3",
            # TODO(Isotr0py): Disable paligemma for now before
            # we support soft cap attention for FlexAttention
            # "paligemma",
        )
        if not hasattr(self.hf_config, "model_type"):
            return False
        return self.hf_config.model_type in MM_PREFIX_LM_MODELS

    def get_head_size(self) -> int:
        # TODO remove hard code
        if self.is_deepseek_mla:
@@ -175,6 +175,31 @@ class PlaceholderRange:

        return int(self.is_embed.sum().item())

    def extract_embeds_range(self) -> list[tuple[int, int]]:
        """Extract the start and end indices of the embedded regions in the prompt.

        For example, given `PlaceholderRange(offset=2, length=5)` and
        `is_embed = [False, True, False, True, True]`, the output is
        `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

        Returns:
            A list of `(start, end)` tuples giving the start and end
            indices (inclusive) of each embedded region.
            Returns the full placeholder range if `is_embed` is `None`.
        """
        if self.is_embed is None:
            return [(self.offset, self.offset + self.length)]

        mask_i = self.is_embed.int()
        starts = torch.nonzero(
            torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
        ).flatten()
        ends = torch.nonzero(
            torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
        ).flatten()
        ranges = torch.stack((starts, ends), dim=1) + self.offset
        return [tuple(x) for x in ranges.tolist()]

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
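As a sanity check of the run-extraction logic above, the same `torch.diff` trick can be applied to a plain tensor outside the dataclass; the values mirror the docstring example.

```python
# Illustrative only: detect contiguous runs of True via sign changes of torch.diff.
import torch

is_embed = torch.tensor([False, True, False, True, True])
offset = 2

mask_i = is_embed.int()
starts = torch.nonzero(torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1).flatten()
ends = torch.nonzero(torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1).flatten()
ranges = torch.stack((starts, ends), dim=1) + offset
print([tuple(x) for x in ranges.tolist()])  # [(3, 3), (5, 6)]
```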
@@ -133,6 +133,7 @@ class CpuPlatform(Platform):
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        attn_type: str | None = None,
    ) -> str:
        if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
@@ -233,6 +233,20 @@ class CudaPlatformBase(Platform):
                "Forcing kv cache block size to 64 for FlashMLASparse backend."
            )

        scheduler_config = vllm_config.scheduler_config
        # Note: model_config may be None during testing
        if (
            model_config is not None
            and model_config.is_mm_prefix_lm
            and scheduler_config.is_multimodal_model
            and not scheduler_config.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            scheduler_config.disable_chunked_mm_input = True

    @classmethod
    def get_current_memory_usage(
        cls, device: torch.types.Device | None = None
@@ -268,6 +282,7 @@ class CudaPlatformBase(Platform):
        use_mla,
        has_sink,
        use_sparse,
        use_mm_prefix,
        device_capability,
        attn_type,
    ) -> tuple[
@@ -289,6 +304,7 @@ class CudaPlatformBase(Platform):
            use_mla,
            has_sink,
            use_sparse,
            use_mm_prefix,
            device_capability,
            attn_type,
        )
@@ -312,6 +328,7 @@ class CudaPlatformBase(Platform):
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        attn_type: str | None = None,
    ) -> str:
        if attn_type is None:
@@ -332,6 +349,7 @@ class CudaPlatformBase(Platform):
            use_mla,
            has_sink,
            use_sparse,
            use_mm_prefix,
            device_capability,
            attn_type,
        )
@@ -356,6 +374,7 @@ class CudaPlatformBase(Platform):
            use_mla,
            has_sink,
            use_sparse,
            use_mm_prefix,
            device_capability,
            attn_type,
        )
@@ -239,6 +239,7 @@ class Platform:
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        attn_type: str | None = None,
    ) -> str:
        """Get the attention backend class of a device."""
@@ -216,6 +216,7 @@ class RocmPlatform(Platform):
        use_mla,
        has_sink,
        use_sparse,
        use_mm_prefix,
        attn_type: str | None = None,
    ) -> str:
        from vllm._aiter_ops import rocm_aiter_ops
@@ -62,8 +62,9 @@ class TpuPlatform(Platform):
        kv_cache_dtype: str | None,
        block_size: int,
        use_mla: bool,
        has_sink,
        use_sparse,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        attn_type: str | None = None,
    ) -> str:
        if use_sparse:
@@ -48,7 +48,8 @@ class XPUPlatform(Platform):
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse,
        use_sparse: bool,
        use_mm_prefix: bool,
        attn_type: str | None = None,
    ) -> str:
        from vllm.v1.attention.backends.utils import set_kv_cache_layout
@@ -17,6 +17,7 @@ from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
    or_masks,
)

from vllm.attention.backends.abstract import (
@@ -42,6 +43,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)

torch._dynamo.config.recompile_limit = 16
create_block_mask_compiled = torch.compile(
    create_block_mask, fullgraph=True, mode="reduce-overhead"
)
@@ -91,6 +93,11 @@ class FlexAttentionBackend(AttentionBackend):
        """FlexAttention supports both decoder and encoder-only attention."""
        return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        """FlexAttention supports full attention for image tokens."""
        return True

    @staticmethod
    def get_impl_cls() -> type["FlexAttentionImpl"]:
        return FlexAttentionImpl
@@ -316,6 +323,7 @@ class FlexAttentionMetadata:
    kv_block_size: int = 16
    transformed_score_mod: _score_mod_signature | None = None
    sliding_window: int | None = None
    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None

    @cached_property
    def logical_block_ids(self):
@@ -443,6 +451,45 @@ class FlexAttentionMetadata:

        return final_mask_mod if self.causal else sliding_window_mask_mod

    def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
        """Creates the prefix LM mask_mod function for FlexAttention."""

        assert self.doc_ids is not None
        request_lookup = self.doc_ids

        def prefix_lm_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            cu_q_idx: torch.Tensor,
            q_idx: torch.Tensor,
            kv_idx: torch.Tensor,
        ):
            mask = torch.zeros_like(q_idx, dtype=torch.bool)
            for req, doc_range_lst in (self.mm_prefix_range or {}).items():
                req_mask = request_lookup[cu_q_idx] == req
                for start, end in doc_range_lst:
                    doc_mask_q = (q_idx >= start) & (q_idx <= end)
                    doc_mask_kv = (kv_idx >= start) & (kv_idx <= end)
                    mask = mask | (req_mask & doc_mask_q & doc_mask_kv)
            return mask

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            (is_valid, logical_q_idx, logical_kv_idx) = (
                self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
            )
            return torch.where(
                is_valid,
                prefix_lm_mask_mod(b, h, q_idx, logical_q_idx, logical_kv_idx),
                False,
            )

        return final_mask_mod

    def get_mask_mod(self):
        # Stage-1: initialize the base mask_mod
        # (causal mask for decoder or bidirectional mask for encoder)
@@ -456,6 +503,10 @@ class FlexAttentionMetadata:
            # Add sliding window mask for sliding window attention
            sliding_window_mask_mod = self.get_sliding_window_mask_mod()
            mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
        if self.mm_prefix_range:
            # Add prefix LM mask for vision-language prefix LM attention
            prefix_lm_mask_mod = self.get_prefix_lm_mask_mod()
            mask_mod = or_masks(mask_mod, prefix_lm_mask_mod)
        return mask_mod

    def get_transformed_score_mod(self) -> _score_mod_signature | None:
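For readers unfamiliar with FlexAttention's mask composition, the snippet below shows the same `or_masks` idea in isolation: a causal `mask_mod` is widened with a bidirectional block over a hypothetical image span, then compiled into a block mask. This is a standalone sketch, not vLLM code, and assumes a PyTorch version whose `torch.nn.attention.flex_attention` runs on the device in use.

```python
# Illustrative only: combine causal attention with a bidirectional image block.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, or_masks

SEQ_LEN, IMG_START, IMG_END = 128, 16, 47  # hypothetical inclusive image span


def causal_mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def image_block_mask_mod(b, h, q_idx, kv_idx):
    # Bidirectional attention when both query and key fall inside the image span.
    q_in = (q_idx >= IMG_START) & (q_idx <= IMG_END)
    kv_in = (kv_idx >= IMG_START) & (kv_idx <= IMG_END)
    return q_in & kv_in


mask_mod = or_masks(causal_mask_mod, image_block_mask_mod)
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, device="cpu")

q = k = v = torch.randn(1, 4, SEQ_LEN, 64)
out = flex_attention(q, k, v, block_mask=block_mask)
```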
@@ -709,6 +760,7 @@ class FlexAttentionImpl(AttentionImpl):
    sliding_window: int | None
    alibi_slopes: torch.Tensor | None
    logits_soft_cap: float | None
    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None

    def __init__(
        self,
@@ -810,11 +862,21 @@ class FlexAttentionImpl(AttentionImpl):

        num_actual_tokens = attn_metadata.num_actual_tokens

        needs_rebuild_block_mask = False
        if attn_metadata.sliding_window != self.sliding_window:
            attn_metadata.sliding_window = self.sliding_window
            if attn_metadata.direct_build:
                # update mask mod in attention metadata
                attn_metadata.mask_mod = attn_metadata.get_mask_mod()
            needs_rebuild_block_mask = True

        if self.mm_prefix_range != getattr(attn_metadata, "mm_prefix_range", None):
            self.mm_prefix_range = attn_metadata.mm_prefix_range
            attn_metadata.mask_mod = attn_metadata.get_mask_mod()
            needs_rebuild_block_mask = True

        if needs_rebuild_block_mask:
            if attn_metadata.direct_build and attn_metadata.causal:
                attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
            else:
                attn_metadata.block_mask = attn_metadata.build_block_mask()
@@ -48,7 +48,10 @@ from vllm.distributed.parallel_state import (
    is_global_first_rank,
    prepare_communication_buffer_for_model,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.forward_context import (
    BatchDescriptor,
    set_forward_context,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.rotary_embedding import (
@@ -329,6 +332,7 @@ class GPUModelRunner(
        self.use_alibi = model_config.uses_alibi

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
        self.is_mm_prefix_lm = self.model_config.is_mm_prefix_lm

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
@@ -1700,6 +1704,26 @@ class GPUModelRunner(
            for layer_name in attn_group.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

        if self.is_mm_prefix_lm:
            req_doc_ranges = {}
            for req_id in self.input_batch.req_ids:
                image_doc_ranges = []
                req_state = self.requests[req_id]
                for mm_feature in req_state.mm_features:
                    pos_info = mm_feature.mm_position
                    img_doc_range = pos_info.extract_embeds_range()
                    image_doc_ranges.extend(img_doc_range)
                req_idx = self.input_batch.req_id_to_index[req_id]
                req_doc_ranges[req_idx] = image_doc_ranges

            if isinstance(attn_metadata, list):
                for ub_metadata in attn_metadata:
                    for _metadata in ub_metadata.values():
                        _metadata.mm_prefix_range = req_doc_ranges  # type: ignore[attr-defined]
            else:
                for _metadata in attn_metadata.values():
                    _metadata.mm_prefix_range = req_doc_ranges  # type: ignore[attr-defined]

        if spec_decode_common_attn_metadata is not None and (
            num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
        ):
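The mapping built in the runner above has a simple shape; the sketch below uses hypothetical values only, to make the `dict[int, list[tuple[int, int]]]` structure carried by `mm_prefix_range` concrete.

```python
# Illustrative only: request index in the batch -> inclusive (start, end) token
# spans of its image placeholders, as produced by extract_embeds_range().
req_doc_ranges: dict[int, list[tuple[int, int]]] = {
    0: [(5, 260)],              # request 0: one image spanning tokens 5..260
    1: [(3, 130), (400, 527)],  # request 1: two images
}
```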