mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 00:57:56 +08:00
[Flashinfer][gpt-oss] Support FP8-qkv Flashinfer TRTLLM Sinks Attention (#25674)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
parent
a462331e36
commit
44f633dba1
@ -49,6 +49,7 @@ KV_LAYOUT = ["HND"] # currently only HND is supported
|
|||||||
BLOCK_SIZE = [16]
|
BLOCK_SIZE = [16]
|
||||||
WINDOW_LEFT = [-1, 127]
|
WINDOW_LEFT = [-1, 127]
|
||||||
SOFT_CAP = [None, 50.0]
|
SOFT_CAP = [None, 50.0]
|
||||||
|
HAS_SINKS = [True, False]
|
||||||
|
|
||||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||||
|
|
||||||
@ -63,6 +64,7 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
|||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||||
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
||||||
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
|
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
|
||||||
|
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def test_flashinfer_trtllm_decode_with_baseline(
|
def test_flashinfer_trtllm_decode_with_baseline(
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
@ -77,9 +79,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
window_left: int,
|
window_left: int,
|
||||||
soft_cap: Optional[float],
|
soft_cap: Optional[float],
|
||||||
|
has_sinks: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(42)
|
||||||
|
|
||||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||||
q_quant_dtype = q_quant_dtype or dtype
|
q_quant_dtype = q_quant_dtype or dtype
|
||||||
@ -101,7 +104,16 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||||
|
|
||||||
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
# max_q_len = 1
|
||||||
|
q_lens = torch.ones((batch_size,), dtype=torch.int32)
|
||||||
|
q_indptr = torch.cat(
|
||||||
|
[
|
||||||
|
torch.tensor([0], dtype=torch.int32),
|
||||||
|
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
||||||
if q_quant_dtype == FP8_DTYPE:
|
if q_quant_dtype == FP8_DTYPE:
|
||||||
query, q_scale = to_float8(query)
|
query, q_scale = to_float8(query)
|
||||||
ref_query = query.to(dtype) * q_scale
|
ref_query = query.to(dtype) * q_scale
|
||||||
@ -112,7 +124,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
|||||||
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
|
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||||
kv_lens[-1] = max_kv_len
|
kv_lens[-1] = max_kv_len
|
||||||
|
|
||||||
seq_lens = kv_lens
|
seq_lens = kv_lens + q_lens
|
||||||
max_seq_len = torch.max(seq_lens).item()
|
max_seq_len = torch.max(seq_lens).item()
|
||||||
|
|
||||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||||
@ -148,27 +160,36 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
|||||||
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
||||||
|
|
||||||
# Baseline Decode
|
# Baseline Decode
|
||||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
if has_sinks:
|
||||||
workspace_buffer, kv_layout, use_tensor_cores=True
|
sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
|
||||||
)
|
wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
|
||||||
|
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sinks = None
|
||||||
|
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
|
||||||
|
)
|
||||||
|
|
||||||
wrapper.plan(
|
wrapper.plan(
|
||||||
kv_indptr,
|
qo_indptr=q_indptr,
|
||||||
kv_indices,
|
paged_kv_indptr=kv_indptr,
|
||||||
kv_last_page_lens,
|
paged_kv_indices=kv_indices,
|
||||||
num_qo_heads,
|
paged_kv_last_page_len=kv_last_page_lens,
|
||||||
num_kv_heads,
|
num_qo_heads=num_qo_heads,
|
||||||
head_size,
|
num_kv_heads=num_kv_heads,
|
||||||
block_size,
|
head_dim_qk=head_size,
|
||||||
"NONE",
|
page_size=block_size,
|
||||||
|
causal=True,
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
q_data_type=dtype,
|
|
||||||
kv_data_type=dtype,
|
|
||||||
window_left=window_left,
|
window_left=window_left,
|
||||||
logits_soft_cap=soft_cap,
|
logits_soft_cap=soft_cap,
|
||||||
|
q_data_type=dtype,
|
||||||
|
kv_data_type=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
|
||||||
|
|
||||||
o_scale = 1.0
|
o_scale = 1.0
|
||||||
o_sf_scale_float = None
|
o_sf_scale_float = None
|
||||||
if o_quant_dtype == FP8_DTYPE:
|
if o_quant_dtype == FP8_DTYPE:
|
||||||
@ -202,6 +223,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
|||||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||||
bmm2_scale=v_scale / o_scale,
|
bmm2_scale=v_scale / o_scale,
|
||||||
window_left=window_left,
|
window_left=window_left,
|
||||||
|
sinks=sinks,
|
||||||
o_sf_scale=o_sf_scale_float,
|
o_sf_scale=o_sf_scale_float,
|
||||||
out=output_trtllm,
|
out=output_trtllm,
|
||||||
)
|
)
|
||||||
@ -217,11 +239,13 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
|||||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||||
|
|
||||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||||
rtol, atol = 3e-1, 1e0
|
rtol, atol = 7e-2, 9e-2
|
||||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||||
rtol, atol = 5e-2, 7e-2
|
rtol, atol = 2e-2, 4e-2
|
||||||
else:
|
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
||||||
rtol, atol = 1e-2, 2e-2
|
rtol, atol = 1e-2, 2e-2
|
||||||
|
else:
|
||||||
|
rtol, atol = 1e-2, 1e-2
|
||||||
|
|
||||||
(
|
(
|
||||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
|
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
|
||||||
@ -239,6 +263,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
|||||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||||
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
||||||
@pytest.mark.parametrize("soft_cap", [None])
|
@pytest.mark.parametrize("soft_cap", [None])
|
||||||
|
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def test_flashinfer_trtllm_prefill_with_baseline(
|
def test_flashinfer_trtllm_prefill_with_baseline(
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
@ -253,9 +278,10 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
window_left: int,
|
window_left: int,
|
||||||
soft_cap: Optional[float],
|
soft_cap: Optional[float],
|
||||||
|
has_sinks: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(42)
|
||||||
|
|
||||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||||
q_quant_dtype = q_quant_dtype or dtype
|
q_quant_dtype = q_quant_dtype or dtype
|
||||||
@ -297,7 +323,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
|||||||
q_scale = 1.0
|
q_scale = 1.0
|
||||||
ref_query = query
|
ref_query = query
|
||||||
|
|
||||||
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||||
kv_lens[-1] = max_kv_len
|
kv_lens[-1] = max_kv_len
|
||||||
|
|
||||||
seq_lens = kv_lens + q_lens
|
seq_lens = kv_lens + q_lens
|
||||||
@ -336,28 +362,36 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
|||||||
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
|
||||||
|
|
||||||
# Baseline Prefill
|
# Baseline Prefill
|
||||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
if has_sinks:
|
||||||
workspace_buffer, kv_layout
|
sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
|
||||||
)
|
wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
|
||||||
|
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sinks = None
|
||||||
|
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
|
||||||
|
)
|
||||||
|
|
||||||
wrapper.plan(
|
wrapper.plan(
|
||||||
q_indptr,
|
qo_indptr=q_indptr,
|
||||||
kv_indptr,
|
paged_kv_indptr=kv_indptr,
|
||||||
kv_indices,
|
paged_kv_indices=kv_indices,
|
||||||
kv_last_page_lens,
|
paged_kv_last_page_len=kv_last_page_lens,
|
||||||
num_qo_heads,
|
num_qo_heads=num_qo_heads,
|
||||||
num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
head_size,
|
head_dim_qk=head_size,
|
||||||
block_size,
|
page_size=block_size,
|
||||||
causal=True,
|
causal=True,
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
q_data_type=dtype,
|
|
||||||
kv_data_type=dtype,
|
|
||||||
window_left=window_left,
|
window_left=window_left,
|
||||||
logits_soft_cap=soft_cap,
|
logits_soft_cap=soft_cap,
|
||||||
|
q_data_type=dtype,
|
||||||
|
kv_data_type=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
|
||||||
|
|
||||||
o_scale = 1.0
|
o_scale = 1.0
|
||||||
o_sf_scale_float = None
|
o_sf_scale_float = None
|
||||||
if o_quant_dtype == FP8_DTYPE:
|
if o_quant_dtype == FP8_DTYPE:
|
||||||
@ -395,6 +429,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
|||||||
cum_seq_lens_q=q_indptr,
|
cum_seq_lens_q=q_indptr,
|
||||||
cum_seq_lens_kv=kv_indptr,
|
cum_seq_lens_kv=kv_indptr,
|
||||||
window_left=window_left,
|
window_left=window_left,
|
||||||
|
sinks=sinks,
|
||||||
o_sf_scale=o_sf_scale_float,
|
o_sf_scale=o_sf_scale_float,
|
||||||
out=output_trtllm,
|
out=output_trtllm,
|
||||||
)
|
)
|
||||||
@ -410,11 +445,11 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
|||||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||||
|
|
||||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||||
rtol, atol = 4e-1, 1e0
|
rtol, atol = 1e-1, 2e-1
|
||||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||||
rtol, atol = 5e-2, 7e-2
|
|
||||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
|
||||||
rtol, atol = 4e-2, 6e-2
|
rtol, atol = 4e-2, 6e-2
|
||||||
|
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
||||||
|
rtol, atol = 2e-2, 3e-2
|
||||||
else:
|
else:
|
||||||
rtol, atol = 1e-2, 1e-2
|
rtol, atol = 1e-2, 1e-2
|
||||||
|
|
||||||
|
|||||||
@ -269,11 +269,6 @@ def use_trtllm_attention(
|
|||||||
|
|
||||||
# Must use TRTLLM attention if query is FP8 quantized
|
# Must use TRTLLM attention if query is FP8 quantized
|
||||||
if q_dtype == current_platform.fp8_dtype():
|
if q_dtype == current_platform.fp8_dtype():
|
||||||
if has_sinks:
|
|
||||||
raise RuntimeError(
|
|
||||||
"TRTLLM FP8-qkv kernel is not supported for attention sinks. "
|
|
||||||
"Use kv_cache_dtype=auto for now."
|
|
||||||
)
|
|
||||||
logger.info_once("Using TRTLLM attention (query is quantized).")
|
logger.info_once("Using TRTLLM attention (query is quantized).")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user