[Performance][B200] silu_mul_quant: pack scales in int32 (#28358)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath 2025-11-13 13:16:55 -05:00 committed by GitHub
parent fdfd5075aa
commit fe1cd7704d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 466 additions and 151 deletions

View File

@ -279,17 +279,17 @@ __device__ __forceinline__ void token_bounds(int32_t n_tokens,
} }
template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type, template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128, typename scale_t, int THREADS, typename Idx_t, bool CEIL_UE8M0,
int NUM_STAGES = 3> int GROUP_SIZE = 128, int NUM_STAGES = 3>
__global__ void silu_mul_fp8_quant_deep_gemm_kernel( __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q, const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert, scale_t* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
// sizes // sizes
Idx_t E, Idx_t T, Idx_t H, Idx_t E, Idx_t T, Idx_t H,
// strides (in elements) // strides (in elements)
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
Idx_t stride_ys_g, Idx_t stride_counts_e) { Idx_t stride_ys_g, Idx_t stride_ys_p, Idx_t stride_counts_e) {
#ifndef USE_ROCM #ifndef USE_ROCM
static constexpr int NUM_WARPS = THREADS / WARP_SIZE; static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
@ -466,9 +466,22 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
__nv_fp8x4_e4m3* y_q_base_ptr = __nv_fp8x4_e4m3* y_q_base_ptr =
reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id; reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id;
auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g;
Idx_t scale_group_offset = 0;
if constexpr (std::is_same<scale_t, uint8_t>::value) {
// packed int32_t format
int pack_id = warp_position_scales / 4;
int scale_in_pack = warp_position_scales % 4;
scale_group_offset = pack_id * stride_ys_p + scale_in_pack * stride_ys_g;
} else {
scale_group_offset = warp_position_scales * stride_ys_g;
}
scale_t* const y_scale_base_ptr = _y_s + scale_group_offset;
for (auto j = tokens_lower; j < tokens_upper; j++) { for (auto j = tokens_lower; j < tokens_upper; j++) {
int current_group_id = warp_position_scales; // Running count of which
// group is being processed
const Idx_t base_ys = expert_id * stride_ys_e; const Idx_t base_ys = expert_id * stride_ys_e;
auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t; auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
__nv_fp8x4_e4m3* y_q_ptr = __nv_fp8x4_e4m3* y_q_ptr =
@ -509,7 +522,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
__nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv); __nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);
if constexpr (USE_UE8M0) { if constexpr (CEIL_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s))); y_s = hexp2(hceil(hlog2(y_s)));
} }
@ -527,8 +540,24 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
y_q_ptr += WARP_SIZE * stride_yq_h; y_q_ptr += WARP_SIZE * stride_yq_h;
if (!lane_id) { if (!lane_id) {
*y_s_ptr = y_s; // Store scales.
y_s_ptr += stride_ys_g; if constexpr (std::is_same<scale_t, uint8_t>::value) {
// Packed UE8M0 format. Remove mantissa.
*y_s_ptr = reinterpret_cast<int16_t&>(y_s) >> 7;
bool const jump_pack = (current_group_id + 1) % 4 == 0;
// Minus 3 because we need to get to the first group in the
// next pack.
y_s_ptr += jump_pack ? (stride_ys_p - 3) : stride_ys_g;
} else {
// float32 format
static_assert(std::is_same<scale_t, float>::value);
*y_s_ptr = y_s;
y_s_ptr += stride_ys_g;
}
current_group_id += 1;
} }
} }
} }
@ -573,7 +602,7 @@ void persistent_masked_m_silu_mul_quant(
const at::Tensor& tokens_per_expert, // (E) const at::Tensor& tokens_per_expert, // (E)
at::Tensor& y_q, // (E, T, H) [OUT] at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT] at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0) { bool cast_scale_ue8m0) {
#ifndef USE_ROCM #ifndef USE_ROCM
// This kernel currently only supports H % 128 == 0 and assumes a // This kernel currently only supports H % 128 == 0 and assumes a
@ -583,9 +612,12 @@ void persistent_masked_m_silu_mul_quant(
TORCH_CHECK(input.dtype() == torch::kBFloat16); TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz); y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0); TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
bool const is_packed_ue8m0 =
(y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
using Idx_t = int64_t; using Idx_t = int64_t;
Idx_t E = input.size(0); Idx_t E = input.size(0);
@ -597,15 +629,18 @@ void persistent_masked_m_silu_mul_quant(
Idx_t stride_yq_e = y_q.stride(0); Idx_t stride_yq_e = y_q.stride(0);
Idx_t stride_yq_t = y_q.stride(1); Idx_t stride_yq_t = y_q.stride(1);
Idx_t stride_yq_h = y_q.stride(2); Idx_t stride_yq_h = y_q.stride(2);
Idx_t stride_ys_e = y_s.stride(0);
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);
Idx_t stride_counts_e = tokens_per_expert.stride(0); Idx_t stride_counts_e = tokens_per_expert.stride(0);
int const NUM_GROUPS = H / GROUP_SIZE;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ // TODO: Get this from cuda_arch ?
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
#define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES) \
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \ static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
int sms = SILU_V2_BLOCK_COUNT; \ int sms = SILU_V2_BLOCK_COUNT; \
static constexpr int max_shared_mem_bytes = \ static constexpr int max_shared_mem_bytes = \
@ -615,43 +650,86 @@ void persistent_masked_m_silu_mul_quant(
VLLM_DISPATCH_FP8_TYPES( \ VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \ y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \ vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \ BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
USE_UE8M0, GROUP_SIZE, STAGES> \ Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \ <<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \ (fp8_t*)y_q.data_ptr(), \
reinterpret_cast<scale_t*>(y_s.data_ptr()), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \ reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \ T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \ stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T, \
stride_ys_g, stride_counts_e); \ STRIDE_YS_G, STRIDE_YS_P, stride_counts_e); \
}); });
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; #define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
STRIDE_YS_P, CEIL_UE8M0) \
if (H >= 4096 && (NUM_GROUPS % 8) == 0) { \
/* 8 warp config */ \
static constexpr int NUM_STAGES = 4; \
static constexpr int THREAD_COUNT = 256; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
} else { \
/* 1 warp config */ \
static constexpr int THREAD_COUNT = 32; \
KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, \
STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2); \
}
int const NUM_GROUPS = H / GROUP_SIZE; Idx_t stride_ys_e = y_s.stride(0);
if (!use_ue8m0) { Idx_t stride_ys_t = y_s.stride(1);
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) { Idx_t stride_ys_g = y_s.stride(2);
/* 8 warps config */ Idx_t stride_ys_p = 0;
static constexpr int NUM_STAGES = 4; if (!cast_scale_ue8m0) {
static constexpr int THREAD_COUNT = 256; TORCH_CHECK(!is_packed_ue8m0);
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES); LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
} else { false);
/* 1 warp config */ return;
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
}
} else {
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
}
} }
if (!is_packed_ue8m0) {
// UE8M0 but not packed
LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true);
return;
}
TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
TORCH_CHECK(y_s.dtype() == torch::kInt32);
// Int32 packed ue8m0 scales tensor.
// Let E, T, G be the number of experts, number of tokens and number of groups
// respectively. Let E = 1, T = 4, G = 6; in this case the int32 scales
// tensor is of shape [1, 4, 2] and stride [8, 1, 4]. The scales are expected
// to be arranged as follows,
// [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
// [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
// [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
// [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
// where TxGy is the ue8m0 scale value of Token x, Group y.
//
// In memory (in bytes) the scale values are arranged as,
// [T0G0, T0G1, T0G2, T0G3, T1G0, T1G1, T1G2, T1G3, T2G0, T2G1, T2G2, T2G3,
// T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
// X, X, T3G4, T3G5, X, X]
//
// An Int32 tensor of size [1, 4, 2] and stride [8, 1, 4] can be represented
// as an uint8 tensor of shape [1, 2, 4, 4] and stride [32, 16, 4, 1]. In
// english, ignoring the Experts dimension, the original int32 tensor is
// simply treated as two packed [4, 4] uint8 tensor (or two [4, 1] int32
// tensor). The following strides setting reflects this change. Caveat: This
// means that the G dimension is no longer contiguous. i.e. Note that to move
// from G3 to G4, we need to jump along the packing dimension. The kernel
// handles this case.
stride_ys_e *= sizeof(int32_t);
stride_ys_p = T * sizeof(int32_t); // Packing dimension
stride_ys_t = sizeof(int32_t);
stride_ys_g = 1;
LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
true);
#endif #endif
} }

View File

@ -1384,3 +1384,16 @@ def image_urls(request, local_asset_server) -> list[str]:
"""Indirect fixture: takes a list of names, returns list of full URLs.""" """Indirect fixture: takes a list of names, returns list of full URLs."""
names: list[str] = request.param names: list[str] = request.param
return [local_asset_server.url_for(name) for name in names] return [local_asset_server.url_for(name) for name in names]
@pytest.fixture
def disable_deepgemm_ue8m0(monkeypatch):
    """Run the requesting test with ``VLLM_USE_DEEP_GEMM_E8M0`` forced to "0".

    ``is_deep_gemm_e8m0_used`` exposes ``cache_clear``, so its cached answer
    is dropped once after the variable is overridden and once more after the
    test body has finished.
    """
    from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used as e8m0_oracle

    with monkeypatch.context() as env_patch:
        env_patch.setenv("VLLM_USE_DEEP_GEMM_E8M0", "0")
        # Forget any answer cached before the override was in place.
        e8m0_oracle.cache_clear()

        yield

        # Forget the answer computed under the override so that the next
        # caller re-evaluates it with the default VLLM_USE_DEEP_GEMM_E8M0
        # setting.
        e8m0_oracle.cache_clear()

View File

@ -21,7 +21,11 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from ...utils import multi_gpu_test from ...utils import multi_gpu_test
@ -413,19 +417,16 @@ NUM_EXPERTS = [32]
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ht_deepep_deepgemm_moe( def test_ht_deepep_deepgemm_moe(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],
num_experts: int, num_experts: int,
topk: int, topk: int,
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0,
): ):
""" """
Tests for High-Throughput DeepEP + DeepGemm integration. Tests for High-Throughput DeepEP + DeepGemm integration.
""" """
import deep_gemm
m, n, k = mnk m, n, k = mnk
current_platform.seed_everything(7) current_platform.seed_everything(7)
@ -433,7 +434,7 @@ def test_ht_deepep_deepgemm_moe(
if topk > num_experts: if topk > num_experts:
pytest.skip(f"Skipping test: topk={topk} > E={num_experts}") pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")
block_m = deep_gemm.get_m_alignment_for_contiguous_layout() block_m = get_mk_alignment_for_contiguous_layout()[0]
block_size = [block_m, block_m] block_size = [block_m, block_m]
world_size, dp_size = world_dp_size world_size, dp_size = world_dp_size
@ -487,9 +488,6 @@ USE_FP8_DISPATCH = [False]
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@requires_deep_ep @requires_deep_ep
@requires_deep_gemm @requires_deep_gemm
@pytest.mark.skipif(
is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM"
)
def test_ll_deepep_deepgemm_moe( def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int], mnk: tuple[int, int, int],
num_experts: int, num_experts: int,
@ -497,10 +495,12 @@ def test_ll_deepep_deepgemm_moe(
use_fp8_dispatch: bool, use_fp8_dispatch: bool,
block_size: list[int], block_size: list[int],
world_dp_size: tuple[int, int], world_dp_size: tuple[int, int],
disable_deepgemm_ue8m0,
): ):
""" """
Tests for Low-Latency DeepEP + DeepGemm integration. Tests for Low-Latency DeepEP + DeepGemm integration.
""" """
assert not is_deep_gemm_e8m0_used()
m, n, k = mnk m, n, k = mnk
current_platform.seed_everything(7) current_platform.seed_everything(7)

View File

@ -294,7 +294,7 @@ def torch_moe_impl(
# blockwise quant and de-quant. # blockwise quant and de-quant.
assert not per_act_token_quant assert not per_act_token_quant
a = test_tensors.rank_tokens a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128) aq, aq_scale = per_token_group_quant_fp8(a, 128, use_ue8m0=False)
a = ( a = (
(aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)) (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1))
.view(a.shape) .view(a.shape)

View File

@ -1,6 +1,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest import pytest
import torch import torch
@ -8,27 +11,30 @@ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
persistent_masked_m_silu_mul_quant, persistent_masked_m_silu_mul_quant,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
from vllm.utils.math_utils import cdiv, round_up
fp8_dtype = torch.float8_e4m3fn fp8_dtype = torch.float8_e4m3fn
CASES = [ CASES = [
(1, 1, 128, fp8_dtype), (1, 1, 128, fp8_dtype),
(1, 4, 128, fp8_dtype), (1, 4, 128 * 1, fp8_dtype),
(2, 4, 256, fp8_dtype), (2, 4, 128 * 2, fp8_dtype),
(32, 64, 256, fp8_dtype), (1, 4, 128 * 3, fp8_dtype),
(17, 31, 768, fp8_dtype), (8, 16, 128 * 4, fp8_dtype),
(1, 1, 128 * 1, fp8_dtype), (8, 16, 128 * 5, fp8_dtype),
(1, 1, 128 * 3, fp8_dtype), (8, 16, 128 * 6, fp8_dtype),
(1, 1, 128 * 4, fp8_dtype), (8, 16, 128 * 7, fp8_dtype),
(8, 16, 128 * 1, fp8_dtype), (8, 16, 128 * 8, fp8_dtype),
(8, 16, 128 * 2, fp8_dtype), (8, 16, 128 * 9, fp8_dtype),
(8, 16, 128 * 3, fp8_dtype),
(8, 64, 7168, fp8_dtype), (8, 64, 7168, fp8_dtype),
(8, 128, 128 * 33, fp8_dtype), (8, 128, 128 * 33, fp8_dtype),
(1, 4, 128 * 10, fp8_dtype),
(8, 128, 7168, fp8_dtype), (8, 128, 7168, fp8_dtype),
(8, 512, 7168, fp8_dtype), (8, 512, 7168, fp8_dtype),
(8, 1024, 7168, fp8_dtype), (8, 1024, 7168, fp8_dtype),
(17, 31, 768, fp8_dtype),
(32, 64, 256, fp8_dtype),
(256, 8, 7168, fp8_dtype), (256, 8, 7168, fp8_dtype),
(256, 32, 7168, fp8_dtype), (256, 32, 7168, fp8_dtype),
(256, 64, 7168, fp8_dtype), (256, 64, 7168, fp8_dtype),
@ -38,14 +44,159 @@ CASES = [
] ]
def as_uint8(x) -> torch.Tensor:
    """Return the raw bytes of ``x`` as a ``torch.uint8`` tensor.

    ``x`` is first copied into a freshly allocated tensor with the same
    shape, dtype and device; the copy is then reinterpreted as uint8, so the
    last dimension grows by the element size of ``x``.
    """
    fresh_copy = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    fresh_copy.copy_(x)
    return fresh_copy.view(torch.uint8)
def silu(x: torch.Tensor) -> torch.Tensor:
    """Compute ``x / (1 + exp(-x))`` in float32 and return it as bfloat16.

    The input is upcast to float32 before the division; only the final
    result is rounded to bfloat16.
    """
    x_fp32 = x.to(torch.float32)
    unit = torch.tensor([1.0], device=x.device, dtype=torch.float32)
    denominator = unit + torch.exp(-x_fp32)
    activated = x_fp32 / denominator
    assert activated.dtype == torch.float32
    return activated.to(torch.bfloat16)
def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool):
    """Group-wise quantize ``x`` to ``fp8_dtype`` using bfloat16 arithmetic.

    ``x`` is viewed as rows of ``group_size`` elements. Each row gets one
    scale, ``amax * (1 / fp8_max)``, with ``amax`` clamped from below by
    1e-10. When ``ceil_ue8m0`` is set, the scale is replaced by
    ``exp2(ceil(log2(scale)))``.

    Returns ``(xq, xs)``: ``xq`` has the shape of ``x`` and dtype
    ``fp8_dtype``; ``xs`` holds the bfloat16 scales with
    ``x.size(-1) // group_size`` columns.
    """

    def _bf16_scalar(value):
        # One-element bfloat16 tensor living on the same device as x.
        return torch.tensor([value], device=x.device, dtype=torch.bfloat16)

    eps_bf16 = _bf16_scalar(1e-10)
    one_bf16 = _bf16_scalar(1.0)
    fp8_max_bf16 = _bf16_scalar(torch.finfo(fp8_dtype).max)
    fp8_min_bf16 = _bf16_scalar(torch.finfo(fp8_dtype).min)

    reciprocal_fp8_max = one_bf16 / fp8_max_bf16
    assert reciprocal_fp8_max.dtype == torch.bfloat16

    assert x.size(-1) % group_size == 0
    group_count = x.numel() // group_size
    original_shape = x.shape

    # One row per quantization group.
    grouped = x.to(torch.bfloat16).view((-1, group_size))

    group_amax = grouped.abs().amax(dim=1).clamp(min=eps_bf16)
    assert group_amax.dtype == torch.bfloat16
    scales = group_amax * reciprocal_fp8_max

    if ceil_ue8m0:
        # Every intermediate is cast to bfloat16 explicitly.
        log2_scales = torch.log2(scales).to(torch.bfloat16)
        ceiled = torch.ceil(log2_scales).to(torch.bfloat16)
        scales = torch.exp2(ceiled).to(torch.bfloat16)

    reciprocal_scales = (one_bf16 / scales).view((group_count, 1))

    scaled = grouped * reciprocal_scales
    clamped = torch.clamp(
        scaled, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()
    )
    xq = clamped.to(fp8_dtype).view(original_shape)

    xs = scales.view((-1, xq.size(-1) // group_size))
    return xq, xs
def silu_mul_quant(
    gate: torch.Tensor, up: torch.Tensor, group_size: int, ceil_ue8m0: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference ``silu(gate) * up`` followed by group-wise quantization.

    Both operands must be bfloat16 with a last dimension divisible by
    ``group_size``. The product is formed in bfloat16 and passed to
    ``do_quant``; its ``(quantized, scales)`` pair is returned unchanged.
    """
    for operand in (gate, up):
        assert operand.size(-1) % group_size == 0
    for operand in (gate, up):
        assert operand.dtype == torch.bfloat16

    activated = silu(gate)
    assert activated.dtype == torch.bfloat16

    # Activation times the "up" projection, still in bfloat16.
    product = activated * up
    assert product.dtype == torch.bfloat16

    return do_quant(product, group_size, ceil_ue8m0)
def pack_scales(x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor:
    """
    Pack float32 scales into an int32 tensor, four scale bytes per element.

    For every expert only the first ``tokens_per_expert[e]`` rows are
    written; each written byte is the float32 bit pattern shifted right by
    23. Remaining rows and the padding columns come from ``torch.empty`` and
    are left uninitialized.
    """
    assert x.dtype == torch.float32
    E, T, G = x.size()

    # Pad the group dimension up to a multiple of 4 so that the uint8
    # buffer can be reinterpreted as int32 at the end.
    padded_groups = round_up(G, 4)
    packed_bytes = torch.empty((E, T, padded_groups), dtype=torch.uint8, device="cuda")

    for e in range(E):
        valid_tokens = tokens_per_expert[e].item()
        shifted_bits = x[e, :valid_tokens].view(torch.int32) >> 23
        packed_bytes[e, :valid_tokens, :G] = shifted_bits

    return packed_bytes.view(torch.int32)
def ref_with_scale_fmt(
    E: int,
    T: int,
    H: int,
    group_size: int,
    tokens_per_expert: torch.Tensor,
    gate: torch.Tensor,
    up: torch.Tensor,
    scale_fmt: DeepGemmQuantScaleFMT,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reference silu-mul-quant for all experts in the requested scale format.

    The precision of each step is chosen to follow the kernel
    implementation closely so that outputs can be compared tightly.
    Experts with zero tokens are skipped, and rows past an expert's token
    count are left as allocated by ``torch.empty``.

    Returns ``(ref_q, ref_s)``; ``ref_s`` is float32 unless ``scale_fmt`` is
    ``UE8M0``, in which case it is the int32 tensor built by ``pack_scales``.
    """
    wants_packed = scale_fmt == DeepGemmQuantScaleFMT.UE8M0
    scale_dtype = torch.int32 if wants_packed else torch.float32

    ceil_ue8m0 = scale_fmt in [
        DeepGemmQuantScaleFMT.UE8M0,
        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
    ]

    num_groups = cdiv(H, group_size)
    ref_q = torch.empty((E, T, H), dtype=fp8_dtype, device="cuda")
    ref_s_f32 = torch.empty((E, T, num_groups), dtype=torch.float32, device="cuda")

    for e in range(E):
        nt = tokens_per_expert[e].item()
        if nt == 0:
            continue
        q_e, s_e = silu_mul_quant(
            gate[e, :nt], up[e, :nt], group_size, ceil_ue8m0=ceil_ue8m0
        )
        ref_q[e, :nt] = q_e
        ref_s_f32[e, :nt] = s_e

    if scale_dtype == torch.float32:
        return ref_q, ref_s_f32

    assert scale_dtype == torch.int32
    return ref_q, pack_scales(ref_s_f32, tokens_per_expert)
def token_random(E, T, H2, tokens_per_expert):
    """
    Build an (E, T, H2) bfloat16 input whose valid tokens span many
    magnitudes, so that a wide range of scale values is exercised.

    Each valid token row is filled uniformly from ``[-2**k, 2**k)`` with
    ``k`` drawn from 1..19. Rows past ``tokens_per_expert[e]`` are left as
    allocated by ``torch.empty``.
    """
    y = torch.empty((E, T, H2), dtype=torch.bfloat16, device="cuda")
    for e in range(E):
        valid_tokens = tokens_per_expert[e].item()
        for t in range(valid_tokens):
            exponent = random.choice(range(1, 20))
            bound = 2**exponent
            y[e, t].uniform_(-bound, bound)
    return y
@pytest.mark.parametrize("E,T,H,fp8_type", CASES) @pytest.mark.parametrize("E,T,H,fp8_type", CASES)
@torch.inference_mode() @torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dtype):
group_size = 128 group_size = 128
current_platform.seed_everything(42) current_platform.seed_everything(42)
# Input tensor of shape (E, T, 2*H)
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
tokens_per_expert = torch.randint( tokens_per_expert = torch.randint(
low=0, low=0,
high=T, high=T,
@ -54,71 +205,83 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
device="cuda", device="cuda",
) )
# Input tensor of shape (E, T, 2*H)
y = token_random(E, T, 2 * H, tokens_per_expert)
gate = y[..., :H].to(torch.bfloat16)
up = y[..., H:].to(torch.bfloat16)
scale_fmts = [
DeepGemmQuantScaleFMT.FLOAT32,
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
DeepGemmQuantScaleFMT.UE8M0,
]
# Run the SiLU V2 kernel # Run the SiLU V2 kernel
# TODO (varun): use_e8m0 is set to false as the reference impl does for scale_fmt in scale_fmts:
# not handle that case. y_q, y_s = persistent_masked_m_silu_mul_quant(
y_q, y_s = persistent_masked_m_silu_mul_quant( y,
y, tokens_per_expert, group_size=group_size, use_ue8m0=False tokens_per_expert,
) group_size=group_size,
quant_scale_fmt=scale_fmt,
torch.cuda.synchronize()
fp8_info = torch.finfo(fp8_dtype)
fp8_max = fp8_info.max
fp8_min = fp8_info.min
eps = 1e-10
y1 = y[..., :H].float()
y2 = y[..., H:]
silu_x = y1 * torch.sigmoid(y1)
merged = silu_x * y2
for e in range(E):
nt = tokens_per_expert[e].item()
ref_s = torch.empty(
(T, cdiv(H, group_size)), dtype=torch.float32, device="cuda"
)
ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda")
for t in range(nt):
data = merged[e, t].float()
ref_q_row = torch.empty_like(data)
# process full groups
n_full_groups = H // group_size
if n_full_groups > 0:
data_grp = data[: n_full_groups * group_size].view(
n_full_groups, group_size
)
amax = data_grp.abs().amax(dim=1).clamp(min=eps)
scale = amax / fp8_max
scaled = data[: n_full_groups * group_size] / scale.repeat_interleave(
group_size
)
ref_q_row[: n_full_groups * group_size] = scaled.clamp(
fp8_min, fp8_max
).to(fp8_dtype)
ref_s[t, :n_full_groups] = scale
# process remainder group
rem = H % group_size
if rem > 0:
data_rem = data[-rem:]
amax = data_rem.abs().amax().clamp(min=eps)
scale = amax / fp8_max
scaled = data_rem / scale
ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype)
ref_s[t, -1] = scale
ref_q[t] = ref_q_row
y_se = y_s[e].float()
y_qe = y_q[e].float()
torch.testing.assert_close(
y_qe[:nt].to(torch.float32),
ref_q[:nt].to(torch.float32),
atol=2,
rtol=2e-1,
) )
torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) ref_y_q, ref_y_s = ref_with_scale_fmt(
E, T, H, group_size, tokens_per_expert, gate, up, scale_fmt=scale_fmt
)
# deepgemm scales transform
dg_scales = None
if (
has_deep_gemm()
and current_platform.has_device_capability(100)
and scale_fmt == DeepGemmQuantScaleFMT.UE8M0
):
from deep_gemm import transform_sf_into_required_layout
_q, _s = ref_with_scale_fmt(
E,
T,
H,
group_size,
tokens_per_expert,
gate,
up,
scale_fmt=DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
)
dg_scales = transform_sf_into_required_layout(
sf=_s,
mn=_q.size(1),
k=_q.size(2),
recipe=(1, 128, 128),
num_groups=_q.size(0),
is_sfa=True,
)
expected_scale_dtype = (
torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32
)
assert y_s.dtype == expected_scale_dtype
assert ref_y_s.dtype == expected_scale_dtype
for e in range(E):
nt = tokens_per_expert[e].item()
torch.testing.assert_close(
y_q[e, :nt].to(torch.float32),
ref_y_q[e, :nt].to(torch.float32),
)
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
G = H // group_size
y_s_sliced = as_uint8(y_s[e])
ref_s_sliced = as_uint8(ref_y_s[e])
torch.testing.assert_close(y_s_sliced[:nt, :G], ref_s_sliced[:nt, :G])
if dg_scales is not None:
dg_sliced = as_uint8(dg_scales[e])
torch.testing.assert_close(y_s_sliced[:nt, :G], dg_sliced[:nt, :G])
else:
torch.testing.assert_close(
y_s[e, :nt],
ref_y_s[e, :nt],
)

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@ -13,14 +14,33 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_m_grouped_gemm_nt_masked, fp8_m_grouped_gemm_nt_masked,
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
) )
from vllm.utils.math_utils import cdiv
logger = init_logger(__name__) logger = init_logger(__name__)
def scales_shape_stride_dtype(
    E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
    """Return ``(shape, strides, dtype)`` for the activation-scales tensor.

    For the two float32 formats the tensor is (E, T, G) with strides
    (T * G, 1, T). For ``UE8M0`` the last dimension becomes ``cdiv(G, 4)``
    and the dtype becomes int32, with strides (T * cdiv(G, 4), 1, T).
    """
    float32_formats = [
        DeepGemmQuantScaleFMT.FLOAT32,
        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
    ]
    if quant_scale_fmt in float32_formats:
        return (E, T, G), (T * G, 1, T), torch.float32

    assert quant_scale_fmt == DeepGemmQuantScaleFMT.UE8M0
    packed_groups = cdiv(G, 4)
    return (E, T, packed_groups), (T * packed_groups, 1, T), torch.int32
@triton.jit @triton.jit
def _silu_mul_fp8_quant_deep_gemm( def _silu_mul_fp8_quant_deep_gemm(
# Pointers ------------------------------------------------------------ # Pointers ------------------------------------------------------------
@ -49,7 +69,7 @@ def _silu_mul_fp8_quant_deep_gemm(
eps: tl.constexpr, eps: tl.constexpr,
fp8_min: tl.constexpr, fp8_min: tl.constexpr,
fp8_max: tl.constexpr, fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr, ceil_ue8m0: tl.constexpr,
# Meta --------------------------------------------------------------- # Meta ---------------------------------------------------------------
BLOCK: tl.constexpr, BLOCK: tl.constexpr,
NUM_STAGES: tl.constexpr, NUM_STAGES: tl.constexpr,
@ -86,7 +106,7 @@ def _silu_mul_fp8_quant_deep_gemm(
y = gate * up y = gate * up
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
if use_ue8m0: if ceil_ue8m0:
y_s = tl.exp2(tl.ceil(tl.log2(y_s))) y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
@ -100,7 +120,7 @@ def persistent_masked_m_silu_mul_quant(
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
num_parallel_tokens=16, num_parallel_tokens=16,
group_size: int = 128, group_size: int = 128,
use_ue8m0: bool | None = None, quant_scale_fmt: DeepGemmQuantScaleFMT = DeepGemmQuantScaleFMT.FLOAT32,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is y has shape (E, T, 2*H). The first half of the last dimension is
@ -137,7 +157,13 @@ def persistent_masked_m_silu_mul_quant(
Returns `(y_q, y_s)` where Returns `(y_q, y_s)` where
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) * `y_s` depends on quant_scale_fmt,
- quant_scale_fmt == FLOAT32,
`y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
- quant_scale_fmt == UE8M0,
`y_s`: Int32 tensor, shape (E, T, cdiv(G, 4)) with G = H // group_size, strides (T*cdiv(G, 4), 1, T)
- quant_scale_fmt == FLOAT32_CEIL_UE8M0,
`y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
Let NUM_WARPS be the number of warps in a single thread block and Let NUM_WARPS be the number of warps in a single thread block and
`GROUP_SIZE = 128` be the size of the quantization group. `GROUP_SIZE = 128` be the size of the quantization group.
""" """
@ -155,17 +181,18 @@ def persistent_masked_m_silu_mul_quant(
fp8_dtype = torch.float8_e4m3fn fp8_dtype = torch.float8_e4m3fn
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
stride_ys_e = T * G ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
stride_ys_t = 1
stride_ys_g = T
y_s = torch.empty_strided( y_s = torch.empty_strided(
(E, T, G), ys_shape,
(stride_ys_e, stride_ys_t, stride_ys_g), ys_strides,
dtype=torch.float32, dtype=ys_dtype,
device=y.device, device=y.device,
) )
use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used() ceil_ue8m0 = quant_scale_fmt in [
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
DeepGemmQuantScaleFMT.UE8M0,
]
cuda_arch = current_platform.get_device_capability( cuda_arch = current_platform.get_device_capability(
device_id=y.device.index device_id=y.device.index
@ -173,7 +200,7 @@ def persistent_masked_m_silu_mul_quant(
if cuda_arch >= 80: if cuda_arch >= 80:
torch.ops._C.persistent_masked_m_silu_mul_quant( torch.ops._C.persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, y_q, y_s, use_ue8m0 y, tokens_per_expert, y_q, y_s, ceil_ue8m0
) )
else: else:
stride_cnt_e = tokens_per_expert.stride()[0] stride_cnt_e = tokens_per_expert.stride()[0]
@ -189,6 +216,10 @@ def persistent_masked_m_silu_mul_quant(
fp8_max = f_info.max fp8_max = f_info.max
fp8_min = f_info.min fp8_min = f_info.min
eps: float = 1e-10 eps: float = 1e-10
assert y_s.dtype == torch.float32, (
"_silu_mul_fp8_quant_deep_gemm does"
"not support {y_s.dtype} scales. Only torch.float32 supported."
)
_silu_mul_fp8_quant_deep_gemm[grid]( _silu_mul_fp8_quant_deep_gemm[grid](
y, y,
y_q, y_q,
@ -202,14 +233,14 @@ def persistent_masked_m_silu_mul_quant(
stride_yq_e, stride_yq_e,
stride_yq_t, stride_yq_t,
stride_yq_h, stride_yq_h,
stride_ys_e, ys_strides[0],
stride_ys_t, ys_strides[1],
stride_ys_g, ys_strides[2],
stride_cnt_e, stride_cnt_e,
eps, eps,
fp8_min, fp8_min,
fp8_max, fp8_max,
is_deep_gemm_e8m0_used(), ceil_ue8m0,
BLOCK=group_size, BLOCK=group_size,
NUM_STAGES=4, NUM_STAGES=4,
num_warps=1, num_warps=1,
@ -255,7 +286,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
""" """
DeepGemm supports packed ue8m0 activation scales format in devices == sm100 DeepGemm supports packed ue8m0 activation scales format in devices == sm100
""" """
return current_platform.is_device_capability(100) return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. # Let PrepareAndFinalize::finalize() decide the impl.
@ -329,10 +360,17 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m, expected_m,
) )
quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
a2q, a2q_scale = persistent_masked_m_silu_mul_quant( a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
workspace1, expert_num_tokens workspace1,
expert_num_tokens,
quant_scale_fmt=quant_scale_fmt,
) )
fp8_m_grouped_gemm_nt_masked( fp8_m_grouped_gemm_nt_masked(
(a2q, a2q_scale), (w2, self.w2_scale), output, expert_num_tokens, expected_m (a2q, a2q_scale),
(w2, self.w2_scale),
output,
expert_num_tokens,
expected_m,
) )

View File

@ -9,6 +9,7 @@ import functools
import importlib import importlib
import os import os
from collections.abc import Callable from collections.abc import Callable
from enum import Enum
from typing import Any, NoReturn from typing import Any, NoReturn
import torch import torch
@ -20,6 +21,28 @@ from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
class DeepGemmQuantScaleFMT(Enum):
    """Layouts in which quantization scales can be produced."""

    # Plain float32 scales stored in a float32 tensor.
    FLOAT32 = 0
    # float32 scales that are ceiled to UE8M0 but still stored in a
    # float32 tensor.
    FLOAT32_CEIL_UE8M0 = 1
    # float32 scales that are ceiled to UE8M0 and packed into an int32
    # tensor, four scale values per int32 element.
    UE8M0 = 2

    @staticmethod
    def from_oracle() -> "DeepGemmQuantScaleFMT":
        """Choose the format from the E8M0 setting and the device capability.

        Returns ``FLOAT32`` when ``is_deep_gemm_e8m0_used()`` is false.
        Otherwise returns ``UE8M0`` if
        ``current_platform.is_device_capability(100)`` holds, and
        ``FLOAT32_CEIL_UE8M0`` if it does not.
        """
        if not is_deep_gemm_e8m0_used():
            return DeepGemmQuantScaleFMT.FLOAT32
        if current_platform.is_device_capability(100):
            return DeepGemmQuantScaleFMT.UE8M0
        return DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
@functools.cache @functools.cache
def is_deep_gemm_supported() -> bool: def is_deep_gemm_supported() -> bool:
"""Return `True` if DeepGEMM is supported on the current platform. """Return `True` if DeepGEMM is supported on the current platform.