From 3bbe11cc136373bd4f6c12912dc094dba086fa11 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 21 Aug 2025 17:56:15 -0400
Subject: [PATCH] [Perf] Small optimizations for silu_mul_fp8_quant_deep_gemm (#23265)

Signed-off-by: mgoin
---
 .../kernels/benchmark_silu_mul_fp8_quant.py   | 77 +++++++++++++++++++
 .../moe/test_silu_mul_fp8_quant_deep_gemm.py  |  4 +-
 .../layers/fused_moe/batched_deep_gemm_moe.py | 58 +++++++-------
 3 files changed, 107 insertions(+), 32 deletions(-)
 create mode 100644 benchmarks/kernels/benchmark_silu_mul_fp8_quant.py

diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
new file mode 100644
index 0000000000000..0650cbf3cc18e
--- /dev/null
+++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+    silu_mul_fp8_quant_deep_gemm,
+)
+from vllm.platforms import current_platform
+
+
+def benchmark(E, T, H, G=128, runs=50):
+    current_platform.seed_everything(42)
+    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
+    tokens_per_expert = torch.randint(
+        T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
+    )
+
+    # Warmup
+    for _ in range(10):
+        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+    torch.cuda.synchronize()
+
+    # Benchmark
+    torch.cuda.synchronize()
+    start = time.perf_counter()
+    for _ in range(runs):
+        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+    torch.cuda.synchronize()
+
+    avg_time = (time.perf_counter() - start) / runs * 1000
+
+    # Calculate actual work done (only count valid tokens)
+    actual_tokens = tokens_per_expert.sum().item()
+    actual_elements = actual_tokens * H
+
+    # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
+    ops_per_element = 8
+    total_ops = actual_elements * ops_per_element
+    gflops = total_ops / (avg_time / 1000) / 1e9
+
+    # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
+    input_bytes = actual_tokens * 2 * H * 2  # 2*H bfloat16 inputs
+    output_bytes = actual_tokens * H * 1  # H fp8 outputs
+    scale_bytes = actual_tokens * (H // G) * 4  # scales in float32
+    total_bytes = input_bytes + output_bytes + scale_bytes
+    memory_bw = total_bytes / (avg_time / 1000) / 1e9
+
+    return avg_time, gflops, memory_bw
+
+
+configs = [
+    (8, 32, 1024),
+    (16, 64, 2048),
+    (32, 128, 4096),
+    # DeepSeekV3 Configs
+    (256, 16, 7168),
+    (256, 32, 7168),
+    (256, 64, 7168),
+    (256, 128, 7168),
+    (256, 256, 7168),
+    (256, 512, 7168),
+    (256, 1024, 7168),
+]
+
+print(f"GPU: {torch.cuda.get_device_name()}")
+print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+print("-" * 50)
+
+for E, T, H in configs:
+    try:
+        time_ms, gflops, gbps = benchmark(E, T, H)
+        print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
+    except Exception:
+        print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
index 673a0aa367948..5a0379dfb4475 100644
--- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
+++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
@@ -24,7 +24,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
     current_platform.seed_everything(seed)
 
     # Input tensor of shape (E, T, 2*H)
-    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
+    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
     tokens_per_expert = torch.randint(
         low=0,
         high=T,
@@ -74,7 +74,7 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
         y_se = y_s[e]
         y_qe = y_q[e]
 
-        torch.testing.assert_close(y_se[:nt], ref_s[:nt])
+        torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)
         torch.testing.assert_close(
             y_qe[:nt].to(torch.float32),
             ref_q[:nt].to(torch.float32),
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index d9cfe96f7a033..c4d680af932f0 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm(
     # number of valid tokens for this expert
     n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
 
-    cols = tl.arange(0, BLOCK)
-    cols = cols.to(tl.int64)
-    mask_h = cols < BLOCK
+    cols = tl.arange(0, BLOCK).to(tl.int64)
+    mask = cols < BLOCK
+
+    base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
+    base_gate_offset = base_input_offset + cols * stride_i_h
+    base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
+    base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h +
+                      cols * stride_yq_h)
+    base_ys_offset = e * stride_ys_e + g * stride_ys_g
 
     for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
-        base_i_offset = (e * stride_i_e + t * stride_i_t +
-                         g * GROUP_SIZE * stride_i_h)
-        base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
-                          g * GROUP_SIZE * stride_yq_h)
-        base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g
-
-        mask = mask_h
-        x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
-                    mask=mask,
-                    other=0.0).to(tl.float32)
-        y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
-                     cols * stride_i_h,
+        gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t,
+                       mask=mask,
+                       other=0.0).to(tl.float32)
+        up = tl.load(input_ptr + base_up_offset + t * stride_i_t,
                      mask=mask,
-                     other=0.0).to(tl.float32)
+                     other=0.0)
 
-        x = x * (1.0 / (1.0 + tl.exp(-x)))
-        y = x * y2
+        gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
+        y = gate * up
+
+        y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
+        if use_ue8m0:
+            y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
 
-        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-        scale_raw = _absmax / fp8_max
-        y_s = tl.math.exp2(tl.ceil(
-            tl.log2(scale_raw))) if use_ue8m0 else scale_raw
         y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
-        tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
-        tl.store(y_s_ptr + base_ys_offset, y_s)
+        tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
+        tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
 
 
 def silu_mul_fp8_quant_deep_gemm(
-    y: torch.Tensor,  # (E, T, 2*H) float32
+    y: torch.Tensor,  # (E, T, 2*H)
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     group_size: int = 128,
     eps: float = 1e-10,
-):
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
 
     y has shape (E, T, 2*H). The first half of the last dimension is
     silu-activated, multiplied by the second half, then quantized into FP8.
 
     Returns `(y_q, y_s)` where
-    * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
-    * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`
+    * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
+    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
     """
     assert y.ndim == 3, "y must be (E, T, 2*H)"
     E, T, H2 = y.shape
@@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm(
 
     stride_cnt_e = tokens_per_expert.stride()[0]
 
-    # static grid over experts and H-groups.
+    # Static grid over experts and H-groups.
     # A loop inside the kernel handles the token dim
     grid = (E * G, )
 
@@ -178,7 +176,7 @@ def silu_mul_fp8_quant_deep_gemm(
         fp8_max,
         is_blackwell_deep_gemm_e8m0_used(),
         BLOCK=group_size,
-        NUM_STAGES=8,
+        NUM_STAGES=4,
        num_warps=1,
    )
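
As a reference for the semantics this patch optimizes, below is a minimal eager-mode PyTorch sketch of what silu_mul_fp8_quant_deep_gemm computes: silu(y[..., :H]) * y[..., H:] in fp32, then per-(token, group) absmax scaling and FP8 quantization, with scales optionally rounded up to a power of two for the UE8M0 path. The helper name ref_silu_mul_fp8_quant and the torch.float8_e4m3fn dtype are illustrative assumptions, not part of the patch; the sketch also returns a contiguous y_s rather than the kernel's (T*G, 1, T)-strided layout and quantizes all T rows per expert instead of only the first tokens_per_expert[e].

import torch


def ref_silu_mul_fp8_quant(
    y: torch.Tensor,  # (E, T, 2*H), e.g. bfloat16
    group_size: int = 128,
    eps: float = 1e-10,
    use_ue8m0: bool = False,
):
    # Eager reference sketch (illustrative only; name and fp8 dtype are assumptions)
    E, T, H2 = y.shape
    H = H2 // 2
    G = H // group_size
    fp8 = torch.finfo(torch.float8_e4m3fn)

    # silu(gate) * up, computed in fp32 like the kernel
    gate = y[..., :H].float()
    up = y[..., H:].float()
    act = gate * torch.sigmoid(gate) * up

    # Per-(token, group) absmax scale over the hidden dimension
    grouped = act.reshape(E, T, G, group_size)
    y_s = grouped.abs().amax(dim=-1).clamp_min(eps) / fp8.max
    if use_ue8m0:
        # Round scales up to the next power of two (UE8M0 scales)
        y_s = torch.exp2(torch.ceil(torch.log2(y_s)))

    y_q = (grouped / y_s.unsqueeze(-1)).clamp(fp8.min, fp8.max)
    y_q = y_q.reshape(E, T, H).to(torch.float8_e4m3fn)
    return y_q, y_s

Comparing y_q and y_s from this sketch against the kernel outputs for the first tokens_per_expert[e] rows of each expert mirrors the per-expert check performed in test_silu_mul_fp8_quant_deep_gemm.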