[v1] Add PrefixLM support to TritonAttention backend (#30386)
Parent: 05a83dc6ee
Commit: 74a1ac38b0
@@ -1,17 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Literal, NamedTuple
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

from typing import Any, NamedTuple

import pytest
from huggingface_hub import hf_hub_download
from pytest import MarkDecorator
from transformers import AutoModelForImageTextToText

from tests.quantization.utils import is_quant_method_supported
from vllm.assets.image import ImageAsset
from vllm.multimodal.image import rescale_image_size
from vllm.utils.torch_utils import set_default_torch_num_threads

from ....conftest import PromptImageInput, VllmRunner
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
from ...utils import check_logprobs_close

@@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
    gguf_backbone: str
    gguf_mmproj: str
    prompt: list[str]
    mm_data: dict[Literal["images"], PromptImageInput]
    image_names: list[str]  # Store names, load PIL images at runtime
    max_model_len: int = 4096
    marks: list[MarkDecorator] = []
    mm_processor_kwargs: dict[str, Any] = {}

    @property
    def gguf_model(self):
@@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
        return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)


# Common prompts aligned with test_common.py "gemma3" entry format
_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts(
    {
        "stop_sign": (
            "<bos><start_of_turn>user\n"
            "<start_of_image>What's the content in the center of the image?"
            "<end_of_turn>\n<start_of_turn>model\n"
        ),
        "cherry_blossom": (
            "<bos><start_of_turn>user\n"
            "<start_of_image>What is the season?"
            "<end_of_turn>\n<start_of_turn>model\n"
        ),
    }
)

# Image asset names - load at runtime to avoid pickle issues with subprocess
_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"]

# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
GEMMA3_CONFIG = GGUFMMTestConfig(
    original_model="google/gemma-3-4b-it",
    gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
    gguf_backbone="gemma-3-4b-it-q4_0.gguf",
    gguf_mmproj="mmproj-model-f16-4B.gguf",
    prompt=["<start_of_image>Describe this image in detail:"],
    mm_data={"images": [ImageAsset("stop_sign").pil_image]},
    prompt=_GEMMA3_PROMPTS,
    image_names=_GEMMA3_IMAGE_NAMES,
    max_model_len=4096,
    marks=[pytest.mark.core_model],
    mm_processor_kwargs={},
)

MODELS_TO_TEST = [GEMMA3_CONFIG]
# Pan-and-scan multimodal - uses unquantized BF16 GGUF
GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig(
    original_model="google/gemma-3-4b-it",
    gguf_repo="unsloth/gemma-3-4b-it-GGUF",
    gguf_backbone="gemma-3-4b-it-BF16.gguf",
    gguf_mmproj="mmproj-BF16.gguf",
    prompt=_GEMMA3_PROMPTS,
    image_names=_GEMMA3_IMAGE_NAMES,
    max_model_len=4096,
    marks=[pytest.mark.core_model],
    mm_processor_kwargs={"do_pan_and_scan": True},
)

MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN]

def run_multimodal_gguf_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    model: GGUFMMTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
):
    # Run gguf model.
    # Load images at runtime (inside subprocess) to avoid pickle issues
    images = [ImageAsset(name).pil_image for name in model.image_names]
    size_factors = [0.25, 0.5, 1.0]
    inputs_per_image = [
        (
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
        )
        for image, prompt in zip(images, model.prompt)
    ]

    # NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
    # Run GGUF model via vLLM.
    with (
        set_default_torch_num_threads(1),
        vllm_runner(
@@ -60,35 +115,42 @@ def run_multimodal_gguf_test(
            tokenizer_name=model.original_model,
            dtype=dtype,
            max_model_len=model.max_model_len,
            mm_processor_kwargs=model.mm_processor_kwargs,
        ) as gguf_model,
    ):
        gguf_outputs = gguf_model.generate_greedy_logprobs(
            prompts=model.prompt,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            **model.mm_data,
        )
        gguf_outputs_per_case = [
            gguf_model.generate_greedy_logprobs(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=images,
            )
            for prompts, images in inputs_per_image
        ]

    # Run unquantized model.
    with vllm_runner(
        model_name=model.original_model,
        enforce_eager=True,  # faster tests
    # Then run HfRunner for HuggingFace baseline comparison.
    with hf_runner(
        model.original_model,
        dtype=dtype,
        max_model_len=model.max_model_len,
    ) as original_model:
        original_outputs = original_model.generate_greedy_logprobs(
            prompts=model.prompt,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            **model.mm_data,
        )
        auto_cls=AutoModelForImageTextToText,
    ) as hf_model:
        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=images,
            )
            for prompts, images in inputs_per_image
        ]

    check_logprobs_close(
        outputs_0_lst=original_outputs,
        outputs_1_lst=gguf_outputs,
        name_0="original",
        name_1="gguf",
    )
    for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=gguf_outputs,
            name_0="hf",
            name_1="gguf",
        )


@pytest.mark.skipif(
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(
def test_gemma3_mm_gguf(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    model: GGUFMMTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
    run_multimodal_gguf_test(
        hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
    )

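For reference, a standalone sketch (not part of the diff) of the inputs_per_image layout built in run_multimodal_gguf_test above: each asset contributes one (prompts, images) case, with the prompt repeated per size factor and the image rescaled per factor. The rescale helper and the (name, width, height) tuples below are illustrative stand-ins, not vLLM APIs.

size_factors = [0.25, 0.5, 1.0]
prompts = ["describe the stop sign", "what season is it?"]
images = [("stop_sign", 640, 480), ("cherry_blossom", 1024, 768)]  # stand-in (name, w, h)


def rescale(img, factor):
    name, w, h = img
    return (name, int(w * factor), int(h * factor))


inputs_per_image = [
    (
        [prompt for _ in size_factors],
        [rescale(image, factor) for factor in size_factors],
    )
    for image, prompt in zip(images, prompts)
]
# Two cases (one per asset), each holding 3 prompt copies and 3 image sizes.
assert len(inputs_per_image) == 2
assert inputs_per_image[0][1][0] == ("stop_sign", 160, 120)
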
@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
    USE_SOFTCAP: tl.constexpr,  # bool
    USE_SINKS: tl.constexpr,  # bool
    SLIDING_WINDOW: tl.constexpr,  # int
    USE_MM_PREFIX: tl.constexpr,  # bool
    MAX_MM_RANGES: tl.constexpr,  # int
    mm_prefix_range_ptr,  # [num_seqs, MAX_MM_RANGES, 2] - bidirectional ranges per sequence
    stride_k_cache_0: tl.int64,  # int
    stride_k_cache_1: tl.int64,  # int
    stride_k_cache_2: tl.int64,  # int
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
        else:
            V = V_load

        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
        # Compute attention mask: causal by default (key <= query)
        query_abs_pos = context_len + query_pos[:, None]
        seq_mask = seq_offset[None, :] <= query_abs_pos

        # Apply sliding window to base mask BEFORE mm_prefix OR.
        # Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
        if SLIDING_WINDOW > 0:
            seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)

        # PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
        # Applied AFTER sliding window so mm_prefix ranges override SW restriction.
        if USE_MM_PREFIX:
            for i in range(MAX_MM_RANGES):
                range_start = tl.load(
                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
                )
                range_end = tl.load(
                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
                )

                is_valid = range_start < range_end
                q_in_range = (
                    (query_abs_pos >= range_start)
                    & (query_abs_pos <= range_end)
                    & is_valid
                )
                k_in_range = (
                    (seq_offset[None, :] >= range_start)
                    & (seq_offset[None, :] <= range_end)
                    & is_valid
                )
                seq_mask |= q_in_range & k_in_range

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
            query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
        )

        if SLIDING_WINDOW > 0:
            S = tl.where(
                (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
                S,
                float("-inf"),
            )

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

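As a non-Triton reference (not part of the diff), the mask assembled above can be written in plain PyTorch; the function below follows the same ordering the comments require, (causal AND sliding_window) OR mm_prefix, with inclusive range endpoints. The variable names mirror the kernel's; the concrete positions and ranges are illustrative only.

import torch


def build_mask(query_abs_pos, seq_offset, sliding_window, mm_ranges):
    q = query_abs_pos[:, None]  # (num_queries, 1) absolute query positions
    k = seq_offset[None, :]  # (1, num_keys) absolute key positions
    mask = k <= q  # causal: key position <= query position
    if sliding_window > 0:
        mask &= (q - k) < sliding_window  # restrict to the local window
    for start, end in mm_ranges:  # bidirectional blocks for multimodal tokens
        if start < end:  # (0, 0) padding entries are skipped
            mask |= (q >= start) & (q <= end) & (k >= start) & (k <= end)
    return mask


positions = torch.arange(8)
mask = build_mask(positions, positions, sliding_window=4, mm_ranges=[(2, 5)])
# Query 2 sees key 5 only through the bidirectional (2, 5) range; query 1 does not.
assert mask[2, 5] and not mask[1, 5]
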
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
    num_seqs: tl.int32,
    BLOCK_M: tl.constexpr,  # int
    NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
    USE_MM_PREFIX: tl.constexpr,  # bool
    MAX_MM_RANGES: tl.constexpr,  # int
    mm_prefix_range_ptr,  # [num_seqs, MAX_MM_RANGES, 2] - bidirectional ranges per sequence
):
    q_block_global_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
        else:
            V = V_load

        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
        # Compute attention mask: causal by default (key <= query)
        query_abs_pos = context_len + query_pos[:, None]
        seq_mask = seq_offset[None, :] <= query_abs_pos

        # Apply sliding window to base mask BEFORE mm_prefix OR.
        # Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
        if SLIDING_WINDOW > 0:
            seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)

        # PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
        # Applied AFTER sliding window so mm_prefix ranges override SW restriction.
        if USE_MM_PREFIX:
            for i in range(MAX_MM_RANGES):
                range_start = tl.load(
                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
                )
                range_end = tl.load(
                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
                )

                is_valid = range_start < range_end
                q_in_range = (
                    (query_abs_pos >= range_start)
                    & (query_abs_pos <= range_end)
                    & is_valid
                )
                k_in_range = (
                    (seq_offset[None, :] >= range_start)
                    & (seq_offset[None, :] <= range_end)
                    & is_valid
                )
                seq_mask |= q_in_range & k_in_range

        # S : (BLOCK_M, TILE_SIZE)
        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
            query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
        )

        if SLIDING_WINDOW > 0:
            S = tl.where(
                (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
                S,
                float("-inf"),
            )

        if USE_ALIBI_SLOPES:
            S += alibi_slope[:, None] * (seq_offset - context_len)

@@ -732,6 +786,43 @@ def reduce_segments(
    tl.store(output_ptr + output_offset, acc, mask=dim_mask)


def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
    """Detect Gemma3 models via unique (head_size, sliding_window) signature.

    Gemma3 models are the only ones using sliding_window=1024 with
    head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
    different window sizes (Mistral=4096, Phi-3=2047).
    """
    return sliding_window == 1024 and head_size in (128, 256)


def _get_tile_size(
    head_size: int,
    sliding_window: int,
    element_size: int,
    is_mm_prefix: bool,
    is_prefill: bool,
) -> int:
    """Select tile size with Gemma3-specific optimization.

    For Gemma3, use 32 for both prefill and decode to better utilize
    the larger head dimension (128/256). For other models, use
    the default vLLM behavior.
    """
    if is_mm_prefix:
        # Multimodal bidirectional attention needs a larger tile size
        return 64

    if _is_gemma3_attention(head_size, sliding_window):
        # Gemma3: use 32 for decode (default is 16)
        return 32

    # Default behavior
    if is_prefill:
        return 32
    return 16 if element_size >= 2 else 32

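Illustrative calls (not part of the diff) showing what the selection above resolves to, assuming _get_tile_size is in scope exactly as defined:

assert _get_tile_size(256, 1024, 2, is_mm_prefix=True, is_prefill=False) == 64  # PrefixLM path
assert _get_tile_size(256, 1024, 2, is_mm_prefix=False, is_prefill=False) == 32  # Gemma3 decode
assert _get_tile_size(128, 4096, 2, is_mm_prefix=False, is_prefill=False) == 16  # default bf16 decode
assert _get_tile_size(128, 4096, 1, is_mm_prefix=False, is_prefill=False) == 32  # fp8 needs tile >= 32
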
def unified_attention(
    q,
    k,
@@ -759,6 +850,8 @@ def unified_attention(
    qq_bias=None,
    # Optional tensor for sinks
    sinks=None,
    # Optional tensor of multimodal prefix ranges (PrefixLM support)
    mm_prefix_range=None,
):
    assert causal, "Only causal attention is supported"
    assert q_descale is None, "Q scales not supported"
@@ -766,6 +859,17 @@ def unified_attention(
    if sinks is not None:
        assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"

    use_mm_prefix = False
    max_mm_ranges = 0
    if mm_prefix_range is not None:
        if mm_prefix_range.ndim == 3:
            use_mm_prefix = True
            max_mm_ranges = mm_prefix_range.shape[1]
        else:
            raise ValueError(
                f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
            )

    use_alibi_slopes = alibi_slopes is not None
    use_qq_bias = qq_bias is not None

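For illustration (not part of the diff), a tensor that passes the ndim == 3 check above for two sequences, where the second sequence has no image tokens and its zero-filled slots are later skipped by the kernel's is_valid test; the values are made up:

import torch

mm_prefix_range = torch.tensor(
    [[[4, 260], [300, 556]],
     [[0, 0], [0, 0]]],
    dtype=torch.int32,
)
assert mm_prefix_range.ndim == 3  # (num_seqs=2, max_mm_ranges=2, 2)
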
@@ -792,11 +896,23 @@ def unified_attention(
    # = floor(q.shape[0] / BLOCK_Q) + num_seqs
    total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs

    # Assigning default tile sizes for prefill and decode.
    # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
    # and at least 16 for all other data types.
    TILE_SIZE_PREFILL = 32
    TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
    # Tile sizes for prefill and decode. Gemma3 models use optimized values.
    # Note: tile size must be at least 32 for fp8 (element_size == 1).
    sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
    TILE_SIZE_PREFILL = _get_tile_size(
        head_size,
        sliding_window_val,
        q.element_size(),
        is_mm_prefix=use_mm_prefix,
        is_prefill=True,
    )
    TILE_SIZE_DECODE = _get_tile_size(
        head_size,
        sliding_window_val,
        q.element_size(),
        is_mm_prefix=use_mm_prefix,
        is_prefill=False,
    )

    # Launch the 2D kernel if
    # 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
@@ -847,6 +963,9 @@ def unified_attention(
            USE_QQ_BIAS=use_qq_bias,
            USE_SOFTCAP=(softcap > 0),
            USE_SINKS=(sinks is not None),
            USE_MM_PREFIX=use_mm_prefix,
            MAX_MM_RANGES=max_mm_ranges,
            mm_prefix_range_ptr=mm_prefix_range,
            SLIDING_WINDOW=(1 + window_size[0]),
            stride_k_cache_0=k.stride(0),
            stride_k_cache_1=k.stride(1),
@@ -895,6 +1014,9 @@ def unified_attention(
            USE_QQ_BIAS=use_qq_bias,
            USE_SOFTCAP=(softcap > 0),
            USE_SINKS=(sinks is not None),
            USE_MM_PREFIX=use_mm_prefix,
            MAX_MM_RANGES=max_mm_ranges,
            mm_prefix_range_ptr=mm_prefix_range,
            SLIDING_WINDOW=(1 + window_size[0]),
            stride_k_cache_0=k.stride(0),
            stride_k_cache_1=k.stride(1),

@@ -19,7 +19,6 @@ from collections.abc import Iterable
from itertools import islice

import torch
import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig

@@ -226,77 +225,9 @@ class Gemma3Attention(nn.Module):

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)

        if not kwargs.get("has_images", False):
            # Fast path for text-only inputs. The performance for the text-only
            # inputs are not affected by the naive attention below.
            output, _ = self.o_proj(attn_output)
            return output

        # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
        # that correspond to the same image while using causal attention
        # otherwise. Current attention backends cannot handle this pattern, so
        # we temporarily use a naive attention implementation with mask tensors.

        # We intentionally keep the attention backend as-is and only override
        # `attn_output` with the naive implementation's output. This minimizes
        # changes to existing model runners and attention backends. The call to
        # `self.attn(q, k, v)` is only used to populate the KV cache - its
        # output is discarded and overwritten below. While this duplicates
        # computation, it maintains compatibility.
        # TODO(woosuk): Optimize by implementing custom attention kernels.
        attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
        output, _ = self.o_proj(attn_output)
        return output

    def naive_attn_with_masks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        out: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # NOTE(woosuk): As described in the comment above, this code is not
        # meant to be performant. It is only meant to be correct.
        q = q.view(-1, self.num_heads, self.head_dim)
        # Expand the key and value to handle GQA.
        num_queries_per_kv = self.num_heads // self.num_kv_heads
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        v = v.repeat_interleave(num_queries_per_kv, dim=-2)

        if self.is_sliding:
            attn_masks = kwargs["local_attn_masks"]
        else:
            attn_masks = kwargs["global_attn_masks"]

        seq_lens = kwargs["seq_lens"]
        start_idx = 0
        for seq_len, attn_mask in zip(seq_lens, attn_masks):
            end_idx = start_idx + seq_len
            query = q[start_idx:end_idx].unsqueeze(0)
            key = k[start_idx:end_idx].unsqueeze(0)
            value = v[start_idx:end_idx].unsqueeze(0)

            # Transpose.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)

            output = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask,
                self.scaling,
            )
            output = output.transpose(1, 2).flatten(-2, -1)
            out[start_idx:end_idx] = output
            start_idx = end_idx
        return out


class Gemma3DecoderLayer(nn.Module):
    def __init__(

@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
    # Optional aot scheduling
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None

    @property
    def mm_prefix_range_tensor(self) -> torch.Tensor | None:
        """Convert mm_prefix_range dict to padded tensor for Triton kernel.

        Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
        Empty ranges have start==end==0, which kernel skips via is_valid check.
        """
        # TODO(Isotr0py): Move to model runner's attention metadata
        # preparation to avoid duplicate computation.
        if self.mm_prefix_range is None:
            return None

        num_seqs = self.seq_lens.shape[0]
        device = self.seq_lens.device

        # Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
        range_lists = [
            self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
        ]

        # Return None if all ranges are trivial (only (0,0) placeholders)
        if all(r == [(0, 0)] for r in range_lists):
            return None

        # Create 2D tensors with shape (num_ranges, 2) for each sequence
        range_tensors = [
            torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
            for r in range_lists
        ]

        return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)

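A standalone sketch (not part of the diff) of the conversion mm_prefix_range_tensor performs, for three sequences where only the first and last contain image tokens; the dict values are illustrative:

import torch

mm_prefix_range = {0: [(4, 260), (300, 556)], 2: [(1, 257)]}
num_seqs = 3

range_lists = [mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)]
range_tensors = [torch.tensor(r, dtype=torch.int32).view(-1, 2) for r in range_lists]
padded = torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)

assert padded.shape == (3, 2, 2)  # (num_seqs, max_ranges, 2)
# Sequence 1 has no ranges, so its row is all zeros and is skipped via is_valid.
assert padded[1].eq(0).all()
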
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size >= 32

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        return True

    @classmethod
    def supports_sink(cls) -> bool:
        return True
@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
        softmax_segm_expsum = attn_metadata.softmax_segm_expsum

        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
        mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor

        unified_attention(
            q=query[:num_actual_tokens],
@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
            softmax_segm_expsum=softmax_segm_expsum,
            sinks=self.sinks,
            output_scale=output_scale,
            mm_prefix_range=mm_prefix_range_tensor,
        )

        return output