From 7b03584de8819a870644bc853cf24cd2ff8a9f64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elvir=20Crn=C4=8Devi=C4=87?= Date: Fri, 10 Oct 2025 17:19:53 +0200 Subject: [PATCH] Silu v2 (#25074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: mgoin Signed-off-by: elvircrn Signed-off-by: Elvir Crnčević Co-authored-by: mgoin Co-authored-by: Varun Sundar Rabindranath --- .../kernels/benchmark_silu_mul_fp8_quant.py | 291 ++++++---- csrc/ops.h | 4 +- csrc/quantization/activation_kernels.cu | 542 ++++++++++-------- csrc/torch_bindings.cpp | 10 +- .../moe/test_silu_mul_fp8_quant_deep_gemm.py | 11 +- .../layers/fused_moe/batched_deep_gemm_moe.py | 66 ++- 6 files changed, 519 insertions(+), 405 deletions(-) diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py index c7a4066b39d7..a5887aafd30d 100644 --- a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -1,5 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Comprehensive 3-way SiLU Benchmark Suite + +This benchmark compares three SiLU implementations: +1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation +2. Triton Kernel - Triton-based implementation + +The suite generates detailed performance comparisons including: +- Memory bandwidth utilization +- Speedup ratios (baseline vs optimized implementations) +- Performance across different expert configurations and token distributions +""" + from collections.abc import Callable import matplotlib.pyplot as plt @@ -7,7 +21,7 @@ import numpy as np import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm_cuda, + persistent_masked_m_silu_mul_quant, ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton( num_parallel_tokens, group_size: int = 128, eps: float = 1e-10, + expert_offsets: torch.Tensor = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales @@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton( # Parse generation strategies -strategies = ["uniform", "max_t", "first_t"] +strategies = ["random_imbalanced", "uniform", "max_t"] def benchmark( @@ -195,15 +210,27 @@ def benchmark( current_platform.seed_everything(42 + seed_offset) y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() - if gen_strategy == "uniform": - r = torch.rand(size=(E,), device="cuda") + if gen_strategy == "random_imbalanced": + + def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"): + mean = total_tokens // n_e + min_max = mean // ratio + e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean + e[0] = min_max + r = torch.rand(size=(E - 1,)) + r /= r.sum() + r *= total_tokens - min_max + r = r.round().long() + e[1:] = r.to(device=device) + return e + + tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda") + elif gen_strategy == "uniform": + r = torch.rand(size=(E,)) r /= r.sum() r *= total_tokens - tokens_per_expert = r.int() - tokens_per_expert = torch.minimum( - tokens_per_expert, - torch.ones((E,), device=r.device, dtype=torch.int) * T, - ) + r = r.round().long() + tokens_per_expert = r elif gen_strategy == "max_t": tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, 
device="cuda") tokens_per_expert.fill_(total_tokens / E) @@ -281,40 +308,34 @@ def benchmark( def create_comparison_plot( - ratio, cuda_times, baseline_times, config_labels, strategy_name, id + ratios, silu_v2_times, triton_times, config_labels, strategy_name, id ): - """Create a comparison plot for a specific generation strategy""" - fig, ax = plt.subplots(1, 1, figsize=(16, 6)) + fig, ax = plt.subplots(1, 1, figsize=(18, 6)) # Configure x-axis positions x = np.arange(len(config_labels)) - width = 0.35 + width = 0.25 # Execution Time plot (lower is better) + ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue") ax.bar( - x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue" - ) - ax.bar( - x + width / 2, - baseline_times, - width, - label="Baseline", - alpha=0.8, - color="orange", + x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green" ) - # Add speedup labels over each bar pair + # Add speedup labels over each bar trio for i in range(len(x)): - speedup = ratio[i] - max_height = max(cuda_times[i], baseline_times[i]) + triton_v2_speedup = ratios[i][1] # triton/v2 + max_height = max(silu_v2_times[i], triton_times[i]) + + # Triton/V2 speedup ax.text( - x[i], + x[i] + width / 2, max_height + max_height * 0.02, - f"{speedup:.2f}x", + f"{triton_v2_speedup:.2f}x", ha="center", va="bottom", fontweight="bold", - fontsize=9, + fontsize=8, ) ax.set_xlabel("Configuration") @@ -332,56 +353,75 @@ def create_comparison_plot( def create_combined_plot(all_results): - """Create a combined plot with all strategies in one PNG""" num_strategies = len(all_results) - fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies)) + fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies)) if num_strategies == 1: axes = [axes] for idx, ( strategy_name, - ratio, - cuda_times, - baseline_times, + all_ratios, + all_silu_v2_results, + all_triton_results, config_labels, + config_x_axis, ) in enumerate(all_results): ax = axes[idx] + # Flatten the nested results to get bandwidth percentages for plotting + silu_v2_bandwidths = [] + triton_bandwidths = [] + flat_ratios = [] + + for config_results in all_silu_v2_results: + for result in config_results: + silu_v2_bandwidths.append(result[3]) # bandwidth percentage + + for config_results in all_triton_results: + for result in config_results: + triton_bandwidths.append(result[3]) # bandwidth percentage + + for config_ratios in all_ratios: + for ratio in config_ratios: + flat_ratios.append(ratio) + # Configure x-axis positions x = np.arange(len(config_labels)) - width = 0.35 + width = 0.25 - # Execution Time plot (lower is better) + # Bandwidth utilization plot (higher is better) ax.bar( - x - width / 2, - cuda_times, + x, + silu_v2_bandwidths, width, - label="CUDA Kernel", + label="SiLU V2 (CUDA)", alpha=0.8, color="blue", ) ax.bar( - x + width / 2, - baseline_times, + x + width, + triton_bandwidths, width, - label="Baseline", + label="Triton Kernel", alpha=0.8, - color="orange", + color="green", ) - # Add speedup labels over each bar pair + # Add speedup labels over each bar trio for i in range(len(x)): - speedup = ratio[i] - max_height = max(cuda_times[i], baseline_times[i]) + triton_v2_speedup = flat_ratios[i] # triton/v2 + max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i]) + + # Triton/V2 speedup ax.text( - x[i], + x[i] + width / 2, max_height + max_height * 0.02, - f"{speedup:.2f}x", + f"{triton_v2_speedup:.2f}x", ha="center", va="bottom", 
fontweight="bold", - fontsize=9, + fontsize=8, ) ax.set_xlabel("Configuration") @@ -395,7 +435,7 @@ def create_combined_plot(all_results): ax.grid(True, alpha=0.3) plt.tight_layout() - filename = "../../silu_bench/silu_benchmark_combined.png" + filename = "silu_benchmark_combined_3way.png" plt.savefig(filename, dpi=300, bbox_inches="tight") plt.show() @@ -405,7 +445,9 @@ def create_combined_plot(all_results): outer_dim = 7168 configs = [ # DeepSeekV3 Configs + # (1, 56, 7168), (8, 1024, 7168), + # (32, 56, 7168), # DeepSeekV3 Configs (32, 1024, 7168), # DeepSeekV3 Configs @@ -417,6 +459,7 @@ num_warmups = 20 strategy_descriptions = { "uniform": "Uniform Random", + "random_imbalanced": "Imbalanced Random", "max_t": "Even Assignment", "first_t": "experts[0] = T, experts[1:] = 0", } @@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies): print(f"Testing strategy: {strategy_descriptions[strategy]}") print(f"{'=' * 60}") - # Collect benchmark data for both algorithms + # Collect benchmark data for all three algorithms config_labels = [] config_x_axis = [] - all_cuda_results = [] - all_baseline_results = [] + all_silu_v2_results = [] + all_triton_results = [] all_ratios = [] for E, T, H in configs: - total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E] + total_tokens_config = [] + for i in [8, 16, 32, 64, 128, 256, 512]: + if i <= T: + total_tokens_config.append(i * E) config_x_axis.append(total_tokens_config) - cuda_results = [] - baseline_results = [] + silu_v2_results = [] + triton_results = [] ratios = [] for total_tokens in total_tokens_config: config_label = f"E={E},T={T},H={H},TT={total_tokens}" config_labels.append(config_label) - # CUDA kernel results - time_ms_cuda, gflops, gbps, perc = benchmark( - silu_mul_fp8_quant_deep_gemm_cuda, + # SiLU V2 (CUDA kernel) results + time_ms_silu_v2, gflops, gbps, perc = benchmark( + persistent_masked_m_silu_mul_quant, E, T, H, @@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies): num_warmups=num_warmups, gen_strategy=strategy, ) - cuda_results.append((time_ms_cuda, gflops, gbps, perc)) + silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc)) - # Baseline results + # Triton kernel results time_ms_triton, gflops, gbps, perc = benchmark( silu_mul_fp8_quant_deep_gemm_triton, E, @@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies): num_warmups=num_warmups, gen_strategy=strategy, ) - baseline_results.append((time_ms_triton, gflops, gbps, perc)) - ratios.append(time_ms_triton / time_ms_cuda) + triton_results.append((time_ms_triton, gflops, gbps, perc)) - print(f"Completed: {config_label}") - all_cuda_results.append(cuda_results) - all_baseline_results.append(baseline_results) + # Calculate speedup ratios (triton baseline / implementation) + triton_v2_ratio = time_ms_triton / time_ms_silu_v2 + ratios.append(triton_v2_ratio) + + print( + f"Completed: {config_label}:" + f" V2: {time_ms_silu_v2:.3f}ms," + f" Triton: {time_ms_triton:.3f}ms" + ) + + all_silu_v2_results.append(silu_v2_results) + all_triton_results.append(triton_results) all_ratios.append(ratios) # Store results for combined plotting @@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies): ( strategy_descriptions[strategy], all_ratios, - all_cuda_results, - all_baseline_results, + all_silu_v2_results, + all_triton_results, config_labels, config_x_axis, ) @@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies): # Print summary table for this strategy print(f"\nSummary Table - {strategy_descriptions[strategy]}:") - 
print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}") - print("-" * 60) + print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}") + print("-" * 90) for i, (E, T, H) in enumerate(configs): - speedup = baseline_results[i][0] / cuda_results[i][0] + # Get the first result for each config (simplifying for summary) + v2_time = silu_v2_results[i][0] + triton_time = triton_results[i][0] + triton_v2_speedup = triton_time / v2_time config_label = f"E={E:3d},T={T:4d},H={H:4d}" print( - f"{config_label:<20} {cuda_results[i][0]:8.5f} " - f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x" + f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} " + f"{triton_v2_speedup:8.2f}x" ) @@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results): num_strategies = len(all_results) num_configs = len(configs) - # Create side-by-side subplots: 2 columns for speedup and bandwidth percentage fig, axs = plt.subplots( - num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies) + num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies) ) # Add main title to the entire figure fig.suptitle( - "Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)", - fontsize=16, + "Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)", + fontsize=18, fontweight="bold", y=0.98, ) @@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results): ( strategy_name, all_ratios, - all_cuda_results, - all_baseline_results, + all_silu_v2_results, + all_triton_results, config_labels, config_x_axis, ) = result @@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results): ratios = all_ratios[config_idx] total_tokens_values = config_x_axis[config_idx] - # Extract CUDA and Triton bandwidth percentages - cuda_bandwidth_percentages = [ - result[3] for result in all_cuda_results[config_idx] + # Extract speedup ratios + triton_v2_ratios = [ratio for ratio in ratios] + + # Extract bandwidth percentages for all implementations + v2_bandwidth_percentages = [ + result[3] for result in all_silu_v2_results[config_idx] ] triton_bandwidth_percentages = [ - result[3] for result in all_baseline_results[config_idx] + result[3] for result in all_triton_results[config_idx] ] # Plot speedup ratios vs total tokens (left plot) ax_speedup.plot( - total_tokens_values, ratios, "bo-", linewidth=3, markersize=8 + total_tokens_values, + triton_v2_ratios, + "go-", + linewidth=3, + markersize=8, + label="Triton/V2 Speedup", ) ax_speedup.set_title( - f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}", + f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}", fontsize=12, fontweight="bold", ) ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) + ax_speedup.legend(prop={"weight": "bold"}) ax_speedup.grid(True, alpha=0.3) + # Plot bandwidth utilization (right plot) ax_bandwidth.plot( total_tokens_values, - cuda_bandwidth_percentages, - "ro-", + v2_bandwidth_percentages, + "o-", linewidth=3, markersize=8, - label="CUDA", + label="SiLU V2", + color="blue", ) ax_bandwidth.plot( total_tokens_values, triton_bandwidth_percentages, - "go-", + "o-", linewidth=3, markersize=8, label="Triton", + color="green", ) ax_bandwidth.set_title( f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", @@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results): for label in ax.get_xticklabels() + ax.get_yticklabels(): label.set_fontweight("bold") - # Add value 
labels on speedup points - for x, y in zip(total_tokens_values, ratios): + # Add value labels on Triton/V2 speedup points + for x, y in zip(total_tokens_values, triton_v2_ratios): ax_speedup.annotate( f"{y:.2f}x", (x, y), textcoords="offset points", - xytext=(0, 12), - ha="center", - fontsize=10, - fontweight="bold", - bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7), - ) - - # Add value labels on CUDA bandwidth points - for x, y in zip(total_tokens_values, cuda_bandwidth_percentages): - ax_bandwidth.annotate( - f"{y:.1f}%", - (x, y), - textcoords="offset points", - xytext=(0, 12), - ha="center", - fontsize=9, - fontweight="bold", - bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3), - ) - - # Add value labels on Triton bandwidth points - for x, y in zip(total_tokens_values, triton_bandwidth_percentages): - ax_bandwidth.annotate( - f"{y:.1f}%", - (x, y), - textcoords="offset points", xytext=(0, -15), ha="center", fontsize=9, @@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results): plt.tight_layout() plt.subplots_adjust(top=0.93) # Make room for main title - filename = "silu_benchmark_total_tokens.png" + filename = "silu_benchmark_total_tokens_3way.png" plt.savefig(filename, dpi=300, bbox_inches="tight") plt.show() return filename -# Create combined plot with all strategies -combined_plot_filename = create_total_tokens_plot(all_results) +# Create comprehensive 3-way comparison plots +combined_plot_filename = create_combined_plot(all_results) +total_tokens_plot_filename = create_total_tokens_plot(all_results) -print(f"\n{'=' * 60}") -print("Benchmark Complete!") -print(f"Generated combined plot: {combined_plot_filename}") -print(f"{'=' * 60}") +print(f"\n{'=' * 80}") +print("3-Way Benchmark Suite Complete!") +print(f"Generated combined comparison plot: {combined_plot_filename}") +print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}") +print("Compared: SiLU V2 (CUDA), and Triton implementations") +print(f"{'=' * 80}") diff --git a/csrc/ops.h b/csrc/ops.h index 9dd302faf5b8..2a9214e7fb03 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -138,12 +138,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& input_global_scale); #endif -void silu_mul_fp8_quant_deep_gemm_cuda( +void persistent_masked_m_silu_mul_quant( const at::Tensor& input, // (E, T, 2*H) const at::Tensor& counts, // (E) at::Tensor& y_q, // (E, T, H) [OUT] at::Tensor& y_s, // (E, T, H//group_size) [OUT] - int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens); + bool use_ue8m0); void mul_and_silu(torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 55da79a12d89..6fcd246f63c5 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -114,13 +114,22 @@ __global__ void act_and_mul_quant_kernel( } __device__ __forceinline__ float silu(float x) { - return (__fdividef(x, (1.f + expf(-x)))); + return __fdividef(x, (1.f + expf(-x))); } __device__ __forceinline__ float2 silu2(float2 x) { return make_float2(silu(x.x), silu(x.y)); } +__device__ __forceinline__ __nv_bfloat162 silu2_v2(float2 x) { +#ifndef USE_ROCM + return make_bfloat162(__float2bfloat16_rn(silu(x.x)), + __float2bfloat16_rn(silu(x.y))); +#else + return __float22bfloat162_rn(make_float2(silu(x.x), silu(x.y))); +#endif +} + #ifndef USE_ROCM __device__ __forceinline__ float warp_max(float v) { static constexpr unsigned FULL_MASK = 0xffffffffu; @@ 
-223,224 +232,308 @@ constexpr __nv_bfloat16 get_fp8_min() { return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032}); } } -#ifndef USE_ROCM -template +__device__ __forceinline__ int warp_expert_search( + int idx, int n, const Idx_t* __restrict__ input, Idx_t val) { + const Idx_t* input_ptr = input + idx; + int base_offset = 0; + + for (;;) { + bool move_on = (idx < n && *input_ptr <= val); + + unsigned mask = __ballot_sync(0xffffffff, move_on); + + if (mask != 0xffffffffu) { + int last_lane = 31 - __clz(mask); + return base_offset + last_lane; + } + + input_ptr += 32; + base_offset += 32; + idx += 32; + } +} + +template +__device__ __forceinline__ void token_bounds(int32_t n_tokens, + int32_t worker_id, + int32_t& n_tokens_lower, + int32_t& n_tokens_upper) { + if (n_tokens < num_parallel_tokens && worker_id < n_tokens) { + if (worker_id >= num_parallel_tokens) return; + n_tokens_lower = worker_id; + n_tokens_upper = worker_id + 1; + } else { + int32_t chunk_size = n_tokens / num_parallel_tokens; + int32_t residual = n_tokens - chunk_size * num_parallel_tokens; + auto calc_id = [&](int32_t id) { + if (id < residual) + return min(n_tokens, id * (chunk_size + 1)); + else + return min(n_tokens, id * chunk_size + residual); + }; + n_tokens_lower = calc_id(worker_id); + n_tokens_upper = calc_id(worker_id + 1); + } +} + +template __global__ void silu_mul_fp8_quant_deep_gemm_kernel( const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q, - float* __restrict__ _y_s, const int32_t* __restrict__ counts, - + float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert, // sizes - int H, int G, - + Idx_t E, Idx_t T, Idx_t H, // strides (in elements) Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, Idx_t stride_ys_g, Idx_t stride_counts_e) { +#ifndef USE_ROCM + static constexpr int NUM_WARPS = THREADS / WARP_SIZE; + + static constexpr int LOAD_STAGE_SIZE = 2 * GROUP_SIZE / 8; + static constexpr int LOAD_STAGE_MOD = NUM_STAGES * LOAD_STAGE_SIZE; + + static constexpr int COMPUTE_STAGE_SIZE = 2 * GROUP_SIZE / 4; + static constexpr int COMPUTE_STAGE_MOD = COMPUTE_STAGE_SIZE * NUM_STAGES; + + extern __shared__ __align__(16) __int128_t smem_128[]; + + int* s_expert_offsets = + reinterpret_cast(smem_128 + (SMEM_SIZE_BYTES_Y / 16)); + static constexpr __nv_bfloat16 fp8_min = get_fp8_min(); static constexpr __nv_bfloat16 fp8_max = get_fp8_max(); - // We assign EPS with its 16-bit unsigned counterpart to allow constexpr. + // We assign EPS with it's 16-bit unsigned counterpart to allow constexpr. static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996}); + int tid = threadIdx.x; + int warp_id = tid >> 5; + int lane_id = tid & 0x1f; - // We pack 8 16-bit bfloat16 values into a 128-bit __int128_t. - static constexpr int32_t BFLOAT16_PER_GROUP = 8; + int running_sum{}; + if (!warp_id) { + for (int i = 0; i < E; i += WARP_SIZE) { + bool valid = (i + threadIdx.x) < E; + int value = + (valid ? tokens_per_expert[i + threadIdx.x * stride_counts_e] : 0) + + (!lane_id ? running_sum : 0); - // We split the shared memory in half, corresponding to gate and up matrices: - // [...gate_i, ...up_i] where 0 <= i < stages. 
- static constexpr int32_t S_NUM_128 = - 2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES; - static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE; - static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2; - static constexpr int32_t S_NUM_64 = S_NUM_128 * 2; - __shared__ __int128_t __align__(16) s_buff_128[S_NUM_128]; + for (int offset = 1; offset < 32; offset *= 2) { + int n = __shfl_up_sync(0xFFFFFFFFu, value, offset); + if (lane_id >= offset) value += n; + } - const int32_t tid = threadIdx.x; - const int32_t warp_id = tid / WARP_SIZE; - const int32_t lane_id = tid % WARP_SIZE; + if (valid) { + s_expert_offsets[i + threadIdx.x + 1] = value; + } - auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128); + running_sum = __shfl_sync(0xFFFFFFFFu, value, WARP_SIZE - 1); + } - // block handles one (expert e, group g) - int32_t pid = blockIdx.x; - int32_t e = pid / G; - int32_t g = pid % G; - - const int32_t n_tokens = counts[e * stride_counts_e]; - - if (!n_tokens) { - return; // Exit ASAP. + if (!lane_id) { + s_expert_offsets[0] = 0; + } } - const Idx_t stride_i_t_128 = stride_i_t / 8u; + __syncthreads(); - int32_t n_tokens_lower, n_tokens_upper; + int32_t total_tokens = s_expert_offsets[E]; + const int warp_position_yq = warp_id * (H / NUM_WARPS); + const int warp_position_scales = warp_id * (H / (GROUP_SIZE * NUM_WARPS)); + + // A single block will handle tokens_per_block tokens. // Each block i iterates over tokens of a slice of n_tokens = // expert_counts[i], with the size of chunk being // (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of // updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling. - if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) { - // Specialize this, but can be likely fused. - if (blockIdx.y >= NUM_PARALLEL_TOKENS) { - return; - } - n_tokens_lower = blockIdx.y; - n_tokens_upper = blockIdx.y + 1; - } else { - auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS; - auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS; - auto calc_id = [&](int32_t id) { - if (id < residual) { - return min(n_tokens, id * (chunk_size + 1)); - } else { - return min(n_tokens, id * chunk_size + residual); - } - }; - n_tokens_lower = calc_id(blockIdx.y); - n_tokens_upper = calc_id(blockIdx.y + 1); - } - if (n_tokens_lower >= n_tokens_upper) { + // Each warp will get space to store its hidden dim for gate and up. + __int128_t* s_hidden_load = smem_128 + warp_id * ((2 * 128 / 8) * NUM_STAGES); + __int128_t* smem_load_ptr = s_hidden_load + lane_id; + + const __nv_bfloat16 fp8_inv = __hdiv(__float2bfloat16(1.f), fp8_max); + + int32_t compute_pipeline_offset_64 = 0; + int32_t load_stage_offset{}; + const __nv_bfloat16 one_bf16 = __float2bfloat16_rn(1.f); + + __int64_t* smem_compute_ptr = reinterpret_cast<__int64_t*>(smem_128) + + warp_id * (2 * (GROUP_SIZE / 4) * NUM_STAGES) + + lane_id; + __int64_t* s_gate64_ptr = smem_compute_ptr; + __int64_t* s_up64_ptr = smem_compute_ptr + GROUP_SIZE / 4; + + int tokens_lower, tokens_upper; + + token_bounds(total_tokens, blockIdx.x, tokens_lower, + tokens_upper); + + Idx_t expert_id{}, expert_offset{}, next_expert_offset{}; + int token_id = tokens_lower; + int32_t t_load{}; + + if (token_id < tokens_upper) { + expert_id = warp_expert_search(lane_id, E, s_expert_offsets, token_id); + expert_offset = s_expert_offsets[expert_id]; + next_expert_offset = s_expert_offsets[expert_id + 1]; + } else { + // This thread block has no work to do. 
return; } - // We do calculations here, using constexpr wherever possible. - const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h; - const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g; - const Idx_t base_yq = - e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h; - Idx_t gate_off_128 = (base_i / static_cast(8u)); - auto input_128_ptr = reinterpret_cast(_input); - auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) + - stride_i_t_128 * n_tokens_lower; - auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u; - auto y_s_ptr = - _y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t; - auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE + - stride_yq_t * n_tokens_lower + 4 * lane_id; - int32_t t_load = n_tokens_lower, load_stage_id = 0; - auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT); - auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u; - int32_t stage_offset{}; + int t_load_bound = H / (GROUP_SIZE * NUM_WARPS); - static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2); - static constexpr int32_t LOAD_STAGE_MOD = - NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2); + Idx_t base_i = ((expert_id * stride_i_e) / 8) + + (token_id - expert_offset) * stride_i_t / 8; + const Idx_t gate_warp_offset = + warp_id * ((stride_i_h * H) / (8 * NUM_WARPS)) + (lane_id & 0b1111); + + const __int128_t* input_128_ptr = + reinterpret_cast(_input) + gate_warp_offset + + ((lane_id < 16) ? 0 : ((H * stride_i_h) / 8)); + __int128_t* load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i); + + auto token_offset = token_id - expert_offset; - // Two halves of all threads in a block conduct global loads for gate and up, - // repsectively. auto load_and_advance_y_pred = [&] { - if (t_load < n_tokens_upper) { - auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset; - auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset; + if (t_load < t_load_bound) { + // Here we are simply continuing to load data + // from the current token. + auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset; // It is very important that LOAD_STAGE_SIZE is constexpr to avoid // unnecessary ALU ops. - stage_offset += LOAD_STAGE_SIZE; - stage_offset %= LOAD_STAGE_MOD; + load_stage_offset += LOAD_STAGE_SIZE; + load_stage_offset %= LOAD_STAGE_MOD; - if (tid < HALF_THREAD_COUNT) { - cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr); - gate_128_ptr += stride_i_t_128; - } else { - cp_async4(s_up_stage_128_staged_ptr, up_128_ptr); - up_128_ptr += stride_i_t_128; - } + cp_async4(smem_load_ptr_staged, load_ptr); + load_ptr += GROUP_SIZE / 8; + ++t_load; + } else if (token_id + 1 < tokens_upper) { + // We loaded everything from the current token, let's move on + // to the next one, and we checked that we have more tokens to load. + ++token_id; + t_load = 0; + if (token_id >= next_expert_offset) { + // We need to find the next expert. + do { + // This is a loop because it's possible + // that some experts are assigned 0 tokens. + // NOTE: We are guaranteed that there's at least + // one more token left so we don't have to check for + // expert_id bounds. + ++expert_id; + // This skips 1 memory read. 
+ expert_offset = next_expert_offset; + next_expert_offset = s_expert_offsets[expert_id + 1]; + } while (next_expert_offset == expert_offset); + + base_i = expert_id * (stride_i_e / 8); + token_offset = 0; + load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i); + } else { + // We remain within the same expert, so just + // move by H/4 __int128_t (2 * H/8). + base_i += stride_yq_t / 4; + token_offset++; + } + + load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i); + + auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset; + + // It is very important that LOAD_STAGE_SIZE is constexpr to avoid + // unnecessary ALU ops. + load_stage_offset += LOAD_STAGE_SIZE; + load_stage_offset %= LOAD_STAGE_MOD; + + cp_async4(smem_load_ptr_staged, load_ptr); + load_ptr += GROUP_SIZE / 8; ++t_load; - ++load_stage_id; } // We fence even if there is nothing to load to simplify pipelining. cp_async_fence(); }; + // We need to warm-up the pipeline. #pragma unroll for (int i = 0; i < NUM_STAGES - 1; i++) { load_and_advance_y_pred(); } - __int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>( - s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) + - lane_id; - __int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2; + __nv_fp8x4_e4m3* y_q_base_ptr = + reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id; + auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g; - static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u; - static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES; + for (auto j = tokens_lower; j < tokens_upper; j++) { + const Idx_t base_ys = expert_id * stride_ys_e; + auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t; + __nv_fp8x4_e4m3* y_q_ptr = + y_q_base_ptr + (expert_id * stride_yq_e + token_offset * stride_yq_t + + warp_position_yq * stride_yq_h) / + 4; + const int COMPUTE_LIMIT = H / (GROUP_SIZE * NUM_WARPS); - int32_t compute_pipeline_offset_64 = 0; + for (int i = 0; i < COMPUTE_LIMIT; i++) { + cp_async_wait(); + __syncthreads(); + load_and_advance_y_pred(); - for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) { - __nv_bfloat162 results_bf162[2]; + __int64_t* gate64_ptr = s_gate64_ptr + compute_pipeline_offset_64; + __int64_t* up64_ptr = s_up64_ptr + compute_pipeline_offset_64; - cp_async_wait(); - __syncthreads(); + // COMPUTE_STAGE_SIZE/MOD must also be constexpr! + compute_pipeline_offset_64 += COMPUTE_STAGE_SIZE; + compute_pipeline_offset_64 %= COMPUTE_STAGE_MOD; - // We double-buffer pipelined loads so that the next load will - // concurrently run with compute without overwrites. - load_and_advance_y_pred(); + __int64_t gate64 = *gate64_ptr; + __int64_t up64 = *up64_ptr; - auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64; - auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64; - - // STAGE_SIZE must also be constexpr! - compute_pipeline_offset_64 += STAGE_SIZE; - compute_pipeline_offset_64 %= STAGE_MOD; - - // Each thread loads (gate/up) 2X 4X bfloat16 values into registers. - __int64_t gate64 = *s_gate_compute_64; - __nv_bfloat162* s_gate_compute_32 = - reinterpret_cast<__nv_bfloat162*>(&gate64); - - __int64_t up64 = *s_up_compute_64; - __nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64); + // Compute + __nv_bfloat162 res[2]; + __nv_bfloat162* s_up_comp = reinterpret_cast<__nv_bfloat162*>(&up64); + __nv_bfloat162* s_gate_comp = reinterpret_cast<__nv_bfloat162*>(&gate64); #pragma unroll - for (int i = 0; i < 2; i++) { - // For silu, we make sure that div is emitted. 
- float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i])); - results_bf162[i] = __float22bfloat162_rn(gate); - } + for (int32_t k = 0; k < 2; ++k) { + __nv_bfloat162 gate = silu2_v2(__bfloat1622float2(s_gate_comp[k])); + res[k] = __hmul2(gate, s_up_comp[k]); + } + + auto _y_max2 = __hmax2(__habs2(res[0]), __habs2(res[1])); + + _y_max2.x = __hmax(__hmax(_y_max2.x, _y_max2.y), EPS); + + __nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv); + + if constexpr (USE_UE8M0) { + y_s = hexp2(hceil(hlog2(y_s))); + } + + __nv_bfloat16 inv_y = __hdiv(one_bf16, y_s); + + auto y_s2 = make_bfloat162(inv_y, inv_y); #pragma unroll - for (int i = 0; i < 2; i++) { - results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]); - } + for (int32_t k = 0; k < 2; ++k) { + res[k] = clip(__hmul2(res[k], y_s2), __bfloat162bfloat162(fp8_min), + __bfloat162bfloat162(fp8_max)); + } - auto _y_max2 = - __hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1])); + *y_q_ptr = __nv_fp8x4_e4m3(res[0], res[1]); + y_q_ptr += WARP_SIZE * stride_yq_h; - __nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y)); - - // An entire group is assigned to a single warp, so a simple warp reduce - // is used. - __nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max; - - if constexpr (USE_UE8M0) { - y_s = hexp2(hceil(hlog2(y_s))); - } - - auto inv_y = __float2bfloat16_rn(1.f) / y_s; - - auto y_s2 = make_bfloat162(inv_y, inv_y); - - #pragma unroll - for (int32_t i = 0; i < 2; ++i) { - results_bf162[i] = - clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min), - __bfloat162bfloat162(fp8_max)); - } - - auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]); - *reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4; - y_q_ptr += stride_yq_t; - - if (lane_id == 0) { - *y_s_ptr = y_s; - y_s_ptr += stride_ys_t; + if (!lane_id) { + *y_s_ptr = y_s; + y_s_ptr += stride_ys_g; + } } } -} #endif +} } // namespace vllm @@ -475,14 +568,14 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d] LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } -void silu_mul_fp8_quant_deep_gemm_cuda( - const at::Tensor& input, // (E, T, 2*H) - const at::Tensor& counts, // (E) - at::Tensor& y_q, // (E, T, H) [OUT] - at::Tensor& y_s, // (E, T, H//group_size) [OUT] - int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) { +void persistent_masked_m_silu_mul_quant( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& tokens_per_expert, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + bool use_ue8m0) { #ifndef USE_ROCM - // This kernel relies heavily on cp.async and fp8 support. + // This kernel currently only supports H % 128 == 0 and assumes a // fixed GROUP_SIZE of 128. TORCH_CHECK(input.dtype() == torch::kBFloat16); @@ -491,10 +584,6 @@ void silu_mul_fp8_quant_deep_gemm_cuda( TORCH_CHECK(y_s.dtype() == torch::kFloat32); TORCH_CHECK(input.size(-1) % 256 == 0); - // Check that num_parallel_tokens is of power of 2 and between 1 and 64. 
- TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64); - TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1))); - using Idx_t = int64_t; Idx_t E = input.size(0); @@ -510,81 +599,54 @@ void silu_mul_fp8_quant_deep_gemm_cuda( Idx_t stride_ys_t = y_s.stride(1); Idx_t stride_ys_g = y_s.stride(2); - Idx_t stride_counts_e = counts.stride(0); + Idx_t stride_counts_e = tokens_per_expert.stride(0); static constexpr int GROUP_SIZE = 128; - #define KERNEL_FN \ - if (use_ue8m0) { \ - vllm::silu_mul_fp8_quant_deep_gemm_kernel \ - <<>>( \ - reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ - (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ - reinterpret_cast(counts.data_ptr()), H, G, \ - stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ - stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ - stride_counts_e); \ - } else { \ - vllm::silu_mul_fp8_quant_deep_gemm_kernel \ - <<>>( \ - reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ - (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ - reinterpret_cast(counts.data_ptr()), H, G, \ - stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ - stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ - stride_counts_e); \ - } - - #define KERNEL_CALL_H \ - if (H % (4 * GROUP_SIZE) == 0) { \ - static constexpr int NUM_WARPS = 4; \ - populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ - KERNEL_FN \ - } else { \ - static constexpr int NUM_WARPS = 1; \ - populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ - KERNEL_FN \ - } - - #define KERNEL_CALL_TOP_LEVEL \ - if (num_parallel_tokens == 1) { \ - static constexpr int NUM_PARALLEL_TOKENS = 1; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 2) { \ - static constexpr int NUM_PARALLEL_TOKENS = 2; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 4) { \ - static constexpr int NUM_PARALLEL_TOKENS = 4; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 8) { \ - static constexpr int NUM_PARALLEL_TOKENS = 8; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 16) { \ - static constexpr int NUM_PARALLEL_TOKENS = 16; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 32) { \ - static constexpr int NUM_PARALLEL_TOKENS = 32; \ - KERNEL_CALL_H \ - } else if (num_parallel_tokens == 64) { \ - static constexpr int NUM_PARALLEL_TOKENS = 64; \ - KERNEL_CALL_H \ - } - - Idx_t G; - dim3 block, grid; - auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) { - G = H / Idx_t(group_size * num_warps); - grid = dim3(E * G, _num_parallel_tokens); - block = dim3(num_warps * WARP_SIZE); - }; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(), - "silu_mul_fp8_quant_deep_gemm_kernel", - [&] { KERNEL_CALL_TOP_LEVEL }); + + #define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \ + static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \ + int sms = SILU_V2_BLOCK_COUNT; \ + static constexpr int max_shared_mem_bytes = \ + GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \ + dim3 grid(sms), block(THREAD_COUNT); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + VLLM_DISPATCH_FP8_TYPES( \ + y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \ + vllm::silu_mul_fp8_quant_deep_gemm_kernel< \ + BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \ + USE_UE8M0, GROUP_SIZE, STAGES> \ + <<>>( \ + reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ + (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ 
+ reinterpret_cast(tokens_per_expert.data_ptr()), E, \ + T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \ + stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \ + stride_ys_g, stride_counts_e); \ + }); + + static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32; + + if (!use_ue8m0) { + if (H >= 4096) { + static constexpr int NUM_STAGES = 4; + static constexpr int THREAD_COUNT = 256; + KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES); + } else { + static constexpr int THREAD_COUNT = 32; + KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2); + } + } else { + if (H >= 4096) { + static constexpr int NUM_STAGES = 4; + static constexpr int THREAD_COUNT = 256; + KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES); + } else { + static constexpr int THREAD_COUNT = 32; + KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2); + } + } #endif } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index bef8cdc33f13..a4a9f87b28f1 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -33,11 +33,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #endif ops.def( - "silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! " - "y_q, Tensor! y_s, int group_size, " - "bool use_ue8m0, int num_parallel_tokens) -> ()"); - ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA, - &silu_mul_fp8_quant_deep_gemm_cuda); + "persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! " + "y_q, Tensor! y_s," + "bool use_ue8m0) -> ()"); + ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA, + &persistent_masked_m_silu_mul_quant); ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index b6ca80e97e91..8b3bebb391f2 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -5,7 +5,7 @@ import pytest import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm_cuda, + persistent_masked_m_silu_mul_quant, ) from vllm.platforms import current_platform from vllm.utils import cdiv @@ -50,15 +50,15 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): # Input tensor of shape (E, T, 2*H) y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( - low=T // 2, + low=0, high=T, size=(E,), dtype=torch.int32, device="cuda", ) - # Run the Triton kernel - y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda( + # Run the SiLU V2 kernel + y_q, y_s = persistent_masked_m_silu_mul_quant( y, tokens_per_expert, group_size=group_size ) @@ -115,10 +115,11 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): y_se = y_s[e].float() y_qe = y_q[e].float() - torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) torch.testing.assert_close( y_qe[:nt].to(torch.float32), ref_q[:nt].to(torch.float32), atol=2, rtol=2e-1, ) + + torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 94b18e51da96..35d2dcb91d25 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project
-from math import log2
 from typing import Optional
 
 import torch
@@ -94,7 +93,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
 
 
-def silu_mul_fp8_quant_deep_gemm_cuda(
+def persistent_masked_m_silu_mul_quant(
     y: torch.Tensor,  # (E, T, 2*H)
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     num_parallel_tokens=16,
@@ -103,9 +102,41 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
 
     y has shape (E, T, 2*H). The first half of the last dimension is
     silu-activated, multiplied by the second half, then quantized into FP8.
+
+    We launch a fixed grid of threads to accommodate CUDA graphs. Let `P2`
+    be a parallelization factor for persistent_masked_m_silu_mul_quant over the
+    hidden dimension.
+
+    Let `expert_offsets = [0] + [num_tokens.cumsum()]` and
+    `total_tokens = expert_offsets[-1]`.
+    persistent_masked_m_silu_mul_quant launches `total_tokens x P2`
+    thread blocks. Each thread block contains `NUM_WARPS` warps.
+
+    Every thread block needs to find its corresponding expert by warp-parallel
+    scanning over the `expert_offsets` array.
+
+    The i-th warp in the first thread block processes
+    `[i * warp_chunk_size, (i + 1) * warp_chunk_size]` groups
+    sequentially, where `warp_chunk_size = ((H / GROUP_SIZE) / P2) / NUM_WARPS`,
+    pipelining loads and computes.
+
+    The shared memory layout for 4 warps with a 2-stage pipeline for SiLU V2
+    is visualized like so:
+
+                      stage0                                  stage1
+    ┌─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┐
+    │gate0│up0│gate1│up1│gate2│up2│gate3│up3│gate0│up0│gate1│up1│gate2│up2│gate3│up3│
+    └─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┘
+
+    with the main difference between V1 and V2 being the global load
+    stride between warps, and between half-warps. Regarding the latter stride,
+    we assign the first half-warp of every warp to `gate` loads and the second
+    half-warp to `up` loads.
+
     Returns `(y_q, y_s)` where
     * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
     * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
+    Here, NUM_WARPS is the number of warps in a single thread block and
+    `GROUP_SIZE = 128` is the size of the quantization group.
     """
     assert y.ndim == 3, "y must be (E, T, 2*H)"
     E, T, H2 = y.shape
@@ -133,30 +164,15 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
 
     use_ue8m0 = is_deep_gemm_e8m0_used()
 
-    if E <= 16:
-        max_empirical_parallelism = 64
-    elif E <= 32:
-        max_empirical_parallelism = 16
-    else:
-        max_empirical_parallelism = 4
-
-    # We never want to launch more than Tx number of threads
-    # This computes the clip.
-    num_parallel_tokens = max(
-        1, min(max_empirical_parallelism, 2 ** int(log2(min(num_parallel_tokens, T))))
-    )
     cuda_arch = current_platform.get_device_capability(
         device_id=y.device.index
     ).to_int()
 
     if cuda_arch >= 80:
-        torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(
-            y, tokens_per_expert, y_q, y_s, group_size, use_ue8m0, num_parallel_tokens
+        torch.ops._C.persistent_masked_m_silu_mul_quant(
+            y, tokens_per_expert, y_q, y_s, use_ue8m0
        )
     else:
-        # Default to triton if not on cuda or if arch is too old
-        y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
-
         stride_cnt_e = tokens_per_expert.stride()[0]
 
         # Static grid over experts and H-groups.
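As a host-side illustration of the scheduling math the docstring above describes (a sketch only, not part of the patch): `P2`, `NUM_WARPS`, and `GROUP_SIZE` mirror the names used there, and the helper name is hypothetical.

import torch


def illustrate_schedule(tokens_per_expert: torch.Tensor, H: int,
                        GROUP_SIZE: int = 128, NUM_WARPS: int = 4, P2: int = 1):
    # expert_offsets = [0] + num_tokens.cumsum(); total_tokens = expert_offsets[-1]
    expert_offsets = torch.cat([
        torch.zeros(1, dtype=torch.int64, device=tokens_per_expert.device),
        tokens_per_expert.to(torch.int64).cumsum(0),
    ])
    total_tokens = int(expert_offsets[-1])

    # Number of GROUP_SIZE-wide hidden-dim groups each warp walks per token.
    warp_chunk_size = (H // GROUP_SIZE) // P2 // NUM_WARPS

    # The expert owning global token index t is the last offset <= t, which is
    # what the kernel's warp-parallel scan over expert_offsets computes.
    def expert_of(t: int) -> int:
        return int(torch.searchsorted(expert_offsets, t, right=True)) - 1

    return expert_offsets, total_tokens, warp_chunk_size, expert_of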
@@ -166,16 +182,6 @@ def silu_mul_fp8_quant_deep_gemm_cuda( stride_i_e, stride_i_t, stride_i_h = y.stride() stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() - # desired scale strides (elements): (T*G, 1, T) - stride_ys_e = T * G - stride_ys_t = 1 - stride_ys_g = T - y_s = torch.empty_strided( - (E, T, G), - (stride_ys_e, stride_ys_t, stride_ys_g), - dtype=torch.float32, - device=y.device, - ) f_info = torch.finfo(fp8_dtype) fp8_max = f_info.max fp8_min = f_info.min @@ -313,7 +319,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expected_m, ) - a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda( + a2q, a2q_scale = persistent_masked_m_silu_mul_quant( workspace1, expert_num_tokens )
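
For reference, a minimal usage sketch of the renamed entry point, modeled on the call in tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py. It assumes a CUDA device with compute capability >= 8.0 and a vLLM build containing this change; the E/T/H values are illustrative.

import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    persistent_masked_m_silu_mul_quant,
)

E, T, H = 8, 64, 7168  # illustrative sizes; the kernel requires H % 128 == 0
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
tokens_per_expert = torch.randint(0, T, (E,), dtype=torch.int32, device="cuda")

# y_q: (E, T, H) FP8 values; y_s: (E, T, H // 128) FP32 per-group scales.
# Only the first tokens_per_expert[e] rows of expert e hold meaningful data.
y_q, y_s = persistent_masked_m_silu_mul_quant(y, tokens_per_expert, group_size=128)

On devices below SM80 (or non-CUDA platforms), the wrapper falls back to the Triton path shown earlier in this diff.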