vllm/tests/kernels/attention/test_prefix_prefill.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import random
import time
from collections.abc import Callable

import pytest
import torch
import torch.nn.functional as F

from vllm.attention.ops.chunked_prefill_paged_decode import (
    chunked_prefill_paged_decode,
)
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE

NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 64]
HEAD_SIZES = [24, 128]
DTYPES = [torch.float16]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
SLIDING_WINDOW = [0, 16, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
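# The two kernel entry points under test: the fused chunked-prefill +
# paged-decode path and the standalone prefix-prefill (context attention)
# Triton kernel. Both are checked against a PyTorch SDPA reference.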
OPS = [chunked_prefill_paged_decode, context_attention_fwd]


def create_causal_attention_mask_for_sdpa(
query_lens: list[int],
seq_lens: list[int],
sliding_window: int = 0,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
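    """Build a block-diagonal additive attention mask for SDPA.

    Queries and keys of all sequences are concatenated along one axis, so each
    sequence contributes a (query_len, seq_len) block in which visible
    positions are 0.0 and masked positions are -inf, with causal masking and
    an optional sliding window applied inside the block.
    """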
total_queries = sum(query_lens)
total_keys = sum(seq_lens)
# Create a mask filled with -inf
mask = torch.full(
(total_queries, total_keys), float("-inf"), device=device, dtype=dtype
)
query_start = 0
key_start = 0
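    # Queries are the last `query_len` tokens of each sequence (the context
    # comes first), so query row j corresponds to sequence position
    # seq_len - query_len + j.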
for query_len, seq_len in zip(query_lens, seq_lens):
query_end = query_start + query_len
key_end = key_start + seq_len
q_indices = torch.arange(query_len, device=device)
k_indices = torch.arange(seq_len, device=device)
q_pos_in_seq = seq_len - query_len + q_indices
valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None]
if sliding_window > 0:
valid_mask &= k_indices[None, :] >= (
q_pos_in_seq[:, None] - sliding_window + 1
)
mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0
query_start = query_end
key_start = key_end
return mask


def create_alibi_causal_mask(
query_len: int,
seq_len: int,
alibi_slopes: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
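    """Build a causal attention bias with ALiBi slopes for one sequence.

    Returns a [1, num_heads, query_len, seq_len] tensor in which visible
    positions carry the (non-positive) ALiBi distance bias and future
    positions are -inf, ready to be passed as `attn_mask` to SDPA.
    """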
query_pos = torch.arange(
seq_len - query_len, seq_len, device=device, dtype=torch.float32
)
key_pos = torch.arange(seq_len, device=device, dtype=torch.float32)
rel_pos = key_pos[None, :] - query_pos[:, None]
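    # rel_pos[i, j] = key_pos[j] - query_pos[i]; it is <= 0 for keys a query
    # may attend to, so multiplying by a positive slope yields a distance penalty.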
# Apply ALiBi slopes: [num_heads, query_len, seq_len]
alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :]
alibi_bias = alibi_bias.to(dtype)
# Apply causal mask: prevent attending to future positions
# causal_mask[i, j] = True if key_pos[j] <= query_pos[i]
causal_mask = key_pos[None, :] <= query_pos[:, None]
alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf"))
# Add batch dimension: [1, num_heads, query_len, seq_len]
# SDPA expects batch dimension even for single sequences
return alibi_bias.unsqueeze(0)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
sliding_window: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
if (
current_platform.is_rocm()
and op is chunked_prefill_paged_decode
and kv_cache_dtype == "fp8_e5m2"
):
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
current_platform.seed_everything(0)
torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)
MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
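    # 10 requests x 64 blocks of 32 tokens each exactly fill the 640-block cache.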
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ensure one sequence in batch is a decode
query_lens[-1] = 1
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)
if kv_cache_dtype == "auto":
cache_dtype = dtype
else:
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
k_cache = torch.zeros(
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
)
v_cache = torch.zeros(
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.int32)
values = values[torch.randperm(cache_size)]
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
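    # Map each request's logical blocks to a random permutation of physical
    # cache blocks, mimicking non-contiguous paged KV-cache allocation.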
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
torch.int32
)
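    # For each request, copy the new (query) tokens into the flat k/v tensors
    # fed to the kernel, and scatter the context tokens block-by-block into the
    # paged k_cache/v_cache according to block_table.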
for i in range(BS):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
start_loc = b_seq_start_loc[i] + cur_ctx
if cur_ctx + block_size > b_ctx_len[i]:
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
else:
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc]
)
v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc]
)
cur_ctx += block_size
block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = (
k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
.permute(0, 2, 3, 1, 4)
.contiguous()
)
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = (
v_cache.view(-1, block_size, num_kv_heads, head_size)
.permute(0, 2, 3, 1)
.contiguous()
)
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
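    # Unit KV-cache scales: the fp8 paths are exercised without any rescaling.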
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
op(
query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window,
)
torch.cuda.synchronize()
start_time = time.time()
op(
query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
# Reshape for SDPA: (seq_len, num_heads, head_size) ->
# (1, num_heads, seq_len, head_size)
query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape(
1, num_heads, num_tokens, head_size
)
# Expand key and value for GQA/MQA to match query heads
key_sdpa = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape(
1, num_heads, sum(seq_lens), head_size
)
value_sdpa = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape(
1, num_heads, sum(seq_lens), head_size
)
attn_mask = create_causal_attention_mask_for_sdpa(
query_lens, seq_lens, sliding_window, device=device, dtype=dtype
)
output_ref = F.scaled_dot_product_attention(
query_sdpa,
key_sdpa,
value_sdpa,
attn_mask=attn_mask,
dropout_p=0.0,
scale=scale,
)
torch.cuda.synchronize()
start_time = time.time()
output_ref = F.scaled_dot_product_attention(
query_sdpa,
key_sdpa,
value_sdpa,
attn_mask=attn_mask,
dropout_p=0.0,
scale=scale,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
# Reshape output back to (num_tokens, num_heads, head_size)
output_ref = output_ref.view(num_heads, num_tokens, head_size)
output_ref = output_ref.permute(1, 0, 2).contiguous()
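    # fp8 KV caches introduce quantization error, so allow a looser tolerance.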
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
pytest.skip(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
if (
current_platform.is_rocm()
and op is chunked_prefill_paged_decode
and kv_cache_dtype == "fp8_e5m2"
):
pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache")
current_platform.seed_everything(0)
torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)

    def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44
closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(
closest_power_of_2, total_num_heads - closest_power_of_2
)
extra_powers = torch.arange(
start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
alibi_slopes = _get_alibi_slopes(num_heads).to(device)
MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024
BS = 10
cache_size = 640
block_size = 32
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
num_tokens = sum(query_lens)
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)
if kv_cache_dtype == "auto":
cache_dtype = dtype
else:
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
k_cache = torch.zeros(
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
)
v_cache = torch.zeros(
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
)
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.int32)
values = values[torch.randperm(cache_size)]
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
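    # Map each request's logical blocks to a random permutation of physical
    # cache blocks, mimicking non-contiguous paged KV-cache allocation.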
b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
torch.int32
)
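    # For each request, copy the new (query) tokens into the flat k/v tensors
    # fed to the kernel, and scatter the context tokens block-by-block into the
    # paged k_cache/v_cache according to block_table.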
for i in range(BS):
for j in range(query_lens[i]):
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j])
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j])
cur_ctx = 0
block_id = 0
while cur_ctx < b_ctx_len[i]:
start_loc = b_seq_start_loc[i] + cur_ctx
if cur_ctx + block_size > b_ctx_len[i]:
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
else:
end_loc = start_loc + block_size
start_slot = block_table[i, block_id] * block_size
end_slot = start_slot + end_loc - start_loc
k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
key[start_loc:end_loc]
)
v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
value[start_loc:end_loc]
)
cur_ctx += block_size
block_id += 1
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
k_cache = (
k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
.permute(0, 2, 3, 1, 4)
.contiguous()
)
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = (
v_cache.view(-1, block_size, num_kv_heads, head_size)
.permute(0, 2, 3, 1)
.contiguous()
)
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring
# generation time
op(
query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes,
)
torch.cuda.synchronize()
start_time = time.time()
op(
query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes,
)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
scale = float(1.0 / (head_size**0.5))
# Prepare query, key, value for SDPA
# Expand key and value for GQA/MQA to match query heads
key_expanded = key[:, :, None, :].expand(
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
)
value_expanded = value[:, :, None, :].expand(
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
)
output_ref = torch.empty_like(output)
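    # Compute the SDPA reference one sequence at a time, building a
    # head-dependent ALiBi + causal mask for each sequence.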
torch.cuda.synchronize()
start_time = time.time()
query_start = 0
key_start = 0
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
query_end = query_start + query_len
key_end = key_start + seq_len
# Get query, key, value for this sequence
q = query[query_start:query_end] # [query_len, num_heads, head_size]
k = key_expanded[
key_start:key_end
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
v = value_expanded[
key_start:key_end
] # [seq_len, num_kv_heads, num_queries_per_kv, head_size]
# Reshape for SDPA: (batch=1, num_heads, seq_len, head_size)
q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size)
q_sdpa = (
q_sdpa.permute(1, 2, 0, 3)
.reshape(1, num_heads, query_len, head_size)
.contiguous()
)
k_sdpa = (
k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
)
v_sdpa = (
v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous()
)
# Create ALiBi causal mask for this sequence using utility function
alibi_mask = create_alibi_causal_mask(
query_len, seq_len, alibi_slopes, device, dtype
)
# Compute attention
out = F.scaled_dot_product_attention(
q_sdpa,
k_sdpa,
v_sdpa,
attn_mask=alibi_mask,
dropout_p=0.0,
scale=scale,
)
# Reshape output back to [query_len, num_heads, head_size]
out = out.view(num_heads, query_len, head_size).permute(1, 0, 2)
output_ref[query_start:query_end].copy_(out)
query_start = query_end
key_start = key_end
torch.cuda.synchronize()
end_time = time.time()
print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)


# These tests are optional and only run when explicitly invoked:
#
# pytest -v -s --optional \
#     tests/kernels/attention/test_prefix_prefill.py::test_contexted_kv_attention_f32
#
# They are useful for testing model dtype float32 on Turing devices.
# We skip them by default to avoid increasing CI runtime.
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
sliding_window: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:
test_contexted_kv_attention(
num_heads,
num_queries_per_kv,
head_size,
sliding_window,
dtype,
kv_cache_dtype,
device,
op,
)


@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:
test_contexted_kv_attention_alibi(
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
)