[Flashinfer] Support Flashinfer TRTLLM FP8-qkv BF16/FP16-out Attention Kernel (#23647)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
parent b6fbc15634
commit bba1042c6f
@@ -259,6 +259,7 @@ if __name__ == "__main__":
         # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
         (None, None, None),
         (None, FP8_DTYPE, None),
+        (FP8_DTYPE, FP8_DTYPE, None),
         (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
         (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
     ]
@@ -274,6 +274,7 @@ if __name__ == "__main__":
     quant_dtypes = [
         # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
         (None, None, None),
+        (FP8_DTYPE, FP8_DTYPE, None),
         (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
         (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
     ]
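In these benchmark sweeps, None in a (q_quant_dtype, kv_quant_dtype, o_quant_dtype) tuple means "leave that tensor in the model's BF16/FP16 dtype", so the (FP8_DTYPE, FP8_DTYPE, None) entry is the FP8-query/FP8-KV, high-precision-output path this commit adds. A minimal sketch of how such a tuple could be resolved to concrete dtypes (the helper and its defaults are illustrative, not code from the benchmark):

import torch

# Illustrative aliases; the benchmark defines its own FP8_DTYPE / FP4_DTYPE.
FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8  # NVFP4 values are packed two per byte

def resolve_quant_dtypes(quant_dtypes, model_dtype=torch.bfloat16):
    """Map a (q, kv, o) quant-dtype tuple to concrete dtypes.

    None means "not quantized": the tensor stays in the model dtype.
    """
    q_quant, kv_quant, o_quant = quant_dtypes
    return (q_quant or model_dtype,
            kv_quant or model_dtype,
            o_quant or model_dtype)

# The combination this commit adds: FP8 query/KV, BF16/FP16 output.
print(resolve_quant_dtypes((FP8_DTYPE, FP8_DTYPE, None)))
# -> (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16)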
@@ -35,6 +35,7 @@ QUANT_DTYPES = [
     # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
     (None, None, None),
     (None, FP8_DTYPE, None),
+    (FP8_DTYPE, FP8_DTYPE, None),
     (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
     (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
 ]
@@ -44,6 +45,7 @@ NUM_HEADS = [(64, 8), (40, 8)]
HEAD_SIZE = [128]
KV_LAYOUT = ["HND"]  # currently only HND is supported
BLOCK_SIZE = [16]
WINDOW_LEFT = [-1, 127]
SOFT_CAP = [None, 50.0]

NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
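"HND" refers to how each paged KV-cache block is laid out: heads ("H") come before the tokens ("N") and the head dimension ("D"). A small sketch of the two layouts, using the constants above for the shapes (the 5-D paged-cache shape is an assumption for illustration, not taken from the test):

import torch

num_blocks, block_size = 32768, 16    # NUM_BLOCKS, BLOCK_SIZE above
num_kv_heads, head_size = 8, 128      # from NUM_HEADS / HEAD_SIZE above

# HND: within each paged block, heads come before the tokens; this is the
# layout the test exercises ("currently only HND is supported").
kv_cache_hnd = torch.empty(num_blocks, 2, num_kv_heads, block_size, head_size,
                           dtype=torch.bfloat16, device="meta")

# NHD would instead place the tokens of a block before the heads.
kv_cache_nhd = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size,
                           dtype=torch.bfloat16, device="meta")

print(kv_cache_hnd.shape, kv_cache_nhd.shape)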
@@ -57,6 +59,7 @@ NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
@@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
    head_size: int,
    kv_layout: str,
    block_size: int,
    window_left: int,
    soft_cap: Optional[float],
) -> None:
    torch.set_default_device("cuda")
@@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
        window_left=window_left,
        logits_soft_cap=soft_cap)

    output = torch.empty(ref_query.shape, dtype=dtype)
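The plan() arguments above configure the FlashInfer baseline that the TRTLLM kernel is compared against, including sliding-window (window_left) and logits soft-cap behaviour. As a rough reference for what that baseline computes, here is a pure-PyTorch sketch of single-query (decode) attention for one head with both features applied (illustrative only; the test's real baseline is the FlashInfer wrapper):

import math
from typing import Optional

import torch

def ref_decode_attention(q: torch.Tensor,      # [head_size]
                         k: torch.Tensor,      # [seq_len, head_size]
                         v: torch.Tensor,      # [seq_len, head_size]
                         sm_scale: float,
                         window_left: int = -1,
                         logits_soft_cap: Optional[float] = None
                         ) -> torch.Tensor:
    logits = (k @ q) * sm_scale                        # [seq_len]
    if logits_soft_cap is not None:
        # tanh soft-capping keeps the logits inside (-cap, cap).
        logits = logits_soft_cap * torch.tanh(logits / logits_soft_cap)
    if window_left >= 0:
        # The single decode query attends only to the last window_left + 1
        # positions of the sequence.
        seq_len = k.shape[0]
        hidden = torch.arange(seq_len) < seq_len - (window_left + 1)
        logits = logits.masked_fill(hidden, float("-inf"))
    return torch.softmax(logits, dim=-1) @ v           # [head_size]

q = torch.randn(128)
k, v = torch.randn(256, 128), torch.randn(256, 128)
out = ref_decode_attention(q, k, v, sm_scale=1.0 / math.sqrt(128),
                           window_left=127, logits_soft_cap=50.0)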
@@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
        max_seq_len=max_seq_len,
        bmm1_scale=q_scale * k_scale * sm_scale,
        bmm2_scale=v_scale / o_scale,
        window_left=window_left,
        o_sf_scale=o_sf_scale,
        out=output_trtllm,
    )
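The two scale arguments fold the per-tensor FP8 scales into the kernel's two GEMMs: QK^T runs on FP8 inputs, so its result is rescaled by q_scale * k_scale (times the softmax scale), while the PV product is rescaled by v_scale and divided by o_scale to land in the quantized output range; for the BF16/FP16-output path added here, o_scale is effectively 1.0. A small arithmetic sketch of that bookkeeping (the example scale values are made up):

import math

head_size = 128
sm_scale = 1.0 / math.sqrt(head_size)

# Per-tensor scales, with real_value = fp8_value * scale (example values).
q_scale, k_scale, v_scale = 0.02, 0.03, 0.05
o_scale = 1.0  # BF16/FP16 output, i.e. no output quantization

# First GEMM (Q @ K^T on FP8 inputs), folded with the softmax scale:
bmm1_scale = q_scale * k_scale * sm_scale
# Second GEMM (P @ V), folded with the output quantization scale:
bmm2_scale = v_scale / o_scale

# Dequantized logits: (q_fp8 * q_scale) . (k_fp8 * k_scale) * sm_scale
#                   = (q_fp8 . k_fp8) * bmm1_scale
# Output:            (softmax(...) @ (v_fp8 * v_scale)) / o_scale
#                   = (softmax(...) @ v_fp8) * bmm2_scale
print(bmm1_scale, bmm2_scale)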
@@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
@@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
    head_size: int,
    kv_layout: str,
    block_size: int,
    window_left: int,
    soft_cap: Optional[float],
) -> None:
    torch.set_default_device("cuda")
@@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
        window_left=window_left,
        logits_soft_cap=soft_cap)

    output = torch.empty(ref_query.shape, dtype=dtype)
@@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
        batch_size=batch_size,
        cum_seq_lens_q=q_indptr,
        cum_seq_lens_kv=kv_indptr,
        window_left=window_left,
        o_sf_scale=o_sf_scale,
        out=output_trtllm,
    )
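cum_seq_lens_q and cum_seq_lens_kv are ordinary ragged-batch index pointers: entry i is the offset at which sequence i starts in the packed token dimension. A minimal sketch of building such indptr tensors from per-sequence lengths (the lengths below are made up):

import torch

q_lens = torch.tensor([3, 1, 5], dtype=torch.int32)    # query tokens per sequence
kv_lens = torch.tensor([7, 4, 9], dtype=torch.int32)   # kv tokens per sequence

def to_indptr(lens: torch.Tensor) -> torch.Tensor:
    # [0, l0, l0 + l1, ...] with length batch_size + 1
    return torch.cat([torch.zeros(1, dtype=torch.int32),
                      torch.cumsum(lens, dim=0, dtype=torch.int32)])

q_indptr = to_indptr(q_lens)    # tensor([0, 3, 4, 9], dtype=torch.int32)
kv_indptr = to_indptr(kv_lens)  # tensor([0, 7, 11, 20], dtype=torch.int32)

# Tokens of sequence i occupy rows q_indptr[i]:q_indptr[i + 1] of the packed
# [total_tokens, num_heads, head_size] query tensor.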
@@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         rtol, atol = 4e-1, 1e0
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
         rtol, atol = 5e-2, 7e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
+        rtol, atol = 4e-2, 6e-2
     else:
         rtol, atol = 1e-2, 1e-2
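The tolerances are loosest for FP4 output, tighter for FP8 output, and tightest for the new FP8-query/KV with BF16/FP16 output branch. Presumably they end up in a float32 closeness check along these lines (a sketch, not the test's exact assertion):

import torch

def check_close(output_trtllm: torch.Tensor, output_ref: torch.Tensor,
                rtol: float, atol: float) -> None:
    # Compare in float32 so quantized outputs can be checked numerically.
    torch.testing.assert_close(output_trtllm.to(torch.float32),
                               output_ref.to(torch.float32),
                               rtol=rtol, atol=atol)

ref = torch.randn(4, 8)
check_close(ref + 0.01 * torch.randn(4, 8), ref, rtol=4e-2, atol=6e-2)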
@@ -258,8 +258,10 @@ class AttnFusionPass(VllmInductorPass):
             pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
             pattern_fp8.register_if_supported(self.patterns)
 
-            pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
-            pattern_nvfp4.register_if_supported(self.patterns)
+            if current_platform.is_cuda() and hasattr(torch.ops._C,
+                                                      "scaled_fp4_quant"):
+                pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
+                pattern_nvfp4.register_if_supported(self.patterns)
 
         if len(attn_layers) == 0:
             logger.warning(
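The fusion-pass change stops registering the NVFP4 attention+quant pattern unconditionally and instead gates it on the platform actually exposing the scaled_fp4_quant custom op. A standalone sketch of that capability-gated registration style (the pattern class and registry here are stand-ins, not vLLM's):

import torch

def has_custom_op(name: str) -> bool:
    # torch.ops._C only has the op if the corresponding extension (here, a
    # CUDA build that ships the FP4 quant kernel) has been loaded.
    return hasattr(torch.ops._C, name)

class DummyQuantPattern:
    """Stand-in for a fusion pattern such as AttentionNvfp4QuantPattern."""

    def __init__(self, layer: object) -> None:
        self.layer = layer

    def register_if_supported(self, patterns: list) -> None:
        patterns.append(self)

patterns: list = []
layer = object()
# The FP8 pattern is registered unconditionally ...
DummyQuantPattern(layer).register_if_supported(patterns)
# ... while the NVFP4 pattern is gated on platform and op availability
# (torch.cuda.is_available() stands in for current_platform.is_cuda()).
if torch.cuda.is_available() and has_custom_op("scaled_fp4_quant"):
    DummyQuantPattern(layer).register_if_supported(patterns)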
@@ -194,19 +194,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
        FlashInferBackend.validate_head_size(self.head_dim)
        self.page_size = self.kv_cache_spec.block_size

        self.enable_fusion = (
            self.compilation_config.pass_config.enable_attn_fusion)
        self.q_data_type = self.model_config.dtype
        self.cache_dtype = self.cache_config.cache_dtype
        if self.cache_dtype.startswith("fp8"):
            self.kv_cache_dtype = (
                FlashInferBackend.get_fp8_dtype_for_flashinfer(
                    self.cache_dtype))
            # Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
            if self.enable_fusion:
                self.q_data_type = self.kv_cache_dtype
        else:
            assert self.kv_cache_spec.dtype == self.model_config.dtype
            self.kv_cache_dtype = self.kv_cache_spec.dtype
        self.q_data_type = self.kv_cache_dtype

        self._cascade_wrapper = None  # Wrapper for cascade attention
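The net effect of the builder change is that the query dtype simply follows the KV-cache dtype: FP8 query whenever the KV cache is FP8, regardless of whether the attention+quant fusion pass is enabled. A self-contained sketch of that resolution, with FlashInferBackend.get_fp8_dtype_for_flashinfer replaced by a plain lookup (an assumption for illustration, not vLLM's exact code):

import torch

# Stand-in for FlashInferBackend.get_fp8_dtype_for_flashinfer().
_FP8_CACHE_DTYPES = {
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}

def resolve_attn_dtypes(cache_dtype: str, model_dtype: torch.dtype):
    """Return (kv_cache_dtype, q_data_type) the way the builder now does."""
    if cache_dtype.startswith("fp8"):
        kv_cache_dtype = _FP8_CACHE_DTYPES[cache_dtype]
    else:
        kv_cache_dtype = model_dtype
    # The query dtype follows the KV-cache dtype, with no dependency on the
    # attn+quant fusion pass any more.
    q_data_type = kv_cache_dtype
    return kv_cache_dtype, q_data_type

print(resolve_attn_dtypes("fp8_e4m3", torch.bfloat16))  # FP8 query + KV cache
print(resolve_attn_dtypes("auto", torch.bfloat16))      # BF16 everywhere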
@@ -668,8 +664,6 @@ class FlashInferImpl(AttentionImpl):
 
         # The attn+quant fusion happens when output_scale is provided.
         if output_scale is None:
-            assert attn_metadata.q_data_type != FP8_DTYPE, \
-                "Query can only be FP8 if output fusion happened."
             assert output_block_scale is None, "output_block_scale "\
                 "is not supported when fusion has not happened"
         else:
@@ -697,7 +691,8 @@ class FlashInferImpl(AttentionImpl):
            elif output.dtype == FP4_DTYPE:
                self.o_sf_scale = layer._o_scale_float

        # Insert FP8 quant for query
        if attn_metadata.q_data_type == FP8_DTYPE:
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
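With q_data_type set to FP8, a BF16/FP16 query is quantized on the fly before the TRTLLM kernel call; ops.scaled_fp8_quant is vLLM's custom op for this. A rough pure-PyTorch equivalent of the per-tensor, static-scale case (illustrative only):

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def scaled_fp8_quant_ref(x: torch.Tensor, scale: torch.Tensor):
    """Per-tensor FP8 quantization: x ~= x_fp8.to(x.dtype) * scale."""
    x_fp8 = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_fp8, scale

# Flatten [num_tokens, num_heads, head_size] to 2-D, quantize, then restore
# the shape, mirroring what the backend does before the TRTLLM kernel call.
num_tokens, num_heads, head_size = 4, 8, 128
query = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
q_scale = torch.tensor(0.02)

query_fp8, _ = scaled_fp8_quant_ref(
    query.reshape(num_tokens, num_heads * head_size), q_scale)
query_fp8 = query_fp8.reshape(num_tokens, num_heads, head_size)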