diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py index 0650cbf3cc18e..c7a4066b39d70 100644 --- a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py +++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py @@ -1,77 +1,675 @@ -#!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time +from collections.abc import Callable +import matplotlib.pyplot as plt +import numpy as np import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm, + silu_mul_fp8_quant_deep_gemm_cuda, ) from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used -def benchmark(E, T, H, G=128, runs=50): - current_platform.seed_everything(42) - y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") - tokens_per_expert = torch.randint( - T // 2, T, size=(E,), dtype=torch.int32, device="cuda" +@triton.jit +def _silu_mul_fp8_quant_deep_gemm( + # Pointers ------------------------------------------------------------ + input_ptr, # 16-bit activations (E, T, 2*H) + y_q_ptr, # fp8 quantized activations (E, T, H) + y_s_ptr, # 16-bit scales (E, T, G) + counts_ptr, # int32 num tokens per expert (E) + # Sizes --------------------------------------------------------------- + H: tl.constexpr, # hidden dimension (per output) + GROUP_SIZE: tl.constexpr, # elements per group (usually 128) + # Strides for input (elements) --------------------------------------- + stride_i_e, + stride_i_t, + stride_i_h, + # Strides for y_q (elements) ----------------------------------------- + stride_yq_e, + stride_yq_t, + stride_yq_h, + # Strides for y_s (elements) ----------------------------------------- + stride_ys_e, + stride_ys_t, + stride_ys_g, + # Stride for counts (elements) + stride_counts_e, + # Numeric params ------------------------------------------------------ + eps: tl.constexpr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + use_ue8m0: tl.constexpr, + # Meta --------------------------------------------------------------- + BLOCK: tl.constexpr, + NUM_STAGES: tl.constexpr, +): + G = H // GROUP_SIZE + + # map program id -> (e, g) + pid = tl.program_id(0) + e = pid // G + g = pid % G + + e = e.to(tl.int64) + g = g.to(tl.int64) + + # number of valid tokens for this expert + n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64) + + cols = tl.arange(0, BLOCK).to(tl.int64) + mask = cols < BLOCK + + base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h + base_gate_offset = base_input_offset + cols * stride_i_h + base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h + base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h + base_ys_offset = e * stride_ys_e + g * stride_ys_g + + for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): + gate = tl.load( + input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0 + ).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0) + + gate = gate * (1.0 / (1.0 + tl.exp(-gate))) + y = gate * up + + y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max + if use_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(y_s))) + + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask) + tl.store(y_s_ptr + base_ys_offset + t * 
stride_ys_t, y_s) + + +def silu_mul_fp8_quant_deep_gemm_triton( + y: torch.Tensor, # (E, T, 2*H) + tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + num_parallel_tokens, + group_size: int = 128, + eps: float = 1e-10, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales + + y has shape (E, T, 2*H). The first half of the last dimension is + silu-activated, multiplied by the second half, then quantized into FP8. + + Returns `(y_q, y_s)` where + * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] + * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) + """ + assert y.ndim == 3, "y must be (E, T, 2*H)" + E, T, H2 = y.shape + assert H2 % 2 == 0, "last dim of y must be even (2*H)" + H = H2 // 2 + G = (H + group_size - 1) // group_size + assert H % group_size == 0, "H must be divisible by group_size" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, ( + "tokens_per_expert must be shape (E,)" + ) + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) + + # allocate outputs + fp8_dtype = torch.float8_e4m3fn + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, ) + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups. + # A loop inside the kernel handles the token dim + grid = (E * G,) + + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) + + return y_q, y_s + + +# Parse generation strategies +strategies = ["uniform", "max_t", "first_t"] + + +def benchmark( + kernel: Callable, + E: int, + T: int, + H: int, + total_tokens: int, + num_parallel_tokens: int = 64, + G: int = 128, + runs: int = 200, + num_warmups: int = 20, + gen_strategy: str = "default", + iterations_per_run: int = 20, +): + def generate_data(seed_offset=0): + """Generate input data with given seed offset""" + current_platform.seed_everything(42 + seed_offset) + y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous() + + if gen_strategy == "uniform": + r = torch.rand(size=(E,), device="cuda") + r /= r.sum() + r *= total_tokens + tokens_per_expert = r.int() + tokens_per_expert = torch.minimum( + tokens_per_expert, + torch.ones((E,), device=r.device, dtype=torch.int) * T, + ) + elif gen_strategy == "max_t": + tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert.fill_(total_tokens / E) + elif gen_strategy == "first_t": + tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda") + tokens_per_expert[0] = min(T, total_tokens) + else: + raise ValueError(f"Unknown generation strategy: {gen_strategy}") + return y, tokens_per_expert + + dataset_count = 4 + # Pre-generate different input matrices for each 
iteration to avoid cache effects + data_sets = [generate_data(i) for i in range(dataset_count)] + # Warmup - for _ in range(10): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) - torch.cuda.synchronize() + y, tokens_per_expert = data_sets[0] + for _ in range(num_warmups): + kernel( + y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G + ) + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) # Benchmark - torch.cuda.synchronize() - start = time.perf_counter() + latencies: list[float] = [] for _ in range(runs): - silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) - torch.cuda.synchronize() + torch.cuda.synchronize() - avg_time = (time.perf_counter() - start) / runs * 1000 + start_event.record() + for i in range(iterations_per_run): + y, tokens_per_expert = data_sets[i % dataset_count] + kernel( + y, + tokens_per_expert, + num_parallel_tokens=num_parallel_tokens, + group_size=G, + ) + end_event.record() + end_event.synchronize() - # Calculate actual work done (only count valid tokens) + total_time_ms = start_event.elapsed_time(end_event) + per_iter_time_ms = total_time_ms / iterations_per_run + latencies.append(per_iter_time_ms) + + # Use median instead of average for better outlier handling + median_time_ms = np.median(latencies) + median_time_s = median_time_ms / 1000 + + # Calculate actual work done (using first dataset for consistency) + _, tokens_per_expert = data_sets[0] actual_tokens = tokens_per_expert.sum().item() actual_elements = actual_tokens * H # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops ops_per_element = 8 total_ops = actual_elements * ops_per_element - gflops = total_ops / (avg_time / 1000) / 1e9 + gflops = total_ops / median_time_s / 1e9 # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs output_bytes = actual_tokens * H * 1 # H fp8 outputs scale_bytes = actual_tokens * (H // G) * 4 # scales in float32 total_bytes = input_bytes + output_bytes + scale_bytes - memory_bw = total_bytes / (avg_time / 1000) / 1e9 + memory_bw = total_bytes / median_time_s / 1e9 - return avg_time, gflops, memory_bw + HOPPER_BANDWIDTH_TBPS = 3.35 + return ( + median_time_ms, + gflops, + memory_bw, + (memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100, + ) +def create_comparison_plot( + ratio, cuda_times, baseline_times, config_labels, strategy_name, id +): + """Create a comparison plot for a specific generation strategy""" + fig, ax = plt.subplots(1, 1, figsize=(16, 6)) + + # Configure x-axis positions + x = np.arange(len(config_labels)) + width = 0.35 + + # Execution Time plot (lower is better) + ax.bar( + x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue" + ) + ax.bar( + x + width / 2, + baseline_times, + width, + label="Baseline", + alpha=0.8, + color="orange", + ) + + # Add speedup labels over each bar pair + for i in range(len(x)): + speedup = ratio[i] + max_height = max(cuda_times[i], baseline_times[i]) + ax.text( + x[i], + max_height + max_height * 0.02, + f"{speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=9, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + 
ax.grid(True, alpha=0.3) + + plt.tight_layout() + return fig, ax + + +def create_combined_plot(all_results): + """Create a combined plot with all strategies in one PNG""" + num_strategies = len(all_results) + fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies)) + + if num_strategies == 1: + axes = [axes] + + for idx, ( + strategy_name, + ratio, + cuda_times, + baseline_times, + config_labels, + ) in enumerate(all_results): + ax = axes[idx] + + # Configure x-axis positions + x = np.arange(len(config_labels)) + width = 0.35 + + # Execution Time plot (lower is better) + ax.bar( + x - width / 2, + cuda_times, + width, + label="CUDA Kernel", + alpha=0.8, + color="blue", + ) + ax.bar( + x + width / 2, + baseline_times, + width, + label="Baseline", + alpha=0.8, + color="orange", + ) + + # Add speedup labels over each bar pair + for i in range(len(x)): + speedup = ratio[i] + max_height = max(cuda_times[i], baseline_times[i]) + ax.text( + x[i], + max_height + max_height * 0.02, + f"{speedup:.2f}x", + ha="center", + va="bottom", + fontweight="bold", + fontsize=9, + ) + + ax.set_xlabel("Configuration") + ax.set_ylabel("% Utilization") + ax.set_title( + f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)" + ) + ax.set_xticks(x) + ax.set_xticklabels(config_labels, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + filename = "../../silu_bench/silu_benchmark_combined.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +outer_dim = 7168 configs = [ - (8, 32, 1024), - (16, 64, 2048), - (32, 128, 4096), # DeepSeekV3 Configs - (256, 16, 7168), - (256, 32, 7168), - (256, 64, 7168), - (256, 128, 7168), - (256, 256, 7168), - (256, 512, 7168), + (8, 1024, 7168), + # DeepSeekV3 Configs + (32, 1024, 7168), + # DeepSeekV3 Configs (256, 1024, 7168), ] -print(f"GPU: {torch.cuda.get_device_name()}") -print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") -print("-" * 50) +runs = 100 +num_warmups = 20 -for E, T, H in configs: - try: - time_ms, gflops, gbps = benchmark(E, T, H) - print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") - except Exception: - print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") +strategy_descriptions = { + "uniform": "Uniform Random", + "max_t": "Even Assignment", + "first_t": "experts[0] = T, experts[1:] = 0", +} + +print(f"GPU: {torch.cuda.get_device_name()}") +print(f"Testing strategies: {', '.join(strategies)}") +print(f"Configurations: {len(configs)} configs") + +all_results = [] + +# Run benchmarks for each strategy +for id, strategy in enumerate(strategies): + print(f"\n{'=' * 60}") + print(f"Testing strategy: {strategy_descriptions[strategy]}") + print(f"{'=' * 60}") + + # Collect benchmark data for both algorithms + config_labels = [] + config_x_axis = [] + all_cuda_results = [] + all_baseline_results = [] + all_ratios = [] + + for E, T, H in configs: + total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E] + config_x_axis.append(total_tokens_config) + + cuda_results = [] + baseline_results = [] + ratios = [] + + for total_tokens in total_tokens_config: + config_label = f"E={E},T={T},H={H},TT={total_tokens}" + config_labels.append(config_label) + + # CUDA kernel results + time_ms_cuda, gflops, gbps, perc = benchmark( + silu_mul_fp8_quant_deep_gemm_cuda, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + cuda_results.append((time_ms_cuda, gflops, gbps, perc)) 
+ + # Baseline results + time_ms_triton, gflops, gbps, perc = benchmark( + silu_mul_fp8_quant_deep_gemm_triton, + E, + T, + H, + total_tokens, + runs=runs, + num_warmups=num_warmups, + gen_strategy=strategy, + ) + baseline_results.append((time_ms_triton, gflops, gbps, perc)) + ratios.append(time_ms_triton / time_ms_cuda) + + print(f"Completed: {config_label}") + all_cuda_results.append(cuda_results) + all_baseline_results.append(baseline_results) + all_ratios.append(ratios) + + # Store results for combined plotting + all_results.append( + ( + strategy_descriptions[strategy], + all_ratios, + all_cuda_results, + all_baseline_results, + config_labels, + config_x_axis, + ) + ) + + # Print summary table for this strategy + print(f"\nSummary Table - {strategy_descriptions[strategy]}:") + print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}") + print("-" * 60) + + for i, (E, T, H) in enumerate(configs): + speedup = baseline_results[i][0] / cuda_results[i][0] + config_label = f"E={E:3d},T={T:4d},H={H:4d}" + print( + f"{config_label:<20} {cuda_results[i][0]:8.5f} " + f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x" + ) + + +def create_total_tokens_plot(all_results): + num_strategies = len(all_results) + num_configs = len(configs) + + # Create side-by-side subplots: 2 columns for speedup and bandwidth percentage + fig, axs = plt.subplots( + num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies) + ) + + # Add main title to the entire figure + fig.suptitle( + "Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)", + fontsize=16, + fontweight="bold", + y=0.98, + ) + + # Handle single strategy case + if num_strategies == 1: + axs = axs.reshape(1, -1) + + # Handle single config case + if num_configs == 1: + axs = axs.reshape(-1, 2) + + for strategy_idx, result in enumerate(all_results): + ( + strategy_name, + all_ratios, + all_cuda_results, + all_baseline_results, + config_labels, + config_x_axis, + ) = result + + for config_idx in range(num_configs): + # Speedup plot (left column) + ax_speedup = axs[strategy_idx, config_idx * 2] + # Bandwidth plot (right column) + ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1] + + E, T, H = configs[config_idx] + ratios = all_ratios[config_idx] + total_tokens_values = config_x_axis[config_idx] + + # Extract CUDA and Triton bandwidth percentages + cuda_bandwidth_percentages = [ + result[3] for result in all_cuda_results[config_idx] + ] + triton_bandwidth_percentages = [ + result[3] for result in all_baseline_results[config_idx] + ] + + # Plot speedup ratios vs total tokens (left plot) + ax_speedup.plot( + total_tokens_values, ratios, "bo-", linewidth=3, markersize=8 + ) + ax_speedup.set_title( + f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11) + ax_speedup.grid(True, alpha=0.3) + + ax_bandwidth.plot( + total_tokens_values, + cuda_bandwidth_percentages, + "ro-", + linewidth=3, + markersize=8, + label="CUDA", + ) + ax_bandwidth.plot( + total_tokens_values, + triton_bandwidth_percentages, + "go-", + linewidth=3, + markersize=8, + label="Triton", + ) + ax_bandwidth.set_title( + f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}", + fontsize=12, + fontweight="bold", + ) + ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11) + ax_bandwidth.set_ylabel( + "% of Peak Bandwidth", 
fontweight="bold", fontsize=11 + ) + ax_bandwidth.legend(prop={"weight": "bold"}) + ax_bandwidth.grid(True, alpha=0.3) + + # Format x-axis labels for both plots + for ax in [ax_speedup, ax_bandwidth]: + ax.set_xticks(total_tokens_values) + ax.set_xticklabels( + [ + f"{tt // 1000}K" if tt >= 1000 else str(tt) + for tt in total_tokens_values + ], + fontweight="bold", + ) + # Make tick labels bold + for label in ax.get_xticklabels() + ax.get_yticklabels(): + label.set_fontweight("bold") + + # Add value labels on speedup points + for x, y in zip(total_tokens_values, ratios): + ax_speedup.annotate( + f"{y:.2f}x", + (x, y), + textcoords="offset points", + xytext=(0, 12), + ha="center", + fontsize=10, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7), + ) + + # Add value labels on CUDA bandwidth points + for x, y in zip(total_tokens_values, cuda_bandwidth_percentages): + ax_bandwidth.annotate( + f"{y:.1f}%", + (x, y), + textcoords="offset points", + xytext=(0, 12), + ha="center", + fontsize=9, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3), + ) + + # Add value labels on Triton bandwidth points + for x, y in zip(total_tokens_values, triton_bandwidth_percentages): + ax_bandwidth.annotate( + f"{y:.1f}%", + (x, y), + textcoords="offset points", + xytext=(0, -15), + ha="center", + fontsize=9, + fontweight="bold", + bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3), + ) + + plt.tight_layout() + plt.subplots_adjust(top=0.93) # Make room for main title + filename = "silu_benchmark_total_tokens.png" + plt.savefig(filename, dpi=300, bbox_inches="tight") + plt.show() + + return filename + + +# Create combined plot with all strategies +combined_plot_filename = create_total_tokens_plot(all_results) + +print(f"\n{'=' * 60}") +print("Benchmark Complete!") +print(f"Generated combined plot: {combined_plot_filename}") +print(f"{'=' * 60}") diff --git a/csrc/ops.h b/csrc/ops.h index 3ecfd2cd9bf3f..c65bf431640d5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -133,6 +133,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& input_global_scale); #endif +void silu_mul_fp8_quant_deep_gemm_cuda( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& counts, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens); void mul_and_silu(torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 8bc2b9bff3d5a..9ddb5af3052fa 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -9,6 +9,26 @@ #include "quantization/fp8/common.cuh" +#include + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + #include + #include + +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16_raw __nv_bfloat16_raw; + +typedef __hip_fp8_e4m3 __nv_fp8_e4m3; +typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3; +#endif + +#include "core/registration.h" namespace vllm { template @@ -87,6 +107,337 @@ __global__ void act_and_mul_quant_kernel( } } } + +__device__ __forceinline__ float silu(float x) { + return (__fdividef(x, (1.f + expf(-x)))); +} + +__device__ __forceinline__ float2 silu2(float2 x) { + return make_float2(silu(x.x), silu(x.y)); +} + +#ifndef USE_ROCM +__device__ __forceinline__ float warp_max(float v) { + static 
constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} + +__device__ __forceinline__ __nv_bfloat16 warp_max(__nv_bfloat16 v) { + static constexpr unsigned FULL_MASK = 0xffffffffu; + for (int offset = 1; offset < WARP_SIZE; offset *= 2) { + v = __hmax(v, __shfl_xor_sync(FULL_MASK, v, offset)); + } + return v; +} +#endif + +template +__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + auto smem_ptr = reinterpret_cast(_smem_ptr); + auto glob_ptr = reinterpret_cast(_glob_ptr); + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +#else + _smem_ptr[0] = _glob_ptr[0]; +#endif +} + +__device__ __forceinline__ void cp_async_fence() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.commit_group;\n" ::); +#else +#endif +} + +template +__device__ __forceinline__ void cp_async_wait() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#else +#endif +} + +template <> +__device__ __forceinline__ void cp_async_wait<0>() { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + asm volatile("cp.async.wait_all;\n" ::); +#else +#endif +} + +__device__ __forceinline__ float clip(float v, float mmin, float mmax) { +#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800 + return fminf(mmax, fmaxf(v, mmin)); +#else +#endif +} + +__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v, + __nv_bfloat16 mmin, + __nv_bfloat16 mmax) { + return __hmin(mmax, __hmax(v, mmin)); +} + +__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v, + __nv_bfloat162 mmin, + __nv_bfloat162 mmax) { + return __hmin2(mmax, __hmax2(v, mmin)); +} + +// We use the following values for fp8 min/max: +// __nv_fp8_e4m3 = (-448, +448) +// __nv_fp8_e4m3uz = (-240.0, +240.0) +// It is currently assumed that only +template +constexpr __nv_bfloat16 get_fp8_max() { + static_assert(std::is_same_v || + std::is_same_v); + if constexpr (std::is_same_v) { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 17376}); + } else { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 17264}); + } +} + +template +constexpr __nv_bfloat16 get_fp8_min() { + static_assert(std::is_same_v || + std::is_same_v); + if constexpr (std::is_same_v) { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 50144}); + } else { + return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032}); + } +} +#ifndef USE_ROCM +template +__global__ void silu_mul_fp8_quant_deep_gemm_kernel( + const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q, + float* __restrict__ _y_s, const int32_t* __restrict__ counts, + + // sizes + int H, int G, + + // strides (in elements) + Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e, + Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t, + Idx_t stride_ys_g, Idx_t stride_counts_e) { + static constexpr __nv_bfloat16 fp8_min = get_fp8_min(); + static constexpr __nv_bfloat16 fp8_max = get_fp8_max(); + // We assign EPS with its 16-bit unsigned counterpart to allow constexpr. + static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996}); + + // We pack 8 16-bit bfloat16 values into a 128-bit __int128_t. 
+ static constexpr int32_t BFLOAT16_PER_GROUP = 8; + + // We split the shared memory in half, corresponding to gate and up matrices: + // [...gate_i, ...up_i] where 0 <= i < stages. + static constexpr int32_t S_NUM_128 = + 2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES; + static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE; + static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2; + static constexpr int32_t S_NUM_64 = S_NUM_128 * 2; + __shared__ __int128_t __align__(16) s_buff_128[S_NUM_128]; + + const int32_t tid = threadIdx.x; + const int32_t warp_id = tid / WARP_SIZE; + const int32_t lane_id = tid % WARP_SIZE; + + auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128); + + // block handles one (expert e, group g) + int32_t pid = blockIdx.x; + int32_t e = pid / G; + int32_t g = pid % G; + + const int32_t n_tokens = counts[e * stride_counts_e]; + + if (!n_tokens) { + return; // Exit ASAP. + } + + const Idx_t stride_i_t_128 = stride_i_t / 8u; + + int32_t n_tokens_lower, n_tokens_upper; + + // Each block i iterates over tokens of a slice of n_tokens = + // expert_counts[i], with the size of chunk being + // (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of + // updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling. + if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) { + // Specialize this, but can be likely fused. + if (blockIdx.y >= NUM_PARALLEL_TOKENS) { + return; + } + n_tokens_lower = blockIdx.y; + n_tokens_upper = blockIdx.y + 1; + } else { + auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS; + auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS; + auto calc_id = [&](int32_t id) { + if (id < residual) { + return min(n_tokens, id * (chunk_size + 1)); + } else { + return min(n_tokens, id * chunk_size + residual); + } + }; + n_tokens_lower = calc_id(blockIdx.y); + n_tokens_upper = calc_id(blockIdx.y + 1); + } + + if (n_tokens_lower >= n_tokens_upper) { + return; + } + + // We do calculations here, using constexpr wherever possible. + const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h; + const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g; + const Idx_t base_yq = + e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h; + Idx_t gate_off_128 = (base_i / static_cast(8u)); + auto input_128_ptr = reinterpret_cast(_input); + auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) + + stride_i_t_128 * n_tokens_lower; + auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u; + auto y_s_ptr = + _y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t; + auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE + + stride_yq_t * n_tokens_lower + 4 * lane_id; + int32_t t_load = n_tokens_lower, load_stage_id = 0; + auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT); + auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u; + int32_t stage_offset{}; + + static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2); + static constexpr int32_t LOAD_STAGE_MOD = + NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2); + + // Two halves of all threads in a block conduct global loads for gate and up, + // repsectively. + auto load_and_advance_y_pred = [&] { + if (t_load < n_tokens_upper) { + auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset; + auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset; + + // It is very important that LOAD_STAGE_SIZE is constexpr to avoid + // unnecessary ALU ops. 
+ stage_offset += LOAD_STAGE_SIZE; + stage_offset %= LOAD_STAGE_MOD; + + if (tid < HALF_THREAD_COUNT) { + cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr); + gate_128_ptr += stride_i_t_128; + } else { + cp_async4(s_up_stage_128_staged_ptr, up_128_ptr); + up_128_ptr += stride_i_t_128; + } + ++t_load; + ++load_stage_id; + } + // We fence even if there is nothing to load to simplify pipelining. + cp_async_fence(); + }; + + #pragma unroll + for (int i = 0; i < NUM_STAGES - 1; i++) { + load_and_advance_y_pred(); + } + + __int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>( + s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) + + lane_id; + __int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2; + + static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u; + static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES; + + int32_t compute_pipeline_offset_64 = 0; + + for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) { + __nv_bfloat16 y_max_bf16 = EPS; + __nv_bfloat162 results_bf162[2]; + + cp_async_wait(); + __syncthreads(); + + // We double-buffer pipelined loads so that the next load will + // concurrently run with compute without overwrites. + load_and_advance_y_pred(); + + auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64; + auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64; + + // STAGE_SIZE must also be constexpr! + compute_pipeline_offset_64 += STAGE_SIZE; + compute_pipeline_offset_64 %= STAGE_MOD; + + // Each thread loads (gate/up) 2X 4X bfloat16 values into registers. + __int64_t gate64 = *s_gate_compute_64; + __nv_bfloat162* s_gate_compute_32 = + reinterpret_cast<__nv_bfloat162*>(&gate64); + + __int64_t up64 = *s_up_compute_64; + __nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64); + + #pragma unroll + for (int i = 0; i < 2; i++) { + // For silu, we make sure that div is emitted. + float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i])); + results_bf162[i] = __float22bfloat162_rn(gate); + } + + #pragma unroll + for (int i = 0; i < 2; i++) { + results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]); + } + + auto _y_max2 = + __hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1])); + + y_max_bf16 = __hmax(_y_max2.x, _y_max2.y); + + // An entire group is assigned to a single warp, so a simple warp reduce + // is used. + __nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max; + + if constexpr (USE_UE8M0) { + y_s = hexp2(hceil(hlog2(y_s))); + } + + auto inv_y = __float2bfloat16_rn(1.f) / y_s; + + auto y_s2 = make_bfloat162(inv_y, inv_y); + + #pragma unroll + for (int32_t i = 0; i < 2; ++i) { + results_bf162[i] = + clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min), + __bfloat162bfloat162(fp8_max)); + } + + auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]); + *reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4; + y_q_ptr += stride_yq_t; + + if (lane_id == 0) { + *y_s_ptr = y_s; + y_s_ptr += stride_ys_t; + } + } +} +#endif + } // namespace vllm // Launch activation, gating, and quantize kernel. 
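Reviewer note on `silu_mul_fp8_quant_deep_gemm_kernel` above: `gridDim.x` enumerates (expert, warp-group) pairs, while `gridDim.y` splits each expert's `n_tokens` across `NUM_PARALLEL_TOKENS` blocks; the first `residual` blocks take `chunk_size + 1` tokens and the rest take `chunk_size`, so the per-block load differs by at most one token. (The raw bf16 constants 17376/17264 and 50144/50032 decode to +448/+240 and -448/-240, i.e. the e4m3 and e4m3fnuz ranges, and 11996 is roughly 1e-10 for EPS.) Below is a minimal Python sketch of that slicing, mirroring the kernel's `calc_id` lambda; the helper name `partition_tokens` is illustrative and not part of the patch.

```python
# Sketch only: how one expert's n_tokens is split over gridDim.y
# (NUM_PARALLEL_TOKENS blocks) in silu_mul_fp8_quant_deep_gemm_kernel.
def partition_tokens(n_tokens: int, num_parallel_tokens: int) -> list[tuple[int, int]]:
    if n_tokens < num_parallel_tokens:
        # Fewer tokens than blocks: one token per block, surplus blocks exit early.
        return [(i, i + 1) for i in range(n_tokens)]
    chunk_size, residual = divmod(n_tokens, num_parallel_tokens)

    def calc_id(i: int) -> int:  # mirrors the kernel's calc_id lambda
        if i < residual:
            return min(n_tokens, i * (chunk_size + 1))
        return min(n_tokens, i * chunk_size + residual)

    return [(calc_id(i), calc_id(i + 1)) for i in range(num_parallel_tokens)]


# e.g. 10 tokens over 4 blocks -> slices of 3, 3, 2, 2 tokens
assert partition_tokens(10, 4) == [(0, 3), (3, 6), (6, 8), (8, 10)]
```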
@@ -119,3 +470,117 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d] TORCH_CHECK(input.size(-1) % 2 == 0); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } + +void silu_mul_fp8_quant_deep_gemm_cuda( + const at::Tensor& input, // (E, T, 2*H) + const at::Tensor& counts, // (E) + at::Tensor& y_q, // (E, T, H) [OUT] + at::Tensor& y_s, // (E, T, H//group_size) [OUT] + int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) { +#ifndef USE_ROCM + // This kernel relies heavily on cp.async and fp8 support. + // This kernel currently only supports H % 128 == 0 and assumes a + // fixed GROUP_SIZE of 128. + TORCH_CHECK(input.dtype() == torch::kBFloat16); + TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn || + y_q.dtype() == torch::kFloat8_e4m3fnuz); + TORCH_CHECK(y_s.dtype() == torch::kFloat32); + TORCH_CHECK(input.size(-1) % 256 == 0); + + // Check that num_parallel_tokens is of power of 2 and between 1 and 64. + TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64); + TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1))); + + using Idx_t = int64_t; + + Idx_t E = input.size(0); + Idx_t T = input.size(1); + Idx_t H = input.size(2) / 2; + Idx_t stride_i_e = input.stride(0); + Idx_t stride_i_t = input.stride(1); + Idx_t stride_i_h = input.stride(2); + Idx_t stride_yq_e = y_q.stride(0); + Idx_t stride_yq_t = y_q.stride(1); + Idx_t stride_yq_h = y_q.stride(2); + Idx_t stride_ys_e = y_s.stride(0); + Idx_t stride_ys_t = y_s.stride(1); + Idx_t stride_ys_g = y_s.stride(2); + + Idx_t stride_counts_e = counts.stride(0); + + static constexpr int GROUP_SIZE = 128; + + #define KERNEL_FN \ + if (use_ue8m0) { \ + vllm::silu_mul_fp8_quant_deep_gemm_kernel \ + <<>>( \ + reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ + (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ + reinterpret_cast(counts.data_ptr()), H, G, \ + stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ + stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ + stride_counts_e); \ + } else { \ + vllm::silu_mul_fp8_quant_deep_gemm_kernel \ + <<>>( \ + reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \ + (fp8_t*)y_q.data_ptr(), y_s.data_ptr(), \ + reinterpret_cast(counts.data_ptr()), H, G, \ + stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \ + stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \ + stride_counts_e); \ + } + + #define KERNEL_CALL_H \ + if (H % (4 * GROUP_SIZE) == 0) { \ + static constexpr int NUM_WARPS = 4; \ + populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ + KERNEL_FN \ + } else { \ + static constexpr int NUM_WARPS = 1; \ + populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \ + KERNEL_FN \ + } + + #define KERNEL_CALL_TOP_LEVEL \ + if (num_parallel_tokens == 1) { \ + static constexpr int NUM_PARALLEL_TOKENS = 1; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 2) { \ + static constexpr int NUM_PARALLEL_TOKENS = 2; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 4) { \ + static constexpr int NUM_PARALLEL_TOKENS = 4; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 8) { \ + static constexpr int NUM_PARALLEL_TOKENS = 8; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 16) { \ + static constexpr int NUM_PARALLEL_TOKENS = 16; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 32) { \ + static constexpr int NUM_PARALLEL_TOKENS = 32; \ + KERNEL_CALL_H \ + } else if (num_parallel_tokens == 64) { \ + static constexpr int NUM_PARALLEL_TOKENS = 64; \ + KERNEL_CALL_H \ + } + + Idx_t G; + dim3 block, grid; + auto 
populate_launch_params = [&](int num_warps, int _num_parallel_tokens) { + G = H / Idx_t(group_size * num_warps); + grid = dim3(E * G, _num_parallel_tokens); + block = dim3(num_warps * WARP_SIZE); + }; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(), + "silu_mul_fp8_quant_deep_gemm_kernel", + [&] { KERNEL_CALL_TOP_LEVEL }); + +#endif +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c63fa7cffc78d..81aca7b8860d5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { #define stride_tag #endif + ops.def( + "silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! " + "y_q, Tensor! y_s, int group_size, " + "bool use_ue8m0, int num_parallel_tokens) -> ()"); + ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA, + &silu_mul_fp8_quant_deep_gemm_cuda); + ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 5a0379dfb4475..383b5ebfba9b7 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -5,28 +5,52 @@ import pytest import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm) + silu_mul_fp8_quant_deep_gemm_cuda) from vllm.platforms import current_platform +from vllm.utils import cdiv + +fp8_dtype = torch.float8_e4m3fn -# (E, T, H, group_size, seed) CASES = [ - (1, 1, 128, 64, 0), - (1, 4, 128, 128, 0), - (2, 4, 256, 128, 0), - (32, 64, 256, 128, 0), - (17, 31, 768, 128, 0), + (1, 1, 128, fp8_dtype), + (1, 4, 128, fp8_dtype), + (2, 4, 256, fp8_dtype), + (32, 64, 256, fp8_dtype), + (17, 31, 768, fp8_dtype), + (1, 1, 128 * 1, fp8_dtype), + (1, 1, 128 * 2, fp8_dtype), + (1, 1, 128 * 3, fp8_dtype), + (1, 1, 128 * 4, fp8_dtype), + (8, 16, 128 * 1, fp8_dtype), + (8, 16, 128 * 2, fp8_dtype), + (8, 16, 128 * 3, fp8_dtype), + (8, 16, 128 * 4, fp8_dtype), + (8, 64, 7168, fp8_dtype), + (8, 128, 7168, fp8_dtype), + (8, 256, 7168, fp8_dtype), + (8, 512, 7168, fp8_dtype), + (8, 1024, 7168, fp8_dtype), + (256, 8, 7168, fp8_dtype), + (256, 16, 7168, fp8_dtype), + (256, 32, 7168, fp8_dtype), + (256, 64, 7168, fp8_dtype), + + # Only add a few fnuz tests to help with long CI times. 
+ (8, 512, 7168, torch.float8_e4m3fnuz), + (8, 1024, 7168, torch.float8_e4m3fnuz), ] -@pytest.mark.parametrize("E,T,H,group_size,seed", CASES) +@pytest.mark.parametrize("E,T,H,fp8_type", CASES) @torch.inference_mode() -def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): - current_platform.seed_everything(seed) +def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): + group_size = 128 + current_platform.seed_everything(42) # Input tensor of shape (E, T, 2*H) y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") tokens_per_expert = torch.randint( - low=0, + low=T // 2, high=T, size=(E, ), dtype=torch.int32, @@ -34,45 +58,59 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed): ) # Run the Triton kernel - y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, - tokens_per_expert, - group_size=group_size, - eps=1e-10) + y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y, + tokens_per_expert, + group_size=group_size) - # Reference implementation - fp8_info = torch.finfo(torch.float8_e4m3fn) + torch.cuda.synchronize() + fp8_info = torch.finfo(fp8_dtype) fp8_max = fp8_info.max fp8_min = fp8_info.min eps = 1e-10 - # Compute silu activation and elementwise multiplication - y1 = y[..., :H] + y1 = y[..., :H].float() y2 = y[..., H:] silu_x = y1 * torch.sigmoid(y1) merged = silu_x * y2 - # Compute reference scales and quantized output, skipping padded tokens for e in range(E): nt = tokens_per_expert[e].item() - ref_s = torch.empty((T, H // group_size), + ref_s = torch.empty((T, cdiv(H, group_size)), dtype=torch.float32, device="cuda") - ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda") + ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda") + for t in range(nt): - data = merged[e, t] - data_grp = data.view(H // group_size, group_size) - amax = data_grp.abs().amax(dim=1).clamp(min=eps) - scale = amax / fp8_max + data = merged[e, t].float() + ref_q_row = torch.empty_like(data) - scaled = data / scale.repeat_interleave(group_size) - clamped = scaled.clamp(fp8_min, fp8_max) - q = clamped.to(torch.float8_e4m3fn) + # process full groups + n_full_groups = H // group_size + if n_full_groups > 0: + data_grp = data[:n_full_groups * group_size].view( + n_full_groups, group_size) + amax = data_grp.abs().amax(dim=1).clamp(min=eps) + scale = amax / fp8_max + scaled = data[:n_full_groups * + group_size] / scale.repeat_interleave(group_size) + ref_q_row[:n_full_groups * group_size] = scaled.clamp( + fp8_min, fp8_max).to(fp8_dtype) + ref_s[t, :n_full_groups] = scale - ref_s[t] = scale - ref_q[t] = q + # process remainder group + rem = H % group_size + if rem > 0: + data_rem = data[-rem:] + amax = data_rem.abs().amax().clamp(min=eps) + scale = amax / fp8_max + scaled = data_rem / scale + ref_q_row[-rem:] = scaled.clamp(fp8_min, fp8_max).to(fp8_dtype) + ref_s[t, -1] = scale - y_se = y_s[e] - y_qe = y_q[e] + ref_q[t] = ref_q_row + + y_se = y_s[e].float() + y_qe = y_q[e].float() torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2) torch.testing.assert_close( diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index a5326dfe84f6d..0ab6355f41565 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from math import log2 from typing import Optional import 
torch @@ -10,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used) @@ -24,35 +26,28 @@ def _silu_mul_fp8_quant_deep_gemm( y_q_ptr, # fp8 quantized activations (E, T, H) y_s_ptr, # 16-bit scales (E, T, G) counts_ptr, # int32 num tokens per expert (E) - # Sizes --------------------------------------------------------------- H: tl.constexpr, # hidden dimension (per output) GROUP_SIZE: tl.constexpr, # elements per group (usually 128) - # Strides for input (elements) --------------------------------------- stride_i_e, stride_i_t, stride_i_h, - # Strides for y_q (elements) ----------------------------------------- stride_yq_e, stride_yq_t, stride_yq_h, - # Strides for y_s (elements) ----------------------------------------- stride_ys_e, stride_ys_t, stride_ys_g, - # Stride for counts (elements) stride_counts_e, - # Numeric params ------------------------------------------------------ eps: tl.constexpr, fp8_min: tl.constexpr, fp8_max: tl.constexpr, use_ue8m0: tl.constexpr, - # Meta --------------------------------------------------------------- BLOCK: tl.constexpr, NUM_STAGES: tl.constexpr, @@ -101,17 +96,15 @@ def _silu_mul_fp8_quant_deep_gemm( tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s) -def silu_mul_fp8_quant_deep_gemm( +def silu_mul_fp8_quant_deep_gemm_cuda( y: torch.Tensor, # (E, T, 2*H) tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert + num_parallel_tokens=16, group_size: int = 128, - eps: float = 1e-10, ) -> tuple[torch.Tensor, torch.Tensor]: """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales - - y has shape (E, T, 2*H). The first half of the last dimension is + y has shape (E, T, 2*H). The first half of the last dimension is silu-activated, multiplied by the second half, then quantized into FP8. - Returns `(y_q, y_s)` where * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H] * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T) @@ -120,22 +113,17 @@ def silu_mul_fp8_quant_deep_gemm( E, T, H2 = y.shape assert H2 % 2 == 0, "last dim of y must be even (2*H)" H = H2 // 2 - G = H // group_size - assert H % group_size == 0, "H must be divisible by group_size" - assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \ - "tokens_per_expert must be shape (E,)" + G = (H + group_size - 1) // group_size + assert H % 8 == 0, "H must be divisible by 8" + assert group_size == 128, "H must be divisible by 8" + assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) - # allocate outputs fp8_dtype = torch.float8_e4m3fn y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) - # strides (elements) - stride_i_e, stride_i_t, stride_i_h = y.stride() - stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() - - # desired scale strides (elements): (T*G, 1, T) stride_ys_e = T * G stride_ys_t = 1 stride_ys_g = T @@ -144,47 +132,86 @@ def silu_mul_fp8_quant_deep_gemm( dtype=torch.float32, device=y.device) - stride_cnt_e = tokens_per_expert.stride()[0] + use_ue8m0 = is_deep_gemm_e8m0_used() - # Static grid over experts and H-groups. 
- # A loop inside the kernel handles the token dim - grid = (E * G, ) + if E <= 16: + max_empirical_parallelism = 64 + elif E <= 32: + max_empirical_parallelism = 16 + else: + max_empirical_parallelism = 4 - f_info = torch.finfo(fp8_dtype) - fp8_max = f_info.max - fp8_min = f_info.min + # We never want to launch more than Tx number of threads + # This computes the clip. + num_parallel_tokens = max( + 1, + min(max_empirical_parallelism, 2**int(log2(min(num_parallel_tokens, + T))))) + cuda_arch = current_platform.get_device_capability( + device_id=y.device.index).to_int() - _silu_mul_fp8_quant_deep_gemm[grid]( - y, - y_q, - y_s, - tokens_per_expert, - H, - group_size, - stride_i_e, - stride_i_t, - stride_i_h, - stride_yq_e, - stride_yq_t, - stride_yq_h, - stride_ys_e, - stride_ys_t, - stride_ys_g, - stride_cnt_e, - eps, - fp8_min, - fp8_max, - is_deep_gemm_e8m0_used(), - BLOCK=group_size, - NUM_STAGES=4, - num_warps=1, - ) + if cuda_arch >= 80: + torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(y, tokens_per_expert, + y_q, y_s, group_size, + use_ue8m0, + num_parallel_tokens) + else: + # Default to triton if not on cuda or if arch is too old + y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) + + stride_cnt_e = tokens_per_expert.stride()[0] + + # Static grid over experts and H-groups. + # A loop inside the kernel handles the token dim + grid = (E * G, ) + # strides (elements) + stride_i_e, stride_i_t, stride_i_h = y.stride() + stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() + + # desired scale strides (elements): (T*G, 1, T) + stride_ys_e = T * G + stride_ys_t = 1 + stride_ys_g = T + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, + ) + f_info = torch.finfo(fp8_dtype) + fp8_max = f_info.max + fp8_min = f_info.min + eps: float = 1e-10 + _silu_mul_fp8_quant_deep_gemm[grid]( + y, + y_q, + y_s, + tokens_per_expert, + H, + group_size, + stride_i_e, + stride_i_t, + stride_i_h, + stride_yq_e, + stride_yq_t, + stride_yq_h, + stride_ys_e, + stride_ys_t, + stride_ys_g, + stride_cnt_e, + eps, + fp8_min, + fp8_max, + is_deep_gemm_e8m0_used(), + BLOCK=group_size, + NUM_STAGES=4, + num_warps=1, + ) return y_q, y_s class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - # The Deep Gemm kernels only support block size of 128 DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128] @@ -297,8 +324,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), workspace1, expert_num_tokens, expected_m) - a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, - expert_num_tokens) + a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda( + workspace1, expert_num_tokens) fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, expert_num_tokens, expected_m)
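For reference, both the CUDA kernel and the Triton fallback compute the same group-wise FP8 quantization of `silu(gate) * up` that the updated test re-derives per token. A minimal PyTorch sketch for a single (expert, token) row, assuming `H` is a multiple of `group_size` and the `float8_e4m3fn` range; the function name is illustrative and not an API added by this patch.

```python
import torch


def silu_mul_fp8_quant_ref(row: torch.Tensor, group_size: int = 128,
                           eps: float = 1e-10, use_ue8m0: bool = False):
    """row has shape (2*H,): the first H entries are the gate, the last H the up projection."""
    H = row.numel() // 2
    gate, up = row[:H].float(), row[H:].float()
    y = torch.nn.functional.silu(gate) * up           # silu(gate) * up
    info = torch.finfo(torch.float8_e4m3fn)

    groups = y.reshape(H // group_size, group_size)   # assumes H % group_size == 0
    scale = groups.abs().amax(dim=1).clamp(min=eps) / info.max
    if use_ue8m0:                                     # UE8M0: round each scale up to a power of two
        scale = torch.exp2(torch.ceil(torch.log2(scale)))

    y_q = (groups / scale[:, None]).clamp(info.min, info.max)
    y_q = y_q.to(torch.float8_e4m3fn).reshape(H)
    return y_q, scale                                 # shapes (H,) and (H // group_size,)
```

Applied over all valid tokens of every expert, this is what `test_silu_mul_fp8_quant_deep_gemm` compares against (scales checked with atol=1e-4 / rtol=1e-2). The kernels additionally store `y_s` as an (E, T, G) tensor with strides (T*G, 1, T), so for a fixed group the per-token scales are contiguous, matching the layout documented in the `silu_mul_fp8_quant_deep_gemm_cuda` docstring.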