vllm/tests/kernels/attention/test_merge_attn_states.py
Harry Mellor d6953beb91
Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-10-05 07:06:22 -07:00

321 lines
10 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton,
)
from vllm.platforms import current_platform
# Naive PyTorch implementation of section 2.2 of
# https://www.arxiv.org/pdf/2501.01005. It can be used to combine partial
# attention results (in the split-KV case).
def merge_attn_states_torch(
    output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
    output_lse: Optional[torch.Tensor] = None,  # [NUM_HEADS, NUM_TOKENS]
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Merge two partial attention results using their log-sum-exp weights.

    Each partial output is weighted by ``exp(lse - max_lse)`` normalised over
    the two partials, where ``max_lse`` is the element-wise maximum of the two
    LSE tensors (subtracted before exponentiating for numerical stability).

    Args:
        output: Accepted for signature parity with the in-place kernels; its
            contents are not read and it is not written to.
        prefix_output: Partial attention output of the prefix.
        prefix_lse: Log-sum-exp of the prefix. ``+inf`` entries are treated as
            ``-inf`` (i.e. zero weight).
        suffix_output: Partial attention output of the suffix.
        suffix_lse: Log-sum-exp of the suffix, same ``+inf`` handling.
        output_lse: Only tested for ``None``; when it is not ``None`` the
            merged LSE is computed and returned (as a new tensor).

    Returns:
        ``(merged_output, merged_lse)``. ``merged_lse`` is ``None`` when
        ``output_lse`` was ``None``.

    Note:
        None of the input tensors are modified. Positions where both LSEs are
        ``-inf`` (or ``+inf``) yield NaN because ``-inf - (-inf)`` is NaN.
    """
    # inf -> -inf. masked_fill is out-of-place, so the caller's LSE tensors
    # keep their original +inf entries.
    p_lse = prefix_lse.masked_fill(prefix_lse == torch.inf, -torch.inf)
    s_lse = suffix_lse.masked_fill(suffix_lse == torch.inf, -torch.inf)

    # max_lse [NUM_HEADS, NUM_TOKENS]
    max_lse = torch.maximum(p_lse, s_lse)
    p_lse_exp = torch.exp(p_lse - max_lse)
    s_lse_exp = torch.exp(s_lse - max_lse)
    out_se = p_lse_exp + s_lse_exp

    if output_lse is not None:
        output_lse = torch.log(out_se) + max_lse

    p_scale = p_lse_exp / out_se  # [NUM_HEADS, NUM_TOKENS]
    s_scale = s_lse_exp / out_se  # [NUM_HEADS, NUM_TOKENS]
    # [NUM_HEADS, NUM_TOKENS] -> [NUM_TOKENS, NUM_HEADS, 1] so the weights
    # broadcast over HEAD_SIZE.
    p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2)
    s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2)

    output = prefix_output * p_scale + suffix_output * s_scale
    return output, output_lse
# Parameter grid for test_merge_attn_states: pytest generates one case per
# combination of the four lists below.
NUM_BATCH_TOKENS = [
    256,
    512,
    613,
    1024,
    1536,
    4096,
]
NUM_QUERY_HEADS = [
    4,
    8,
    16,
    32,
    48,
    64,
]
HEAD_SIZES = [
    32,
    48,
    64,
    96,
    128,
    256,
]
DTYPES = [
    torch.float32,
    torch.half,
    torch.bfloat16,
]

# One tuple per finished test case, appended by test_merge_attn_states and
# read back by generate_markdown_table.
all_case_info: list[tuple] = []
def generate_markdown_table():
    """Print the collected benchmark records as a Markdown table.

    Reads the module-level ``all_case_info`` list; each entry is unpacked as
    (tokens, heads, head size, dtype, device name, torch ms, triton ms,
    cuda ms, speedup). The header and separator rows are printed even when
    the list is empty.
    """
    header_cells = (
        "tokens",
        "heads",
        "headsize",
        "dtype",
        "device",
        "torch",
        "triton",
        "cuda",
        "speedup",
    )
    print("| " + " | ".join(header_cells) + " |")
    print("| " + " | ".join("---" for _ in header_cells) + " |")

    for record in all_case_info:
        (
            tokens,
            heads,
            headsize,
            dtype,
            device,
            torch_ms,
            triton_ms,
            cuda_ms,
            speedup,
        ) = record
        row_cells = [
            f"{tokens}",
            f"{heads}",
            f"{headsize}",
            # "torch.float32" -> "float32"
            str(dtype).removeprefix("torch."),
            # Drop a leading "NVIDIA" vendor tag from the device name.
            device.removeprefix("NVIDIA").strip(),
            f"{torch_ms:.5f}ms",
            f"{triton_ms:.5f}ms",
            f"{cuda_ms:.5f}ms",
            f"{speedup:.4f}x",
        ]
        print("| " + " | ".join(row_cells) + " |")
def _time_kernel(kernel, warmup_times: int, repeat_times: int):
    """Time a zero-argument callable with CUDA events.

    Calls ``kernel`` ``warmup_times`` times untimed, synchronizes, then calls
    it ``repeat_times`` more times, bracketing each call with CUDA events and
    synchronizing before reading the elapsed time.

    Returns:
        ``(average_ms, last_result)`` where ``average_ms`` is the mean elapsed
        time of the timed calls in milliseconds and ``last_result`` is
        whatever the final call to ``kernel`` returned.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    result = None
    for _ in range(warmup_times):
        result = kernel()
    torch.cuda.synchronize()

    total_ms = 0.0
    for _ in range(repeat_times):
        start.record()
        result = kernel()
        end.record()
        # The events are recorded asynchronously; wait before reading them.
        torch.cuda.synchronize()
        total_ms += start.elapsed_time(end)
    return total_ms / repeat_times, result


@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(
    num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
):
    """Check the CUDA merge_attn_states kernel against the Triton kernel.

    Builds random prefix/suffix outputs and LSEs (with roughly 10% of the
    LSE entries on each side set to +inf, never both at the same position),
    times the PyTorch reference, the Triton kernel and the CUDA kernel, and
    asserts that the CUDA output and output LSE are close to the Triton
    ones. Differences against the PyTorch reference are only printed, not
    asserted. Timing results are appended to ``all_case_info`` and a
    Markdown summary is printed once every parametrized case has been
    recorded.
    """
    if not current_platform.is_cuda():
        pytest.skip(
            "Currently only support compare triton merge_attn_states "
            "with custom cuda merge_attn_states kernel"
        )

    NUM_TOKENS = num_tokens
    NUM_HEADS = num_query_heads
    HEAD_SIZE = head_size

    print(
        f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
        f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
        f"Device: {current_platform.get_device_name()}"
    )

    # prefix_lse and suffix_lse contain inf and normal values
    prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")
    suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")

    # Generate boolean masks
    mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
    mask_suffix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
    # Ensure that the same position is not True at the same time
    combined_mask = torch.logical_and(mask_prefix, mask_suffix)
    mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
    mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)

    prefix_lse[mask_prefix] = float("inf")
    suffix_lse[mask_suffix] = float("inf")

    # Other input tensors (need to be initialized but
    # no actual calculation needed)
    output = torch.zeros(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )
    output_lse = torch.zeros(
        (NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda"
    )
    prefix_output = torch.randn(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )
    suffix_output = torch.randn(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )

    warmup_times = 2
    repeat_times = 20

    # 0. Run the Torch kernel
    # The reference gets its own copies of the LSE tensors so that the
    # Triton and CUDA kernels below still see the original +inf entries
    # regardless of what the reference does with its inputs.
    prefix_lse_torch = prefix_lse.clone()
    suffix_lse_torch = suffix_lse.clone()
    # The reference returns new tensors instead of filling these buffers;
    # they are passed only to satisfy its signature (and to request the LSE).
    output_torch_buf = output.clone()
    output_lse_torch_buf = output_lse.clone()
    avg_time_torch_kernel, (output_torch, output_lse_torch) = _time_kernel(
        lambda: merge_attn_states_torch(
            output_torch_buf,
            prefix_output,
            prefix_lse_torch,
            suffix_output,
            suffix_lse_torch,
            output_lse_torch_buf,
        ),
        warmup_times,
        repeat_times,
    )

    # 1. Run the Triton kernel
    output_ref_triton = output.clone()
    output_lse_ref_triton = output_lse.clone()
    avg_time_triton_kernel, _ = _time_kernel(
        lambda: merge_attn_states_triton(
            output_ref_triton,
            prefix_output,
            prefix_lse,
            suffix_output,
            suffix_lse,
            output_lse_ref_triton,
        ),
        warmup_times,
        repeat_times,
    )

    # 2. Run the CUDA kernel
    output_cuda = output.clone()
    output_lse_cuda = output_lse.clone()
    avg_time_cuda_kernel, _ = _time_kernel(
        lambda: merge_attn_states_cuda(
            output_cuda,
            prefix_output,
            prefix_lse,
            suffix_output,
            suffix_lse,
            output_lse_cuda,
        ),
        warmup_times,
        repeat_times,
    )

    # 3. Performance compare
    # Speedup of the CUDA kernel relative to the Triton kernel.
    performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel
    print(f" Torch time: {avg_time_torch_kernel:.6f}ms")
    print(f"Triton time: {avg_time_triton_kernel:.6f}ms")
    print(
        f"  CUDA time: {avg_time_cuda_kernel:.6f}ms, "
        f"Performance: {performance_improved:.5f}x"
    )
    print("-" * 100)

    # 4. Correctness compare
    # Liger Kernel: Efficient Triton Kernels for LLM Training
    # https://arxiv.org/pdf/2410.10989, 3.3 Correctness
    # use rtol = 1e-2 for bfloat16.
    rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3

    def diff(a: torch.Tensor, b: torch.Tensor):
        max_diff = torch.max(torch.abs(a.float() - b.float()))
        return max_diff

    # Use Triton output as reference because we want to replace
    # the Triton kernel with custom CUDA kernel for merge attn
    # states operation.
    output_ref = output_ref_triton
    output_lse_ref = output_lse_ref_triton
    torch.testing.assert_close(
        output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol
    )
    print("Output all match, max abs diff:")
    print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
    print(f"  (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
    print(f"  (CUDA vs Triton): {diff(output_ref, output_cuda)}")
    print("-" * 100)

    torch.testing.assert_close(
        output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
    )
    print("Output LSE all match, max abs diff:")
    print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
    print(f"  (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}")
    print(f"  (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}")
    print("-" * 100)

    print(
        "All output values test passed! All inf values "
        "are correctly replaced with -inf."
    )
    print("-" * 100)

    device = current_platform.get_device_name()
    all_case_info.append(
        (
            NUM_TOKENS,
            NUM_HEADS,
            HEAD_SIZE,
            output_dtype,
            device,
            avg_time_torch_kernel,
            avg_time_triton_kernel,
            avg_time_cuda_kernel,
            performance_improved,
        )
    )
    # Print the summary table only after the last parametrized case has been
    # recorded in this process.
    if len(all_case_info) == (
        len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
    ):
        generate_markdown_table()