# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import flashinfer
import pytest
import torch

from vllm.platforms import current_platform

if not current_platform.is_device_capability(100):
    pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
                allow_module_level=True)

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)

MAX_Q_LEN = 1024
MAX_KV_LEN = 4096
BATCH_SIZES = [4, 12]
NUM_HEADS = [(16, 16), (40, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16]
KV_LAYOUTS = ["HND"]
DTYPES = [torch.bfloat16]
KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 50.0]


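# Symmetric per-tensor FP8 quantization helper. The 0.1 factor keeps values
# well inside the FP8 dynamic range (a test-side headroom choice); the
# returned reciprocal is the dequantization scale that is later passed to
# the kernels as k_scale / v_scale.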
def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
    batch_size: int,
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
    kv_layout: str,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[torch.dtype],
    soft_cap: Optional[float],
) -> None:
    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype

    torch.set_default_device("cuda")
    current_platform.seed_everything(0)

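    # Draw a random KV length per sequence and pin the last one to
    # MAX_KV_LEN so every run exercises the maximum-length path.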
    kv_lens = torch.randint(1, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
    kv_lens[-1] = MAX_KV_LEN
    max_kv_len = torch.max(kv_lens).item()
    num_seqs = len(kv_lens)

    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0

    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)

    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
    kv_scale = 1.0
    if kv_cache_dtype is current_platform.fp8_dtype():
        key_value_cache, kv_scale = to_float8(key_value_cache,
                                              current_platform.fp8_dtype())

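    # Map each sequence's logical pages to random physical blocks,
    # mimicking a fragmented paged KV cache.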
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = kv_scale
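    # Build FlashInfer's CSR-style paged-KV metadata: kv_indptr is the
    # cumulative block count per sequence, kv_indices holds the flattened
    # block ids, and kv_last_page_lens the number of valid tokens in each
    # sequence's final page.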
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

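    # Baseline: FlashInfer's generic batch-decode kernel. The tensor-core
    # path is enabled when the GQA group size (query heads per KV head)
    # exceeds 4.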
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_query_heads // num_kv_heads) > 4))
    wrapper.plan(kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 "NONE",
                 sm_scale=scale,
                 q_data_type=dtype,
                 kv_data_type=kv_cache_dtype,
                 logits_soft_cap=soft_cap)

    output = torch.empty(query.shape, dtype=dtype)
    wrapper.run(query,
                key_value_cache,
                k_scale=k_scale,
                v_scale=v_scale,
                out=output)

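    # The TRT-LLM kernel takes pre-folded scales: bmm1_scale combines the
    # K dequantization scale with the softmax scale for the Q*K^T GEMM,
    # and bmm2_scale applies the V dequantization scale after softmax.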
    # TRTLLM Decode
    output_trtllm = torch.empty(query.shape, dtype=dtype)
    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query=query.contiguous(),
        kv_cache=key_value_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=kv_lens,
        max_seq_len=max_kv_len,
        bmm1_scale=k_scale * scale,
        bmm2_scale=v_scale,
        out=output_trtllm,
    )

    torch.testing.assert_close(
        output, output_trtllm, atol=1e-2, rtol=1e-2,
        msg=f"max diff: {torch.max(torch.abs(output - output_trtllm))}")


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
    batch_size: int,
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
    kv_layout: str,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[torch.dtype],
    soft_cap: Optional[float],
) -> None:
    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
    if dtype != kv_cache_dtype:
        pytest.skip(f"Not supported dtype({dtype}) with "
                    f"kv_cache_dtype({kv_cache_dtype})")

    torch.set_default_device("cuda")
    current_platform.seed_everything(0)

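    # Random query lengths with the last pinned to MAX_Q_LEN; q_indptr is
    # the cumulative offset of each sequence's queries in the packed,
    # ragged query tensor.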
    q_lens = torch.randint(1, MAX_Q_LEN, (batch_size, ), dtype=torch.int32)
    q_lens[-1] = MAX_Q_LEN
    max_q_len = torch.max(q_lens).item()
    q_indptr = torch.cat([
        torch.tensor([0], dtype=torch.int32),
        torch.cumsum(q_lens, dim=0, dtype=torch.int32),
    ])

    kv_lens = torch.randint(0, MAX_KV_LEN, (batch_size, ), dtype=torch.int32)
    kv_lens[-1] = MAX_KV_LEN

    seq_lens = kv_lens + q_lens
    max_seq_len = torch.max(seq_lens).item()
    num_seqs = len(seq_lens)

    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0

    scale = head_size**-0.5

    query = torch.randn(torch.sum(q_lens).item(),
                        num_query_heads,
                        head_size,
                        dtype=dtype)

    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
    kv_scale = 1.0
    if kv_cache_dtype is current_platform.fp8_dtype():
        key_value_cache, kv_scale = to_float8(key_value_cache,
                                              current_platform.fp8_dtype())

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = kv_scale
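    # Same CSR-style paged-KV metadata as in the decode test above, but
    # computed over each sequence's combined context + query length.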
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

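    # Baseline: FlashInfer's generic batch-prefill kernel with causal
    # masking over the packed query batch.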
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout)
    wrapper.plan(q_indptr,
                 kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 causal=True,
                 sm_scale=scale,
                 q_data_type=dtype,
                 kv_data_type=kv_cache_dtype,
                 logits_soft_cap=soft_cap)

    output = torch.empty(query.shape, dtype=dtype)
    wrapper.run(query,
                key_value_cache,
                k_scale=k_scale,
                v_scale=v_scale,
                out=output)

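    # As in the decode test, bmm1_scale folds the K dequantization scale
    # into the softmax scale and bmm2_scale carries the V dequantization
    # scale.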
    # TRTLLM Prefill
    output_trtllm = torch.empty(query.shape, dtype=dtype)
    flashinfer.prefill.trtllm_batch_context_with_kv_cache(
        query=query.contiguous(),
        kv_cache=key_value_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_q_len=max_q_len,
        max_kv_len=max_seq_len,
        bmm1_scale=k_scale * scale,
        bmm2_scale=v_scale,
        batch_size=num_seqs,
        cum_seq_lens_q=q_indptr,
        cum_seq_lens_kv=kv_indptr,
        out=output_trtllm,
    )

    torch.testing.assert_close(
        output, output_trtllm, atol=1e-2, rtol=1e-2,
        msg=f"max diff: {torch.max(torch.abs(output - output_trtllm))}")