mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 12:31:20 +08:00
[v1] Add PrefixLM support to TritonAttention backend (#30386)
This commit is contained in:
parent
05a83dc6ee
commit
74a1ac38b0
@ -1,17 +1,23 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Literal, NamedTuple
|
import os
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||||
|
|
||||||
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from pytest import MarkDecorator
|
from pytest import MarkDecorator
|
||||||
|
from transformers import AutoModelForImageTextToText
|
||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
|
from vllm.multimodal.image import rescale_image_size
|
||||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||||
|
|
||||||
from ....conftest import PromptImageInput, VllmRunner
|
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
|
||||||
from ...utils import check_logprobs_close
|
from ...utils import check_logprobs_close
|
||||||
|
|
||||||
|
|
||||||
@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
|
|||||||
gguf_backbone: str
|
gguf_backbone: str
|
||||||
gguf_mmproj: str
|
gguf_mmproj: str
|
||||||
prompt: list[str]
|
prompt: list[str]
|
||||||
mm_data: dict[Literal["images"], PromptImageInput]
|
image_names: list[str] # Store names, load PIL images at runtime
|
||||||
max_model_len: int = 4096
|
max_model_len: int = 4096
|
||||||
marks: list[MarkDecorator] = []
|
marks: list[MarkDecorator] = []
|
||||||
|
mm_processor_kwargs: dict[str, Any] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def gguf_model(self):
|
def gguf_model(self):
|
||||||
@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
|
|||||||
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
|
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
|
||||||
|
|
||||||
|
|
||||||
|
# Common prompts aligned with test_common.py "gemma3" entry format
|
||||||
|
_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts(
|
||||||
|
{
|
||||||
|
"stop_sign": (
|
||||||
|
"<bos><start_of_turn>user\n"
|
||||||
|
"<start_of_image>What's the content in the center of the image?"
|
||||||
|
"<end_of_turn>\n<start_of_turn>model\n"
|
||||||
|
),
|
||||||
|
"cherry_blossom": (
|
||||||
|
"<bos><start_of_turn>user\n"
|
||||||
|
"<start_of_image>What is the season?"
|
||||||
|
"<end_of_turn>\n<start_of_turn>model\n"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Image asset names - load at runtime to avoid pickle issues with subprocess
|
||||||
|
_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"]
|
||||||
|
|
||||||
|
# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
|
||||||
GEMMA3_CONFIG = GGUFMMTestConfig(
|
GEMMA3_CONFIG = GGUFMMTestConfig(
|
||||||
original_model="google/gemma-3-4b-it",
|
original_model="google/gemma-3-4b-it",
|
||||||
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
|
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
|
||||||
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
|
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
|
||||||
gguf_mmproj="mmproj-model-f16-4B.gguf",
|
gguf_mmproj="mmproj-model-f16-4B.gguf",
|
||||||
prompt=["<start_of_image>Describe this image in detail:"],
|
prompt=_GEMMA3_PROMPTS,
|
||||||
mm_data={"images": [ImageAsset("stop_sign").pil_image]},
|
image_names=_GEMMA3_IMAGE_NAMES,
|
||||||
|
max_model_len=4096,
|
||||||
marks=[pytest.mark.core_model],
|
marks=[pytest.mark.core_model],
|
||||||
|
mm_processor_kwargs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
MODELS_TO_TEST = [GEMMA3_CONFIG]
|
# Pan-and-scan multimodal - uses unquantized BF16 GGUF
|
||||||
|
GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig(
|
||||||
|
original_model="google/gemma-3-4b-it",
|
||||||
|
gguf_repo="unsloth/gemma-3-4b-it-GGUF",
|
||||||
|
gguf_backbone="gemma-3-4b-it-BF16.gguf",
|
||||||
|
gguf_mmproj="mmproj-BF16.gguf",
|
||||||
|
prompt=_GEMMA3_PROMPTS,
|
||||||
|
image_names=_GEMMA3_IMAGE_NAMES,
|
||||||
|
max_model_len=4096,
|
||||||
|
marks=[pytest.mark.core_model],
|
||||||
|
mm_processor_kwargs={"do_pan_and_scan": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN]
|
||||||
|
|
||||||
|
|
||||||
def run_multimodal_gguf_test(
|
def run_multimodal_gguf_test(
|
||||||
|
hf_runner: type[HfRunner],
|
||||||
vllm_runner: type[VllmRunner],
|
vllm_runner: type[VllmRunner],
|
||||||
model: GGUFMMTestConfig,
|
model: GGUFMMTestConfig,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
):
|
):
|
||||||
# Run gguf model.
|
# Load images at runtime (inside subprocess) to avoid pickle issues
|
||||||
|
images = [ImageAsset(name).pil_image for name in model.image_names]
|
||||||
|
size_factors = [0.25, 0.5, 1.0]
|
||||||
|
inputs_per_image = [
|
||||||
|
(
|
||||||
|
[prompt for _ in size_factors],
|
||||||
|
[rescale_image_size(image, factor) for factor in size_factors],
|
||||||
|
)
|
||||||
|
for image, prompt in zip(images, model.prompt)
|
||||||
|
]
|
||||||
|
|
||||||
|
# NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
|
||||||
|
# Run GGUF model via vLLM.
|
||||||
with (
|
with (
|
||||||
set_default_torch_num_threads(1),
|
set_default_torch_num_threads(1),
|
||||||
vllm_runner(
|
vllm_runner(
|
||||||
@ -60,35 +115,42 @@ def run_multimodal_gguf_test(
|
|||||||
tokenizer_name=model.original_model,
|
tokenizer_name=model.original_model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
max_model_len=model.max_model_len,
|
max_model_len=model.max_model_len,
|
||||||
|
mm_processor_kwargs=model.mm_processor_kwargs,
|
||||||
) as gguf_model,
|
) as gguf_model,
|
||||||
):
|
):
|
||||||
gguf_outputs = gguf_model.generate_greedy_logprobs(
|
gguf_outputs_per_case = [
|
||||||
prompts=model.prompt,
|
gguf_model.generate_greedy_logprobs(
|
||||||
max_tokens=max_tokens,
|
prompts,
|
||||||
num_logprobs=num_logprobs,
|
max_tokens,
|
||||||
**model.mm_data,
|
num_logprobs=num_logprobs,
|
||||||
)
|
images=images,
|
||||||
|
)
|
||||||
|
for prompts, images in inputs_per_image
|
||||||
|
]
|
||||||
|
|
||||||
# Run unquantized model.
|
# Then run HfRunner for HuggingFace baseline comparison.
|
||||||
with vllm_runner(
|
with hf_runner(
|
||||||
model_name=model.original_model,
|
model.original_model,
|
||||||
enforce_eager=True, # faster tests
|
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
max_model_len=model.max_model_len,
|
auto_cls=AutoModelForImageTextToText,
|
||||||
) as original_model:
|
) as hf_model:
|
||||||
original_outputs = original_model.generate_greedy_logprobs(
|
hf_outputs_per_case = [
|
||||||
prompts=model.prompt,
|
hf_model.generate_greedy_logprobs_limit(
|
||||||
max_tokens=max_tokens,
|
prompts,
|
||||||
num_logprobs=num_logprobs,
|
max_tokens,
|
||||||
**model.mm_data,
|
num_logprobs=num_logprobs,
|
||||||
)
|
images=images,
|
||||||
|
)
|
||||||
|
for prompts, images in inputs_per_image
|
||||||
|
]
|
||||||
|
|
||||||
check_logprobs_close(
|
for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
|
||||||
outputs_0_lst=original_outputs,
|
check_logprobs_close(
|
||||||
outputs_1_lst=gguf_outputs,
|
outputs_0_lst=hf_outputs,
|
||||||
name_0="original",
|
outputs_1_lst=gguf_outputs,
|
||||||
name_1="gguf",
|
name_0="hf",
|
||||||
)
|
name_1="gguf",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
|
|||||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
@pytest.mark.parametrize("max_tokens", [32])
|
@pytest.mark.parametrize("max_tokens", [32])
|
||||||
@pytest.mark.parametrize("num_logprobs", [10])
|
@pytest.mark.parametrize("num_logprobs", [10])
|
||||||
def test_models(
|
def test_gemma3_mm_gguf(
|
||||||
|
hf_runner: type[HfRunner],
|
||||||
vllm_runner: type[VllmRunner],
|
vllm_runner: type[VllmRunner],
|
||||||
model: GGUFMMTestConfig,
|
model: GGUFMMTestConfig,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
|
run_multimodal_gguf_test(
|
||||||
|
hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
|
||||||
|
)
|
||||||
|
|||||||
@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
|
|||||||
USE_SOFTCAP: tl.constexpr, # bool
|
USE_SOFTCAP: tl.constexpr, # bool
|
||||||
USE_SINKS: tl.constexpr, # bool
|
USE_SINKS: tl.constexpr, # bool
|
||||||
SLIDING_WINDOW: tl.constexpr, # int
|
SLIDING_WINDOW: tl.constexpr, # int
|
||||||
|
USE_MM_PREFIX: tl.constexpr, # bool
|
||||||
|
MAX_MM_RANGES: tl.constexpr, # int
|
||||||
|
mm_prefix_range_ptr, # [num_seqs, MAX_MM_RANGES, 2] - inclusive (start, end) bidirectional ranges per sequence
|
||||||
stride_k_cache_0: tl.int64, # int
|
stride_k_cache_0: tl.int64, # int
|
||||||
stride_k_cache_1: tl.int64, # int
|
stride_k_cache_1: tl.int64, # int
|
||||||
stride_k_cache_2: tl.int64, # int
|
stride_k_cache_2: tl.int64, # int
|
||||||
@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
|
|||||||
else:
|
else:
|
||||||
V = V_load
|
V = V_load
|
||||||
|
|
||||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
# Compute attention mask: causal by default (key <= query)
|
||||||
|
query_abs_pos = context_len + query_pos[:, None]
|
||||||
|
seq_mask = seq_offset[None, :] <= query_abs_pos
|
||||||
|
|
||||||
|
# Apply sliding window to base mask BEFORE mm_prefix OR.
|
||||||
|
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
|
||||||
|
if SLIDING_WINDOW > 0:
|
||||||
|
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
|
||||||
|
|
||||||
|
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
|
||||||
|
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
|
||||||
|
if USE_MM_PREFIX:
|
||||||
|
for i in range(MAX_MM_RANGES):
|
||||||
|
range_start = tl.load(
|
||||||
|
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
|
||||||
|
)
|
||||||
|
range_end = tl.load(
|
||||||
|
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
is_valid = range_start < range_end
|
||||||
|
q_in_range = (
|
||||||
|
(query_abs_pos >= range_start)
|
||||||
|
& (query_abs_pos <= range_end)
|
||||||
|
& is_valid
|
||||||
|
)
|
||||||
|
k_in_range = (
|
||||||
|
(seq_offset[None, :] >= range_start)
|
||||||
|
& (seq_offset[None, :] <= range_end)
|
||||||
|
& is_valid
|
||||||
|
)
|
||||||
|
seq_mask |= q_in_range & k_in_range
|
||||||
|
|
||||||
# S : (BLOCK_M, TILE_SIZE)
|
# S : (BLOCK_M, TILE_SIZE)
|
||||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||||
@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
|
|||||||
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
||||||
)
|
)
|
||||||
|
|
||||||
if SLIDING_WINDOW > 0:
|
|
||||||
S = tl.where(
|
|
||||||
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
|
|
||||||
S,
|
|
||||||
float("-inf"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if USE_ALIBI_SLOPES:
|
if USE_ALIBI_SLOPES:
|
||||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||||
|
|
||||||
@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
|
|||||||
num_seqs: tl.int32,
|
num_seqs: tl.int32,
|
||||||
BLOCK_M: tl.constexpr, # int
|
BLOCK_M: tl.constexpr, # int
|
||||||
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
||||||
|
USE_MM_PREFIX: tl.constexpr, # bool
|
||||||
|
MAX_MM_RANGES: tl.constexpr, # int
|
||||||
|
mm_prefix_range_ptr, # [num_seqs, MAX_MM_RANGES, 2] - inclusive (start, end) bidirectional ranges per sequence
|
||||||
):
|
):
|
||||||
q_block_global_idx = tl.program_id(0)
|
q_block_global_idx = tl.program_id(0)
|
||||||
kv_head_idx = tl.program_id(1)
|
kv_head_idx = tl.program_id(1)
|
||||||
@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
|
|||||||
else:
|
else:
|
||||||
V = V_load
|
V = V_load
|
||||||
|
|
||||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
# Compute attention mask: causal by default (key <= query)
|
||||||
|
query_abs_pos = context_len + query_pos[:, None]
|
||||||
|
seq_mask = seq_offset[None, :] <= query_abs_pos
|
||||||
|
|
||||||
|
# Apply sliding window to base mask BEFORE mm_prefix OR.
|
||||||
|
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
|
||||||
|
if SLIDING_WINDOW > 0:
|
||||||
|
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
|
||||||
|
|
||||||
|
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
|
||||||
|
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
|
||||||
|
if USE_MM_PREFIX:
|
||||||
|
for i in range(MAX_MM_RANGES):
|
||||||
|
range_start = tl.load(
|
||||||
|
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
|
||||||
|
)
|
||||||
|
range_end = tl.load(
|
||||||
|
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
is_valid = range_start < range_end
|
||||||
|
q_in_range = (
|
||||||
|
(query_abs_pos >= range_start)
|
||||||
|
& (query_abs_pos <= range_end)
|
||||||
|
& is_valid
|
||||||
|
)
|
||||||
|
k_in_range = (
|
||||||
|
(seq_offset[None, :] >= range_start)
|
||||||
|
& (seq_offset[None, :] <= range_end)
|
||||||
|
& is_valid
|
||||||
|
)
|
||||||
|
seq_mask |= q_in_range & k_in_range
|
||||||
|
|
||||||
# S : (BLOCK_M, TILE_SIZE)
|
# S : (BLOCK_M, TILE_SIZE)
|
||||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||||
@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
|
|||||||
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
||||||
)
|
)
|
||||||
|
|
||||||
if SLIDING_WINDOW > 0:
|
|
||||||
S = tl.where(
|
|
||||||
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
|
|
||||||
S,
|
|
||||||
float("-inf"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if USE_ALIBI_SLOPES:
|
if USE_ALIBI_SLOPES:
|
||||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||||
|
|
||||||
@ -732,6 +786,43 @@ def reduce_segments(
|
|||||||
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
|
||||||
|
"""Detect Gemma3 models via unique (head_size, sliding_window) signature.
|
||||||
|
|
||||||
|
Gemma3 models are the only ones using sliding_window=1024 with
|
||||||
|
head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
|
||||||
|
different window sizes (Mistral=4096, Phi-3=2047).
|
||||||
|
"""
|
||||||
|
return sliding_window == 1024 and head_size in (128, 256)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tile_size(
    head_size: int,
    sliding_window: int,
    element_size: int,
    is_mm_prefix: bool,
    is_prefill: bool,
) -> int:
    """Choose the tile size used by the unified attention kernels.

    Rules are evaluated in order; the first match wins:

    1. mm-prefix (bidirectional multimodal) attention -> 64.
    2. Gemma3 signature (see ``_is_gemma3_attention``) -> 32, for prefill
       and decode alike.
    3. Everything else -> 32 for prefill; for decode, 16 when
       ``element_size >= 2`` and 32 otherwise.

    Args:
        head_size: Per-head hidden dimension.
        sliding_window: Sliding-window length (0 when disabled).
        element_size: Size in bytes of one query element.
        is_mm_prefix: Whether mm-prefix ranges are active for this call.
        is_prefill: True to select the prefill tile size, False for decode.

    Returns:
        The tile size as an int (one of 16, 32 or 64).
    """
    if is_mm_prefix:
        tile_size = 64
    elif _is_gemma3_attention(head_size, sliding_window):
        tile_size = 32
    elif is_prefill or element_size < 2:
        tile_size = 32
    else:
        tile_size = 16
    return tile_size
|
||||||
|
|
||||||
|
|
||||||
def unified_attention(
|
def unified_attention(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
@ -759,6 +850,8 @@ def unified_attention(
|
|||||||
qq_bias=None,
|
qq_bias=None,
|
||||||
# Optional tensor for sinks
|
# Optional tensor for sinks
|
||||||
sinks=None,
|
sinks=None,
|
||||||
|
# Optional (num_seqs, max_ranges, 2) tensor of bidirectional mm-prefix ranges (PrefixLM support)
|
||||||
|
mm_prefix_range=None,
|
||||||
):
|
):
|
||||||
assert causal, "Only causal attention is supported"
|
assert causal, "Only causal attention is supported"
|
||||||
assert q_descale is None, "Q scales not supported"
|
assert q_descale is None, "Q scales not supported"
|
||||||
@ -766,6 +859,17 @@ def unified_attention(
|
|||||||
if sinks is not None:
|
if sinks is not None:
|
||||||
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
|
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
|
||||||
|
|
||||||
|
use_mm_prefix = False
|
||||||
|
max_mm_ranges = 0
|
||||||
|
if mm_prefix_range is not None:
|
||||||
|
if mm_prefix_range.ndim == 3:
|
||||||
|
use_mm_prefix = True
|
||||||
|
max_mm_ranges = mm_prefix_range.shape[1]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
use_alibi_slopes = alibi_slopes is not None
|
use_alibi_slopes = alibi_slopes is not None
|
||||||
use_qq_bias = qq_bias is not None
|
use_qq_bias = qq_bias is not None
|
||||||
|
|
||||||
@ -792,11 +896,23 @@ def unified_attention(
|
|||||||
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
|
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
|
||||||
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
|
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
|
||||||
|
|
||||||
# Assigning default tile sizes for prefill and decode.
|
# Tile sizes for prefill and decode. Gemma3 models use optimized values.
|
||||||
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
|
# Note: tile size must be at least 32 for fp8 (element_size == 1).
|
||||||
# and at least 16 for all other data types.
|
sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
|
||||||
TILE_SIZE_PREFILL = 32
|
TILE_SIZE_PREFILL = _get_tile_size(
|
||||||
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
|
head_size,
|
||||||
|
sliding_window_val,
|
||||||
|
q.element_size(),
|
||||||
|
is_mm_prefix=use_mm_prefix,
|
||||||
|
is_prefill=True,
|
||||||
|
)
|
||||||
|
TILE_SIZE_DECODE = _get_tile_size(
|
||||||
|
head_size,
|
||||||
|
sliding_window_val,
|
||||||
|
q.element_size(),
|
||||||
|
is_mm_prefix=use_mm_prefix,
|
||||||
|
is_prefill=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Launch the 2D kernel if
|
# Launch the 2D kernel if
|
||||||
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
|
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
|
||||||
@ -847,6 +963,9 @@ def unified_attention(
|
|||||||
USE_QQ_BIAS=use_qq_bias,
|
USE_QQ_BIAS=use_qq_bias,
|
||||||
USE_SOFTCAP=(softcap > 0),
|
USE_SOFTCAP=(softcap > 0),
|
||||||
USE_SINKS=(sinks is not None),
|
USE_SINKS=(sinks is not None),
|
||||||
|
USE_MM_PREFIX=use_mm_prefix,
|
||||||
|
MAX_MM_RANGES=max_mm_ranges,
|
||||||
|
mm_prefix_range_ptr=mm_prefix_range,
|
||||||
SLIDING_WINDOW=(1 + window_size[0]),
|
SLIDING_WINDOW=(1 + window_size[0]),
|
||||||
stride_k_cache_0=k.stride(0),
|
stride_k_cache_0=k.stride(0),
|
||||||
stride_k_cache_1=k.stride(1),
|
stride_k_cache_1=k.stride(1),
|
||||||
@ -895,6 +1014,9 @@ def unified_attention(
|
|||||||
USE_QQ_BIAS=use_qq_bias,
|
USE_QQ_BIAS=use_qq_bias,
|
||||||
USE_SOFTCAP=(softcap > 0),
|
USE_SOFTCAP=(softcap > 0),
|
||||||
USE_SINKS=(sinks is not None),
|
USE_SINKS=(sinks is not None),
|
||||||
|
USE_MM_PREFIX=use_mm_prefix,
|
||||||
|
MAX_MM_RANGES=max_mm_ranges,
|
||||||
|
mm_prefix_range_ptr=mm_prefix_range,
|
||||||
SLIDING_WINDOW=(1 + window_size[0]),
|
SLIDING_WINDOW=(1 + window_size[0]),
|
||||||
stride_k_cache_0=k.stride(0),
|
stride_k_cache_0=k.stride(0),
|
||||||
stride_k_cache_1=k.stride(1),
|
stride_k_cache_1=k.stride(1),
|
||||||
|
|||||||
@ -19,7 +19,6 @@ from collections.abc import Iterable
|
|||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import Gemma3TextConfig
|
from transformers import Gemma3TextConfig
|
||||||
|
|
||||||
@ -226,77 +225,9 @@ class Gemma3Attention(nn.Module):
|
|||||||
|
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
|
|
||||||
if not kwargs.get("has_images", False):
|
|
||||||
# Fast path for text-only inputs. The performance for the text-only
|
|
||||||
# inputs are not affected by the naive attention below.
|
|
||||||
output, _ = self.o_proj(attn_output)
|
|
||||||
return output
|
|
||||||
|
|
||||||
# NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
|
|
||||||
# that correspond to the same image while using causal attention
|
|
||||||
# otherwise. Current attention backends cannot handle this pattern, so
|
|
||||||
# we temporarily use a naive attention implementation with mask tensors.
|
|
||||||
|
|
||||||
# We intentionally keep the attention backend as-is and only override
|
|
||||||
# `attn_output` with the naive implementation's output. This minimizes
|
|
||||||
# changes to existing model runners and attention backends. The call to
|
|
||||||
# `self.attn(q, k, v)` is only used to populate the KV cache - its
|
|
||||||
# output is discarded and overwritten below. While this duplicates
|
|
||||||
# computation, it maintains compatibility.
|
|
||||||
# TODO(woosuk): Optimize by implementing custom attention kernels.
|
|
||||||
attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
|
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def naive_attn_with_masks(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
out: torch.Tensor,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# NOTE(woosuk): As described in the comment above, this code is not
|
|
||||||
# meant to be performant. It is only meant to be correct.
|
|
||||||
q = q.view(-1, self.num_heads, self.head_dim)
|
|
||||||
# Expand the key and value to handle GQA.
|
|
||||||
num_queries_per_kv = self.num_heads // self.num_kv_heads
|
|
||||||
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
|
||||||
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
|
|
||||||
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
|
||||||
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
|
|
||||||
|
|
||||||
if self.is_sliding:
|
|
||||||
attn_masks = kwargs["local_attn_masks"]
|
|
||||||
else:
|
|
||||||
attn_masks = kwargs["global_attn_masks"]
|
|
||||||
|
|
||||||
seq_lens = kwargs["seq_lens"]
|
|
||||||
start_idx = 0
|
|
||||||
for seq_len, attn_mask in zip(seq_lens, attn_masks):
|
|
||||||
end_idx = start_idx + seq_len
|
|
||||||
query = q[start_idx:end_idx].unsqueeze(0)
|
|
||||||
key = k[start_idx:end_idx].unsqueeze(0)
|
|
||||||
value = v[start_idx:end_idx].unsqueeze(0)
|
|
||||||
|
|
||||||
# Transpose.
|
|
||||||
query = query.transpose(1, 2)
|
|
||||||
key = key.transpose(1, 2)
|
|
||||||
value = value.transpose(1, 2)
|
|
||||||
|
|
||||||
output = F.scaled_dot_product_attention(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
attn_mask,
|
|
||||||
self.scaling,
|
|
||||||
)
|
|
||||||
output = output.transpose(1, 2).flatten(-2, -1)
|
|
||||||
out[start_idx:end_idx] = output
|
|
||||||
start_idx = end_idx
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3DecoderLayer(nn.Module):
|
class Gemma3DecoderLayer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -76,6 +76,39 @@ class TritonAttentionMetadata:
|
|||||||
# Optional aot scheduling
|
# Optional aot scheduling
|
||||||
scheduler_metadata: torch.Tensor | None = None
|
scheduler_metadata: torch.Tensor | None = None
|
||||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||||
|
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
|
||||||
|
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
|
||||||
|
|
||||||
|
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
|
||||||
|
Empty ranges have start==end==0, which kernel skips via is_valid check.
|
||||||
|
"""
|
||||||
|
# TODO(Isotr0py): Move to model runner's attention metadata
|
||||||
|
# preparation to avoid duplicate computation.
|
||||||
|
if self.mm_prefix_range is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
num_seqs = self.seq_lens.shape[0]
|
||||||
|
device = self.seq_lens.device
|
||||||
|
|
||||||
|
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
|
||||||
|
range_lists = [
|
||||||
|
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Return None if all ranges are trivial (only (0,0) placeholders)
|
||||||
|
if all(r == [(0, 0)] for r in range_lists):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create 2D tensors with shape (num_ranges, 2) for each sequence
|
||||||
|
range_tensors = [
|
||||||
|
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
|
||||||
|
for r in range_lists
|
||||||
|
]
|
||||||
|
|
||||||
|
return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
|
||||||
|
|
||||||
|
|
||||||
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||||
@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
|
|||||||
def supports_head_size(cls, head_size: int) -> bool:
|
def supports_head_size(cls, head_size: int) -> bool:
|
||||||
return head_size >= 32
|
return head_size >= 32
|
||||||
|
|
||||||
|
@classmethod
def supports_mm_prefix(cls) -> bool:
    """Report whether this backend supports mm-prefix (PrefixLM) attention.

    Returns:
        Always True for this backend.
    """
    return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def supports_sink(cls) -> bool:
|
def supports_sink(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
||||||
|
|
||||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
||||||
|
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
|
||||||
|
|
||||||
unified_attention(
|
unified_attention(
|
||||||
q=query[:num_actual_tokens],
|
q=query[:num_actual_tokens],
|
||||||
@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
softmax_segm_expsum=softmax_segm_expsum,
|
softmax_segm_expsum=softmax_segm_expsum,
|
||||||
sinks=self.sinks,
|
sinks=self.sinks,
|
||||||
output_scale=output_scale,
|
output_scale=output_scale,
|
||||||
|
mm_prefix_range=mm_prefix_range_tensor,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user