[v1] Add PrefixLM support to TritonAttention backend (#30386)

This commit is contained in:
Isotr0py 2025-12-18 08:05:24 +08:00 committed by GitHub
parent 05a83dc6ee
commit 74a1ac38b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 280 additions and 123 deletions

View File

@ -1,17 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, NamedTuple
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from typing import Any, NamedTuple
import pytest
from huggingface_hub import hf_hub_download
from pytest import MarkDecorator
from transformers import AutoModelForImageTextToText
from tests.quantization.utils import is_quant_method_supported
from vllm.assets.image import ImageAsset
from vllm.multimodal.image import rescale_image_size
from vllm.utils.torch_utils import set_default_torch_num_threads
from ....conftest import PromptImageInput, VllmRunner
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
from ...utils import check_logprobs_close
@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
gguf_backbone: str
gguf_mmproj: str
prompt: list[str]
mm_data: dict[Literal["images"], PromptImageInput]
image_names: list[str] # Store names, load PIL images at runtime
max_model_len: int = 4096
marks: list[MarkDecorator] = []
mm_processor_kwargs: dict[str, Any] = {}
@property
def gguf_model(self):
@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
# Common prompts aligned with test_common.py "gemma3" entry format
_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts(
{
"stop_sign": (
"<bos><start_of_turn>user\n"
"<start_of_image>What's the content in the center of the image?"
"<end_of_turn>\n<start_of_turn>model\n"
),
"cherry_blossom": (
"<bos><start_of_turn>user\n"
"<start_of_image>What is the season?"
"<end_of_turn>\n<start_of_turn>model\n"
),
}
)
# Image asset names - load at runtime to avoid pickle issues with subprocess
_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"]
# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
GEMMA3_CONFIG = GGUFMMTestConfig(
original_model="google/gemma-3-4b-it",
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
gguf_mmproj="mmproj-model-f16-4B.gguf",
prompt=["<start_of_image>Describe this image in detail:"],
mm_data={"images": [ImageAsset("stop_sign").pil_image]},
prompt=_GEMMA3_PROMPTS,
image_names=_GEMMA3_IMAGE_NAMES,
max_model_len=4096,
marks=[pytest.mark.core_model],
mm_processor_kwargs={},
)
MODELS_TO_TEST = [GEMMA3_CONFIG]
# Pan-and-scan multimodal - uses unquantized BF16 GGUF
GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig(
original_model="google/gemma-3-4b-it",
gguf_repo="unsloth/gemma-3-4b-it-GGUF",
gguf_backbone="gemma-3-4b-it-BF16.gguf",
gguf_mmproj="mmproj-BF16.gguf",
prompt=_GEMMA3_PROMPTS,
image_names=_GEMMA3_IMAGE_NAMES,
max_model_len=4096,
marks=[pytest.mark.core_model],
mm_processor_kwargs={"do_pan_and_scan": True},
)
MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN]
def run_multimodal_gguf_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
model: GGUFMMTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
):
# Run gguf model.
# Load images at runtime (inside subprocess) to avoid pickle issues
images = [ImageAsset(name).pil_image for name in model.image_names]
size_factors = [0.25, 0.5, 1.0]
inputs_per_image = [
(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
)
for image, prompt in zip(images, model.prompt)
]
# NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
# Run GGUF model via vLLM.
with (
set_default_torch_num_threads(1),
vllm_runner(
@ -60,35 +115,42 @@ def run_multimodal_gguf_test(
tokenizer_name=model.original_model,
dtype=dtype,
max_model_len=model.max_model_len,
mm_processor_kwargs=model.mm_processor_kwargs,
) as gguf_model,
):
gguf_outputs = gguf_model.generate_greedy_logprobs(
prompts=model.prompt,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
**model.mm_data,
)
gguf_outputs_per_case = [
gguf_model.generate_greedy_logprobs(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images,
)
for prompts, images in inputs_per_image
]
# Run unquantized model.
with vllm_runner(
model_name=model.original_model,
enforce_eager=True, # faster tests
# Then run HfRunner for HuggingFace baseline comparison.
with hf_runner(
model.original_model,
dtype=dtype,
max_model_len=model.max_model_len,
) as original_model:
original_outputs = original_model.generate_greedy_logprobs(
prompts=model.prompt,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
**model.mm_data,
)
auto_cls=AutoModelForImageTextToText,
) as hf_model:
hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images,
)
for prompts, images in inputs_per_image
]
check_logprobs_close(
outputs_0_lst=original_outputs,
outputs_1_lst=gguf_outputs,
name_0="original",
name_1="gguf",
)
for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=gguf_outputs,
name_0="hf",
name_1="gguf",
)
@pytest.mark.skipif(
@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(
def test_gemma3_mm_gguf(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
model: GGUFMMTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
run_multimodal_gguf_test(
hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
)

View File

@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
USE_MM_PREFIX: tl.constexpr, # bool
MAX_MM_RANGES: tl.constexpr, # int
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
else:
V = V_load
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None]
seq_mask = seq_offset[None, :] <= query_abs_pos
# Apply sliding window to base mask BEFORE mm_prefix OR.
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
if SLIDING_WINDOW > 0:
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
if USE_MM_PREFIX:
for i in range(MAX_MM_RANGES):
range_start = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
)
range_end = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
)
is_valid = range_start < range_end
q_in_range = (
(query_abs_pos >= range_start)
& (query_abs_pos <= range_end)
& is_valid
)
k_in_range = (
(seq_offset[None, :] >= range_start)
& (seq_offset[None, :] <= range_end)
& is_valid
)
seq_mask |= q_in_range & k_in_range
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
)
if SLIDING_WINDOW > 0:
S = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
S,
float("-inf"),
)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
USE_MM_PREFIX: tl.constexpr, # bool
MAX_MM_RANGES: tl.constexpr, # int
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
else:
V = V_load
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None]
seq_mask = seq_offset[None, :] <= query_abs_pos
# Apply sliding window to base mask BEFORE mm_prefix OR.
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
if SLIDING_WINDOW > 0:
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
if USE_MM_PREFIX:
for i in range(MAX_MM_RANGES):
range_start = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
)
range_end = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
)
is_valid = range_start < range_end
q_in_range = (
(query_abs_pos >= range_start)
& (query_abs_pos <= range_end)
& is_valid
)
k_in_range = (
(seq_offset[None, :] >= range_start)
& (seq_offset[None, :] <= range_end)
& is_valid
)
seq_mask |= q_in_range & k_in_range
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
)
if SLIDING_WINDOW > 0:
S = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
S,
float("-inf"),
)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
@ -732,6 +786,43 @@ def reduce_segments(
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
"""Detect Gemma3 models via unique (head_size, sliding_window) signature.
Gemma3 models are the only ones using sliding_window=1024 with
head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
different window sizes (Mistral=4096, Phi-3=2047).
"""
return sliding_window == 1024 and head_size in (128, 256)
def _get_tile_size(
head_size: int,
sliding_window: int,
element_size: int,
is_mm_prefix: bool,
is_prefill: bool,
) -> int:
"""Select tile size with Gemma3-specific optimization.
For Gemma3, use 32 for both prefill and decode to better utilize
the larger head dimension (128/256). For other models, use
the default vLLM behavior.
"""
if is_mm_prefix:
# Multimodal bidirectional attention needs a larger tile size
return 64
if _is_gemma3_attention(head_size, sliding_window):
# Gemma3: use 32 for decode (default is 16)
return 32
# Default behavior
if is_prefill:
return 32
return 16 if element_size >= 2 else 32
def unified_attention(
q,
k,
@ -759,6 +850,8 @@ def unified_attention(
qq_bias=None,
# Optional tensor for sinks
sinks=None,
# Optional tensor for prefix lengths (PrefixLM support)
mm_prefix_range=None,
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
@ -766,6 +859,17 @@ def unified_attention(
if sinks is not None:
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
use_mm_prefix = False
max_mm_ranges = 0
if mm_prefix_range is not None:
if mm_prefix_range.ndim == 3:
use_mm_prefix = True
max_mm_ranges = mm_prefix_range.shape[1]
else:
raise ValueError(
f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
)
use_alibi_slopes = alibi_slopes is not None
use_qq_bias = qq_bias is not None
@ -792,11 +896,23 @@ def unified_attention(
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
# Assigning default tile sizes for prefill and decode.
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
# and at least 16 for all other data types.
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
# Tile sizes for prefill and decode. Gemma3 models use optimized values.
# Note: tile size must be at least 32 for fp8 (element_size == 1).
sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
TILE_SIZE_PREFILL = _get_tile_size(
head_size,
sliding_window_val,
q.element_size(),
is_mm_prefix=use_mm_prefix,
is_prefill=True,
)
TILE_SIZE_DECODE = _get_tile_size(
head_size,
sliding_window_val,
q.element_size(),
is_mm_prefix=use_mm_prefix,
is_prefill=False,
)
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
@ -847,6 +963,9 @@ def unified_attention(
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
@ -895,6 +1014,9 @@ def unified_attention(
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),

View File

@ -19,7 +19,6 @@ from collections.abc import Iterable
from itertools import islice
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig
@ -226,77 +225,9 @@ class Gemma3Attention(nn.Module):
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if not kwargs.get("has_images", False):
# Fast path for text-only inputs. The performance for the text-only
# inputs are not affected by the naive attention below.
output, _ = self.o_proj(attn_output)
return output
# NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
# that correspond to the same image while using causal attention
# otherwise. Current attention backends cannot handle this pattern, so
# we temporarily use a naive attention implementation with mask tensors.
# We intentionally keep the attention backend as-is and only override
# `attn_output` with the naive implementation's output. This minimizes
# changes to existing model runners and attention backends. The call to
# `self.attn(q, k, v)` is only used to populate the KV cache - its
# output is discarded and overwritten below. While this duplicates
# computation, it maintains compatibility.
# TODO(woosuk): Optimize by implementing custom attention kernels.
attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
output, _ = self.o_proj(attn_output)
return output
def naive_attn_with_masks(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
**kwargs,
) -> torch.Tensor:
# NOTE(woosuk): As described in the comment above, this code is not
# meant to be performant. It is only meant to be correct.
q = q.view(-1, self.num_heads, self.head_dim)
# Expand the key and value to handle GQA.
num_queries_per_kv = self.num_heads // self.num_kv_heads
k = k.view(-1, self.num_kv_heads, self.head_dim)
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
v = v.view(-1, self.num_kv_heads, self.head_dim)
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
if self.is_sliding:
attn_masks = kwargs["local_attn_masks"]
else:
attn_masks = kwargs["global_attn_masks"]
seq_lens = kwargs["seq_lens"]
start_idx = 0
for seq_len, attn_mask in zip(seq_lens, attn_masks):
end_idx = start_idx + seq_len
query = q[start_idx:end_idx].unsqueeze(0)
key = k[start_idx:end_idx].unsqueeze(0)
value = v[start_idx:end_idx].unsqueeze(0)
# Transpose.
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask,
self.scaling,
)
output = output.transpose(1, 2).flatten(-2, -1)
out[start_idx:end_idx] = output
start_idx = end_idx
return out
class Gemma3DecoderLayer(nn.Module):
def __init__(

View File

@ -76,6 +76,39 @@ class TritonAttentionMetadata:
# Optional aot scheduling
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
@property
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
Empty ranges have start==end==0, which kernel skips via is_valid check.
"""
# TODO(Isotr0py): Move to model runner's attention metadata
# preparation to avoid duplicate computation.
if self.mm_prefix_range is None:
return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
]
# Return None if all ranges are trivial (only (0,0) placeholders)
if all(r == [(0, 0)] for r in range_lists):
return None
# Create 2D tensors with shape (num_ranges, 2) for each sequence
range_tensors = [
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
for r in range_lists
]
return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@classmethod
def supports_sink(cls) -> bool:
return True
@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
unified_attention(
q=query[:num_actual_tokens],
@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_expsum=softmax_segm_expsum,
sinks=self.sinks,
output_scale=output_scale,
mm_prefix_range=mm_prefix_range_tensor,
)
return output