# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math import random import time from collections.abc import Callable import pytest import torch import torch.nn.functional as F from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 64] HEAD_SIZES = [24, 128] DTYPES = [torch.float16] CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] SLIDING_WINDOW = [0, 16, 2048] KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] OPS = [chunked_prefill_paged_decode, context_attention_fwd] def create_causal_attention_mask_for_sdpa( query_lens: list[int], seq_lens: list[int], sliding_window: int = 0, device: torch.device = None, dtype: torch.dtype = None, ) -> torch.Tensor: total_queries = sum(query_lens) total_keys = sum(seq_lens) # Create a mask filled with -inf mask = torch.full( (total_queries, total_keys), float("-inf"), device=device, dtype=dtype ) query_start = 0 key_start = 0 for query_len, seq_len in zip(query_lens, seq_lens): query_end = query_start + query_len key_end = key_start + seq_len q_indices = torch.arange(query_len, device=device) k_indices = torch.arange(seq_len, device=device) q_pos_in_seq = seq_len - query_len + q_indices valid_mask = k_indices[None, :] <= q_pos_in_seq[:, None] if sliding_window > 0: valid_mask &= k_indices[None, :] >= ( q_pos_in_seq[:, None] - sliding_window + 1 ) mask[query_start:query_end, key_start:key_end][valid_mask] = 0.0 query_start = query_end key_start = key_end return mask def create_alibi_causal_mask( query_len: int, seq_len: int, alibi_slopes: torch.Tensor, device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: query_pos = torch.arange( seq_len - query_len, seq_len, device=device, dtype=torch.float32 ) key_pos = torch.arange(seq_len, device=device, dtype=torch.float32) rel_pos = key_pos[None, :] - query_pos[:, None] # Apply ALiBi slopes: [num_heads, query_len, seq_len] alibi_bias = alibi_slopes[:, None, None] * rel_pos[None, :, :] alibi_bias = alibi_bias.to(dtype) # Apply causal mask: prevent attending to future positions # causal_mask[i, j] = True if key_pos[j] <= query_pos[i] causal_mask = key_pos[None, :] <= query_pos[:, None] alibi_bias = alibi_bias.masked_fill(~causal_mask[None, :, :], float("-inf")) # Add batch dimension: [1, num_heads, query_len, seq_len] # SDPA expects batch dimension even for single sequences return alibi_bias.unsqueeze(0) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) @pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, num_queries_per_kv: int, head_size: int, sliding_window: int, dtype: torch.dtype, kv_cache_dtype: str, device: str, op: Callable, ) -> None: if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): pytest.skip( "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" ) if ( current_platform.is_rocm() and op is chunked_prefill_paged_decode and kv_cache_dtype == "fp8_e5m2" ): pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache") current_platform.seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process # for GPU 1 would run on both GPU0 and GPU1 and things would hang # # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 torch.cuda.set_device(device) MAX_SEQ_LEN = 1024 MAX_CTX_LEN = 1024 BS = 10 cache_size = 640 block_size = 32 max_block_per_request = 64 query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] # ensure one sequence in batch is a decode query_lens[-1] = 1 ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv num_tokens = sum(query_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) kv.uniform_(-1e-3, 1e-3) key, value = kv.unbind(dim=1) if kv_cache_dtype == "auto": cache_dtype = dtype else: cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] k_cache = torch.zeros( cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype ) v_cache = torch.zeros( cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.int32) values = values[torch.randperm(cache_size)] block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.int32) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32) b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32) max_input_len = MAX_SEQ_LEN # copy kv to cache b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to( torch.int32 ) for i in range(BS): for j in range(query_lens[i]): k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: start_loc = b_seq_start_loc[i] + cur_ctx if cur_ctx + block_size > b_ctx_len[i]: end_loc = b_seq_start_loc[i] + b_ctx_len[i] else: end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( key[start_loc:end_loc] ) v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( value[start_loc:end_loc] ) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] k_cache = ( k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8) .permute(0, 2, 3, 1, 4) .contiguous() ) # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = ( v_cache.view(-1, block_size, num_kv_heads, head_size) .permute(0, 2, 3, 1) .contiguous() ) k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time op( query, k, v, output, kv_cache_dtype, k_cache, v_cache, block_table, b_start_loc, b_seq_len, MAX_CTX_LEN, max_input_len, k_scale, v_scale, sliding_window=sliding_window, ) torch.cuda.synchronize() start_time = time.time() op( query, k, v, output, kv_cache_dtype, k_cache, v_cache, block_table, b_start_loc, b_seq_len, MAX_CTX_LEN, max_input_len, k_scale, v_scale, sliding_window=sliding_window, ) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) # Reshape for SDPA: (seq_len, num_heads, head_size) -> # (1, num_heads, seq_len, head_size) query_sdpa = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size) query_sdpa = query_sdpa.permute(1, 2, 0, 3).reshape( 1, num_heads, num_tokens, head_size ) # Expand key and value for GQA/MQA to match query heads key_sdpa = key[:, :, None, :].expand( key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] ) key_sdpa = key_sdpa.permute(1, 2, 0, 3).reshape( 1, num_heads, sum(seq_lens), head_size ) value_sdpa = value[:, :, None, :].expand( value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] ) value_sdpa = value_sdpa.permute(1, 2, 0, 3).reshape( 1, num_heads, sum(seq_lens), head_size ) attn_mask = create_causal_attention_mask_for_sdpa( query_lens, seq_lens, sliding_window, device=device, dtype=dtype ) output_ref = F.scaled_dot_product_attention( query_sdpa, key_sdpa, value_sdpa, attn_mask=attn_mask, dropout_p=0.0, scale=scale, ) torch.cuda.synchronize() start_time = time.time() output_ref = F.scaled_dot_product_attention( query_sdpa, key_sdpa, value_sdpa, attn_mask=attn_mask, dropout_p=0.0, scale=scale, ) torch.cuda.synchronize() end_time = time.time() print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms") # Reshape output back to (num_tokens, num_heads, head_size) output_ref = output_ref.view(num_heads, num_tokens, head_size) output_ref = output_ref.permute(1, 0, 2).contiguous() atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention_alibi( num_heads: int, num_queries_per_kv: int, head_size: int, dtype: torch.dtype, kv_cache_dtype: str, device: str, op: Callable, ) -> None: if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): pytest.skip( "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" ) if ( current_platform.is_rocm() and op is chunked_prefill_paged_decode and kv_cache_dtype == "fp8_e5m2" ): pytest.skip("ROCm custom paged attention does not support fp8_e5m2 KV cache") current_platform.seed_everything(0) torch.set_default_device(device) # Need this, otherwise when we capture the graph the process # for GPU 1 would run on both GPU0 and GPU1 and things would hang # # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 torch.cuda.set_device(device) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) base = torch.tensor( 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, ) num_remaining_heads = min( closest_power_of_2, total_num_heads - closest_power_of_2 ) extra_powers = torch.arange( start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 ) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes alibi_slopes = _get_alibi_slopes(num_heads).to(device) MAX_SEQ_LEN = 1024 MAX_CTX_LEN = 1024 BS = 10 cache_size = 640 block_size = 32 max_block_per_request = 64 query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv num_tokens = sum(query_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) kv.uniform_(-1e-3, 1e-3) key, value = kv.unbind(dim=1) if kv_cache_dtype == "auto": cache_dtype = dtype else: cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] k_cache = torch.zeros( cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype ) v_cache = torch.zeros( cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.int32) values = values[torch.randperm(cache_size)] block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.int32) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32) b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32) max_input_len = MAX_SEQ_LEN # copy kv to cache b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to( torch.int32 ) for i in range(BS): for j in range(query_lens[i]): k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: start_loc = b_seq_start_loc[i] + cur_ctx if cur_ctx + block_size > b_ctx_len[i]: end_loc = b_seq_start_loc[i] + b_ctx_len[i] else: end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( key[start_loc:end_loc] ) v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( value[start_loc:end_loc] ) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] k_cache = ( k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8) .permute(0, 2, 3, 1, 4) .contiguous() ) # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = ( v_cache.view(-1, block_size, num_kv_heads, head_size) .permute(0, 2, 3, 1) .contiguous() ) k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time op( query, k, v, output, kv_cache_dtype, k_cache, v_cache, block_table, b_start_loc, b_seq_len, MAX_CTX_LEN, max_input_len, k_scale, v_scale, alibi_slopes=alibi_slopes, ) torch.cuda.synchronize() start_time = time.time() op( query, k, v, output, kv_cache_dtype, k_cache, v_cache, block_table, b_start_loc, b_seq_len, MAX_CTX_LEN, max_input_len, k_scale, v_scale, alibi_slopes=alibi_slopes, ) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) # Prepare query, key, value for SDPA # Expand key and value for GQA/MQA to match query heads key_expanded = key[:, :, None, :].expand( key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] ) value_expanded = value[:, :, None, :].expand( value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] ) output_ref = torch.empty_like(output) torch.cuda.synchronize() start_time = time.time() query_start = 0 key_start = 0 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): query_end = query_start + query_len key_end = key_start + seq_len # Get query, key, value for this sequence q = query[query_start:query_end] # [query_len, num_heads, head_size] k = key_expanded[ key_start:key_end ] # [seq_len, num_kv_heads, num_queries_per_kv, head_size] v = value_expanded[ key_start:key_end ] # [seq_len, num_kv_heads, num_queries_per_kv, head_size] # Reshape for SDPA: (batch=1, num_heads, seq_len, head_size) q_sdpa = q.view(query_len, num_kv_heads, num_queries_per_kv, head_size) q_sdpa = ( q_sdpa.permute(1, 2, 0, 3) .reshape(1, num_heads, query_len, head_size) .contiguous() ) k_sdpa = ( k.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous() ) v_sdpa = ( v.permute(1, 2, 0, 3).reshape(1, num_heads, seq_len, head_size).contiguous() ) # Create ALiBi causal mask for this sequence using utility function alibi_mask = create_alibi_causal_mask( query_len, seq_len, alibi_slopes, device, dtype ) # Compute attention out = F.scaled_dot_product_attention( q_sdpa, k_sdpa, v_sdpa, attn_mask=alibi_mask, dropout_p=0.0, scale=scale, ) # Reshape output back to [query_len, num_heads, head_size] out = out.view(num_heads, query_len, head_size).permute(1, 0, 2) output_ref[query_start:query_end].copy_(out) query_start = query_end key_start = key_end torch.cuda.synchronize() end_time = time.time() print(f"PyTorch SDPA Time: {(end_time - start_time) * 1000:.2f} ms") atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) # These tests are optional to only run when explicitly invoked # # pytest -v -s --optional \ # tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32 # # These tests are useful to test model dtype float32 on Turing devices. # We skip them to not increase the time when running tests on CI @pytest.mark.optional @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) @pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention_f32( num_heads: int, num_queries_per_kv: int, head_size: int, sliding_window: int, dtype: torch.dtype, kv_cache_dtype: str, device: str, op: Callable, ) -> None: test_contexted_kv_attention( num_heads, num_queries_per_kv, head_size, sliding_window, dtype, kv_cache_dtype, device, op, ) @pytest.mark.optional @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("op", OPS) @torch.inference_mode() def test_contexted_kv_attention_alibi_f32( num_heads: int, num_queries_per_kv: int, head_size: int, dtype: torch.dtype, kv_cache_dtype: str, device: str, op: Callable, ) -> None: test_contexted_kv_attention_alibi( num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op )