[CI/Build] Refactor Attention backend for test_prefix_prefill from xformers to SDPA (#28424)
Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 5a1271d83a
commit e553424919
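As context for the diff below: the refactor swaps the xformers reference path for torch.nn.functional.scaled_dot_product_attention (SDPA). A minimal sketch of the layout and masking change, with illustrative shapes and names that are not taken from the test itself: xformers' memory_efficient_attention_forward consumes (batch, tokens, heads, head_size) tensors plus a block-diagonal causal bias object, whereas SDPA consumes (batch, heads, tokens, head_size) tensors plus an explicit additive float mask (0.0 where attention is allowed, -inf elsewhere).

# Illustrative sketch only; names, shapes, and the 16-token example are assumptions,
# not code from this diff.
import torch
import torch.nn.functional as F

q = torch.randn(1, 16, 8, 64)  # xformers layout: (batch, tokens, heads, head_size)
k = torch.randn(1, 16, 8, 64)
v = torch.randn(1, 16, 8, 64)

# SDPA layout: (batch, heads, tokens, head_size), with an additive causal mask.
q_sdpa, k_sdpa, v_sdpa = (t.transpose(1, 2) for t in (q, k, v))
causal_mask = torch.triu(torch.full((16, 16), float("-inf")), diagonal=1)

out = F.scaled_dot_product_attention(
    q_sdpa, k_sdpa, v_sdpa, attn_mask=causal_mask, dropout_p=0.0, scale=64**-0.5
)
out = out.transpose(1, 2)  # back to (batch, tokens, heads, head_size)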
@@ -8,10 +8,8 @@ from collections.abc import Callable
import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
import torch.nn.functional as F

from tests.kernels.utils import make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
@@ -28,6 +26,74 @@ KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
OPS = [chunked_prefill_paged_decode, context_attention_fwd]


def create_causal_attention_mask_for_sdpa(
    query_lens: list[int],
    seq_lens: list[int],
    sliding_window: int = 0,
    device: torch.device = None,
    dtype: torch.dtype = None,
) -> torch.Tensor:
    total_queries = sum(query_lens)
    total_keys = sum(seq_lens)

    # Create a mask filled with -inf
    mask = torch.full(
        (total_queries, total_keys), float("-inf"), device=device, dtype=dtype
    )

    query_start = 0
    key_start = 0

    for query_len, seq_len in zip(query_lens, seq_lens):
        query_end = query_start + query_len
        key_end = key_start + seq_len
        q_indices = torch.arange(query_len, device=device)
        k_indices = torch.arange(seq_len, device=device)
        q_pos_in_seq = seq_len - query_len + q_indices

        valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]

        if sliding_window > 0:
            valid_mask &= k_indices[None, :] >= (
                q_pos_in_seq[:, None] - sliding_window + 1
            )

        mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0

        query_start = query_end
        key_start = key_end

    return mask
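A quick illustration of what this helper returns (a sketch, not part of the diff; the lengths below are made up and assume torch and the helper above are in scope):

# Two packed sequences: query_lens=[2, 1], seq_lens=[4, 3].
mask = create_causal_attention_mask_for_sdpa(
    [2, 1], [4, 3], sliding_window=0, device=torch.device("cpu"), dtype=torch.float32
)
# mask.shape == (3, 7). Row 0 is the first query of sequence 0, sitting at
# position 2 of 4, so columns 0-2 are 0.0 and everything else in that row
# (the rest of sequence 0 and all of sequence 1) stays -inf.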
def create_alibi_causal_mask(
    query_len: int,
    seq_len: int,
    alibi_slopes: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    query_pos = torch.arange(
        seq_len - query_len, seq_len, device=device, dtype=torch.float32
    )
    key_pos = torch.arange(seq_len, device=device, dtype=torch.float32)

    rel_pos = key_pos[None, :] - query_pos[:, None]

    # Apply ALiBi slopes: [num_heads, query_len, seq_len]
    alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :]
    alibi_bias = alibi_bias.to(dtype)

    # Apply causal mask: prevent attending to future positions
    # causal_mask[i, j] = True if key_pos[j] <= query_pos[i]
    causal_mask = key_pos[None, :] <= query_pos[:, None]
    alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf"))

    # Add batch dimension: [1, num_heads, query_len, seq_len]
    # SDPA expects batch dimension even for single sequences
    return alibi_bias.unsqueeze(0)
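Similarly, a small sketch of the bias this helper builds (illustrative values only, assuming torch and the helper above are in scope):

# One sequence of length 5 with 2 new query tokens and 4 heads.
slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])
bias = create_alibi_causal_mask(2, 5, slopes, torch.device("cpu"), torch.float32)
# bias.shape == (1, 4, 2, 5). Future key positions are -inf; allowed positions
# hold slope * (key_pos - query_pos), which is always <= 0.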
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -52,6 +118,13 @@ def test_contexted_kv_attention(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if (
        current_platform.is_rocm()
        and op is chunked_prefill_paged_decode
        and kv_cache_dtype == "fp8_e5m2"
    ):
        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")

    current_platform.seed_everything(0)
    torch.set_default_device(device)
@@ -96,16 +169,16 @@ test_contexted_kv_attention(
    )
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = torch.arange(0, cache_size, dtype=torch.int32)
    values = values[torch.randperm(cache_size)]
    block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
    max_input_len = MAX_SEQ_LEN
    # copy kv to cache
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
    )
    for i in range(BS):
        for j in range(query_lens[i]):
@@ -189,56 +262,57 @@ test_contexted_kv_attention(

    scale = float(1.0 / (head_size**0.5))

    attn_op = xops.fmha.cutlass.FwOp()

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        query = query.view(
            query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
        )
        key = key[:, :, None, :].expand(
            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
        )
        value = value[:, :, None, :].expand(
            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
        )
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens
    # Reshape for SDPA: (seq_len, num_heads, head_size) ->
    # (1, num_heads, seq_len, head_size)
    query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
    query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
        1, num_heads, num_tokens, head_size
    )
    if sliding_window > 0:
        attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
    output_ref = xops.memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=0.0,

    # Expand key and value for GQA/MQA to match query heads
    key_sdpa = key[:, :, None, :].expand(
        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
    )
    key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
        1, num_heads, sum(seq_lens), head_size
    )

    value_sdpa = value[:, :, None, :].expand(
        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
    )
    value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
        1, num_heads, sum(seq_lens), head_size
    )

    attn_mask = create_causal_attention_mask_for_sdpa(
        query_lens, seq_lens, sliding_window, device=device, dtype=dtype
    )

    output_ref = F.scaled_dot_product_attention(
        query_sdpa,
        key_sdpa,
        value_sdpa,
        attn_mask=attn_mask,
        dropout_p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    start_time = time.time()
    output_ref = xops.memory_efficient_attention_forward(
        query,
        key,
        value,
        attn_bias=attn_bias,
        p=0.0,
    output_ref = F.scaled_dot_product_attention(
        query_sdpa,
        key_sdpa,
        value_sdpa,
        attn_mask=attn_mask,
        dropout_p=0.0,
        scale=scale,
        op=attn_op,
    )
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
    output_ref = output_ref.reshape(output.shape)
    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")

    # Reshape output back to (num_tokens, num_heads, head_size)
    output_ref = output_ref.view(num_heads, num_tokens, head_size)
    output_ref = output_ref.permute(1, 0, 2).contiguous()
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
@@ -265,6 +339,13 @@ def test_contexted_kv_attention_alibi(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )

    if (
        current_platform.is_rocm()
        and op is chunked_prefill_paged_decode
        and kv_cache_dtype == "fp8_e5m2"
    ):
        pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")

    current_platform.seed_everything(0)
    torch.set_default_device(device)
@@ -331,16 +412,16 @@ test_contexted_kv_attention_alibi(
    )
    k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = torch.arange(0, cache_size, dtype=torch.int32)
    values = values[torch.randperm(cache_size)]
    block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
    max_input_len = MAX_SEQ_LEN
    # copy kv to cache
    b_seq_start_loc = torch.cumsum(
        torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
    )
    for i in range(BS):
        for j in range(query_lens[i]):
@@ -423,78 +504,75 @@ test_contexted_kv_attention_alibi(
    print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
    scale = float(1.0 / (head_size**0.5))

    # NOTE(DefTruth): In order to reuse _make_alibi_bias function,
    # we have to pad query tensor before MQA/GQA expanding.
    if query.shape[0] != key.shape[0]:
        query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
        query_pad.uniform_(-1e-3, 1e-3)
        seq_start = 0
        query_start = 0
        for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
            seq_end = seq_start + seq_len
            query_end = query_start + query_len
            query_pad[seq_start:seq_end, ...] = torch.cat(
                [
                    torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
                    query[query_start:query_end, ...],
                ],
                dim=0,
            )
            seq_start += seq_len
            query_start += query_len
        query = query_pad
    # Prepare query, key, value for SDPA
    # Expand key and value for GQA/MQA to match query heads
    key_expanded = key[:, :, None, :].expand(
        key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
    )
    value_expanded = value[:, :, None, :].expand(
        value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
    )

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        key = key[:, :, None, :].expand(
            key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
        )
        value = value[:, :, None, :].expand(
            value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
        )
        # [seq, num_kv_heads, num_queries_per_kv, dk]=>
        # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
        # codebase. We save some time reshaping alibi matrix at runtime.
        key = key.reshape(key.shape[0], -1, key.shape[-1])
        value = value.reshape(value.shape[0], -1, value.shape[-1])
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
    output_ref = torch.empty_like(output)
    seq_start = 0
    query_start = 0
    torch.cuda.synchronize()
    start_time = time.time()
    # Attention with alibi slopes.
    # FIXME(DefTruth): Because xformers does not support dynamic sequence
    # lengths with custom attention bias, we process each prompt one by
    # one. This is inefficient, especially when we have many short prompts.
    # modified from: vllm/v1/attention/backends/xformers.py#L343

    query_start = 0
    key_start = 0
    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
        seq_end = seq_start + seq_len
        query_end = query_start + query_len
        out = xops.memory_efficient_attention_forward(
            query[:, seq_start:seq_end],
            key[:, seq_start:seq_end],
            value[:, seq_start:seq_end],
            attn_bias=attn_bias[i],
            p=0.0,
        key_end = key_start + seq_len

        # Get query, key, value for this sequence
        q = query[query_start:query_end]  # [query_len, num_heads, head_size]
        k = key_expanded[
            key_start:key_end
        ]  # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
        v = value_expanded[
            key_start:key_end
        ]  # [seq_len, num_kv_heads, num_queries_per_kv, head_size]

        # Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
        q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
        q_sdpa = (
            q_sdpa.permute(1, 2, 0, 3)
            .reshape(1, num_heads, query_len, head_size)
            .contiguous()
        )

        k_sdpa = (
            k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
        )
        v_sdpa = (
            v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
        )

        # Create ALiBi causal mask for this sequence using utility function
        alibi_mask = create_alibi_causal_mask(
            query_len, seq_len, alibi_slopes, device, dtype
        )

        # Compute attention
        out = F.scaled_dot_product_attention(
            q_sdpa,
            k_sdpa,
            v_sdpa,
            attn_mask=alibi_mask,
            dropout_p=0.0,
            scale=scale,
        )
        out = out.view_as(query[:, seq_start:seq_end]).view(
            seq_len, num_heads, head_size
        )
        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
        seq_start += seq_len
        query_start += query_len

        # Reshape output back to [query_len, num_heads, head_size]
        out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
        output_ref[query_start:query_end].copy_(out)

        query_start = query_end
        key_start = key_end

    torch.cuda.synchronize()
    end_time = time.time()
    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
    print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
    atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
    torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)