diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
index 2521b2797e2c2..0c3bcf3b64b26 100644
--- a/csrc/quantization/activation_kernels.cu
+++ b/csrc/quantization/activation_kernels.cu
@@ -279,17 +279,17 @@ __device__ __forceinline__ void token_bounds(int32_t n_tokens,
 }
 
 template <int BLOCK_COUNT, int SMEM_SIZE_BYTES, typename fp8_type,
-          int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128,
-          int NUM_STAGES = 3>
+          typename scale_t, int THREADS, typename Idx_t, bool CEIL_UE8M0,
+          int GROUP_SIZE = 128, int NUM_STAGES = 3>
 __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
     const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
-    float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
+    scale_t* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
     // sizes
     Idx_t E, Idx_t T, Idx_t H,
     // strides (in elements)
     Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
     Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
-    Idx_t stride_ys_g, Idx_t stride_counts_e) {
+    Idx_t stride_ys_g, Idx_t stride_ys_p, Idx_t stride_counts_e) {
 #ifndef USE_ROCM
 
   static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
@@ -466,9 +466,22 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
   __nv_fp8x4_e4m3* y_q_base_ptr =
       reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id;
-  auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g;
+
+  Idx_t scale_group_offset = 0;
+  if constexpr (std::is_same<scale_t, uint8_t>::value) {
+    // packed int32_t format
+    int pack_id = warp_position_scales / 4;
+    int scale_in_pack = warp_position_scales % 4;
+    scale_group_offset = pack_id * stride_ys_p + scale_in_pack * stride_ys_g;
+  } else {
+    scale_group_offset = warp_position_scales * stride_ys_g;
+  }
+
+  scale_t* const y_scale_base_ptr = _y_s + scale_group_offset;
 
   for (auto j = tokens_lower; j < tokens_upper; j++) {
+    int current_group_id = warp_position_scales;  // Running count of which
+                                                  // group is being processed
     const Idx_t base_ys = expert_id * stride_ys_e;
     auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
     __nv_fp8x4_e4m3* y_q_ptr =
@@ -509,7 +522,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
     __nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);
 
-    if constexpr (USE_UE8M0) {
+    if constexpr (CEIL_UE8M0) {
       y_s = hexp2(hceil(hlog2(y_s)));
     }
 
@@ -527,8 +540,24 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
       y_q_ptr += WARP_SIZE * stride_yq_h;
 
       if (!lane_id) {
-        *y_s_ptr = y_s;
-        y_s_ptr += stride_ys_g;
+        // Store scales.
+        if constexpr (std::is_same<scale_t, uint8_t>::value) {
+          // Packed UE8M0 format. Remove the mantissa.
+          *y_s_ptr = reinterpret_cast<uint16_t&>(y_s) >> 7;
+
+          bool const jump_pack = (current_group_id + 1) % 4 == 0;
+          // Minus 3 because we need to get to the first group in the
+          // next pack.
+          y_s_ptr += jump_pack ? (stride_ys_p - 3) : stride_ys_g;
+
+        } else {
+          // float32 format
+          static_assert(std::is_same<scale_t, float>::value);
+          *y_s_ptr = y_s;
+          y_s_ptr += stride_ys_g;
+        }
+
+        current_group_id += 1;
       }
     }
   }
@@ -573,7 +602,7 @@ void persistent_masked_m_silu_mul_quant(
     const at::Tensor& tokens_per_expert,  // (E)
     at::Tensor& y_q,                      // (E, T, H) [OUT]
     at::Tensor& y_s,                      // (E, T, H//group_size) [OUT]
-    bool use_ue8m0) {
+    bool cast_scale_ue8m0) {
 #ifndef USE_ROCM
 
   // This kernel currently only supports H % 128 == 0 and assumes a
@@ -583,9 +612,12 @@ void persistent_masked_m_silu_mul_quant(
   TORCH_CHECK(input.dtype() == torch::kBFloat16);
   TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
               y_q.dtype() == torch::kFloat8_e4m3fnuz);
-  TORCH_CHECK(y_s.dtype() == torch::kFloat32);
   TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
 
+  bool const is_packed_ue8m0 =
+      (y_s.dtype() == torch::kInt32 && cast_scale_ue8m0);
+  TORCH_CHECK(y_s.dtype() == torch::kFloat32 || is_packed_ue8m0);
+
   using Idx_t = int64_t;
 
   Idx_t E = input.size(0);
@@ -597,15 +629,18 @@ void persistent_masked_m_silu_mul_quant(
   Idx_t stride_yq_e = y_q.stride(0);
   Idx_t stride_yq_t = y_q.stride(1);
   Idx_t stride_yq_h = y_q.stride(2);
-  Idx_t stride_ys_e = y_s.stride(0);
-  Idx_t stride_ys_t = y_s.stride(1);
-  Idx_t stride_ys_g = y_s.stride(2);
   Idx_t stride_counts_e = tokens_per_expert.stride(0);
 
+  int const NUM_GROUPS = H / GROUP_SIZE;
+
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES)               \
+  // TODO: Get this from cuda_arch ?
+  static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
+
+  #define KERNEL(BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G, \
+                 STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, STAGES)               \
     static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE;                \
    int sms = SILU_V2_BLOCK_COUNT;                                             \
    static constexpr int max_shared_mem_bytes =                                \
@@ -615,43 +650,86 @@ void persistent_masked_m_silu_mul_quant(
     VLLM_DISPATCH_FP8_TYPES(                                                  \
         y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] {       \
           vllm::silu_mul_fp8_quant_deep_gemm_kernel<                          \
-              BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t,  \
-              USE_UE8M0, GROUP_SIZE, STAGES>                                  \
+              BLOCK_COUNT, max_shared_mem_bytes, fp8_t, scale_t, THREAD_COUNT, \
+              Idx_t, CEIL_UE8M0, GROUP_SIZE, STAGES>                          \
              <<<sms, THREAD_COUNT, max_shared_mem_bytes, stream>>>(           \
                  reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),          \
-                 (fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(),               \
+                 (fp8_t*)y_q.data_ptr(),                                      \
+                 reinterpret_cast<scale_t*>(y_s.data_ptr()),                  \
                  reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
                  T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e,       \
-                 stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t,          \
-                 stride_ys_g, stride_counts_e);                               \
+                 stride_yq_t, stride_yq_h, STRIDE_YS_E, STRIDE_YS_T,          \
+                 STRIDE_YS_G, STRIDE_YS_P, stride_counts_e);                  \
        });
 
-  static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
+  #define LAUNCH_ON_H(scale_t, STRIDE_YS_E, STRIDE_YS_T, STRIDE_YS_G,         \
+                      STRIDE_YS_P, CEIL_UE8M0)                                \
+    if (H >= 4096 && (NUM_GROUPS % 8) == 0) {                                 \
+      /* 8 warp config */                                                     \
+      static constexpr int NUM_STAGES = 4;                                    \
+      static constexpr int THREAD_COUNT = 256;                                \
+      KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T,          \
+             STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, NUM_STAGES); \
+    } else {                                                                  \
+      /* 1 warp config */                                                     \
+      static constexpr int THREAD_COUNT = 32;                                 \
+      KERNEL(SILU_V2_BLOCK_COUNT, scale_t, STRIDE_YS_E, STRIDE_YS_T,          \
+             STRIDE_YS_G, STRIDE_YS_P, CEIL_UE8M0, THREAD_COUNT, 2);          \
+    }
 
-  int const NUM_GROUPS = H / GROUP_SIZE;
-  if (!use_ue8m0) {
-    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
-      /* 8 warps config */
-      static constexpr int NUM_STAGES = 4;
-      static constexpr int THREAD_COUNT = 256;
-      KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
-    } else {
-      /* 1 warp config */
-      static constexpr int THREAD_COUNT = 32;
-      KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
-    }
-  } else {
-    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
-      /* 8 warps config */
-      static constexpr int NUM_STAGES = 4;
-      static constexpr int THREAD_COUNT = 256;
-      KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
-    } else {
-      /* 1 warp config */
-      static constexpr int THREAD_COUNT = 32;
-      KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
-    }
+  Idx_t stride_ys_e = y_s.stride(0);
+  Idx_t stride_ys_t = y_s.stride(1);
+  Idx_t stride_ys_g = y_s.stride(2);
+  Idx_t stride_ys_p = 0;
+  if (!cast_scale_ue8m0) {
+    TORCH_CHECK(!is_packed_ue8m0);
+    LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
+                false);
+    return;
   }
+  if (!is_packed_ue8m0) {
+    // UE8M0 but not packed
+    LAUNCH_ON_H(float, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p,
+                true);
+    return;
+  }
+
+  TORCH_CHECK(cast_scale_ue8m0 && is_packed_ue8m0);
+  TORCH_CHECK(y_s.dtype() == torch::kInt32);
+
+  // Int32 packed ue8m0 scales tensor.
+  // Let E, T, G be the number of experts, the number of tokens and the number
+  // of groups respectively. Let E = 1, T = 4, G = 6. In this case the int32
+  // scales tensor is of shape [1, 4, 2] with strides [8, 1, 4]. The scales are
+  // expected to be arranged as follows,
+  // [[T0G0-T0G1-T0G2-T0G3, T0G4-T0G5-X-X,],
+  //  [T1G0-T1G1-T1G2-T1G3, T1G4-T1G5-X-X,]
+  //  [T2G0-T2G1-T2G2-T2G3, T2G4-T2G5-X-X,]
+  //  [T3G0-T3G1-T3G2-T3G3, T3G4-T3G5-X-X,]]
+  // where TxGy is the ue8m0 scale value of Token x, Group y.
+  //
+  // In memory (in bytes) the scale values are arranged as,
+  // [T0G0, T0G1, T0G2, T0G3, T1G0, T1G1, T1G2, T1G3, T2G0, T2G1, T2G2, T2G3,
+  //  T3G0, T3G1, T3G2, T3G3, T0G4, T0G5, X, X, T1G4, T1G5, X, X, T2G4, T2G5,
+  //  X, X, T3G4, T3G5, X, X]
+  //
+  // An int32 tensor of size [1, 4, 2] and strides [8, 1, 4] can be represented
+  // as a uint8 tensor of shape [1, 2, 4, 4] and strides [32, 16, 4, 1]. In
+  // English, ignoring the expert dimension, the original int32 tensor is
+  // simply treated as two packed [4, 4] uint8 tensors (or two [4, 1] int32
+  // tensors). The following stride settings reflect this change. Caveat: this
+  // means that the G dimension is no longer contiguous, i.e. to move from G3
+  // to G4 we need to jump along the packing dimension. The kernel handles
+  // this case.
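For reference, the packing scheme described in the comment above can be reproduced on the host side with a short PyTorch sketch. This is illustrative only: the helper name pack_ue8m0_scales and the CPU tensors are assumptions for the example and are not part of this patch; the kernel writes the exponent bytes directly as shown in the hunks above.

import torch

def pack_ue8m0_scales(scales_f32: torch.Tensor) -> torch.Tensor:
    """Pack float32 group scales (E, T, G) into int32 words of 4 ue8m0 bytes."""
    E, T, G = scales_f32.shape
    padded_g = (G + 3) // 4 * 4  # round G up to a multiple of 4
    packed_u8 = torch.zeros((E, T, padded_g), dtype=torch.uint8)
    # The ue8m0 byte is the biased exponent of the float32 scale
    # (sign and mantissa are dropped).
    exponents = (scales_f32.view(torch.int32) >> 23) & 0xFF
    packed_u8[..., :G] = exponents.to(torch.uint8)
    # Reinterpret each run of 4 consecutive bytes as one int32,
    # giving shape (E, T, ceil(G / 4)).
    return packed_u8.view(torch.int32)

E, T, G = 1, 4, 6
scales = torch.rand(E, T, G) + 0.5   # positive scales, as produced by the kernel
packed = pack_ue8m0_scales(scales)
print(packed.shape)                  # torch.Size([1, 4, 2])
# The kernel writes into a tensor of this shape but with strides
# (T * ceil(G / 4), 1, T), i.e. the non-contiguous layout described in the
# comment above.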
+ + stride_ys_e *= sizeof(int32_t); + stride_ys_p = T * sizeof(int32_t); // Packing dimension + stride_ys_t = sizeof(int32_t); + stride_ys_g = 1; + + LAUNCH_ON_H(uint8_t, stride_ys_e, stride_ys_t, stride_ys_g, stride_ys_p, + true); + #endif } diff --git a/tests/conftest.py b/tests/conftest.py index 5e127e4e939e6..b17081352edcf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1384,3 +1384,16 @@ def image_urls(request, local_asset_server) -> list[str]: """Indirect fixture: takes a list of names, returns list of full URLs.""" names: list[str] = request.param return [local_asset_server.url_for(name) for name in names] + + +@pytest.fixture +def disable_deepgemm_ue8m0(monkeypatch): + from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + + with monkeypatch.context() as monkeypatch_ctx: + monkeypatch_ctx.setenv("VLLM_USE_DEEP_GEMM_E8M0", "0") + is_deep_gemm_e8m0_used.cache_clear() + yield + # Clear cache so the next time it is used it is processed with the + # default VLLM_USE_DEEP_GEMM_E8M0 setting. + is_deep_gemm_e8m0_used.cache_clear() diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 9d039b81690a1..0faf8bc95d2ec 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -21,7 +21,11 @@ from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform -from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported +from vllm.utils.deep_gemm import ( + get_mk_alignment_for_contiguous_layout, + is_deep_gemm_e8m0_used, + is_deep_gemm_supported, +) from vllm.utils.import_utils import has_deep_ep, has_deep_gemm from ...utils import multi_gpu_test @@ -413,19 +417,16 @@ NUM_EXPERTS = [32] @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif( - is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM" -) def test_ht_deepep_deepgemm_moe( mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int], + disable_deepgemm_ue8m0, ): """ Tests for High-Throughput DeepEP + DeepGemm integration. """ - import deep_gemm m, n, k = mnk current_platform.seed_everything(7) @@ -433,7 +434,7 @@ def test_ht_deepep_deepgemm_moe( if topk > num_experts: pytest.skip(f"Skipping test: topk={topk} > E={num_experts}") - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_m = get_mk_alignment_for_contiguous_layout()[0] block_size = [block_m, block_m] world_size, dp_size = world_dp_size @@ -487,9 +488,6 @@ USE_FP8_DISPATCH = [False] @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif( - is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM" -) def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], num_experts: int, @@ -497,10 +495,12 @@ def test_ll_deepep_deepgemm_moe( use_fp8_dispatch: bool, block_size: list[int], world_dp_size: tuple[int, int], + disable_deepgemm_ue8m0, ): """ Tests for Low-Latency DeepEP + DeepGemm integration. 
""" + assert not is_deep_gemm_e8m0_used() m, n, k = mnk current_platform.seed_everything(7) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index b49319a7e6f54..d78b8250463a9 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -294,7 +294,7 @@ def torch_moe_impl( # blockwise quant and de-quant. assert not per_act_token_quant a = test_tensors.rank_tokens - aq, aq_scale = per_token_group_quant_fp8(a, 128) + aq, aq_scale = per_token_group_quant_fp8(a, 128, use_ue8m0=False) a = ( (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)) .view(a.shape) diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 420dbbffaac08..d6b78dd2c2323 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random + import pytest import torch @@ -8,27 +11,30 @@ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( persistent_masked_m_silu_mul_quant, ) from vllm.platforms import current_platform -from vllm.utils.math_utils import cdiv +from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm +from vllm.utils.math_utils import cdiv, round_up fp8_dtype = torch.float8_e4m3fn CASES = [ (1, 1, 128, fp8_dtype), - (1, 4, 128, fp8_dtype), - (2, 4, 256, fp8_dtype), - (32, 64, 256, fp8_dtype), - (17, 31, 768, fp8_dtype), - (1, 1, 128 * 1, fp8_dtype), - (1, 1, 128 * 3, fp8_dtype), - (1, 1, 128 * 4, fp8_dtype), - (8, 16, 128 * 1, fp8_dtype), - (8, 16, 128 * 2, fp8_dtype), - (8, 16, 128 * 3, fp8_dtype), + (1, 4, 128 * 1, fp8_dtype), + (2, 4, 128 * 2, fp8_dtype), + (1, 4, 128 * 3, fp8_dtype), + (8, 16, 128 * 4, fp8_dtype), + (8, 16, 128 * 5, fp8_dtype), + (8, 16, 128 * 6, fp8_dtype), + (8, 16, 128 * 7, fp8_dtype), + (8, 16, 128 * 8, fp8_dtype), + (8, 16, 128 * 9, fp8_dtype), (8, 64, 7168, fp8_dtype), (8, 128, 128 * 33, fp8_dtype), + (1, 4, 128 * 10, fp8_dtype), (8, 128, 7168, fp8_dtype), (8, 512, 7168, fp8_dtype), (8, 1024, 7168, fp8_dtype), + (17, 31, 768, fp8_dtype), + (32, 64, 256, fp8_dtype), (256, 8, 7168, fp8_dtype), (256, 32, 7168, fp8_dtype), (256, 64, 7168, fp8_dtype), @@ -38,14 +44,159 @@ CASES = [ ] +def as_uint8(x) -> torch.Tensor: + return ( + torch.empty(x.shape, dtype=x.dtype, device=x.device).copy_(x).view(torch.uint8) + ) + + +def silu(x: torch.Tensor) -> torch.Tensor: + one_f32 = torch.tensor([1.0], device=x.device, dtype=torch.float32) + x_f32 = x.to(torch.float32) + act_f32 = x_f32 / (one_f32 + torch.exp(-x_f32)) + assert act_f32.dtype == torch.float32 + return act_f32.to(torch.bfloat16) + + +def do_quant(x: torch.Tensor, group_size: int, ceil_ue8m0: bool): + eps_bf16 = torch.tensor([1e-10], device=x.device, dtype=torch.bfloat16) + one_bf16 = torch.tensor([1.0], device=x.device, dtype=torch.bfloat16) + fp8_max_bf16 = torch.tensor( + [torch.finfo(fp8_dtype).max], device=x.device, dtype=torch.bfloat16 + ) + fp8_min_bf16 = torch.tensor( + [torch.finfo(fp8_dtype).min], device=x.device, dtype=torch.bfloat16 + ) + fp8_max_inv = one_bf16 / fp8_max_bf16 + assert fp8_max_inv.dtype == torch.bfloat16 + + assert x.size(-1) % group_size == 0 + num_groups = x.numel() // group_size + x_og_shape = x.shape + + x = x.to(torch.bfloat16) + x = x.view((-1, group_size)) + amax = x.abs().amax(dim=1).clamp(min=eps_bf16) + assert 
amax.dtype == torch.bfloat16 + s = amax * fp8_max_inv + + if ceil_ue8m0: + s = torch.exp2( + torch.ceil(torch.log2(s).to(torch.bfloat16)).to(torch.bfloat16) + ).to(torch.bfloat16) + + inv_s = one_bf16 / s + inv_s = inv_s.view((num_groups, 1)) + xq = torch.clamp(x * inv_s, min=fp8_min_bf16.item(), max=fp8_max_bf16.item()).to( + fp8_dtype + ) + + xq = xq.view(x_og_shape) + xs = s.view((-1, xq.size(-1) // group_size)) + return xq, xs + + +def silu_mul_quant( + gate: torch.Tensor, up: torch.Tensor, group_size: int, ceil_ue8m0: bool +) -> tuple[torch.Tensor, torch.Tensor]: + assert gate.size(-1) % group_size == 0 + assert up.size(-1) % group_size == 0 + + assert gate.dtype == torch.bfloat16 + assert up.dtype == torch.bfloat16 + + act_bf16 = silu(gate) + assert act_bf16.dtype == torch.bfloat16 + + # act & mul + a_m = act_bf16 * up + assert a_m.dtype == torch.bfloat16 + + q, s = do_quant(a_m, group_size, ceil_ue8m0) + return q, s + + +def pack_scales(x: torch.Tensor, tokens_per_expert: torch.Tensor) -> torch.Tensor: + """ + pack float32 scales into a int32 tensor + """ + assert x.dtype == torch.float32 + E, T, G = x.size() + + # Add i32_padding here so we can view it as a i32 tensor later on. + i32_padding = round_up(G, 4) - G + ref_s_i8 = torch.empty((E, T, G + i32_padding), dtype=torch.uint8, device="cuda") + for e in range(E): + nt = tokens_per_expert[e].item() + ref_s_i8[e, :nt, :G] = x[e, :nt].view(torch.int32) >> 23 + + ref_s_i32 = ref_s_i8.view(torch.int32) + + return ref_s_i32 + + +def ref_with_scale_fmt( + E: int, + T: int, + H: int, + group_size: int, + tokens_per_expert: torch.Tensor, + gate: torch.Tensor, + up: torch.Tensor, + scale_fmt: DeepGemmQuantScaleFMT, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + The precision types of the operations triggered by this function + match closely with the kernel implementation so we compare more + accurately. + """ + scale_dtype = ( + torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32 + ) + ceil_ue8m0 = scale_fmt in [ + DeepGemmQuantScaleFMT.UE8M0, + DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, + ] + + ref_q = torch.empty((E, T, H), dtype=fp8_dtype, device="cuda") + ref_s_f32 = torch.empty( + (E, T, cdiv(H, group_size)), dtype=torch.float32, device="cuda" + ) + + for e in range(E): + nt = tokens_per_expert[e].item() + if nt == 0: + continue + ref_q[e, :nt], ref_s_f32[e, :nt] = silu_mul_quant( + gate[e, :nt], up[e, :nt], group_size, ceil_ue8m0=ceil_ue8m0 + ) + + if scale_dtype == torch.float32: + return ref_q, ref_s_f32 + + assert scale_dtype == torch.int32 + return ref_q, pack_scales(ref_s_f32, tokens_per_expert) + + +def token_random(E, T, H2, tokens_per_expert): + """ + Initialize each token in a random range so we test a range of + scale values. 
+ """ + y = torch.empty((E, T, H2), dtype=torch.bfloat16, device="cuda") + for e in range(E): + for t in range(tokens_per_expert[e].item()): + exp = random.choice(range(1, 20)) + y[e, t].uniform_(-(2**exp), 2**exp) + return y + + @pytest.mark.parametrize("E,T,H,fp8_type", CASES) @torch.inference_mode() -def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): +def test_silu_mul_fp8_quant_deep_gemm(E: int, T: int, H: int, fp8_type: torch.dtype): group_size = 128 current_platform.seed_everything(42) - # Input tensor of shape (E, T, 2*H) - y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( low=0, high=T, @@ -54,71 +205,83 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): device="cuda", ) + # Input tensor of shape (E, T, 2*H) + y = token_random(E, T, 2 * H, tokens_per_expert) + + gate = y[..., :H].to(torch.bfloat16) + up = y[..., H:].to(torch.bfloat16) + + scale_fmts = [ + DeepGemmQuantScaleFMT.FLOAT32, + DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, + DeepGemmQuantScaleFMT.UE8M0, + ] + # Run the SiLU V2 kernel - # TODO (varun): use_e8m0 is set to false as the reference impl does - # not handle that case. - y_q, y_s = persistent_masked_m_silu_mul_quant( - y, tokens_per_expert, group_size=group_size, use_ue8m0=False - ) - - torch.cuda.synchronize() - fp8_info = torch.finfo(fp8_dtype) - fp8_max = fp8_info.max - fp8_min = fp8_info.min - eps = 1e-10 - - y1 = y[..., :H].float() - y2 = y[..., H:] - silu_x = y1 * torch.sigmoid(y1) - merged = silu_x * y2 - - for e in range(E): - nt = tokens_per_expert[e].item() - ref_s = torch.empty( - (T, cdiv(H, group_size)), dtype=torch.float32, device="cuda" - ) - ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda") - - for t in range(nt): - data = merged[e, t].float() - ref_q_row = torch.empty_like(data) - - # process full groups - n_full_groups = H // group_size - if n_full_groups > 0: - data_grp = data[: n_full_groups * group_size].view( - n_full_groups, group_size - ) - amax = data_grp.abs().amax(dim=1).clamp(min=eps) - scale = amax / fp8_max - scaled = data[: n_full_groups * group_size] / scale.repeat_interleave( - group_size - ) - ref_q_row[: n_full_groups * group_size] = scaled.clamp( - fp8_min, fp8_max - ).to(fp8_dtype) - ref_s[t, :n_full_groups] = scale - - # process remainder group - rem = H % group_size - if rem > 0: - data_rem = data[-rem:] - amax = data_rem.abs().amax().clamp(min=eps) - scale = amax / fp8_max - scaled = data_rem / scale - ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype) - ref_s[t, -1] = scale - - ref_q[t] = ref_q_row - - y_se = y_s[e].float() - y_qe = y_q[e].float() - - torch.testing.assert_close( - y_qe[:nt].to(torch.float32), - ref_q[:nt].to(torch.float32), - atol=2, - rtol=2e-1, + for scale_fmt in scale_fmts: + y_q, y_s = persistent_masked_m_silu_mul_quant( + y, + tokens_per_expert, + group_size=group_size, + quant_scale_fmt=scale_fmt, ) - torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) + ref_y_q, ref_y_s = ref_with_scale_fmt( + E, T, H, group_size, tokens_per_expert, gate, up, scale_fmt=scale_fmt + ) + + # deepgemm scales transform + dg_scales = None + if ( + has_deep_gemm() + and current_platform.has_device_capability(100) + and scale_fmt == DeepGemmQuantScaleFMT.UE8M0 + ): + from deep_gemm import transform_sf_into_required_layout + + _q, _s = ref_with_scale_fmt( + E, + T, + H, + group_size, + tokens_per_expert, + gate, + up, + scale_fmt=DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, + ) + dg_scales = 
transform_sf_into_required_layout( + sf=_s, + mn=_q.size(1), + k=_q.size(2), + recipe=(1, 128, 128), + num_groups=_q.size(0), + is_sfa=True, + ) + + expected_scale_dtype = ( + torch.int32 if scale_fmt == DeepGemmQuantScaleFMT.UE8M0 else torch.float32 + ) + assert y_s.dtype == expected_scale_dtype + assert ref_y_s.dtype == expected_scale_dtype + + for e in range(E): + nt = tokens_per_expert[e].item() + + torch.testing.assert_close( + y_q[e, :nt].to(torch.float32), + ref_y_q[e, :nt].to(torch.float32), + ) + + if scale_fmt == DeepGemmQuantScaleFMT.UE8M0: + G = H // group_size + y_s_sliced = as_uint8(y_s[e]) + ref_s_sliced = as_uint8(ref_y_s[e]) + torch.testing.assert_close(y_s_sliced[:nt, :G], ref_s_sliced[:nt, :G]) + if dg_scales is not None: + dg_sliced = as_uint8(dg_scales[e]) + torch.testing.assert_close(y_s_sliced[:nt, :G], dg_sliced[:nt, :G]) + else: + torch.testing.assert_close( + y_s[e, :nt], + ref_y_s[e, :nt], + ) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 869082f8231d1..79c92eb48612d 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -13,14 +14,33 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import ( + DeepGemmQuantScaleFMT, fp8_m_grouped_gemm_nt_masked, get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, ) +from vllm.utils.math_utils import cdiv logger = init_logger(__name__) +def scales_shape_stride_dtype( + E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT +) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]: + shape = (E, T, G) + strides = (T * G, 1, T) + if quant_scale_fmt in [ + DeepGemmQuantScaleFMT.FLOAT32, + DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0, + ]: + return shape, strides, torch.float32 + + assert quant_scale_fmt == DeepGemmQuantScaleFMT.UE8M0 + shape = (E, T, cdiv(G, 4)) + strides = (T * cdiv(G, 4), 1, T) + return shape, strides, torch.int32 + + @triton.jit def _silu_mul_fp8_quant_deep_gemm( # Pointers ------------------------------------------------------------ @@ -49,7 +69,7 @@ def _silu_mul_fp8_quant_deep_gemm( eps: tl.constexpr, fp8_min: tl.constexpr, fp8_max: tl.constexpr, - use_ue8m0: tl.constexpr, + ceil_ue8m0: tl.constexpr, # Meta --------------------------------------------------------------- BLOCK: tl.constexpr, NUM_STAGES: tl.constexpr, @@ -86,7 +106,7 @@ def _silu_mul_fp8_quant_deep_gemm( y = gate * up y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max - if use_ue8m0: + if ceil_ue8m0: y_s = tl.exp2(tl.ceil(tl.log2(y_s))) y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) @@ -100,7 +120,7 @@ def persistent_masked_m_silu_mul_quant( tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert num_parallel_tokens=16, group_size: int = 128, - use_ue8m0: bool | None = None, + quant_scale_fmt: DeepGemmQuantScaleFMT = DeepGemmQuantScaleFMT.FLOAT32, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales y has shape (E, T, 2*H). 
The first half of the last dimension is
@@ -137,7 +157,13 @@ def persistent_masked_m_silu_mul_quant(
 
     Returns `(y_q, y_s)` where
     * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
-    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+    * `y_s` depends on quant_scale_fmt,
+      - quant_scale_fmt == FLOAT32,
+        `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+      - quant_scale_fmt == UE8M0,
+        `y_s`: Int32 tensor, shape (E, T, ceil(G/4)), strides (T*ceil(G/4), 1, T)
+      - quant_scale_fmt == FLOAT32_CEIL_UE8M0,
+        `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
 
     Let NUM_WARPS be the number of warps in a single thread block and
     `GROUP_SIZE = 128` be the size of the quantization group.
     """
@@ -155,17 +181,18 @@ def persistent_masked_m_silu_mul_quant(
 
     fp8_dtype = torch.float8_e4m3fn
     y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
 
-    stride_ys_e = T * G
-    stride_ys_t = 1
-    stride_ys_g = T
+    ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
     y_s = torch.empty_strided(
-        (E, T, G),
-        (stride_ys_e, stride_ys_t, stride_ys_g),
-        dtype=torch.float32,
+        ys_shape,
+        ys_strides,
+        dtype=ys_dtype,
         device=y.device,
     )
 
-    use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
+    ceil_ue8m0 = quant_scale_fmt in [
+        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
+        DeepGemmQuantScaleFMT.UE8M0,
+    ]
 
     cuda_arch = current_platform.get_device_capability(
         device_id=y.device.index
@@ -173,7 +200,7 @@ def persistent_masked_m_silu_mul_quant(
 
     if cuda_arch >= 80:
         torch.ops._C.persistent_masked_m_silu_mul_quant(
-            y, tokens_per_expert, y_q, y_s, use_ue8m0
+            y, tokens_per_expert, y_q, y_s, ceil_ue8m0
         )
     else:
         stride_cnt_e = tokens_per_expert.stride()[0]
@@ -189,6 +216,10 @@ def persistent_masked_m_silu_mul_quant(
         fp8_max = f_info.max
         fp8_min = f_info.min
         eps: float = 1e-10
+        assert y_s.dtype == torch.float32, (
+            "_silu_mul_fp8_quant_deep_gemm does not "
+            f"support {y_s.dtype} scales. Only torch.float32 is supported."
+        )
         _silu_mul_fp8_quant_deep_gemm[grid](
             y,
             y_q,
@@ -202,14 +233,14 @@ def persistent_masked_m_silu_mul_quant(
             stride_yq_e,
             stride_yq_t,
             stride_yq_h,
-            stride_ys_e,
-            stride_ys_t,
-            stride_ys_g,
+            ys_strides[0],
+            ys_strides[1],
+            ys_strides[2],
             stride_cnt_e,
             eps,
             fp8_min,
             fp8_max,
-            is_deep_gemm_e8m0_used(),
+            ceil_ue8m0,
             BLOCK=group_size,
             NUM_STAGES=4,
             num_warps=1,
@@ -255,7 +286,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         """
         DeepGemm supports packed ue8m0 activation scales format in devices == sm100
         """
-        return current_platform.is_device_capability(100)
+        return is_deep_gemm_e8m0_used() and current_platform.is_device_capability(100)
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Let PrepareAndFinalize::finalize() decide the impl.
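To make the relationship between the three quant_scale_fmt variants documented above concrete, here is a small standalone sketch (illustrative values only, not part of the patch): FLOAT32 keeps the raw per-group scale, FLOAT32_CEIL_UE8M0 rounds it up to a power of two but still stores a float32, and UE8M0 keeps only the biased exponent byte of that power of two, four of which are packed per int32 element of y_s.

import math
import struct

amax, fp8_max = 3.7, 448.0                # example group amax and FP8 E4M3 max
s_float32 = amax / fp8_max                # FLOAT32: raw per-group scale
s_ceiled = 2.0 ** math.ceil(math.log2(s_float32))  # FLOAT32_CEIL_UE8M0: power of
                                                   # two, still stored as float32
# UE8M0: keep only the biased exponent byte of the ceiled scale; the kernel
# packs four such bytes into each int32 element of y_s.
bits = struct.unpack("<I", struct.pack("<f", s_ceiled))[0]
ue8m0_byte = (bits >> 23) & 0xFF
assert ue8m0_byte == math.ceil(math.log2(s_float32)) + 127
print(s_float32, s_ceiled, ue8m0_byte)    # 0.00825..., 0.015625, 121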
@@ -329,10 +360,17 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             expected_m,
         )
 
+        quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
         a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
-            workspace1, expert_num_tokens
+            workspace1,
+            expert_num_tokens,
+            quant_scale_fmt=quant_scale_fmt,
         )
 
         fp8_m_grouped_gemm_nt_masked(
-            (a2q, a2q_scale), (w2, self.w2_scale), output, expert_num_tokens, expected_m
+            (a2q, a2q_scale),
+            (w2, self.w2_scale),
+            output,
+            expert_num_tokens,
+            expected_m,
         )
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 4c15baf7a8f93..b5ab37534dd78 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -9,6 +9,7 @@ import functools
 import importlib
 import os
 from collections.abc import Callable
+from enum import Enum
 from typing import Any, NoReturn
 
 import torch
@@ -20,6 +21,28 @@ from vllm.utils.import_utils import has_deep_gemm
 from vllm.utils.math_utils import cdiv
 
 
+class DeepGemmQuantScaleFMT(Enum):
+    # Float32 scales in a Float32 tensor
+    FLOAT32 = 0
+    # Compute float32 scales and ceil the scales to UE8M0.
+    # Keep the scales in a Float32 tensor.
+    FLOAT32_CEIL_UE8M0 = 1
+    # Compute float32 scales and ceil the scales to UE8M0.
+    # Pack the scales into an int32 tensor where each int32
+    # element contains 4 scale values.
+    UE8M0 = 2
+
+    @staticmethod
+    def from_oracle() -> "DeepGemmQuantScaleFMT":
+        if not is_deep_gemm_e8m0_used():
+            return DeepGemmQuantScaleFMT.FLOAT32
+        return (
+            DeepGemmQuantScaleFMT.UE8M0
+            if current_platform.is_device_capability(100)
+            else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
+        )
+
+
 @functools.cache
 def is_deep_gemm_supported() -> bool:
     """Return `True` if DeepGEMM is supported on the current platform.