[CI/Build] Refactor Attention backend for test_prefix_prefill from xformers to SDPA (#28424)

Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Zhewen Li 2025-11-11 09:09:47 -08:00 committed by GitHub
parent 5a1271d83a
commit e553424919
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,10 +8,8 @@ from collections.abc import Callable
import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
import torch.nn.functional as F
from tests.kernels.utils import make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
@ -28,6 +26,74 @@ KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
OPS = [chunked_prefill_paged_decode, context_attention_fwd]
def create_causal_attention_mask_for_sdpa(
query_lens: list[int],
seq_lens: list[int],
sliding_window: int = 0,
device: torch.device = None,
dtype: torch.dtype = None,
) -> torch.Tensor:
total_queries = sum(query_lens)
total_keys = sum(seq_lens)
# Create a mask filled with -inf
mask = torch.full(
(total_queries, total_keys), float("-inf"), device=device, dtype=dtype
)
query_start = 0
key_start = 0
for query_len, seq_len in zip(query_lens, seq_lens):
query_end = query_start + query_len
key_end = key_start + seq_len
q_indices = torch.arange(query_len, device=device)
k_indices = torch.arange(seq_len, device=device)
q_pos_in_seq = seq_len - query_len + q_indices
valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]
if sliding_window > 0:
valid_mask &= k_indices[None, :] >= (
q_pos_in_seq[:, None] - sliding_window + 1
)
mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0
query_start = query_end
key_start = key_end
return mask
def create_alibi_causal_mask(
query_len: int,
seq_len: int,
alibi_slopes: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
query_pos = torch.arange(
seq_len - query_len, seq_len, device=device, dtype=torch.float32
)
key_pos = torch.arange(seq_len, device=device, dtype=torch.float32)
rel_pos = key_pos[None, :] - query_pos[:, None]
# Apply ALiBi slopes: [num_heads, query_len, seq_len]
alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :]
alibi_bias = alibi_bias.to(dtype)
# Apply causal mask: prevent attending to future positions
# causal_mask[i, j] = True if key_pos[j] <= query_pos[i]
causal_mask = key_pos[None, :] <= query_pos[:, None]
alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf"))
# Add batch dimension: [1, num_heads, query_len, seq_len]
# SDPA expects batch dimension even for single sequences
return alibi_bias.unsqueeze(0)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@ -52,6 +118,13 @@ def test_contexted_kv_attention(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
if (
current_platform.is_rocm()
and op is chunked_prefill_paged_decode
and kv_cache_dtype == "fp8_e5m2"
):
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
current_platform.seed_everything(0)
torch.set_default_device(device)
@ -96,16 +169,16 @@ def test_contexted_kv_attention(
)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = torch.arange(0, cache_size, dtype=torch.int32)
values = values[torch.randperm(cache_size)]
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
)
for i in range(BS):
for j in range(query_lens[i]):
@ -189,56 +262,57 @@ def test_contexted_kv_attention(
scale = float(1.0 / (head_size**0.5))
attn_op = xops.fmha.cutlass.FwOp()
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(
query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
)
key = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
value = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
query_lens, seq_lens
# Reshape for SDPA: (seq_len, num_heads, head_size) ->
# (1, num_heads, seq_len, head_size)
query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
1, num_heads, num_tokens, head_size
)
if sliding_window > 0:
attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
# Expand key and value for GQA/MQA to match query heads
key_sdpa = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
1, num_heads, sum(seq_lens), head_size
)
value_sdpa = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
1, num_heads, sum(seq_lens), head_size
)
attn_mask = create_causal_attention_mask_for_sdpa(
query_lens, seq_lens, sliding_window, device=device, dtype=dtype
)
output_ref = F.scaled_dot_product_attention(
query_sdpa,
key_sdpa,
value_sdpa,
attn_mask=attn_mask,
dropout_p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
start_time = time.time()
output_ref = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias,
p=0.0,
output_ref = F.scaled_dot_product_attention(
query_sdpa,
key_sdpa,
value_sdpa,
attn_mask=attn_mask,
dropout_p=0.0,
scale=scale,
op=attn_op,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
# Reshape output back to (num_tokens, num_heads, head_size)
output_ref = output_ref.view(num_heads, num_tokens, head_size)
output_ref = output_ref.permute(1, 0, 2).contiguous()
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
@ -265,6 +339,13 @@ def test_contexted_kv_attention_alibi(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
if (
current_platform.is_rocm()
and op is chunked_prefill_paged_decode
and kv_cache_dtype == "fp8_e5m2"
):
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
current_platform.seed_everything(0)
torch.set_default_device(device)
@ -331,16 +412,16 @@ def test_contexted_kv_attention_alibi(
)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = torch.arange(0, cache_size, dtype=torch.int32)
values = values[torch.randperm(cache_size)]
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
)
for i in range(BS):
for j in range(query_lens[i]):
@ -423,78 +504,75 @@ def test_contexted_kv_attention_alibi(
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
# we have to pad query tensor before MQA/GQA expanding.
if query.shape[0] != key.shape[0]:
query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
query_pad.uniform_(-1e-3, 1e-3)
seq_start = 0
query_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
query_pad[seq_start:seq_end, ...] = torch.cat(
[
torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
query[query_start:query_end, ...],
],
dim=0,
)
seq_start += seq_len
query_start += query_len
query = query_pad
# Prepare query, key, value for SDPA
# Expand key and value for GQA/MQA to match query heads
key_expanded = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
value_expanded = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
if num_kv_heads != num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
#
# see also: vllm/model_executor/layers/attention.py
key = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
value = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
# [seq, num_kv_heads, num_queries_per_kv, dk]=>
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
# codebase. We save some time reshaping alibi matrix at runtime.
key = key.reshape(key.shape[0], -1, key.shape[-1])
value = value.reshape(value.shape[0], -1, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
torch.cuda.synchronize()
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/v1/attention/backends/xformers.py#L343
query_start = 0
key_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(
query[:, seq_start:seq_end],
key[:, seq_start:seq_end],
value[:, seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
key_end = key_start + seq_len
# Get query, key, value for this sequence
q = query[query_start:query_end] # [query_len, num_heads, head_size]
k = key_expanded[
key_start:key_end
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
v = value_expanded[
key_start:key_end
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
# Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
q_sdpa = (
q_sdpa.permute(1, 2, 0, 3)
.reshape(1, num_heads, query_len, head_size)
.contiguous()
)
k_sdpa = (
k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
)
v_sdpa = (
v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
)
# Create ALiBi causal mask for this sequence using utility function
alibi_mask = create_alibi_causal_mask(
query_len, seq_len, alibi_slopes, device, dtype
)
# Compute attention
out = F.scaled_dot_product_attention(
q_sdpa,
k_sdpa,
v_sdpa,
attn_mask=alibi_mask,
dropout_p=0.0,
scale=scale,
)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size
)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
seq_start += seq_len
query_start += query_len
# Reshape output back to [query_len, num_heads, head_size]
out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
output_ref[query_start:query_end].copy_(out)
query_start = query_end
key_start = key_end
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)