mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 03:24:56 +08:00
Signed-off-by: Yeshwanth Surya <yeshsurya@gmail.com> Signed-off-by: Yeshwanth N <yeshsurya@gmail.com> Signed-off-by: yeshsurya <yeshsurya@gmail.com>
295 lines
10 KiB
Python
295 lines
10 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import random
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.deep_gemm import (
|
|
_ceil_to_ue8m0,
|
|
calc_diff,
|
|
fp8_mqa_logits,
|
|
fp8_paged_mqa_logits,
|
|
get_num_sms,
|
|
get_paged_mqa_logits_metadata,
|
|
)
|
|
from vllm.utils.import_utils import has_deep_gemm
|
|
from vllm.utils.math_utils import cdiv
|
|
|
|
|
|
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
|
|
# x: (num_blocks, block_size, 1, head_dim)
|
|
num_blocks, block_size, num_heads, head_dim = x.shape
|
|
assert num_heads == 1
|
|
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
|
|
sf = x_amax / 448.0
|
|
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
|
x_fp8 = torch.empty(
|
|
(num_blocks, block_size * (head_dim + 4)),
|
|
device=x.device,
|
|
dtype=torch.uint8,
|
|
)
|
|
x_fp8[:, : block_size * head_dim] = x_scaled.view(
|
|
num_blocks, block_size * head_dim
|
|
).view(dtype=torch.uint8)
|
|
x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view(
|
|
dtype=torch.uint8
|
|
)
|
|
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
|
|
|
|
|
|
def per_custom_dims_cast_to_fp8(
|
|
x: torch.Tensor, dims: tuple, use_ue8m0: bool
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
|
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
|
sf = x_amax / 448.0
|
|
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
|
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
|
return x_scaled, sf.squeeze()
|
|
|
|
|
|
def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
|
|
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
|
|
chunk_size = seq_len // 2
|
|
cp_size = seq_len_kv // seq_len
|
|
cp_id = cp_size // 3
|
|
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
|
ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
|
for i in range(chunk_size):
|
|
ke[i] = cp_id * chunk_size + i
|
|
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
|
|
return ks, ke
|
|
|
|
|
|
def _ref_fp8_mqa_logits(
|
|
q: torch.Tensor,
|
|
kv: torch.Tensor,
|
|
weights: torch.Tensor,
|
|
cu_seqlen_ks: torch.Tensor,
|
|
cu_seqlen_ke: torch.Tensor,
|
|
):
|
|
seq_len_kv = kv.shape[0]
|
|
|
|
k = kv
|
|
q = q.float()
|
|
k = k.float()
|
|
|
|
mask_lo = (
|
|
torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
|
|
)
|
|
mask_hi = (
|
|
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
|
|
)
|
|
mask = mask_lo & mask_hi
|
|
score = torch.einsum("mhd,nd->hmn", q, k)
|
|
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
|
logits = logits.masked_fill(~mask, float("-inf"))
|
|
|
|
return logits
|
|
|
|
|
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
def test_deepgemm_fp8_mqa_logits():
    """Compare DeepGEMM's FP8 MQA logits kernel with a float32 reference.

    Both context-parallel key ranges and the plain (non-CP) ranges are
    exercised. The ``-inf`` pattern must match exactly, and the remaining
    logits must agree to ``calc_diff < 1e-3``.
    """
    torch.manual_seed(0)
    random.seed(0)
    num_heads, head_dim = 32, 128
    neg_inf = float("-inf")

    for seq_len, seq_len_kv in ((512, 1024),):
        for disable_cp in (False, True):
            # Keep this randn order: it fixes the seeded test data.
            q = torch.randn(
                seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16
            )
            kv = torch.randn(
                seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16
            )
            weights = torch.randn(
                seq_len, num_heads, device="cuda", dtype=torch.float32
            )

            if not disable_cp:
                ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
            else:
                # Row i may see keys [0, seq_len_kv - seq_len + i).
                ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
                offset = seq_len_kv - seq_len
                ke = torch.arange(seq_len, dtype=torch.int, device="cuda") + offset

            logits = fp8_mqa_logits(
                q.to(torch.float8_e4m3fn),
                per_custom_dims_cast_to_fp8(kv, (0,), False),
                weights,
                ks,
                ke,
            )
            ref_logits = _ref_fp8_mqa_logits(
                q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
            )

            expected_inf = ref_logits == neg_inf
            actual_inf = logits == neg_inf
            assert torch.equal(actual_inf, expected_inf)

            diff = calc_diff(
                logits.masked_fill(actual_inf, 0),
                ref_logits.masked_fill(expected_inf, 0),
            )
            assert diff < 1e-3, f"{diff=}"
|
|
|
|
|
|
def _ref_fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    """Loop-based float32 reference for ``fp8_paged_mqa_logits``.

    For sequence ``i`` the ``next_n`` query rows sit at positions
    ``context_len - next_n .. context_len - 1``. Each row attends to cache
    positions ``<=`` its own position. Keys are gathered block by block
    through ``block_tables[i]``.

    Args:
        q: Queries of shape ``(batch_size, next_n, heads, head_dim)``.
        kv_cache: Paged cache of shape
            ``(num_blocks, block_size, kv_heads, head_dim)``; the matmul
            below assumes ``kv_heads == 1``.
        weights: Head weights of shape ``(batch_size * next_n, heads)``.
        context_lens: Context length per sequence, ``(batch_size,)``.
        block_tables: Physical block index per (sequence, logical block).
        max_model_len: Width of the returned logits.

    Returns:
        Float32 logits of shape ``(batch_size * next_n, max_model_len)``,
        ``-inf`` wherever a position is not attended.
    """
    batch_size, next_n, _, _ = q.size()
    _, block_size, _, _ = kv_cache.size()
    # Allocate every helper on q's device; "cuda" used to be hard-coded for
    # the aranges even though logits already followed q.device.
    device = q.device
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=device,
        dtype=torch.float32,
    )
    context_lens_list = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens_list[i]
        # Positions of this sequence's next_n query tokens (the last next_n).
        q_offsets = torch.arange(context_len - next_n, context_len, device=device)
        # (heads, next_n) so it broadcasts against per-head scores below.
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size,
                (block_rk + 1) * block_size,
                device=device,
            )
            # A key is visible to a query row if it is not after that row.
            causal = k_offsets[None, :] <= q_offsets[:, None]
            mask = (k_offsets[None, :] < context_len) & causal
            # (heads, next_n, d) @ (1, d, block_size) -> (heads, next_n, block_size),
            # computed in the input dtype and then cast to float32.
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            # relu maps masked -inf scores to 0 before the head-weighted sum.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(causal, s, float("-inf"))
    return logits
|
|
|
|
|
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
def test_deepgemm_fp8_paged_mqa_logits():
    """Compare DeepGEMM's FP8 paged MQA logits kernel with a float32 reference.

    For each configuration, a random bf16 query and KV cache are built, and
    every sequence's logical blocks are mapped to randomly chosen distinct
    physical blocks. The kernel output is then compared with
    ``_ref_fp8_paged_mqa_logits`` on the positions each query row may see.
    """
    torch.manual_seed(0)
    random.seed(0)

    max_model_len = 4096
    num_blocks, blocksize = max_model_len * 2, 64
    # (batch_size, next_n, heads, index_dim, avg_kv), in the same order a
    # nested sweep over these fields would visit them.
    configs = [
        (4, 1, 32, 128, 2048),
        (2, 2, 32, 128, 2048),
    ]
    for batch_size, next_n, heads, index_dim, avg_kv in configs:
        num_rows = batch_size * next_n

        # The randn/randint calls below consume the seeded RNGs; keep order.
        q = torch.randn(
            (batch_size, next_n, heads, index_dim),
            device="cuda",
            dtype=torch.bfloat16,
        )
        kv_cache = torch.randn(
            (num_blocks, blocksize, 1, index_dim),
            device="cuda",
            dtype=torch.bfloat16,
        )
        weights = torch.randn(
            (num_rows, heads),
            device="cuda",
            dtype=torch.float32,
        )
        context_lens = (
            torch.randint(int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,))
            .cuda()
            .to(torch.int32)
        )

        # NOTE(review): this is the longest context rounded up to a multiple
        # of blocksize (a token count), used as the number of block-table
        # columns. Only the first ceil(ctx_len / blocksize) columns of each
        # row are filled; the rest stay 0.
        max_block_len = (
            (context_lens.max().item() + blocksize - 1) // blocksize * blocksize
        )
        block_tables = torch.zeros(
            (batch_size, max_block_len),
            device="cuda",
            dtype=torch.int32,
        )
        # Hand out distinct physical blocks in a shuffled order.
        free_blocks = list(range(num_blocks))
        random.shuffle(free_blocks)
        used = 0
        for row, ctx_len in enumerate(context_lens.tolist()):
            n_blocks = (ctx_len + blocksize - 1) // blocksize
            block_tables[row, :n_blocks] = torch.tensor(
                free_blocks[used : used + n_blocks], dtype=torch.int32
            )
            used += n_blocks

        q_fp8 = q.to(torch.float8_e4m3fn)
        kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
        schedule_metadata = get_paged_mqa_logits_metadata(
            context_lens, blocksize, get_num_sms()
        )
        logits = fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8,
            weights,
            context_lens,
            block_tables,
            schedule_metadata,
            max_model_len,
        )
        ref_logits = _ref_fp8_paged_mqa_logits(
            q, kv_cache, weights, context_lens, block_tables, max_model_len
        )

        # Row r belongs to sequence r // next_n and sits at position
        # context_len - next_n + r % next_n; compare only positions <= that.
        rows = torch.arange(num_rows, device="cuda")
        last_pos = context_lens[rows // next_n] - next_n + rows % next_n
        positions = torch.arange(max_model_len, device="cuda")
        visible = positions[None, :] <= last_pos[:, None]

        diff = calc_diff(
            logits.masked_fill(~visible, 0),
            ref_logits.masked_fill(~visible, 0),
        )
        assert diff < 1e-3, f"{diff=}"
|