# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import flashinfer
import pytest
import torch

from tests.kernels.quantization.nvfp4_utils import (
    dequantize_nvfp4_to_dtype,
    get_nvfp4_global_scale,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up

if not current_platform.is_device_capability(100):
    pytest.skip(
        "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
    )

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()
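# Illustrative usage of to_float8 (commentary added for clarity; not executed):
# it performs per-tensor symmetric FP8 quantization, mapping the tensor's
# absolute max to roughly 10% of the FP8 range to leave headroom, and returns
# the quantized tensor together with its dequantization (inverse) scale, e.g.
#
#   t = torch.randn(4, 8, dtype=torch.bfloat16)
#   t_fp8, inv_scale = to_float8(t)
#   t_ref = t_fp8.to(torch.bfloat16) * inv_scale  # dequantized reference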

DTYPE = [torch.bfloat16]
QUANT_DTYPES = [
    # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
    (None, None, None),
    (None, FP8_DTYPE, None),
    (FP8_DTYPE, FP8_DTYPE, None),
    (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
    (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)]
NUM_HEADS = [(64, 8), (40, 8)]
HEAD_SIZE = [128]
KV_LAYOUT = ["HND"]  # currently only HND is supported
BLOCK_SIZE = [16]
WINDOW_LEFT = [-1, 127]
SOFT_CAP = [None, 50.0]
HAS_SINKS = [True, False]

NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
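# Rough sizing note (added commentary): with the default shapes used below,
# the paged KV cache spans NUM_BLOCKS * 2 * num_kv_heads * block_size *
# head_size = 32768 * 2 * 8 * 16 * 128 = 2**30 elements, so byte offsets for
# 16-bit dtypes exceed INT32_MAX and naive 32-bit index arithmetic would wrap.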


@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
    dtype: torch.dtype,
    quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
    batch_size: int,
    max_seq_lens: tuple[int, int],
    num_heads: tuple[int, int],
    head_size: int,
    kv_layout: str,
    block_size: int,
    window_left: int,
    soft_cap: float | None,
    has_sinks: bool,
) -> None:
    torch.set_default_device("cuda")
    current_platform.seed_everything(42)

    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    _, max_kv_len = max_seq_lens

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")

    # max_q_len = 1
    q_lens = torch.ones((batch_size,), dtype=torch.int32)
    q_indptr = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(q_lens, dim=0, dtype=torch.int32),
        ]
    )
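    # (Added note) q_indptr is a CSR-style prefix sum over query lengths; for
    # decode every request contributes one token, so with batch_size=4 it is
    # simply [0, 1, 2, 3, 4].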

    query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
    if q_quant_dtype == FP8_DTYPE:
        query, q_scale = to_float8(query)
        ref_query = query.to(dtype) * q_scale
    else:
        q_scale = 1.0
        ref_query = query

    kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_kv_len

    seq_lens = kv_lens + q_lens
    max_seq_len = torch.max(seq_lens).item()

    kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, kv_scale = to_float8(kv_cache)
        ref_kv_cache = kv_cache.to(dtype) * kv_scale
    else:
        kv_scale = 1.0
        ref_kv_cache = kv_cache
    k_scale = v_scale = kv_scale

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(batch_size):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)
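    # (Added note) This loop builds the paged-KV metadata consumed by the
    # FlashInfer baseline: kv_indices concatenates each request's block ids,
    # kv_indptr holds prefix sums delimiting those per-request runs, and
    # kv_last_page_lens records how many tokens are valid in the final block.
    # For example, seq_lens [20, 16] with block_size 16 would give
    # kv_indptr [0, 2, 3] and kv_last_page_lens [4, 16].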

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)

    # Baseline Decode
    if has_sinks:
        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
        )
    else:
        sinks = None
        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
        )
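    # (Added note) The unquantized FA2 wrapper serves as the reference: the
    # attention-sink variant is used when per-head sinks are being tested,
    # otherwise the standard paged prefill wrapper (which also covers decode,
    # since every request here has q_len == 1) produces the baseline output.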

    wrapper.plan(
        qo_indptr=q_indptr,
        paged_kv_indptr=kv_indptr,
        paged_kv_indices=kv_indices,
        paged_kv_last_page_len=kv_last_page_lens,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim_qk=head_size,
        page_size=block_size,
        causal=True,
        sm_scale=sm_scale,
        window_left=window_left,
        logits_soft_cap=soft_cap,
        q_data_type=dtype,
        kv_data_type=dtype,
    )
    output = torch.empty(ref_query.shape, dtype=dtype)
    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
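    # (Added note) The baseline runs entirely in the reference dtype: ref_query
    # and ref_kv_cache are the dequantized copies, so `output` is the ground
    # truth that the quantized TRTLLM kernel below is compared against.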

    o_scale = 1.0
    o_sf_scale_float = None
    if o_quant_dtype == FP8_DTYPE:
        _, o_scale = to_float8(output)
    elif o_quant_dtype == FP4_DTYPE:
        o_sf_scale = get_nvfp4_global_scale(output)
        o_sf_scale_float = o_sf_scale.item()

    # TRTLLM Decode
    if o_quant_dtype == FP4_DTYPE:
        output_trtllm = flashinfer.utils.FP4Tensor(
            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
            torch.empty(
                (
                    round_up(query.shape[0], 128),
                    round_up(query.shape[1] * query.shape[2] // 16, 4),
                ),
                dtype=torch.float8_e4m3fn,
            ),
        )
    else:
        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
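    # (Added note) For NVFP4 output, the FP4Tensor packs two 4-bit values per
    # uint8 (hence the halved last dimension) and carries a separate FP8
    # scale-factor tensor with one scale per 16 elements, padded to multiples
    # of 128 rows and 4 columns per the round_up calls above.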

    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query=query,
        kv_cache=kv_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_seq_len=max_seq_len,
        bmm1_scale=q_scale * k_scale * sm_scale,
        bmm2_scale=v_scale / o_scale,
        window_left=window_left,
        sinks=sinks,
        o_sf_scale=o_sf_scale_float,
        out=output_trtllm,
    )
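    # (Added note) The dequantization scales are folded into the kernel's two
    # GEMM scales: bmm1_scale = q_scale * k_scale * sm_scale rescales the
    # Q @ K^T logits, and bmm2_scale = v_scale / o_scale rescales the P @ V
    # product while folding in the requested output quantization scale.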
    if o_quant_dtype == FP8_DTYPE:
        output_trtllm = output_trtllm.to(dtype) * o_scale
    elif o_quant_dtype == FP4_DTYPE:
        output_trtllm.data = output_trtllm.data.reshape(
            -1, query.shape[1] * query.shape[2] // 2
        )
        output_trtllm = dequantize_nvfp4_to_dtype(
            output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
        )
        output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])

    if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
        rtol, atol = 7e-2, 9e-2
    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
        rtol, atol = 3e-2, 4e-2
    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
        rtol, atol = 2e-2, 2e-2
    elif kv_quant_dtype == FP8_DTYPE:
        rtol, atol = 4e-2, 6e-2
    else:
        rtol, atol = 1e-2, 1e-2
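    # (Added note) Tolerances are tiered by how much of the pipeline is
    # quantized: looser bounds for FP4/FP8 outputs, tighter ones when only the
    # KV cache (or nothing) is quantized.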

    torch.testing.assert_close(
        output,
        output_trtllm,
        atol=atol,
        rtol=rtol,
        msg=lambda m: (
            f"{m}\nmax abs diff: {torch.max(torch.abs(output - output_trtllm))}"
        ),
    )


@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("quant_dtypes", QUANT_DTYPES)
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("max_seq_lens", MAX_SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("has_sinks", HAS_SINKS)
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
    dtype: torch.dtype,
    quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
    batch_size: int,
    max_seq_lens: tuple[int, int],
    num_heads: tuple[int, int],
    head_size: int,
    kv_layout: str,
    block_size: int,
    window_left: int,
    soft_cap: float | None,
    has_sinks: bool,
) -> None:
    torch.set_default_device("cuda")
    current_platform.seed_everything(42)

    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    if q_quant_dtype != kv_quant_dtype:
        pytest.skip("Skipped mixed QKV dtypes for prefill")
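    # (Added note) Only matching Q/KV quant dtypes are exercised for prefill;
    # the mixed (None, FP8) combination from QUANT_DTYPES is skipped here.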

    max_q_len, max_kv_len = max_seq_lens

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")

    q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
    q_lens[-1] = max_q_len
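    # (Added note) Unlike the decode test (q_len == 1 per request), prefill
    # draws random query lengths and pins the last request to max_q_len so the
    # longest-query path is always covered.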
    q_indptr = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(q_lens, dim=0, dtype=torch.int32),
        ]
    )

    query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
    if q_quant_dtype == FP8_DTYPE:
        query, q_scale = to_float8(query)
        ref_query = query.to(dtype) * q_scale
    else:
        q_scale = 1.0
        ref_query = query

    kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_kv_len

    seq_lens = kv_lens + q_lens
    max_seq_len = torch.max(seq_lens).item()

    kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, kv_scale = to_float8(kv_cache)
        ref_kv_cache = kv_cache.to(dtype) * kv_scale
    else:
        kv_scale = 1.0
        ref_kv_cache = kv_cache
    k_scale = v_scale = kv_scale

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(batch_size):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)

    # Baseline Prefill
    if has_sinks:
        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
        )
    else:
        sinks = None
        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            float_workspace_buffer=workspace_buffer, kv_layout=kv_layout, backend="fa2"
        )

    wrapper.plan(
        qo_indptr=q_indptr,
        paged_kv_indptr=kv_indptr,
        paged_kv_indices=kv_indices,
        paged_kv_last_page_len=kv_last_page_lens,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim_qk=head_size,
        page_size=block_size,
        causal=True,
        sm_scale=sm_scale,
        window_left=window_left,
        logits_soft_cap=soft_cap,
        q_data_type=dtype,
        kv_data_type=dtype,
    )
    output = torch.empty(ref_query.shape, dtype=dtype)
    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)

    o_scale = 1.0
    o_sf_scale_float = None
    if o_quant_dtype == FP8_DTYPE:
        _, o_scale = to_float8(output)
    elif o_quant_dtype == FP4_DTYPE:
        o_sf_scale = get_nvfp4_global_scale(output)
        o_sf_scale_float = o_sf_scale.item()

    # TRTLLM Prefill
    if o_quant_dtype == FP4_DTYPE:
        output_trtllm = flashinfer.utils.FP4Tensor(
            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
            torch.empty(
                (
                    round_up(query.shape[0], 128),
                    round_up(query.shape[1] * query.shape[2] // 16, 4),
                ),
                dtype=torch.float8_e4m3fn,
            ),
        )
    else:
        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

    flashinfer.prefill.trtllm_batch_context_with_kv_cache(
        query=query,
        kv_cache=kv_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_q_len=max_q_len,
        max_kv_len=max_seq_len,
        bmm1_scale=q_scale * k_scale * sm_scale,
        bmm2_scale=v_scale / o_scale,
        batch_size=batch_size,
        cum_seq_lens_q=q_indptr,
        cum_seq_lens_kv=kv_indptr,
        window_left=window_left,
        sinks=sinks,
        o_sf_scale=o_sf_scale_float,
        out=output_trtllm,
    )
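    # (Added note) Compared with the decode call, the context (prefill) kernel
    # is additionally given the ragged-batch bookkeeping: batch_size, the
    # cum_seq_lens_q / cum_seq_lens_kv prefix-sum tensors, and both max_q_len
    # and max_kv_len, since each request now contributes multiple query tokens.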
    if o_quant_dtype == FP8_DTYPE:
        output_trtllm = output_trtllm.to(dtype) * o_scale
    elif o_quant_dtype == FP4_DTYPE:
        output_trtllm.data = output_trtllm.data.reshape(
            -1, query.shape[1] * query.shape[2] // 2
        )
        output_trtllm = dequantize_nvfp4_to_dtype(
            output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
        )
        output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])

    if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
        rtol, atol = 1e-1, 2e-1
    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
        rtol, atol = 4e-2, 6e-2
    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
        rtol, atol = 2e-2, 3e-2
    else:
        rtol, atol = 1e-2, 1e-2

    torch.testing.assert_close(
        output,
        output_trtllm,
        atol=atol,
        rtol=rtol,
        msg=lambda m: (
            f"{m}\nmax abs diff: {torch.max(torch.abs(output - output_trtllm))}"
        ),
    )