From 74a1ac38b00a8cf502db085d1bbd77712cf47e41 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 18 Dec 2025 08:05:24 +0800 Subject: [PATCH] [v1] Add PrefixLM support to TritonAttention backend (#30386) --- .../generation/test_multimodal_gguf.py | 131 ++++++++++---- .../attention/ops/triton_unified_attention.py | 164 +++++++++++++++--- vllm/model_executor/models/gemma3.py | 69 -------- vllm/v1/attention/backends/triton_attn.py | 39 +++++ 4 files changed, 280 insertions(+), 123 deletions(-) diff --git a/tests/models/multimodal/generation/test_multimodal_gguf.py b/tests/models/multimodal/generation/test_multimodal_gguf.py index e596b20c6302b..813dccf1451b5 100644 --- a/tests/models/multimodal/generation/test_multimodal_gguf.py +++ b/tests/models/multimodal/generation/test_multimodal_gguf.py @@ -1,17 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Literal, NamedTuple +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +from typing import Any, NamedTuple import pytest from huggingface_hub import hf_hub_download from pytest import MarkDecorator +from transformers import AutoModelForImageTextToText from tests.quantization.utils import is_quant_method_supported from vllm.assets.image import ImageAsset +from vllm.multimodal.image import rescale_image_size from vllm.utils.torch_utils import set_default_torch_num_threads -from ....conftest import PromptImageInput, VllmRunner +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner from ...utils import check_logprobs_close @@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple): gguf_backbone: str gguf_mmproj: str prompt: list[str] - mm_data: dict[Literal["images"], PromptImageInput] + image_names: list[str] # Store names, load PIL images at runtime max_model_len: int = 4096 marks: list[MarkDecorator] = [] + mm_processor_kwargs: dict[str, Any] = {} @property def gguf_model(self): @@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple): return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone) +# Common prompts aligned with test_common.py "gemma3" entry format +_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": ( + "user\n" + "What's the content in the center of the image?" + "\nmodel\n" + ), + "cherry_blossom": ( + "user\n" + "What is the season?" 
+ "\nmodel\n" + ), + } +) + +# Image asset names - load at runtime to avoid pickle issues with subprocess +_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"] + +# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF GEMMA3_CONFIG = GGUFMMTestConfig( original_model="google/gemma-3-4b-it", gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf", gguf_backbone="gemma-3-4b-it-q4_0.gguf", gguf_mmproj="mmproj-model-f16-4B.gguf", - prompt=["Describe this image in detail:"], - mm_data={"images": [ImageAsset("stop_sign").pil_image]}, + prompt=_GEMMA3_PROMPTS, + image_names=_GEMMA3_IMAGE_NAMES, + max_model_len=4096, marks=[pytest.mark.core_model], + mm_processor_kwargs={}, ) -MODELS_TO_TEST = [GEMMA3_CONFIG] +# Pan-and-scan multimodal - uses unquantized BF16 GGUF +GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig( + original_model="google/gemma-3-4b-it", + gguf_repo="unsloth/gemma-3-4b-it-GGUF", + gguf_backbone="gemma-3-4b-it-BF16.gguf", + gguf_mmproj="mmproj-BF16.gguf", + prompt=_GEMMA3_PROMPTS, + image_names=_GEMMA3_IMAGE_NAMES, + max_model_len=4096, + marks=[pytest.mark.core_model], + mm_processor_kwargs={"do_pan_and_scan": True}, +) + +MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN] def run_multimodal_gguf_test( + hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], model: GGUFMMTestConfig, dtype: str, max_tokens: int, num_logprobs: int, ): - # Run gguf model. + # Load images at runtime (inside subprocess) to avoid pickle issues + images = [ImageAsset(name).pil_image for name in model.image_names] + size_factors = [0.25, 0.5, 1.0] + inputs_per_image = [ + ( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) + for image, prompt in zip(images, model.prompt) + ] + + # NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork. + # Run GGUF model via vLLM. with ( set_default_torch_num_threads(1), vllm_runner( @@ -60,35 +115,42 @@ def run_multimodal_gguf_test( tokenizer_name=model.original_model, dtype=dtype, max_model_len=model.max_model_len, + mm_processor_kwargs=model.mm_processor_kwargs, ) as gguf_model, ): - gguf_outputs = gguf_model.generate_greedy_logprobs( - prompts=model.prompt, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - **model.mm_data, - ) + gguf_outputs_per_case = [ + gguf_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + ) + for prompts, images in inputs_per_image + ] - # Run unquantized model. - with vllm_runner( - model_name=model.original_model, - enforce_eager=True, # faster tests + # Then run HfRunner for HuggingFace baseline comparison. 
+    with hf_runner(
+        model.original_model,
         dtype=dtype,
-        max_model_len=model.max_model_len,
-    ) as original_model:
-        original_outputs = original_model.generate_greedy_logprobs(
-            prompts=model.prompt,
-            max_tokens=max_tokens,
-            num_logprobs=num_logprobs,
-            **model.mm_data,
-        )
+        auto_cls=AutoModelForImageTextToText,
+    ) as hf_model:
+        hf_outputs_per_case = [
+            hf_model.generate_greedy_logprobs_limit(
+                prompts,
+                max_tokens,
+                num_logprobs=num_logprobs,
+                images=images,
+            )
+            for prompts, images in inputs_per_image
+        ]
 
-    check_logprobs_close(
-        outputs_0_lst=original_outputs,
-        outputs_1_lst=gguf_outputs,
-        name_0="original",
-        name_1="gguf",
-    )
+    for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=gguf_outputs,
+            name_0="hf",
+            name_1="gguf",
+        )
 
 
 @pytest.mark.skipif(
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [10])
-def test_models(
+def test_gemma3_mm_gguf(
+    hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
     model: GGUFMMTestConfig,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
+    run_multimodal_gguf_test(
+        hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
+    )
diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index a1877bb4429b9..ae5a48ec3d26d 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
     USE_SOFTCAP: tl.constexpr,  # bool
     USE_SINKS: tl.constexpr,  # bool
     SLIDING_WINDOW: tl.constexpr,  # int
+    USE_MM_PREFIX: tl.constexpr,  # bool
+    MAX_MM_RANGES: tl.constexpr,  # int
+    mm_prefix_range_ptr,  # [num_seqs, MAX_MM_RANGES, 2] - (start, end) ranges per seq
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
     stride_k_cache_2: tl.int64,  # int
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
         else:
             V = V_load
 
-        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
+        # Compute attention mask: causal by default (key <= query)
+        query_abs_pos = context_len + query_pos[:, None]
+        seq_mask = seq_offset[None, :] <= query_abs_pos
+
+        # Apply sliding window to base mask BEFORE mm_prefix OR.
+        # Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
+        if SLIDING_WINDOW > 0:
+            seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
+
+        # PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
+        # Applied AFTER sliding window so mm_prefix ranges override SW restriction.
+        if USE_MM_PREFIX:
+            for i in range(MAX_MM_RANGES):
+                range_start = tl.load(
+                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
+                )
+                range_end = tl.load(
+                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
+                )
+
+                is_valid = range_start < range_end
+                q_in_range = (
+                    (query_abs_pos >= range_start)
+                    & (query_abs_pos <= range_end)
+                    & is_valid
+                )
+                k_in_range = (
+                    (seq_offset[None, :] >= range_start)
+                    & (seq_offset[None, :] <= range_end)
+                    & is_valid
+                )
+                seq_mask |= q_in_range & k_in_range
 
         # S : (BLOCK_M, TILE_SIZE)
         S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
             query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
         )
 
-        if SLIDING_WINDOW > 0:
-            S = tl.where(
-                (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
-                S,
-                float("-inf"),
-            )
-
         if USE_ALIBI_SLOPES:
             S += alibi_slope[:, None] * (seq_offset - context_len)
 
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
     num_seqs: tl.int32,
     BLOCK_M: tl.constexpr,  # int
     NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
+    USE_MM_PREFIX: tl.constexpr,  # bool
+    MAX_MM_RANGES: tl.constexpr,  # int
+    mm_prefix_range_ptr,  # [num_seqs, MAX_MM_RANGES, 2] - (start, end) ranges per seq
 ):
     q_block_global_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
         else:
             V = V_load
 
-        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
+        # Compute attention mask: causal by default (key <= query)
+        query_abs_pos = context_len + query_pos[:, None]
+        seq_mask = seq_offset[None, :] <= query_abs_pos
+
+        # Apply sliding window to base mask BEFORE mm_prefix OR.
+        # Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
+        if SLIDING_WINDOW > 0:
+            seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
+
+        # PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
+        # Applied AFTER sliding window so mm_prefix ranges override SW restriction.
+        if USE_MM_PREFIX:
+            for i in range(MAX_MM_RANGES):
+                range_start = tl.load(
+                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
+                )
+                range_end = tl.load(
+                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
+                )
+
+                is_valid = range_start < range_end
+                q_in_range = (
+                    (query_abs_pos >= range_start)
+                    & (query_abs_pos <= range_end)
+                    & is_valid
+                )
+                k_in_range = (
+                    (seq_offset[None, :] >= range_start)
+                    & (seq_offset[None, :] <= range_end)
+                    & is_valid
+                )
+                seq_mask |= q_in_range & k_in_range
 
         # S : (BLOCK_M, TILE_SIZE)
         S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
             query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
         )
 
-        if SLIDING_WINDOW > 0:
-            S = tl.where(
-                (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
-                S,
-                float("-inf"),
-            )
-
         if USE_ALIBI_SLOPES:
             S += alibi_slope[:, None] * (seq_offset - context_len)
 
@@ -732,6 +786,43 @@ def reduce_segments(
     tl.store(output_ptr + output_offset, acc, mask=dim_mask)
 
 
+def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
+    """Detect Gemma3 models via unique (head_size, sliding_window) signature.
+
+    Gemma3 models are the only ones using sliding_window=1024 with
+    head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
+    different window sizes (Mistral=4096, Phi-3=2047).
+ """ + return sliding_window == 1024 and head_size in (128, 256) + + +def _get_tile_size( + head_size: int, + sliding_window: int, + element_size: int, + is_mm_prefix: bool, + is_prefill: bool, +) -> int: + """Select tile size with Gemma3-specific optimization. + + For Gemma3, use 32 for both prefill and decode to better utilize + the larger head dimension (128/256). For other models, use + the default vLLM behavior. + """ + if is_mm_prefix: + # Multimodal bidirectional attention needs a larger tile size + return 64 + + if _is_gemma3_attention(head_size, sliding_window): + # Gemma3: use 32 for decode (default is 16) + return 32 + + # Default behavior + if is_prefill: + return 32 + return 16 if element_size >= 2 else 32 + + def unified_attention( q, k, @@ -759,6 +850,8 @@ def unified_attention( qq_bias=None, # Optional tensor for sinks sinks=None, + # Optional tensor for prefix lengths (PrefixLM support) + mm_prefix_range=None, ): assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -766,6 +859,17 @@ def unified_attention( if sinks is not None: assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size" + use_mm_prefix = False + max_mm_ranges = 0 + if mm_prefix_range is not None: + if mm_prefix_range.ndim == 3: + use_mm_prefix = True + max_mm_ranges = mm_prefix_range.shape[1] + else: + raise ValueError( + f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}" + ) + use_alibi_slopes = alibi_slopes is not None use_qq_bias = qq_bias is not None @@ -792,11 +896,23 @@ def unified_attention( # = floor(q.shape[0] / BLOCK_Q) + num_seqs total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs - # Assigning default tile sizes for prefill and decode. - # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1) - # and at least 16 for all other data types. - TILE_SIZE_PREFILL = 32 - TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32 + # Tile sizes for prefill and decode. Gemma3 models use optimized values. + # Note: tile size must be at least 32 for fp8 (element_size == 1). + sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0 + TILE_SIZE_PREFILL = _get_tile_size( + head_size, + sliding_window_val, + q.element_size(), + is_mm_prefix=use_mm_prefix, + is_prefill=True, + ) + TILE_SIZE_DECODE = _get_tile_size( + head_size, + sliding_window_val, + q.element_size(), + is_mm_prefix=use_mm_prefix, + is_prefill=False, + ) # Launch the 2D kernel if # 1. 
No intermediate tiled softmax buffers for the 3D kernel have been allocated, or @@ -847,6 +963,9 @@ def unified_attention( USE_QQ_BIAS=use_qq_bias, USE_SOFTCAP=(softcap > 0), USE_SINKS=(sinks is not None), + USE_MM_PREFIX=use_mm_prefix, + MAX_MM_RANGES=max_mm_ranges, + mm_prefix_range_ptr=mm_prefix_range, SLIDING_WINDOW=(1 + window_size[0]), stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), @@ -895,6 +1014,9 @@ def unified_attention( USE_QQ_BIAS=use_qq_bias, USE_SOFTCAP=(softcap > 0), USE_SINKS=(sinks is not None), + USE_MM_PREFIX=use_mm_prefix, + MAX_MM_RANGES=max_mm_ranges, + mm_prefix_range_ptr=mm_prefix_range, SLIDING_WINDOW=(1 + window_size[0]), stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 70f72b5cb9beb..e6a201c669e96 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -19,7 +19,6 @@ from collections.abc import Iterable from itertools import islice import torch -import torch.nn.functional as F from torch import nn from transformers import Gemma3TextConfig @@ -226,77 +225,9 @@ class Gemma3Attention(nn.Module): q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) - - if not kwargs.get("has_images", False): - # Fast path for text-only inputs. The performance for the text-only - # inputs are not affected by the naive attention below. - output, _ = self.o_proj(attn_output) - return output - - # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens - # that correspond to the same image while using causal attention - # otherwise. Current attention backends cannot handle this pattern, so - # we temporarily use a naive attention implementation with mask tensors. - - # We intentionally keep the attention backend as-is and only override - # `attn_output` with the naive implementation's output. This minimizes - # changes to existing model runners and attention backends. The call to - # `self.attn(q, k, v)` is only used to populate the KV cache - its - # output is discarded and overwritten below. While this duplicates - # computation, it maintains compatibility. - # TODO(woosuk): Optimize by implementing custom attention kernels. - attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs) output, _ = self.o_proj(attn_output) return output - def naive_attn_with_masks( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - out: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - # NOTE(woosuk): As described in the comment above, this code is not - # meant to be performant. It is only meant to be correct. - q = q.view(-1, self.num_heads, self.head_dim) - # Expand the key and value to handle GQA. - num_queries_per_kv = self.num_heads // self.num_kv_heads - k = k.view(-1, self.num_kv_heads, self.head_dim) - k = k.repeat_interleave(num_queries_per_kv, dim=-2) - v = v.view(-1, self.num_kv_heads, self.head_dim) - v = v.repeat_interleave(num_queries_per_kv, dim=-2) - - if self.is_sliding: - attn_masks = kwargs["local_attn_masks"] - else: - attn_masks = kwargs["global_attn_masks"] - - seq_lens = kwargs["seq_lens"] - start_idx = 0 - for seq_len, attn_mask in zip(seq_lens, attn_masks): - end_idx = start_idx + seq_len - query = q[start_idx:end_idx].unsqueeze(0) - key = k[start_idx:end_idx].unsqueeze(0) - value = v[start_idx:end_idx].unsqueeze(0) - - # Transpose. 
- query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - output = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask, - self.scaling, - ) - output = output.transpose(1, 2).flatten(-2, -1) - out[start_idx:end_idx] = output - start_idx = end_idx - return out - class Gemma3DecoderLayer(nn.Module): def __init__( diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 7bea3862a03f9..ca7be990ca555 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -76,6 +76,39 @@ class TritonAttentionMetadata: # Optional aot scheduling scheduler_metadata: torch.Tensor | None = None prefix_scheduler_metadata: torch.Tensor | None = None + mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None + + @property + def mm_prefix_range_tensor(self) -> torch.Tensor | None: + """Convert mm_prefix_range dict to padded tensor for Triton kernel. + + Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges. + Empty ranges have start==end==0, which kernel skips via is_valid check. + """ + # TODO(Isotr0py): Move to model runner's attention metadata + # preparation to avoid duplicate computation. + if self.mm_prefix_range is None: + return None + + num_seqs = self.seq_lens.shape[0] + device = self.seq_lens.device + + # Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims + range_lists = [ + self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs) + ] + + # Return None if all ranges are trivial (only (0,0) placeholders) + if all(r == [(0, 0)] for r in range_lists): + return None + + # Create 2D tensors with shape (num_ranges, 2) for each sequence + range_tensors = [ + torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2) + for r in range_lists + ] + + return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0) class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): @@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend): def supports_head_size(cls, head_size: int) -> bool: return head_size >= 32 + @classmethod + def supports_mm_prefix(cls) -> bool: + return True + @classmethod def supports_sink(cls) -> bool: return True @@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl): softmax_segm_expsum = attn_metadata.softmax_segm_expsum descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2]) + mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor unified_attention( q=query[:num_actual_tokens], @@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl): softmax_segm_expsum=softmax_segm_expsum, sinks=self.sinks, output_scale=output_scale, + mm_prefix_range=mm_prefix_range_tensor, ) return output
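
Note on the masking semantics (illustrative sketch, not part of the diff above): the kernels compute the final mask as (causal AND sliding_window) OR bidirectional-mm-range, consuming the 0-padded (num_seqs, max_ranges, 2) tensor built by `mm_prefix_range_tensor`, where rows with start >= end are padding. The PyTorch sketch below reproduces that rule densely for a single sequence so it can be eyeballed against the kernel output or a FlexAttention mask composed the same way. The helper name `build_prefixlm_mask` and the toy values are assumptions for illustration only, not vLLM APIs.

# Minimal dense-mask reference for the PrefixLM semantics above (sketch only).
import torch


def build_prefixlm_mask(
    seq_len: int,
    sliding_window: int,  # mirrors the kernel's SLIDING_WINDOW constexpr; 0 disables it
    ranges: torch.Tensor,  # (max_ranges, 2) int tensor, 0-padded like mm_prefix_range_tensor
) -> torch.Tensor:
    """Return a (seq_len, seq_len) bool mask where True means "may attend"."""
    pos = torch.arange(seq_len)
    q_pos = pos[:, None]  # query (row) positions
    k_pos = pos[None, :]  # key (column) positions

    # Causal: a query attends to keys at or before its own position.
    mask = k_pos <= q_pos

    # Sliding window restricts the causal mask first, mirroring the kernel order:
    # (causal AND sliding_window) OR mm_prefix.
    if sliding_window > 0:
        mask &= (q_pos - k_pos) < sliding_window

    # Bidirectional multimodal ranges are OR-ed in afterwards, overriding both
    # causality and the window inside each range. Rows with start >= end are
    # padding and skipped, matching the kernel's is_valid check.
    for start, end in ranges.tolist():
        if start < end:
            q_in = (q_pos >= start) & (q_pos <= end)
            k_in = (k_pos >= start) & (k_pos <= end)
            mask |= q_in & k_in

    return mask


if __name__ == "__main__":
    # Toy example: 10 tokens, window of 4, one bidirectional range covering
    # tokens 2..6 (inclusive, matching the kernel's <= comparisons).
    ranges = torch.tensor([[2, 6], [0, 0]], dtype=torch.int32)
    print(build_prefixlm_mask(10, sliding_window=4, ranges=ranges).int())

Comparing this dense mask against the boolean mask implied by the Triton kernels (or a FlexAttention mask_mod composed as causal-and-window OR per-image block) is one way to sanity-check the mask ordering the comments above require.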