[Performance][B200] silu_mul_quant: pack scales in int32 (#28358)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Varun Sundar Rabindranath 2025-11-13 13:16:55 -05:00 committed by GitHub
parent fdfd5075aa
commit fe1cd7704d
7 changed files with 466 additions and 151 deletions

View File

@ -279,17 +279,17 @@ __device__ __forceinline__ void token_bounds(int32_t n_tokens,
}
template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128,
int NUM_STAGES = 3>
typename scale_t, int THREADS, typename Idx_t, bool CEIL_UE8M0,
int GROUP_SIZE = 128, int NUM_STAGES = 3>
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
scale_t* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
// sizes
Idx_t E, Idx_t T, Idx_t H,
// strides (in elements)
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
Idx_t stride_ys_g, Idx_t stride_counts_e) {
Idx_t stride_ys_g, Idx_t stride_ys_p, Idx_t stride_counts_e) {
#ifndef USE_ROCM
static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
@ -466,9 +466,22 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
__nv_fp8x4_e4m3* y_q_base_ptr =
reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id;
auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g;
Idx_t scale_group_offset = 0;
if constexpr (std::is_same<scale_t, uint8_t>::value) {
// packed int32_t format
int pack_id = warp_position_scales / 4;
int scale_in_pack = warp_position_scales % 4;
scale_group_offset = pack_id * stride_ys_p + scale_in_pack * stride_ys_g;
} else {
scale_group_offset = warp_position_scales * stride_ys_g;
}
scale_t* const y_scale_base_ptr = _y_s + scale_group_offset;
for (auto j = tokens_lower; j < tokens_upper; j++) {
int current_group_id = warp_position_scales; // Running count of which
// group is being processed
const Idx_t base_ys = expert_id * stride_ys_e;
auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
__nv_fp8x4_e4m3* y_q_ptr =
@ -509,7 +522,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
__nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);
if constexpr (USE_UE8M0) {
if constexpr (CEIL_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}
@ -527,8 +540,24 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
y_q_ptr += WARP_SIZE * stride_yq_h;
if (!lane_id) {
*y_s_ptr = y_s;
y_s_ptr += stride_ys_g;
// Store scales.
if constexpr (std::is_same<scale_t, uint8_t>::value) {
// Packed UE8M0 format. Drop the mantissa bits.
*y_s_ptr = reinterpret_cast<int16_t&>(y_s) >> 7;
bool const jump_pack = (current_group_id + 1) % 4 == 0;
// Minus 3 because we need to get to the first group in the
// next pack.
y_s_ptr += jump_pack ? (stride_ys_p - 3) : stride_ys_g;
} else {
// float32 format
static_assert(std::is_same<scale_t, float>::value);
*y_s_ptr = y_s;
y_s_ptr += stride_ys_g;
}
current_group_id += 1;
}
}
}
@ -573,7 +602,7 @@ void persistent_masked_m_silu_mul_quant(
const at::Tensor& tokens_per_expert, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0) {
bool cast_scale_ue8m0) {
#ifndef USE_ROCM
// This kernel currently only supports H % 128 == 0 and assumes a
@ -583,9 +612,12 @@ void persistent_masked_m_silu_mul_quant(
TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
bool const is_packed_ue8m0 =
(y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
using Idx_t = int64_t;
Idx_t E = input.size(0);
@ -597,15 +629,18 @@ void persistent_masked_m_silu_mul_quant(
Idx_t stride_yq_e = y_q.stride(0);
Idx_t stride_yq_t = y_q.stride(1);
Idx_t stride_yq_h = y_q.stride(2);
Idx_t stride_ys_e = y_s.stride(0);
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_counts_e = tokens_per_expert.stride(0);
int const NUM_GROUPS = H / GROUP_SIZE;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
// TODO: Get this from cuda_arch ?
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
int sms = SILU_V2_BLOCK_COUNT; \
static constexpr int max_shared_mem_bytes = \
@ -615,43 +650,86 @@ void persistent_masked_m_silu_mul_quant(
VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \
USE_UE8M0, GROUP_SIZE, STAGES> \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
(fp8_t*)y_q.data_ptr(), \
reinterpret_cast<scale_t*>(y_s.data_ptr()), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \
stride_ys_g, stride_counts_e); \
stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
});
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
#define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0) \
if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
/* 8 warp config */ \
static constexpr int NUM_STAGES = 4; \
static constexpr int THREAD_COUNT = 256; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
} else { \
/* 1 warp config */ \
static constexpr int THREAD_COUNT = 32; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
}
int const NUM_GROUPS = H / GROUP_SIZE;
if (!use_ue8m0) {
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
}
} else {
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
}
Idx_t stride_ys_e = y_s.stride(0);
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_ys_p = 0;
if (!cast_scale_ue8m0) {
TORCH_CHECK(!is_packed_ue8m0);
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
false);
return;
}
if (!is_packed_ue8m0) {
// UE8M0 but not packed
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true);
return;
}
TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kInt32);
// Int32 packed ue8m0 scales tensor.
// Let E, T, G be the number of experts, tokens and groups respectively.
// With E = 1, T = 4, G = 6, the int32 scales tensor has shape [1, 4, 2] and
// strides [8, 1, 4]. The scales are expected to be arranged as follows,
// [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X],
//  [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X],
//  [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X],
//  [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X]]
// where TxGy is the ue8m0 scale value of Token x, Group y.
//
// In memory (in bytes) the scale values are arranged as,
// [T0G0, T0G1, T0G2, T0G3, T1G0, T1G1, T1G2, T1G3, T2G0, T2G1, T2G2, T2G3,
//  T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
//  X, X, T3G4, T3G5, X, X]
//
// An int32 tensor of shape [1, 4, 2] and strides [8, 1, 4] can equivalently
// be represented as a uint8 tensor of shape [1, 2, 4, 4] and strides
// [32, 16, 4, 1]. In other words, ignoring the expert dimension, the original
// int32 tensor is simply treated as two packed [4, 4] uint8 tensors (or two
// [4, 1] int32 tensors). The stride settings below reflect this change.
// Caveat: the G dimension is no longer contiguous, i.e. moving from G3 to G4
// requires a jump along the packing dimension. The kernel handles this case.
stride_ys_e *= sizeof(int32_t);
stride_ys_p = T * sizeof(int32_t); // Packing dimension
stride_ys_t = sizeof(int32_t);
stride_ys_g = 1;
LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true);
#endif
}
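To make the int32-to-uint8 stride remapping described in the layout comment above concrete, here is a small PyTorch sketch (reviewer illustration, not part of this diff; little-endian assumed, names are illustrative) using the same T = 4, G = 6 example with a single expert:

import torch

E, T, G = 1, 4, 6
Gp = (G + 3) // 4  # number of int32 packs per token (= 2)

# One flat int32 buffer backs both views.
storage = torch.zeros(E * T * Gp, dtype=torch.int32)

# Kernel-facing view: shape (E, T, Gp), strides (T*Gp, 1, T) == (8, 1, 4).
y_s_i32 = torch.as_strided(storage, (E, T, Gp), (T * Gp, 1, T))

# Same bytes viewed as uint8: shape (E, Gp, T, 4), strides (32, 16, 4, 1).
y_s_u8 = torch.as_strided(
    storage.view(torch.uint8), (E, Gp, T, 4), (4 * T * Gp, 4 * T, 4, 1)
)

# Writing the ue8m0 byte for Token 1, Group 2 (pack 0, byte 2) through the
# byte view lands in byte 2 of y_s_i32[0, 1, 0], exactly as the layout
# comment predicts.
y_s_u8[0, 0, 1, 2] = 124  # biased exponent of 0.125 == 2**-3
assert (int(y_s_i32[0, 1, 0]) >> 16) & 0xFF == 124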

View File

@ -1384,3 +1384,16 @@ def image_urls(request, local_asset_server) -> list[str]:
"""Indirect fixture: takes a list of names, returns list of full URLs."""
names: list[str] = request.param
return [local_asset_server.url_for(name) for name in names]
@pytest.fixture
def disable_deepgemm_ue8m0(monkeypatch):
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
with monkeypatch.context() as monkeypatch_ctx:
monkeypatch_ctx.setenv("VLLM_USE_DEEP_GEMM_E8M0", "0")
is_deep_gemm_e8m0_used.cache_clear()
yield
# Clear the cache so that subsequent calls pick up the default
# VLLM_USE_DEEP_GEMM_E8M0 setting again.
is_deep_gemm_e8m0_used.cache_clear()

View File

@ -21,7 +21,11 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from ...utils import multi_gpu_test
@ -413,19 +417,16 @@ NUM_EXPERTS = [32]
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ht_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0,
):
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
"""
import deep_gemm
m, n, k = mnk
current_platform.seed_everything(7)
@ -433,7 +434,7 @@ def test_ht_deepep_deepgemm_moe(
if topk > num_experts:
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_m = get_mk_alignment_for_contiguous_layout()[0]
block_size = [block_m, block_m]
world_size, dp_size = world_dp_size
@ -487,9 +488,6 @@ USE_FP8_DISPATCH = [False]
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
@ -497,10 +495,12 @@ def test_ll_deepep_deepgemm_moe(
use_fp8_dispatch: bool,
block_size: list[int],
world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0,
):
"""
Tests for Low-Latency DeepEP + DeepGemm integration.
"""
assert not is_deep_gemm_e8m0_used()
m, n, k = mnk
current_platform.seed_everything(7)

View File

@ -294,7 +294,7 @@ def torch_moe_impl(
# blockwise quant and de-quant.
assert not per_act_token_quant
a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128)
aq, aq_scale = per_token_group_quant_fp8(a, 128, use_ue8m0=False)
a = (
(aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1))
.view(a.shape)

View File

@ -1,6 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
@ -8,27 +11,30 @@ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
persistent_masked_m_silu_mul_quant,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
from vllm.utils.math_utils import cdiv, round_up
fp8_dtype = torch.float8_e4m3fn
CASES = [
(1, 1, 128, fp8_dtype),
(1, 4, 128, fp8_dtype),
(2, 4, 256, fp8_dtype),
(32, 64, 256, fp8_dtype),
(17, 31, 768, fp8_dtype),
(1, 1, 128 * 1, fp8_dtype),
(1, 1, 128 * 3, fp8_dtype),
(1, 1, 128 * 4, fp8_dtype),
(8, 16, 128 * 1, fp8_dtype),
(8, 16, 128 * 2, fp8_dtype),
(8, 16, 128 * 3, fp8_dtype),
(1, 4, 128 * 1, fp8_dtype),
(2, 4, 128 * 2, fp8_dtype),
(1, 4, 128 * 3, fp8_dtype),
(8, 16, 128 * 4, fp8_dtype),
(8, 16, 128 * 5, fp8_dtype),
(8, 16, 128 * 6, fp8_dtype),
(8, 16, 128 * 7, fp8_dtype),
(8, 16, 128 * 8, fp8_dtype),
(8, 16, 128 * 9, fp8_dtype),
(8, 64, 7168, fp8_dtype),
(8, 128, 128 * 33, fp8_dtype),
(1, 4, 128 * 10, fp8_dtype),
(8, 128, 7168, fp8_dtype),
(8, 512, 7168, fp8_dtype),
(8, 1024, 7168, fp8_dtype),
(17, 31, 768, fp8_dtype),
(32, 64, 256, fp8_dtype),
(256, 8, 7168, fp8_dtype),
(256, 32, 7168, fp8_dtype),
(256, 64, 7168, fp8_dtype),
@ -38,14 +44,159 @@ CASES = [
]
def as_uint8(x) -> torch.Tensor:
"""Copy x into a contiguous buffer and reinterpret its bytes as uint8."""
return (
torch.empty(x.shape, dtype=x.dtype, device=x.device).copy_(x).view(torch.uint8)
)
def silu(x: torch.Tensor) -> torch.Tensor:
one_f32 = torch.tensor([1.0], device=x.device, dtype=torch.float32)
x_f32 = x.to(torch.float32)
act_f32 = x_f32 / (one_f32 + torch.exp(-x_f32))
assert act_f32.dtype == torch.float32
return act_f32.to(torch.bfloat16)
def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool):
eps_bf16 = torch.tensor([1e-10], device=x.device, dtype=torch.bfloat16)
one_bf16 = torch.tensor([1.0], device=x.device, dtype=torch.bfloat16)
fp8_max_bf16 = torch.tensor(
[torch.finfo(fp8_dtype).max], device=x.device, dtype=torch.bfloat16
)
fp8_min_bf16 = torch.tensor(
[torch.finfo(fp8_dtype).min], device=x.device, dtype=torch.bfloat16
)
fp8_max_inv = one_bf16 / fp8_max_bf16
assert fp8_max_inv.dtype == torch.bfloat16
assert x.size(-1) % group_size == 0
num_groups = x.numel() // group_size
x_og_shape = x.shape
x = x.to(torch.bfloat16)
x = x.view((-1, group_size))
amax = x.abs().amax(dim=1).clamp(min=eps_bf16)
assert amax.dtype == torch.bfloat16
s = amax * fp8_max_inv
if ceil_ue8m0:
s = torch.exp2(
torch.ceil(torch.log2(s).to(torch.bfloat16)).to(torch.bfloat16)
).to(torch.bfloat16)
inv_s = one_bf16 / s
inv_s = inv_s.view((num_groups, 1))
xq = torch.clamp(x * inv_s, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()).to(
fp8_dtype
)
xq = xq.view(x_og_shape)
xs = s.view((-1, xq.size(-1) // group_size))
return xq, xs
def silu_mul_quant(
gate: torch.Tensor, up: torch.Tensor, group_size: int, ceil_ue8m0: bool
) -> tuple[torch.Tensor, torch.Tensor]:
assert gate.size(-1) % group_size == 0
assert up.size(-1) % group_size == 0
assert gate.dtype == torch.bfloat16
assert up.dtype == torch.bfloat16
act_bf16 = silu(gate)
assert act_bf16.dtype == torch.bfloat16
# act & mul
a_m = act_bf16 * up
assert a_m.dtype == torch.bfloat16
q, s = do_quant(a_m, group_size, ceil_ue8m0)
return q, s
def pack_scales(x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor:
"""
Pack float32 scales into an int32 tensor (4 ue8m0 scale bytes per int32 element).
"""
assert x.dtype == torch.float32
E, T, G = x.size()
# Add i32_padding here so we can view it as an i32 tensor later on.
i32_padding = round_up(G, 4) - G
ref_s_i8 = torch.empty((E, T, G + i32_padding), dtype=torch.uint8, device="cuda")
for e in range(E):
nt = tokens_per_expert[e].item()
ref_s_i8[e, :nt, :G] = x[e, :nt].view(torch.int32) >> 23
ref_s_i32 = ref_s_i8.view(torch.int32)
return ref_s_i32
def ref_with_scale_fmt(
E: int,
T: int,
H: int,
group_size: int,
tokens_per_expert: torch.Tensor,
gate: torch.Tensor,
up: torch.Tensor,
scale_fmt: DeepGemmQuantScaleFMT,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
The precision of the operations performed by this function closely
matches the kernel implementation, so the comparison is more
accurate.
"""
scale_dtype = (
torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32
)
ceil_ue8m0 = scale_fmt in [
DeepGemmQuantScaleFMT.UE8M0,
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
]
ref_q = torch.empty((E, T, H), dtype=fp8_dtype, device="cuda")
ref_s_f32 = torch.empty(
(E, T, cdiv(H, group_size)), dtype=torch.float32, device="cuda"
)
for e in range(E):
nt = tokens_per_expert[e].item()
if nt == 0:
continue
ref_q[e, :nt], ref_s_f32[e, :nt] = silu_mul_quant(
gate[e, :nt], up[e, :nt], group_size, ceil_ue8m0=ceil_ue8m0
)
if scale_dtype == torch.float32:
return ref_q, ref_s_f32
assert scale_dtype == torch.int32
return ref_q, pack_scales(ref_s_f32, tokens_per_expert)
def token_random(E, T, H2, tokens_per_expert):
"""
Initialize each token in a random range so we test a range of
scale values.
"""
y = torch.empty((E, T, H2), dtype=torch.bfloat16, device="cuda")
for e in range(E):
for t in range(tokens_per_expert[e].item()):
exp = random.choice(range(1, 20))
y[e, t].uniform_(-(2**exp), 2**exp)
return y
@pytest.mark.parametrize("E,T,H,fp8_type", CASES)
@torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dtype):
group_size = 128
current_platform.seed_everything(42)
# Input tensor of shape (E, T, 2*H)
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
tokens_per_expert = torch.randint(
low=0,
high=T,
@ -54,71 +205,83 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
device="cuda",
)
# Input tensor of shape (E, T, 2*H)
y = token_random(E, T, 2 * H, tokens_per_expert)
gate = y[..., :H].to(torch.bfloat16)
up = y[..., H:].to(torch.bfloat16)
scale_fmts = [
DeepGemmQuantScaleFMT.FLOAT32,
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
DeepGemmQuantScaleFMT.UE8M0,
]
# Run the SiLU V2 kernel
# TODO (varun): use_e8m0 is set to false as the reference impl does
# not handle that case.
y_q, y_s = persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, group_size=group_size, use_ue8m0=False
)
torch.cuda.synchronize()
fp8_info = torch.finfo(fp8_dtype)
fp8_max = fp8_info.max
fp8_min = fp8_info.min
eps = 1e-10
y1 = y[..., :H].float()
y2 = y[..., H:]
silu_x = y1 * torch.sigmoid(y1)
merged = silu_x * y2
for e in range(E):
nt = tokens_per_expert[e].item()
ref_s = torch.empty(
(T, cdiv(H, group_size)), dtype=torch.float32, device="cuda"
)
ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda")
for t in range(nt):
data = merged[e, t].float()
ref_q_row = torch.empty_like(data)
# process full groups
n_full_groups = H // group_size
if n_full_groups > 0:
data_grp = data[: n_full_groups * group_size].view(
n_full_groups, group_size
)
amax = data_grp.abs().amax(dim=1).clamp(min=eps)
scale = amax / fp8_max
scaled = data[: n_full_groups * group_size] / scale.repeat_interleave(
group_size
)
ref_q_row[: n_full_groups * group_size] = scaled.clamp(
fp8_min, fp8_max
).to(fp8_dtype)
ref_s[t, :n_full_groups] = scale
# process remainder group
rem = H % group_size
if rem > 0:
data_rem = data[-rem:]
amax = data_rem.abs().amax().clamp(min=eps)
scale = amax / fp8_max
scaled = data_rem / scale
ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype)
ref_s[t, -1] = scale
ref_q[t] = ref_q_row
y_se = y_s[e].float()
y_qe = y_q[e].float()
torch.testing.assert_close(
y_qe[:nt].to(torch.float32),
ref_q[:nt].to(torch.float32),
atol=2,
rtol=2e-1,
for scale_fmt in scale_fmts:
y_q, y_s = persistent_masked_m_silu_mul_quant(
y,
tokens_per_expert,
group_size=group_size,
quant_scale_fmt=scale_fmt,
)
torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)
ref_y_q, ref_y_s = ref_with_scale_fmt(
E, T, H, group_size, tokens_per_expert, gate, up, scale_fmt=scale_fmt
)
# deepgemm scales transform
dg_scales = None
if (
has_deep_gemm()
and current_platform.has_device_capability(100)
and scale_fmt == DeepGemmQuantScaleFMT.UE8M0
):
from deep_gemm import transform_sf_into_required_layout
_q, _s = ref_with_scale_fmt(
E,
T,
H,
group_size,
tokens_per_expert,
gate,
up,
scale_fmt=DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
)
dg_scales = transform_sf_into_required_layout(
sf=_s,
mn=_q.size(1),
k=_q.size(2),
recipe=(1, 128, 128),
num_groups=_q.size(0),
is_sfa=True,
)
expected_scale_dtype = (
torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32
)
assert y_s.dtype == expected_scale_dtype
assert ref_y_s.dtype == expected_scale_dtype
for e in range(E):
nt = tokens_per_expert[e].item()
torch.testing.assert_close(
y_q[e, :nt].to(torch.float32),
ref_y_q[e, :nt].to(torch.float32),
)
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
G = H // group_size
y_s_sliced = as_uint8(y_s[e])
ref_s_sliced = as_uint8(ref_y_s[e])
torch.testing.assert_close(y_s_sliced[:nt, :G], ref_s_sliced[:nt, :G])
if dg_scales is not None:
dg_sliced = as_uint8(dg_scales[e])
torch.testing.assert_close(y_s_sliced[:nt, :G], dg_sliced[:nt, :G])
else:
torch.testing.assert_close(
y_s[e, :nt],
ref_y_s[e, :nt],
)
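For debugging packed outputs, a hypothetical inverse of pack_scales (not part of this PR; the name unpack_scales is my own) can recover float32 scales from the int32-packed tensor:

import torch

def unpack_scales(packed: torch.Tensor, G: int) -> torch.Tensor:
    """Recover float32 scales from int32-packed ue8m0 exponent bytes."""
    assert packed.dtype == torch.int32
    E, T, _ = packed.size()
    # Each int32 pack holds 4 biased-exponent bytes; drop the padding beyond G.
    exp_bytes = packed.contiguous().view(torch.uint8).reshape(E, T, -1)[..., :G]
    # Scales are positive, so each byte is just the biased exponent: 2**(e - 127).
    return torch.exp2(exp_bytes.to(torch.float32) - 127.0)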

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@ -13,14 +14,33 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_m_grouped_gemm_nt_masked,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
)
from vllm.utils.math_utils import cdiv
logger = init_logger(__name__)
def scales_shape_stride_dtype(
E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
"""Return the scales tensor (shape, strides, dtype) for the given quant scale format."""
shape = (E, T, G)
strides = (T * G, 1, T)
if quant_scale_fmt in [
DeepGemmQuantScaleFMT.FLOAT32,
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
]:
return shape, strides, torch.float32
assert quant_scale_fmt == DeepGemmQuantScaleFMT.UE8M0
shape = (E, T, cdiv(G, 4))
strides = (T * cdiv(G, 4), 1, T)
return shape, strides, torch.int32
@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
# Pointers ------------------------------------------------------------
@ -49,7 +69,7 @@ def _silu_mul_fp8_quant_deep_gemm(
eps: tl.constexpr,
fp8_min: tl.constexpr,
fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr,
ceil_ue8m0: tl.constexpr,
# Meta ---------------------------------------------------------------
BLOCK: tl.constexpr,
NUM_STAGES: tl.constexpr,
@ -86,7 +106,7 @@ def _silu_mul_fp8_quant_deep_gemm(
y = gate * up
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
if use_ue8m0:
if ceil_ue8m0:
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
@ -100,7 +120,7 @@ def persistent_masked_m_silu_mul_quant(
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
num_parallel_tokens=16,
group_size: int = 128,
use_ue8m0: bool | None = None,
quant_scale_fmt: DeepGemmQuantScaleFMT = DeepGemmQuantScaleFMT.FLOAT32,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
@ -137,7 +157,13 @@ def persistent_masked_m_silu_mul_quant(
Returns `(y_q, y_s)` where
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
* `y_s` depends on quant_scale_fmt (with G = H // group_size):
- quant_scale_fmt == FLOAT32:
`y_s`: FP32 tensor, shape (E, T, G), strides (T*G, 1, T)
- quant_scale_fmt == FLOAT32_CEIL_UE8M0:
`y_s`: FP32 tensor, shape (E, T, G), strides (T*G, 1, T), scales rounded up to powers of two
- quant_scale_fmt == UE8M0:
`y_s`: Int32 tensor, shape (E, T, ceil(G / 4)), strides (T*ceil(G/4), 1, T)
Let NUM_WARPS be the number of warps in a single thread block and
`GROUP_SIZE = 128` be the size of the quantization group.
"""
@ -155,17 +181,18 @@ def persistent_masked_m_silu_mul_quant(
fp8_dtype = torch.float8_e4m3fn
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
stride_ys_e = T * G
stride_ys_t = 1
stride_ys_g = T
ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
y_s = torch.empty_strided(
(E, T, G),
(stride_ys_e, stride_ys_t, stride_ys_g),
dtype=torch.float32,
ys_shape,
ys_strides,
dtype=ys_dtype,
device=y.device,
)
use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
ceil_ue8m0 = quant_scale_fmt in [
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
DeepGemmQuantScaleFMT.UE8M0,
]
cuda_arch = current_platform.get_device_capability(
device_id=y.device.index
@ -173,7 +200,7 @@ def persistent_masked_m_silu_mul_quant(
if cuda_arch >= 80:
torch.ops._C.persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, y_q, y_s, use_ue8m0
y, tokens_per_expert, y_q, y_s, ceil_ue8m0
)
else:
stride_cnt_e = tokens_per_expert.stride()[0]
@ -189,6 +216,10 @@ def persistent_masked_m_silu_mul_quant(
fp8_max = f_info.max
fp8_min = f_info.min
eps: float = 1e-10
assert y_s.dtype == torch.float32, (
"_silu_mul_fp8_quant_deep_gemm does "
f"not support {y_s.dtype} scales. Only torch.float32 is supported."
)
_silu_mul_fp8_quant_deep_gemm[grid](
y,
y_q,
@ -202,14 +233,14 @@ def persistent_masked_m_silu_mul_quant(
stride_yq_e,
stride_yq_t,
stride_yq_h,
stride_ys_e,
stride_ys_t,
stride_ys_g,
ys_strides[0],
ys_strides[1],
ys_strides[2],
stride_cnt_e,
eps,
fp8_min,
fp8_max,
is_deep_gemm_e8m0_used(),
ceil_ue8m0,
BLOCK=group_size,
NUM_STAGES=4,
num_warps=1,
@ -255,7 +286,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
DeepGemm supports the packed ue8m0 activation scales format on sm100 devices.
"""
return current_platform.is_device_capability(100)
return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
@ -329,10 +360,17 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m,
)
quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
workspace1, expert_num_tokens
workspace1,
expert_num_tokens,
quant_scale_fmt=quant_scale_fmt,
)
fp8_m_grouped_gemm_nt_masked(
(a2q, a2q_scale), (w2, self.w2_scale), output, expert_num_tokens, expected_m
(a2q, a2q_scale),
(w2, self.w2_scale),
output,
expert_num_tokens,
expected_m,
)

View File

@ -9,6 +9,7 @@ import functools
import importlib
import os
from collections.abc import Callable
from enum import Enum
from typing import Any, NoReturn
import torch
@ -20,6 +21,28 @@ from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.math_utils import cdiv
class DeepGemmQuantScaleFMT(Enum):
# Float32 scales in a Float32 tensor
FLOAT32 = 0
# Compute float32 scales and ceil the scales to UE8M0.
# Keep the scales in a Float32 tensor.
FLOAT32_CEIL_UE8M0 = 1
# Compute float32 scales and ceil the scales to UE8M0.
# Pack the scales into an int32 tensor where each int32
# element contains 4 scale values.
UE8M0 = 2
@staticmethod
def from_oracle() -> "DeepGemmQuantScaleFMT":
if not is_deep_gemm_e8m0_used():
return DeepGemmQuantScaleFMT.FLOAT32
return (
DeepGemmQuantScaleFMT.UE8M0
if current_platform.is_device_capability(100)
else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
)
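As a quick numeric illustration of the three formats (my own example, not from this diff): for a raw group scale of 0.3, FLOAT32 stores 0.3 unchanged, FLOAT32_CEIL_UE8M0 rounds it up to the next power of two, and UE8M0 keeps only the biased exponent byte of that power of two, four bytes per int32:

import math

import torch

s = 0.3
s_ceil = 2.0 ** math.ceil(math.log2(s))  # FLOAT32_CEIL_UE8M0 -> 0.5
exp_byte = int(torch.tensor([s_ceil]).view(torch.int32)[0]) >> 23  # UE8M0 byte -> 126
assert (s_ceil, exp_byte) == (0.5, 126)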
@functools.cache
def is_deep_gemm_supported() -> bool:
"""Return `True` if DeepGEMM is supported on the current platform.