[Flashinfer] Support Flashinfer TRTLLM FP8-qkv BF16/FP16-out Attention Kernel (#23647)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
elvischenv 2025-09-09 11:53:07 +08:00 committed by GitHub
parent b6fbc15634
commit bba1042c6f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 22 additions and 11 deletions

View File

@ -259,6 +259,7 @@ if __name__ == "__main__":
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]

View File

@ -274,6 +274,7 @@ if __name__ == "__main__":
quant_dtypes = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]

View File

@ -35,6 +35,7 @@ QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
@ -44,6 +45,7 @@ NUM_HEADS = [(64, 8), (40, 8)]
HEAD_SIZE = [128]
KV_LAYOUT = ["HND"] # currently only HND is supported
BLOCK_SIZE = [16]
WINDOW_LEFT = [-1, 127]
SOFT_CAP = [None, 50.0]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@ -57,6 +59,7 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
head_size: int,
kv_layout: str,
block_size: int,
window_left: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap)
output = torch.empty(ref_query.shape, dtype=dtype)
@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
window_left=window_left,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
head_size: int,
kv_layout: str,
block_size: int,
window_left: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap)
output = torch.empty(ref_query.shape, dtype=dtype)
@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
window_left=window_left,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
rtol, atol = 4e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
rtol, atol = 4e-2, 6e-2
else:
rtol, atol = 1e-2, 1e-2

View File

@ -258,8 +258,10 @@ class AttnFusionPass(VllmInductorPass):
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
pattern_fp8.register_if_supported(self.patterns)
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4.register_if_supported(self.patterns)
if current_platform.is_cuda() and hasattr(torch.ops._C,
"scaled_fp4_quant"):
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0:
logger.warning(

View File

@ -194,19 +194,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
FlashInferBackend.validate_head_size(self.head_dim)
self.page_size = self.kv_cache_spec.block_size
self.enable_fusion = (
self.compilation_config.pass_config.enable_attn_fusion)
self.q_data_type = self.model_config.dtype
self.cache_dtype = self.cache_config.cache_dtype
if self.cache_dtype.startswith("fp8"):
self.kv_cache_dtype = (
FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.cache_dtype))
# Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
if self.enable_fusion:
self.q_data_type = self.kv_cache_dtype
else:
assert self.kv_cache_spec.dtype == self.model_config.dtype
self.kv_cache_dtype = self.kv_cache_spec.dtype
self.q_data_type = self.kv_cache_dtype
self._cascade_wrapper = None # Wrapper for cascade attention
@ -668,8 +664,6 @@ class FlashInferImpl(AttentionImpl):
# The attn+quant fusion happens when output_scale is provided.
if output_scale is None:
assert attn_metadata.q_data_type != FP8_DTYPE, \
"Query can only be FP8 if output fusion happened."
assert output_block_scale is None, "output_block_scale "\
"is not supported when fusion has not happened"
else:
@ -697,7 +691,8 @@ class FlashInferImpl(AttentionImpl):
elif output.dtype == FP4_DTYPE:
self.o_sf_scale = layer._o_scale_float
# Insert FP8 quant for query
# Insert FP8 quant for query
if attn_metadata.q_data_type == FP8_DTYPE:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(