diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
index d7ccfcddc6d6a..61157429ec9cc 100644
--- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py
+++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -49,6 +49,7 @@ KV_LAYOUT = ["HND"]  # currently only HND is supported
 BLOCK_SIZE = [16]
 WINDOW_LEFT = [-1, 127]
 SOFT_CAP = [None, 50.0]
+HAS_SINKS = [True, False]
 
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 
@@ -63,6 +64,7 @@ NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
 @pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", SOFT_CAP)
+@pytest.mark.parametrize("has_sinks", HAS_SINKS)
 @torch.inference_mode
 def test_flashinfer_trtllm_decode_with_baseline(
     dtype: torch.dtype,
@@ -77,9 +79,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
     block_size: int,
     window_left: int,
     soft_cap: Optional[float],
+    has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")
-    current_platform.seed_everything(0)
+    current_platform.seed_everything(42)
 
     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
     q_quant_dtype = q_quant_dtype or dtype
@@ -101,7 +104,16 @@ def test_flashinfer_trtllm_decode_with_baseline(
     else:
         raise ValueError(f"Invalid kv_layout: {kv_layout}")
 
-    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+    # max_q_len = 1
+    q_lens = torch.ones((batch_size,), dtype=torch.int32)
+    q_indptr = torch.cat(
+        [
+            torch.tensor([0], dtype=torch.int32),
+            torch.cumsum(q_lens, dim=0, dtype=torch.int32),
+        ]
+    )
+
+    query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
     if q_quant_dtype == FP8_DTYPE:
         query, q_scale = to_float8(query)
         ref_query = query.to(dtype) * q_scale
@@ -112,7 +124,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
-    seq_lens = kv_lens
+    seq_lens = kv_lens + q_lens
     max_seq_len = torch.max(seq_lens).item()
 
     kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
@@ -148,27 +160,36 @@ def test_flashinfer_trtllm_decode_with_baseline(
     workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
 
     # Baseline Decode
-    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout, use_tensor_cores=True
-    )
+    if has_sinks:
+        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
+        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
+            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
+        )
+    else:
+        sinks = None
+        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
+        )
+
     wrapper.plan(
-        kv_indptr,
-        kv_indices,
-        kv_last_page_lens,
-        num_qo_heads,
-        num_kv_heads,
-        head_size,
-        block_size,
-        "NONE",
+        qo_indptr=q_indptr,
+        paged_kv_indptr=kv_indptr,
+        paged_kv_indices=kv_indices,
+        paged_kv_last_page_len=kv_last_page_lens,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim_qk=head_size,
+        page_size=block_size,
+        causal=True,
         sm_scale=sm_scale,
-        q_data_type=dtype,
-        kv_data_type=dtype,
         window_left=window_left,
         logits_soft_cap=soft_cap,
+        q_data_type=dtype,
+        kv_data_type=dtype,
     )
 
     output = torch.empty(ref_query.shape, dtype=dtype)
-    wrapper.run(ref_query, ref_kv_cache, out=output)
+    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
     o_scale = 1.0
     o_sf_scale_float = None
     if o_quant_dtype == FP8_DTYPE:
@@ -202,6 +223,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         bmm1_scale=q_scale * k_scale * sm_scale,
         bmm2_scale=v_scale / o_scale,
         window_left=window_left,
+        sinks=sinks,
         o_sf_scale=o_sf_scale_float,
         out=output_trtllm,
     )
@@ -217,11 +239,13 @@ def test_flashinfer_trtllm_decode_with_baseline(
         output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
-        rtol, atol = 3e-1, 1e0
+        rtol, atol = 7e-2, 9e-2
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
-        rtol, atol = 5e-2, 7e-2
-    else:
+        rtol, atol = 2e-2, 4e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
         rtol, atol = 1e-2, 2e-2
+    else:
+        rtol, atol = 1e-2, 1e-2
 
     (
         torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
@@ -239,6 +263,7 @@
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
 @pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", [None])
+@pytest.mark.parametrize("has_sinks", HAS_SINKS)
 @torch.inference_mode
 def test_flashinfer_trtllm_prefill_with_baseline(
     dtype: torch.dtype,
@@ -253,9 +278,10 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     block_size: int,
     window_left: int,
     soft_cap: Optional[float],
+    has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")
-    current_platform.seed_everything(0)
+    current_platform.seed_everything(42)
 
     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
     q_quant_dtype = q_quant_dtype or dtype
@@ -297,7 +323,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         q_scale = 1.0
         ref_query = query
 
-    kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
+    kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
     seq_lens = kv_lens + q_lens
@@ -336,28 +362,36 @@
     workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
 
     # Baseline Prefill
-    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout
-    )
+    if has_sinks:
+        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
+        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
+            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
+        )
+    else:
+        sinks = None
+        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
+        )
+
     wrapper.plan(
-        q_indptr,
-        kv_indptr,
-        kv_indices,
-        kv_last_page_lens,
-        num_qo_heads,
-        num_kv_heads,
-        head_size,
-        block_size,
+        qo_indptr=q_indptr,
+        paged_kv_indptr=kv_indptr,
+        paged_kv_indices=kv_indices,
+        paged_kv_last_page_len=kv_last_page_lens,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim_qk=head_size,
+        page_size=block_size,
         causal=True,
         sm_scale=sm_scale,
-        q_data_type=dtype,
-        kv_data_type=dtype,
         window_left=window_left,
         logits_soft_cap=soft_cap,
+        q_data_type=dtype,
+        kv_data_type=dtype,
    )
 
     output = torch.empty(ref_query.shape, dtype=dtype)
-    wrapper.run(ref_query, ref_kv_cache, out=output)
+    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
     o_scale = 1.0
     o_sf_scale_float = None
     if o_quant_dtype == FP8_DTYPE:
@@ -395,6 +429,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
         window_left=window_left,
+        sinks=sinks,
         o_sf_scale=o_sf_scale_float,
         out=output_trtllm,
     )
@@ -410,11 +445,11 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
-        rtol, atol = 4e-1, 1e0
+        rtol, atol = 1e-1, 2e-1
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
-        rtol, atol = 5e-2, 7e-2
-    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
         rtol, atol = 4e-2, 6e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
+        rtol, atol = 2e-2, 3e-2
     else:
         rtol, atol = 1e-2, 1e-2
 
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 78d4a92dc1af1..b0bbdd6834f48 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -269,11 +269,6 @@ def use_trtllm_attention(
 
     # Must use TRTLLM attention if query is FP8 quantized
     if q_dtype == current_platform.fp8_dtype():
-        if has_sinks:
-            raise RuntimeError(
-                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
-                "Use kv_cache_dtype=auto for now."
-            )
         logger.info_once("Using TRTLLM attention (query is quantized).")
         return True
 