[Flashinfer][gpt-oss] Support FP8-qkv Flashinfer TRTLLM Sinks Attention (#25674)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
elvischenv 2025-10-10 04:13:39 +08:00 committed by GitHub
parent a462331e36
commit 44f633dba1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 76 additions and 46 deletions

View File

@ -49,6 +49,7 @@ KV_LAYOUT = ["HND"] # currently only HND is supported
BLOCK_SIZE = [16]
WINDOW_LEFT = [-1, 127]
SOFT_CAP = [None, 50.0]
HAS_SINKS = [True, False]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@ -63,6 +64,7 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
dtype: torch.dtype,
@ -77,9 +79,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
block_size: int,
window_left: int,
soft_cap: Optional[float],
has_sinks: bool,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
current_platform.seed_everything(42)
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
@ -101,7 +104,16 @@ def test_flashinfer_trtllm_decode_with_baseline(
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
# max_q_len = 1
q_lens = torch.ones((batch_size,), dtype=torch.int32)
q_indptr = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
]
)
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
@ -112,7 +124,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
@ -148,27 +160,36 @@ def test_flashinfer_trtllm_decode_with_baseline(
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Decode
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, kv_layout, use_tensor_cores=True
)
if has_sinks:
sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
)
else:
sinks = None
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
)
wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
qo_indptr=q_indptr,
paged_kv_indptr=kv_indptr,
paged_kv_indices=kv_indices,
paged_kv_last_page_len=kv_last_page_lens,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_size,
page_size=block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap,
q_data_type=dtype,
kv_data_type=dtype,
)
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
o_scale = 1.0
o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE:
@ -202,6 +223,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
window_left=window_left,
sinks=sinks,
o_sf_scale=o_sf_scale_float,
out=output_trtllm,
)
@ -217,11 +239,13 @@ def test_flashinfer_trtllm_decode_with_baseline(
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 1e0
rtol, atol = 7e-2, 9e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 2e-2, 4e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
rtol, atol = 1e-2, 2e-2
else:
rtol, atol = 1e-2, 1e-2
(
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
@ -239,6 +263,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
dtype: torch.dtype,
@ -253,9 +278,10 @@ def test_flashinfer_trtllm_prefill_with_baseline(
block_size: int,
window_left: int,
soft_cap: Optional[float],
has_sinks: bool,
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
current_platform.seed_everything(42)
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
@ -297,7 +323,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
q_scale = 1.0
ref_query = query
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len
seq_lens = kv_lens + q_lens
@ -336,28 +362,36 @@ def test_flashinfer_trtllm_prefill_with_baseline(
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
# Baseline Prefill
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
)
if has_sinks:
sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
)
else:
sinks = None
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
)
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_size,
block_size,
qo_indptr=q_indptr,
paged_kv_indptr=kv_indptr,
paged_kv_indices=kv_indices,
paged_kv_last_page_len=kv_last_page_lens,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_size,
page_size=block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap,
q_data_type=dtype,
kv_data_type=dtype,
)
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
o_scale = 1.0
o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE:
@ -395,6 +429,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
window_left=window_left,
sinks=sinks,
o_sf_scale=o_sf_scale_float,
out=output_trtllm,
)
@ -410,11 +445,11 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 4e-1, 1e0
rtol, atol = 1e-1, 2e-1
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
rtol, atol = 4e-2, 6e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
rtol, atol = 2e-2, 3e-2
else:
rtol, atol = 1e-2, 1e-2

View File

@ -269,11 +269,6 @@ def use_trtllm_attention(
# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
if has_sinks:
raise RuntimeError(
"TRTLLM FP8-qkv kernel is not supported for attention sinks. "
"Use kv_cache_dtype=auto for now."
)
logger.info_once("Using TRTLLM attention (query is quantized).")
return True