Silu v2 (#25074)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: elvircrn <elvircrn@gmail.com>
Signed-off-by: Elvir Crnčević <elvircrn@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>

parent ae9d0e7da5
commit 7b03584de8
@@ -1,5 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Comprehensive 3-way SiLU Benchmark Suite

This benchmark compares three SiLU implementations:
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
2. Triton Kernel - Triton-based implementation

The suite generates detailed performance comparisons including:
- Memory bandwidth utilization
- Speedup ratios (baseline vs optimized implementations)
- Performance across different expert configurations and token distributions
"""

from collections.abc import Callable

import matplotlib.pyplot as plt
@@ -7,7 +21,7 @@ import numpy as np
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm_cuda,
persistent_masked_m_silu_mul_quant,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
num_parallel_tokens,
group_size: int = 128,
eps: float = 1e-10,
expert_offsets: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales

@@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(


# Parse generation strategies
strategies = ["uniform", "max_t", "first_t"]
strategies = ["random_imbalanced", "uniform", "max_t"]


def benchmark(
@@ -195,15 +210,27 @@ def benchmark(
current_platform.seed_everything(42 + seed_offset)
y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()

if gen_strategy == "uniform":
r = torch.rand(size=(E,), device="cuda")
if gen_strategy == "random_imbalanced":

def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
mean = total_tokens // n_e
min_max = mean // ratio
e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
e[0] = min_max
r = torch.rand(size=(E - 1,))
r /= r.sum()
r *= total_tokens - min_max
r = r.round().long()
e[1:] = r.to(device=device)
return e

tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
elif gen_strategy == "uniform":
r = torch.rand(size=(E,))
r /= r.sum()
r *= total_tokens
tokens_per_expert = r.int()
tokens_per_expert = torch.minimum(
tokens_per_expert,
torch.ones((E,), device=r.device, dtype=torch.int) * T,
)
r = r.round().long()
tokens_per_expert = r
elif gen_strategy == "max_t":
tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
tokens_per_expert.fill_(total_tokens / E)
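The generation strategies above control how many tokens land on each expert. A small host-side sketch (a hypothetical helper, mirroring the diff's `generate_expert_loads` and the `uniform`/`max_t` branches) that reproduces the three distributions for quick inspection:

```python
import torch


def make_tokens_per_expert(strategy: str, E: int, total_tokens: int,
                           ratio: float = 0.7) -> torch.Tensor:
    """CPU mock of the benchmark's token-distribution strategies."""
    if strategy == "max_t":
        # Even assignment: every expert receives total_tokens / E tokens.
        return torch.full((E,), total_tokens // E, dtype=torch.int64)
    if strategy == "uniform":
        # Random split of total_tokens across all experts.
        r = torch.rand(E)
        return (r / r.sum() * total_tokens).round().long()
    if strategy == "random_imbalanced":
        # Expert 0 gets an outsized share (mean // ratio with ratio < 1);
        # the remaining experts split whatever is left at random.
        mean = total_tokens // E
        head = int(mean // ratio)
        r = torch.rand(E - 1)
        rest = (r / r.sum() * (total_tokens - head)).round().long()
        return torch.cat([torch.tensor([head], dtype=torch.int64), rest])
    raise ValueError(f"unknown strategy: {strategy}")


print(make_tokens_per_expert("random_imbalanced", E=8, total_tokens=8 * 128))
```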
@@ -281,40 +308,34 @@ def benchmark(


def create_comparison_plot(
ratio, cuda_times, baseline_times, config_labels, strategy_name, id
ratios, silu_v2_times, triton_times, config_labels, strategy_name, id
):
"""Create a comparison plot for a specific generation strategy"""
fig, ax = plt.subplots(1, 1, figsize=(16, 6))
fig, ax = plt.subplots(1, 1, figsize=(18, 6))

# Configure x-axis positions
x = np.arange(len(config_labels))
width = 0.35
width = 0.25

# Execution Time plot (lower is better)
ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
ax.bar(
x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue"
)
ax.bar(
x + width / 2,
baseline_times,
width,
label="Baseline",
alpha=0.8,
color="orange",
x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green"
)

# Add speedup labels over each bar pair
# Add speedup labels over each bar trio
for i in range(len(x)):
speedup = ratio[i]
max_height = max(cuda_times[i], baseline_times[i])
triton_v2_speedup = ratios[i][1]  # triton/v2
max_height = max(silu_v2_times[i], triton_times[i])

# Triton/V2 speedup
ax.text(
x[i],
x[i] + width / 2,
max_height + max_height * 0.02,
f"{speedup:.2f}x",
f"{triton_v2_speedup:.2f}x",
ha="center",
va="bottom",
fontweight="bold",
fontsize=9,
fontsize=8,
)

ax.set_xlabel("Configuration")
@@ -332,56 +353,75 @@ def create_comparison_plot(


def create_combined_plot(all_results):
"""Create a combined plot with all strategies in one PNG"""
num_strategies = len(all_results)
fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies))
fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies))

if num_strategies == 1:
axes = [axes]

for idx, (
strategy_name,
ratio,
cuda_times,
baseline_times,
all_ratios,
all_silu_v2_results,
all_triton_results,
config_labels,
config_x_axis,
) in enumerate(all_results):
ax = axes[idx]

# Flatten the nested results to get bandwidth percentages for plotting
silu_v2_bandwidths = []
triton_bandwidths = []
flat_ratios = []

for config_results in all_silu_v2_results:
for result in config_results:
silu_v2_bandwidths.append(result[3])  # bandwidth percentage

for config_results in all_triton_results:
for result in config_results:
triton_bandwidths.append(result[3])  # bandwidth percentage

for config_ratios in all_ratios:
for ratio in config_ratios:
flat_ratios.append(ratio)

# Configure x-axis positions
x = np.arange(len(config_labels))
width = 0.35
width = 0.25

# Execution Time plot (lower is better)
# Bandwidth utilization plot (higher is better)
ax.bar(
x - width / 2,
cuda_times,
x,
silu_v2_bandwidths,
width,
label="CUDA Kernel",
label="SiLU V2 (CUDA)",
alpha=0.8,
color="blue",
)
ax.bar(
x + width / 2,
baseline_times,
x + width,
triton_bandwidths,
width,
label="Baseline",
label="Triton Kernel",
alpha=0.8,
color="orange",
color="green",
)

# Add speedup labels over each bar pair
# Add speedup labels over each bar trio
for i in range(len(x)):
speedup = ratio[i]
max_height = max(cuda_times[i], baseline_times[i])
triton_v2_speedup = flat_ratios[i]  # triton/v2
max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i])

# Triton/V2 speedup
ax.text(
x[i],
x[i] + width / 2,
max_height + max_height * 0.02,
f"{speedup:.2f}x",
f"{triton_v2_speedup:.2f}x",
ha="center",
va="bottom",
fontweight="bold",
fontsize=9,
fontsize=8,
)

ax.set_xlabel("Configuration")
@@ -395,7 +435,7 @@ def create_combined_plot(all_results):
ax.grid(True, alpha=0.3)

plt.tight_layout()
filename = "../../silu_bench/silu_benchmark_combined.png"
filename = "silu_benchmark_combined_3way.png"
plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show()

@@ -405,7 +445,9 @@ def create_combined_plot(all_results):
outer_dim = 7168
configs = [
# DeepSeekV3 Configs
# (1, 56, 7168),
(8, 1024, 7168),
# (32, 56, 7168),
# DeepSeekV3 Configs
(32, 1024, 7168),
# DeepSeekV3 Configs
@@ -417,6 +459,7 @@ num_warmups = 20

strategy_descriptions = {
"uniform": "Uniform Random",
"random_imbalanced": "Imbalanced Random",
"max_t": "Even Assignment",
"first_t": "experts[0] = T, experts[1:] = 0",
}
@@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies):
print(f"Testing strategy: {strategy_descriptions[strategy]}")
print(f"{'=' * 60}")

# Collect benchmark data for both algorithms
# Collect benchmark data for all three algorithms
config_labels = []
config_x_axis = []
all_cuda_results = []
all_baseline_results = []
all_silu_v2_results = []
all_triton_results = []
all_ratios = []

for E, T, H in configs:
total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E]
total_tokens_config = []
for i in [8, 16, 32, 64, 128, 256, 512]:
if i <= T:
total_tokens_config.append(i * E)
config_x_axis.append(total_tokens_config)

cuda_results = []
baseline_results = []
silu_v2_results = []
triton_results = []
ratios = []

for total_tokens in total_tokens_config:
config_label = f"E={E},T={T},H={H},TT={total_tokens}"
config_labels.append(config_label)

# CUDA kernel results
time_ms_cuda, gflops, gbps, perc = benchmark(
silu_mul_fp8_quant_deep_gemm_cuda,
# SiLU V2 (CUDA kernel) results
time_ms_silu_v2, gflops, gbps, perc = benchmark(
persistent_masked_m_silu_mul_quant,
E,
T,
H,
@@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups,
gen_strategy=strategy,
)
cuda_results.append((time_ms_cuda, gflops, gbps, perc))
silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc))

# Baseline results
# Triton kernel results
time_ms_triton, gflops, gbps, perc = benchmark(
silu_mul_fp8_quant_deep_gemm_triton,
E,
@@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies):
num_warmups=num_warmups,
gen_strategy=strategy,
)
baseline_results.append((time_ms_triton, gflops, gbps, perc))
ratios.append(time_ms_triton / time_ms_cuda)
triton_results.append((time_ms_triton, gflops, gbps, perc))

print(f"Completed: {config_label}")
all_cuda_results.append(cuda_results)
all_baseline_results.append(baseline_results)
# Calculate speedup ratios (triton baseline / implementation)
triton_v2_ratio = time_ms_triton / time_ms_silu_v2
ratios.append(triton_v2_ratio)

print(
f"Completed: {config_label}:"
f" V2: {time_ms_silu_v2:.3f}ms,"
f" Triton: {time_ms_triton:.3f}ms"
)

all_silu_v2_results.append(silu_v2_results)
all_triton_results.append(triton_results)
all_ratios.append(ratios)

# Store results for combined plotting
@@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies):
(
strategy_descriptions[strategy],
all_ratios,
all_cuda_results,
all_baseline_results,
all_silu_v2_results,
all_triton_results,
config_labels,
config_x_axis,
)
@@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies):

# Print summary table for this strategy
print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}")
print("-" * 60)
print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}")
print("-" * 90)

for i, (E, T, H) in enumerate(configs):
speedup = baseline_results[i][0] / cuda_results[i][0]
# Get the first result for each config (simplifying for summary)
v2_time = silu_v2_results[i][0]
triton_time = triton_results[i][0]
triton_v2_speedup = triton_time / v2_time
config_label = f"E={E:3d},T={T:4d},H={H:4d}"
print(
f"{config_label:<20} {cuda_results[i][0]:8.5f} "
f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x"
f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} "
f"{triton_v2_speedup:8.2f}x"
)

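For orientation, the quantities reported above (times in ms, a Triton/V2 speedup ratio, and a bandwidth-utilization percentage) can be related by a simple memory-traffic model. The sketch below is an assumption about how such numbers could be derived, not the `benchmark()` body itself, and the peak-bandwidth figure is a placeholder:

```python
def silu_quant_traffic_bytes(total_tokens: int, H: int, group_size: int = 128) -> int:
    """Approximate bytes moved by one silu-mul-quant call."""
    read_bf16 = total_tokens * 2 * H * 2                  # gate and up halves, bf16
    write_fp8 = total_tokens * H * 1                      # quantized output
    write_scales = total_tokens * (H // group_size) * 4   # fp32 group scales
    return read_bf16 + write_fp8 + write_scales


def summarize(time_ms_silu_v2: float, time_ms_triton: float,
              total_tokens: int, H: int, peak_gb_s: float = 3350.0) -> None:
    # peak_gb_s is an assumed HBM peak; substitute the target GPU's value.
    gb = silu_quant_traffic_bytes(total_tokens, H) / 1e9
    utilization = 100.0 * gb / (time_ms_silu_v2 / 1e3) / peak_gb_s
    speedup = time_ms_triton / time_ms_silu_v2
    print(f"Triton/V2 speedup {speedup:.2f}x, V2 bandwidth {utilization:.1f}% of peak")


summarize(time_ms_silu_v2=0.050, time_ms_triton=0.080, total_tokens=8192, H=7168)
```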
@@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results):
num_strategies = len(all_results)
num_configs = len(configs)

# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
fig, axs = plt.subplots(
num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies)
num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies)
)

# Add main title to the entire figure
fig.suptitle(
"Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)",
fontsize=16,
"Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)",
fontsize=18,
fontweight="bold",
y=0.98,
)
@@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results):
(
strategy_name,
all_ratios,
all_cuda_results,
all_baseline_results,
all_silu_v2_results,
all_triton_results,
config_labels,
config_x_axis,
) = result
@@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results):
ratios = all_ratios[config_idx]
total_tokens_values = config_x_axis[config_idx]

# Extract CUDA and Triton bandwidth percentages
cuda_bandwidth_percentages = [
result[3] for result in all_cuda_results[config_idx]
# Extract speedup ratios
triton_v2_ratios = [ratio for ratio in ratios]

# Extract bandwidth percentages for all implementations
v2_bandwidth_percentages = [
result[3] for result in all_silu_v2_results[config_idx]
]
triton_bandwidth_percentages = [
result[3] for result in all_baseline_results[config_idx]
result[3] for result in all_triton_results[config_idx]
]

# Plot speedup ratios vs total tokens (left plot)
ax_speedup.plot(
total_tokens_values, ratios, "bo-", linewidth=3, markersize=8
total_tokens_values,
triton_v2_ratios,
"go-",
linewidth=3,
markersize=8,
label="Triton/V2 Speedup",
)
ax_speedup.set_title(
f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}",
f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}",
fontsize=12,
fontweight="bold",
)
ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
ax_speedup.legend(prop={"weight": "bold"})
ax_speedup.grid(True, alpha=0.3)

# Plot bandwidth utilization (right plot)
ax_bandwidth.plot(
total_tokens_values,
cuda_bandwidth_percentages,
"ro-",
v2_bandwidth_percentages,
"o-",
linewidth=3,
markersize=8,
label="CUDA",
label="SiLU V2",
color="blue",
)
ax_bandwidth.plot(
total_tokens_values,
triton_bandwidth_percentages,
"go-",
"o-",
linewidth=3,
markersize=8,
label="Triton",
color="green",
)
ax_bandwidth.set_title(
f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
@@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results):
for label in ax.get_xticklabels() + ax.get_yticklabels():
label.set_fontweight("bold")

# Add value labels on speedup points
for x, y in zip(total_tokens_values, ratios):
# Add value labels on Triton/V2 speedup points
for x, y in zip(total_tokens_values, triton_v2_ratios):
ax_speedup.annotate(
f"{y:.2f}x",
(x, y),
textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=10,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
)

# Add value labels on CUDA bandwidth points
for x, y in zip(total_tokens_values, cuda_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, 12),
ha="center",
fontsize=9,
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3),
)

# Add value labels on Triton bandwidth points
for x, y in zip(total_tokens_values, triton_bandwidth_percentages):
ax_bandwidth.annotate(
f"{y:.1f}%",
(x, y),
textcoords="offset points",
xytext=(0, -15),
ha="center",
fontsize=9,
@@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results):

plt.tight_layout()
plt.subplots_adjust(top=0.93)  # Make room for main title
filename = "silu_benchmark_total_tokens.png"
filename = "silu_benchmark_total_tokens_3way.png"
plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.show()

return filename


# Create combined plot with all strategies
combined_plot_filename = create_total_tokens_plot(all_results)
# Create comprehensive 3-way comparison plots
combined_plot_filename = create_combined_plot(all_results)
total_tokens_plot_filename = create_total_tokens_plot(all_results)

print(f"\n{'=' * 60}")
print("Benchmark Complete!")
print(f"Generated combined plot: {combined_plot_filename}")
print(f"{'=' * 60}")
print(f"\n{'=' * 80}")
print("3-Way Benchmark Suite Complete!")
print(f"Generated combined comparison plot: {combined_plot_filename}")
print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}")
print("Compared: SiLU V2 (CUDA), and Triton implementations")
print(f"{'=' * 80}")

@@ -138,12 +138,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void silu_mul_fp8_quant_deep_gemm_cuda(
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
bool use_ue8m0);

void mul_and_silu(torch::Tensor& out, torch::Tensor& input);


@@ -114,13 +114,22 @@ __global__ void act_and_mul_quant_kernel(
}

__device__ __forceinline__ float silu(float x) {
return (__fdividef(x, (1.f + expf(-x))));
return __fdividef(x, (1.f + expf(-x)));
}

__device__ __forceinline__ float2 silu2(float2 x) {
return make_float2(silu(x.x), silu(x.y));
}

__device__ __forceinline__ __nv_bfloat162 silu2_v2(float2 x) {
#ifndef USE_ROCM
return make_bfloat162(__float2bfloat16_rn(silu(x.x)),
__float2bfloat16_rn(silu(x.y)));
#else
return __float22bfloat162_rn(make_float2(silu(x.x), silu(x.y)));
#endif
}

#ifndef USE_ROCM
__device__ __forceinline__ float warp_max(float v) {
static constexpr unsigned FULL_MASK = 0xffffffffu;
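A quick host-side cross-check of what the device helpers above compute per element: the SiLU of the gate half, rounded to bf16 (as `silu2_v2` returns), multiplied by the bf16 up half. This is a PyTorch sketch for intuition, not the kernel code:

```python
import torch


def silu_ref(x: torch.Tensor) -> torch.Tensor:
    # Same formula as the device helper: x / (1 + exp(-x)).
    return x / (1.0 + torch.exp(-x))


gate = torch.randn(4, dtype=torch.float32)
up = torch.randn(4, dtype=torch.bfloat16)

# silu2_v2 returns the activation already rounded to bf16; the kernel then
# multiplies it with the bf16 `up` half (__hmul2).
y = silu_ref(gate).to(torch.bfloat16) * up
print(y)
```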
@@ -223,224 +232,308 @@ constexpr __nv_bfloat16 get_fp8_min() {
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
}
}
#ifndef USE_ROCM
template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,

template <typename Idx_t>
__device__ __forceinline__ int warp_expert_search(
int idx, int n, const Idx_t* __restrict__ input, Idx_t val) {
const Idx_t* input_ptr = input + idx;
int base_offset = 0;

for (;;) {
bool move_on = (idx < n && *input_ptr <= val);

unsigned mask = __ballot_sync(0xffffffff, move_on);

if (mask != 0xffffffffu) {
int last_lane = 31 - __clz(mask);
return base_offset + last_lane;
}

input_ptr += 32;
base_offset += 32;
idx += 32;
}
}

template <int num_parallel_tokens>
__device__ __forceinline__ void token_bounds(int32_t n_tokens,
int32_t worker_id,
int32_t& n_tokens_lower,
int32_t& n_tokens_upper) {
if (n_tokens < num_parallel_tokens && worker_id < n_tokens) {
if (worker_id >= num_parallel_tokens) return;
n_tokens_lower = worker_id;
n_tokens_upper = worker_id + 1;
} else {
int32_t chunk_size = n_tokens / num_parallel_tokens;
int32_t residual = n_tokens - chunk_size * num_parallel_tokens;
auto calc_id = [&](int32_t id) {
if (id < residual)
return min(n_tokens, id * (chunk_size + 1));
else
return min(n_tokens, id * chunk_size + residual);
};
n_tokens_lower = calc_id(worker_id);
n_tokens_upper = calc_id(worker_id + 1);
}
}

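The `token_bounds` helper above slices `n_tokens` across the parallel workers, handing the first `residual` workers one extra token. A plain-Python sketch of the same arithmetic (slightly simplified for the case where a worker gets no work):

```python
def token_bounds(n_tokens: int, worker_id: int, workers: int) -> tuple[int, int]:
    if n_tokens < workers:
        # Fewer tokens than workers: one token per worker, the rest idle.
        if worker_id >= n_tokens:
            return 0, 0
        return worker_id, worker_id + 1
    chunk, residual = divmod(n_tokens, workers)

    def start(i: int) -> int:
        return min(n_tokens, i * (chunk + 1) if i < residual else i * chunk + residual)

    return start(worker_id), start(worker_id + 1)


assert [token_bounds(10, w, 4) for w in range(4)] == [(0, 3), (3, 6), (6, 8), (8, 10)]
```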
template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128,
int NUM_STAGES = 3>
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
float* __restrict__ _y_s, const int32_t* __restrict__ counts,

float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
// sizes
int H, int G,

Idx_t E, Idx_t T, Idx_t H,
// strides (in elements)
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
Idx_t stride_ys_g, Idx_t stride_counts_e) {
#ifndef USE_ROCM
static constexpr int NUM_WARPS = THREADS / WARP_SIZE;

static constexpr int LOAD_STAGE_SIZE = 2 * GROUP_SIZE / 8;
static constexpr int LOAD_STAGE_MOD = NUM_STAGES * LOAD_STAGE_SIZE;

static constexpr int COMPUTE_STAGE_SIZE = 2 * GROUP_SIZE / 4;
static constexpr int COMPUTE_STAGE_MOD = COMPUTE_STAGE_SIZE * NUM_STAGES;

extern __shared__ __align__(16) __int128_t smem_128[];

int* s_expert_offsets =
reinterpret_cast<int*>(smem_128 + (SMEM_SIZE_BYTES_Y / 16));

static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
int tid = threadIdx.x;
int warp_id = tid >> 5;
int lane_id = tid & 0x1f;

// We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
static constexpr int32_t BFLOAT16_PER_GROUP = 8;
int running_sum{};
if (!warp_id) {
for (int i = 0; i < E; i += WARP_SIZE) {
bool valid = (i + threadIdx.x) < E;
int value =
(valid ? tokens_per_expert[i + threadIdx.x * stride_counts_e] : 0) +
(!lane_id ? running_sum : 0);

// We split the shared memory in half, corresponding to gate and up matrices:
// [...gate_i, ...up_i] where 0 <= i < stages.
static constexpr int32_t S_NUM_128 =
2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
__shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];
for (int offset = 1; offset < 32; offset *= 2) {
int n = __shfl_up_sync(0xFFFFFFFFu, value, offset);
if (lane_id >= offset) value += n;
}

const int32_t tid = threadIdx.x;
const int32_t warp_id = tid / WARP_SIZE;
const int32_t lane_id = tid % WARP_SIZE;
if (valid) {
s_expert_offsets[i + threadIdx.x + 1] = value;
}

auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);
running_sum = __shfl_sync(0xFFFFFFFFu, value, WARP_SIZE - 1);
}

// block handles one (expert e, group g)
int32_t pid = blockIdx.x;
int32_t e = pid / G;
int32_t g = pid % G;

const int32_t n_tokens = counts[e * stride_counts_e];

if (!n_tokens) {
return; // Exit ASAP.
if (!lane_id) {
s_expert_offsets[0] = 0;
}
}

const Idx_t stride_i_t_128 = stride_i_t / 8u;
__syncthreads();

int32_t n_tokens_lower, n_tokens_upper;
int32_t total_tokens = s_expert_offsets[E];

const int warp_position_yq = warp_id * (H / NUM_WARPS);
const int warp_position_scales = warp_id * (H / (GROUP_SIZE * NUM_WARPS));

// A single block will handle tokens_per_block tokens.
// Each block i iterates over tokens of a slice of n_tokens =
// expert_counts[i], with the size of chunk being
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
// Specialize this, but can be likely fused.
if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
return;
}
n_tokens_lower = blockIdx.y;
n_tokens_upper = blockIdx.y + 1;
} else {
auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
auto calc_id = [&](int32_t id) {
if (id < residual) {
return min(n_tokens, id * (chunk_size + 1));
} else {
return min(n_tokens, id * chunk_size + residual);
}
};
n_tokens_lower = calc_id(blockIdx.y);
n_tokens_upper = calc_id(blockIdx.y + 1);
}

if (n_tokens_lower >= n_tokens_upper) {
// Each warp will get space to store its hidden dim for gate and up.
__int128_t* s_hidden_load = smem_128 + warp_id * ((2 * 128 / 8) * NUM_STAGES);
__int128_t* smem_load_ptr = s_hidden_load + lane_id;

const __nv_bfloat16 fp8_inv = __hdiv(__float2bfloat16(1.f), fp8_max);

int32_t compute_pipeline_offset_64 = 0;
int32_t load_stage_offset{};
const __nv_bfloat16 one_bf16 = __float2bfloat16_rn(1.f);

__int64_t* smem_compute_ptr = reinterpret_cast<__int64_t*>(smem_128) +
warp_id * (2 * (GROUP_SIZE / 4) * NUM_STAGES) +
lane_id;
__int64_t* s_gate64_ptr = smem_compute_ptr;
__int64_t* s_up64_ptr = smem_compute_ptr + GROUP_SIZE / 4;

int tokens_lower, tokens_upper;

token_bounds<BLOCK_COUNT>(total_tokens, blockIdx.x, tokens_lower,
tokens_upper);

Idx_t expert_id{}, expert_offset{}, next_expert_offset{};
int token_id = tokens_lower;
int32_t t_load{};

if (token_id < tokens_upper) {
expert_id = warp_expert_search<int>(lane_id, E, s_expert_offsets, token_id);
expert_offset = s_expert_offsets[expert_id];
next_expert_offset = s_expert_offsets[expert_id + 1];
} else {
// This thread block has no work to do.
return;
}

// We do calculations here, using constexpr wherever possible.
const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
const Idx_t base_yq =
e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
stride_i_t_128 * n_tokens_lower;
auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
auto y_s_ptr =
_y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
stride_yq_t * n_tokens_lower + 4 * lane_id;
int32_t t_load = n_tokens_lower, load_stage_id = 0;
auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
int32_t stage_offset{};
int t_load_bound = H / (GROUP_SIZE * NUM_WARPS);

static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
static constexpr int32_t LOAD_STAGE_MOD =
NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);
Idx_t base_i = ((expert_id * stride_i_e) / 8) +
(token_id - expert_offset) * stride_i_t / 8;
const Idx_t gate_warp_offset =
warp_id * ((stride_i_h * H) / (8 * NUM_WARPS)) + (lane_id & 0b1111);

const __int128_t* input_128_ptr =
reinterpret_cast<const __int128_t*>(_input) + gate_warp_offset +
((lane_id < 16) ? 0 : ((H * stride_i_h) / 8));
__int128_t* load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);

auto token_offset = token_id - expert_offset;

// Two halves of all threads in a block conduct global loads for gate and up,
// respectively.
auto load_and_advance_y_pred = [&] {
if (t_load < n_tokens_upper) {
auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;
if (t_load < t_load_bound) {
// Here we are simply continuing to load data
// from the current token.
auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;

// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
stage_offset += LOAD_STAGE_SIZE;
stage_offset %= LOAD_STAGE_MOD;
load_stage_offset += LOAD_STAGE_SIZE;
load_stage_offset %= LOAD_STAGE_MOD;

if (tid < HALF_THREAD_COUNT) {
cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
gate_128_ptr += stride_i_t_128;
} else {
cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
up_128_ptr += stride_i_t_128;
}
cp_async4(smem_load_ptr_staged, load_ptr);
load_ptr += GROUP_SIZE / 8;
++t_load;
} else if (token_id + 1 < tokens_upper) {
// We loaded everything from the current token, let's move on
// to the next one, and we checked that we have more tokens to load.
++token_id;
t_load = 0;
if (token_id >= next_expert_offset) {
// We need to find the next expert.
do {
// This is a loop because it's possible
// that some experts are assigned 0 tokens.
// NOTE: We are guaranteed that there's at least
// one more token left so we don't have to check for
// expert_id bounds.
++expert_id;
// This skips 1 memory read.
expert_offset = next_expert_offset;
next_expert_offset = s_expert_offsets[expert_id + 1];
} while (next_expert_offset == expert_offset);

base_i = expert_id * (stride_i_e / 8);
token_offset = 0;
load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
} else {
// We remain within the same expert, so just
// move by H/4 __int128_t (2 * H/8).
base_i += stride_yq_t / 4;
token_offset++;
}

load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);

auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;

// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
// unnecessary ALU ops.
load_stage_offset += LOAD_STAGE_SIZE;
load_stage_offset %= LOAD_STAGE_MOD;

cp_async4(smem_load_ptr_staged, load_ptr);
load_ptr += GROUP_SIZE / 8;
++t_load;
++load_stage_id;
}
// We fence even if there is nothing to load to simplify pipelining.
cp_async_fence();
};

// We need to warm-up the pipeline.
#pragma unroll
for (int i = 0; i < NUM_STAGES - 1; i++) {
load_and_advance_y_pred();
}

__int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
lane_id;
__int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;
__nv_fp8x4_e4m3* y_q_base_ptr =
reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id;
auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g;

static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;
for (auto j = tokens_lower; j < tokens_upper; j++) {
const Idx_t base_ys = expert_id * stride_ys_e;
auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
__nv_fp8x4_e4m3* y_q_ptr =
y_q_base_ptr + (expert_id * stride_yq_e + token_offset * stride_yq_t +
warp_position_yq * stride_yq_h) /
4;
const int COMPUTE_LIMIT = H / (GROUP_SIZE * NUM_WARPS);

int32_t compute_pipeline_offset_64 = 0;
for (int i = 0; i < COMPUTE_LIMIT; i++) {
cp_async_wait<NUM_STAGES - 2>();
__syncthreads();
load_and_advance_y_pred();

for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
__nv_bfloat162 results_bf162[2];
__int64_t* gate64_ptr = s_gate64_ptr + compute_pipeline_offset_64;
__int64_t* up64_ptr = s_up64_ptr + compute_pipeline_offset_64;

cp_async_wait<NUM_STAGES - 2>();
__syncthreads();
// COMPUTE_STAGE_SIZE/MOD must also be constexpr!
compute_pipeline_offset_64 += COMPUTE_STAGE_SIZE;
compute_pipeline_offset_64 %= COMPUTE_STAGE_MOD;

// We double-buffer pipelined loads so that the next load will
// concurrently run with compute without overwrites.
load_and_advance_y_pred();
__int64_t gate64 = *gate64_ptr;
__int64_t up64 = *up64_ptr;

auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;

// STAGE_SIZE must also be constexpr!
compute_pipeline_offset_64 += STAGE_SIZE;
compute_pipeline_offset_64 %= STAGE_MOD;

// Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
__int64_t gate64 = *s_gate_compute_64;
__nv_bfloat162* s_gate_compute_32 =
reinterpret_cast<__nv_bfloat162*>(&gate64);

__int64_t up64 = *s_up_compute_64;
__nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);
// Compute
__nv_bfloat162 res[2];
__nv_bfloat162* s_up_comp = reinterpret_cast<__nv_bfloat162*>(&up64);
__nv_bfloat162* s_gate_comp = reinterpret_cast<__nv_bfloat162*>(&gate64);

#pragma unroll
for (int i = 0; i < 2; i++) {
// For silu, we make sure that div is emitted.
float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
results_bf162[i] = __float22bfloat162_rn(gate);
}
for (int32_t k = 0; k < 2; ++k) {
__nv_bfloat162 gate = silu2_v2(__bfloat1622float2(s_gate_comp[k]));
res[k] = __hmul2(gate, s_up_comp[k]);
}

auto _y_max2 = __hmax2(__habs2(res[0]), __habs2(res[1]));

_y_max2.x = __hmax(__hmax(_y_max2.x, _y_max2.y), EPS);

__nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);

if constexpr (USE_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}

__nv_bfloat16 inv_y = __hdiv(one_bf16, y_s);

auto y_s2 = make_bfloat162(inv_y, inv_y);

#pragma unroll
for (int i = 0; i < 2; i++) {
results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
}
for (int32_t k = 0; k < 2; ++k) {
res[k] = clip(__hmul2(res[k], y_s2), __bfloat162bfloat162(fp8_min),
__bfloat162bfloat162(fp8_max));
}

auto _y_max2 =
__hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));
*y_q_ptr = __nv_fp8x4_e4m3(res[0], res[1]);
y_q_ptr += WARP_SIZE * stride_yq_h;

__nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y));

// An entire group is assigned to a single warp, so a simple warp reduce
// is used.
__nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;

if constexpr (USE_UE8M0) {
y_s = hexp2(hceil(hlog2(y_s)));
}

auto inv_y = __float2bfloat16_rn(1.f) / y_s;

auto y_s2 = make_bfloat162(inv_y, inv_y);

#pragma unroll
for (int32_t i = 0; i < 2; ++i) {
results_bf162[i] =
clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
__bfloat162bfloat162(fp8_max));
}

auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
*reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
y_q_ptr += stride_yq_t;

if (lane_id == 0) {
*y_s_ptr = y_s;
y_s_ptr += stride_ys_t;
if (!lane_id) {
*y_s_ptr = y_s;
y_s_ptr += stride_ys_g;
}
}
}
}
#endif
}

} // namespace vllm

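In the new kernel, the first warp of each block builds an inclusive prefix sum of `tokens_per_expert` into shared memory (`s_expert_offsets`), and `warp_expert_search` then maps a global token index back to the expert that owns it. The same bookkeeping in a few lines of Python (a reference sketch, not the warp-level code):

```python
import torch


def expert_offsets(tokens_per_expert: torch.Tensor) -> torch.Tensor:
    # s_expert_offsets[0] = 0, s_expert_offsets[e + 1] = running sum of counts.
    return torch.cat([torch.zeros(1, dtype=torch.long),
                      tokens_per_expert.long().cumsum(0)])


def expert_of_token(offsets: torch.Tensor, token_id: int) -> int:
    # Last expert whose starting offset is <= token_id, matching
    # warp_expert_search's "last lane that may still move on".
    return int(torch.searchsorted(offsets, token_id, right=True).item()) - 1


offsets = expert_offsets(torch.tensor([3, 0, 5, 2]))
assert offsets.tolist() == [0, 3, 3, 8, 10]
assert [expert_of_token(offsets, t) for t in [0, 2, 3, 7, 9]] == [0, 0, 2, 2, 3]
```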
@@ -475,14 +568,14 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

void silu_mul_fp8_quant_deep_gemm_cuda(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& tokens_per_expert, // (E)
at::Tensor& y_q, // (E, T, H) [OUT]
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
bool use_ue8m0) {
#ifndef USE_ROCM
// This kernel relies heavily on cp.async and fp8 support.

// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
TORCH_CHECK(input.dtype() == torch::kBFloat16);
@@ -491,10 +584,6 @@ void silu_mul_fp8_quant_deep_gemm_cuda(
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % 256 == 0);

// Check that num_parallel_tokens is of power of 2 and between 1 and 64.
TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));

using Idx_t = int64_t;

Idx_t E = input.size(0);
@@ -510,81 +599,54 @@ void silu_mul_fp8_quant_deep_gemm_cuda(
Idx_t stride_ys_t = y_s.stride(1);
Idx_t stride_ys_g = y_s.stride(2);

Idx_t stride_counts_e = counts.stride(0);
Idx_t stride_counts_e = tokens_per_expert.stride(0);

static constexpr int GROUP_SIZE = 128;

#define KERNEL_FN \
if (use_ue8m0) { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, true> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
} else { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
NUM_PARALLEL_TOKENS, false> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
stride_counts_e); \
}

#define KERNEL_CALL_H \
if (H % (4 * GROUP_SIZE) == 0) { \
static constexpr int NUM_WARPS = 4; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
} else { \
static constexpr int NUM_WARPS = 1; \
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
KERNEL_FN \
}

#define KERNEL_CALL_TOP_LEVEL \
if (num_parallel_tokens == 1) { \
static constexpr int NUM_PARALLEL_TOKENS = 1; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 2) { \
static constexpr int NUM_PARALLEL_TOKENS = 2; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 4) { \
static constexpr int NUM_PARALLEL_TOKENS = 4; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 8) { \
static constexpr int NUM_PARALLEL_TOKENS = 8; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 16) { \
static constexpr int NUM_PARALLEL_TOKENS = 16; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 32) { \
static constexpr int NUM_PARALLEL_TOKENS = 32; \
KERNEL_CALL_H \
} else if (num_parallel_tokens == 64) { \
static constexpr int NUM_PARALLEL_TOKENS = 64; \
KERNEL_CALL_H \
}

Idx_t G;
dim3 block, grid;
auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
G = H / Idx_t(group_size * num_warps);
grid = dim3(E * G, _num_parallel_tokens);
block = dim3(num_warps * WARP_SIZE);
};

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
"silu_mul_fp8_quant_deep_gemm_kernel",
[&] { KERNEL_CALL_TOP_LEVEL });

#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE; \
int sms = SILU_V2_BLOCK_COUNT; \
static constexpr int max_shared_mem_bytes = \
GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2; \
dim3 grid(sms), block(THREAD_COUNT); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
VLLM_DISPATCH_FP8_TYPES( \
y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] { \
vllm::silu_mul_fp8_quant_deep_gemm_kernel< \
BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t, \
USE_UE8M0, GROUP_SIZE, STAGES> \
<<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>( \
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E, \
T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e, \
stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t, \
stride_ys_g, stride_counts_e); \
});

static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;

if (!use_ue8m0) {
if (H >= 4096) {
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
} else {
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
}
} else {
if (H >= 4096) {
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
} else {
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
}
}

#endif
}

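The launcher above drops the old `num_parallel_tokens` dispatch in favor of a fixed persistent grid (`SILU_V2_BLOCK_COUNT`) and picks thread count and pipeline depth from `H`. A compact sketch of that selection logic and the dynamic shared-memory size it implies (values taken from the macro above; a reading aid, not the build configuration):

```python
GROUP_SIZE = 128
WARP_SIZE = 32


def launch_config(H: int, E: int) -> dict:
    # H >= 4096: 256 threads (8 warps) with a 4-stage pipeline;
    # otherwise a single warp with a 2-stage pipeline.
    threads, stages = (256, 4) if H >= 4096 else (32, 2)
    num_warps = threads // WARP_SIZE
    # max_shared_mem_bytes = GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2 (bf16 gate+up),
    # plus (E + 1) * 16 bytes appended for the expert-offsets table.
    smem_y = GROUP_SIZE * 2 * stages * num_warps * 2
    return {
        "block_count": 132 * 32,  # SILU_V2_BLOCK_COUNT
        "threads": threads,
        "stages": stages,
        "dynamic_smem_bytes": smem_y + (E + 1) * 16,
    }


print(launch_config(H=7168, E=32))
```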
@@ -33,11 +33,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif

ops.def(
"silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s, int group_size, "
"bool use_ue8m0, int num_parallel_tokens) -> ()");
ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA,
&silu_mul_fp8_quant_deep_gemm_cuda);
"persistent_masked_m_silu_mul_quant(Tensor input, Tensor counts, Tensor! "
"y_q, Tensor! y_s,"
"bool use_ue8m0) -> ()");
ops.impl("persistent_masked_m_silu_mul_quant", torch::kCUDA,
&persistent_masked_m_silu_mul_quant);

ops.def("weak_ref_tensor(Tensor input) -> Tensor");
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

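With the registration above, the custom op is reachable from Python through `torch.ops._C` with the new, shorter signature (this is the same call the updated wrapper makes later in this diff):

```python
import torch


# Shapes: y is (E, T, 2*H) bf16, tokens_per_expert is (E,) int32,
# y_q is (E, T, H) fp8 and y_s is (E, T, H // 128) fp32 (pre-allocated outputs).
def call_v2(y, tokens_per_expert, y_q, y_s, use_ue8m0: bool = False) -> None:
    torch.ops._C.persistent_masked_m_silu_mul_quant(
        y, tokens_per_expert, y_q, y_s, use_ue8m0
    )
```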
@@ -5,7 +5,7 @@ import pytest
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm_cuda,
persistent_masked_m_silu_mul_quant,
)
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -50,15 +50,15 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
# Input tensor of shape (E, T, 2*H)
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
tokens_per_expert = torch.randint(
low=T // 2,
low=0,
high=T,
size=(E,),
dtype=torch.int32,
device="cuda",
)

# Run the Triton kernel
y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(
# Run the SiLU V2 kernel
y_q, y_s = persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, group_size=group_size
)

@@ -115,10 +115,11 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
y_se = y_s[e].float()
y_qe = y_q[e].float()

torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)
torch.testing.assert_close(
y_qe[:nt].to(torch.float32),
ref_q[:nt].to(torch.float32),
atol=2,
rtol=2e-1,
)

torch.testing.assert_close(y_se[:nt], ref_s[:nt], atol=1e-4, rtol=1e-2)

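The test compares the kernel output against a reference (`ref_q`, `ref_s`) whose construction lies outside this hunk. A plausible pure-PyTorch reference for one expert, written from the op's documented semantics (group size 128, per-group scale = max(|silu(gate) * up|, eps) / fp8_max), is sketched below as an assumption rather than the test's actual reference:

```python
import torch


def ref_silu_mul_quant(y_e: torch.Tensor, group_size: int = 128,
                       eps: float = 1e-10,
                       fp8_dtype: torch.dtype = torch.float8_e4m3fn):
    """y_e: (T, 2*H) bf16 slice for one expert -> (y_q, y_s)."""
    T, H2 = y_e.shape
    H = H2 // 2
    gate, up = y_e[:, :H].float(), y_e[:, H:].float()
    out = gate / (1.0 + torch.exp(-gate)) * up               # silu(gate) * up
    groups = out.view(T, H // group_size, group_size)
    fp8_max = torch.finfo(fp8_dtype).max
    y_s = groups.abs().amax(dim=-1).clamp_min(eps) / fp8_max  # (T, H // group_size)
    y_q = (groups / y_s.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    return y_q.view(T, H).to(fp8_dtype), y_s


y_q_ref, y_s_ref = ref_silu_mul_quant(torch.randn(4, 2 * 256, dtype=torch.bfloat16))
```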
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import log2
from typing import Optional

import torch
@@ -94,7 +93,7 @@ def _silu_mul_fp8_quant_deep_gemm(
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)


def silu_mul_fp8_quant_deep_gemm_cuda(
def persistent_masked_m_silu_mul_quant(
y: torch.Tensor, # (E, T, 2*H)
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
num_parallel_tokens=16,
@@ -103,9 +102,41 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
We launch a fixed grid of threads to accommodate CUDA graphs. Let `P2`
be a parallelization factor for persistent_masked_m_silu_mul_quant over the
hidden dimension.

Let `expert_offsets = [0] + [num_tokens.cumsum()]` and
`total_tokens = expert_offsets[-1]`.
persistent_masked_m_silu_mul_quant launches `total_tokens x P2` number of
thread blocks. Each thread block contains `NUM_WARPS` warps.

Every thread block needs to find its corresponding expert by warp-parallel scanning
over the `expert_offsets` array.

The i-th warp in the first thread block processes
`[i * warp_chunk_size, (i + 1) * warp_chunk_size]` groups
sequentially, where `warp_chunk_size = ((H / GROUP_SIZE) / P2) / NUM_WARPS`,
pipelining loads and computes.

The shared memory layout for 4 warps with a 2-stage pipeline for SiLU V2
can be visualized like so:

stage0 stage1
┌─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┐
│gate0│up0│gate1│up1│gate2│up2│gate3│up3│gate0│up0│gate1│up1│gate2│up2│gate3│up3│
└─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┘

with the main difference between V1 and V2 being the global load
stride between warps, and between half-warps. Regarding the latter stride,
we assign the first half-warp of every warp to `gate` loads and the second
half-warp to `up` loads.

Returns `(y_q, y_s)` where
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
Let NUM_WARPS be the number of warps in a single thread block and
`GROUP_SIZE = 128` be the size of the quantization group.
"""
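The docstring's launch bookkeeping, made concrete: a sketch of the quantities it defines (`expert_offsets`, `total_tokens`, the `total_tokens x P2` grid, and `warp_chunk_size`). The numbers below are illustrative only:

```python
import torch

GROUP_SIZE = 128

tokens_per_expert = torch.tensor([40, 0, 88, 16])  # (E,)
H, P2, NUM_WARPS = 7168, 2, 4                      # illustrative values

expert_offsets = torch.cat(
    [torch.zeros(1, dtype=torch.long), tokens_per_expert.cumsum(0)]
)
total_tokens = int(expert_offsets[-1])             # 144

num_thread_blocks = total_tokens * P2              # grid size, per the docstring
warp_chunk_size = (H // GROUP_SIZE) // P2 // NUM_WARPS  # groups handled per warp

print(num_thread_blocks, warp_chunk_size)          # 288 7
```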
assert y.ndim == 3, "y must be (E, T, 2*H)"
E, T, H2 = y.shape
@@ -133,30 +164,15 @@ def silu_mul_fp8_quant_deep_gemm_cuda(

use_ue8m0 = is_deep_gemm_e8m0_used()

if E <= 16:
max_empirical_parallelism = 64
elif E <= 32:
max_empirical_parallelism = 16
else:
max_empirical_parallelism = 4

# We never want to launch more than Tx number of threads
# This computes the clip.
num_parallel_tokens = max(
1, min(max_empirical_parallelism, 2 ** int(log2(min(num_parallel_tokens, T))))
)
cuda_arch = current_platform.get_device_capability(
device_id=y.device.index
).to_int()

if cuda_arch >= 80:
torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(
y, tokens_per_expert, y_q, y_s, group_size, use_ue8m0, num_parallel_tokens
torch.ops._C.persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, y_q, y_s, use_ue8m0
)
else:
# Default to triton if not on cuda or if arch is too old
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)

stride_cnt_e = tokens_per_expert.stride()[0]

# Static grid over experts and H-groups.
@@ -166,16 +182,6 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
stride_i_e, stride_i_t, stride_i_h = y.stride()
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()

# desired scale strides (elements): (T*G, 1, T)
stride_ys_e = T * G
stride_ys_t = 1
stride_ys_g = T
y_s = torch.empty_strided(
(E, T, G),
(stride_ys_e, stride_ys_t, stride_ys_g),
dtype=torch.float32,
device=y.device,
)
f_info = torch.finfo(fp8_dtype)
fp8_max = f_info.max
fp8_min = f_info.min
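The Triton fallback allocates `y_s` with strides `(T*G, 1, T)` for shape `(E, T, G)`: within one expert, the scales for a fixed group index are contiguous across tokens. A tiny sketch of what that allocation produces:

```python
import torch

E, T, G = 2, 4, 3
y_s = torch.empty_strided((E, T, G), (T * G, 1, T), dtype=torch.float32)

print(y_s.shape, y_s.stride())       # torch.Size([2, 4, 3]) (12, 1, 4)
# For a fixed expert and group index, scales across tokens sit contiguously:
print(y_s[0, :, 1].is_contiguous())  # True
```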
@@ -313,7 +319,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m,
)

a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
workspace1, expert_num_tokens
)
