[Feature] Integrate new deepgemm (#19820)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-06-24 15:51:56 -04:00 committed by GitHub
parent 91f7d9d0b6
commit c6e3bba8e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 230 additions and 264 deletions

View File

@ -86,6 +86,9 @@ def benchmark_config(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_deep_gemm:
# we use the default block shape for deepgemm
block_quant_shape = [128, 128]
if use_fp8_w8a8:
if block_quant_shape:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]

View File

@ -1,13 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# fmt: off
# ruff: noqa: E501
import time
# Import DeepGEMM functions
import deep_gemm
import torch
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
from deep_gemm import fp8_gemm_nt
from deep_gemm.testing.numeric import calc_diff
from deep_gemm.utils.math import ceil_div, per_block_cast_to_fp8, per_token_cast_to_fp8
# Import vLLM functions
from vllm import _custom_ops as ops
@ -18,107 +16,84 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.triton_utils import triton
# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
def per_token_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert tensor to FP8 format with per-token scaling."""
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
def per_block_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def per_block_cast_to_fp8_vllm(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert tensor to FP8 format with per-block scaling."""
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
dtype=x.dtype,
device=x.device)
x_padded = torch.zeros(
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
x_view.size(0), x_view.size(2)
)
def benchmark_shape(m: int,
n: int,
k: int,
warmup: int = 100,
repeat: int = 10000,
verbose: bool = False) -> dict:
def benchmark_shape(
m: int,
n: int,
k: int,
warmup: int = 100,
repeat: int = 10000,
verbose: bool = False,
) -> dict:
"""Benchmark all implementations for a specific (m, n, k) shape."""
if verbose:
print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
# Create test tensors
A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
# Reference result in BF16
A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
torch.cuda.synchronize()
C_ref = A @ B.t()
# Pre-quantize B for all implementations
# (weights can be pre-quantized offline)
B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)
B_vllm, B_scale_vllm = per_block_cast_to_fp8_vllm(B)
# Block size configuration
block_size = [128, 128]
# Pre-quantize A for all implementations
A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
C_deepgemm = (
torch.empty((n, m), device="cuda", dtype=torch.bfloat16).t().contiguous()
)
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
A, block_size[1], column_major_scales=True)
A, block_size[1], column_major_scales=True
)
# === DeepGEMM Implementation ===
def deepgemm_gemm():
# A quantization is inside the loop as it depends on activations
# A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
# A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
# A, block_size[1])
# A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
# C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
(B_deepgemm, B_scale_deepgemm),
C_deepgemm)
fp8_gemm_nt(
(A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
)
return C_deepgemm
# === vLLM Triton Implementation ===
def vllm_triton_gemm():
# A quantization is inside the loop as it depends on activations
# A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
return w8a8_block_fp8_matmul(A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
block_size,
output_dtype=torch.bfloat16)
return w8a8_block_fp8_matmul(
A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
block_size,
output_dtype=torch.bfloat16,
)
# === vLLM CUTLASS Implementation ===
def vllm_cutlass_gemm():
# A quantization is inside the loop as it depends on activations
# A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
# A, block_size[1], column_major_scales=True)
return ops.cutlass_scaled_mm(A_vllm_cutlass,
B_vllm.T,
scale_a=A_scale_vllm_cutlass,
scale_b=B_scale_vllm.T,
out_dtype=torch.bfloat16)
return ops.cutlass_scaled_mm(
A_vllm_cutlass,
B_vllm.T,
scale_a=A_scale_vllm_cutlass,
scale_b=B_scale_vllm.T,
out_dtype=torch.bfloat16,
)
# Run correctness check first
if verbose:
print("Running correctness check...")
C_deepgemm = deepgemm_gemm()
@ -133,26 +108,22 @@ def benchmark_shape(m: int,
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
print("vLLM Triton vs DeepGEMM difference: "
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
print("vLLM CUTLASS vs DeepGEMM difference: "
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
print(
"vLLM Triton vs DeepGEMM difference: "
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
)
print(
"vLLM CUTLASS vs DeepGEMM difference: "
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
)
# Benchmark implementations
implementations = {
"DeepGEMM": deepgemm_gemm,
"vLLM Triton": vllm_triton_gemm,
"vLLM CUTLASS": vllm_cutlass_gemm
"vLLM CUTLASS": vllm_cutlass_gemm,
}
benchmark_results = {
"shape": {
"m": m,
"n": n,
"k": k
},
"implementations": {}
}
benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}
for name, func in implementations.items():
# Warmup
@ -180,38 +151,36 @@ def benchmark_shape(m: int,
"tflops": tflops,
"gb_s": gb_s,
"diff": {
"DeepGEMM":
0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
"Reference":
deepgemm_diff if name == "DeepGEMM" else
(vllm_triton_diff
if name == "vLLM Triton" else vllm_cutlass_diff)
}
"DeepGEMM": 0.0
if name == "DeepGEMM"
else calc_diff(func(), C_deepgemm),
"Reference": deepgemm_diff
if name == "DeepGEMM"
else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
},
}
if verbose:
print(
f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
)
print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")
# Calculate speedups
baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
for name, data in benchmark_results["implementations"].items():
if name != "DeepGEMM":
speedup = baseline / data["time_ms"]
benchmark_results["implementations"][name][
"speedup_vs_deepgemm"] = speedup
benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
if verbose:
print(f"DeepGEMM is {1/speedup:.2f}x "
f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
print(
f"DeepGEMM is {1 / speedup:.2f}x "
f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
)
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
"time_ms"]
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
"time_ms"]
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
benchmark_results["implementations"]["vLLM CUTLASS"][
"speedup_vs_triton"] = cutlass_vs_triton
benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
cutlass_vs_triton
)
if verbose:
print(
f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
@ -223,8 +192,7 @@ def benchmark_shape(m: int,
def format_table_row(values, widths):
"""Format a row with specified column widths."""
return "| " + " | ".join(f"{val:{w}}"
for val, w in zip(values, widths)) + " |"
return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"
def print_table(headers, rows, title=None):
@ -232,16 +200,12 @@ def print_table(headers, rows, title=None):
if title:
print(f"\n{title}")
# Calculate column widths based on headers and data
widths = [
max(len(str(h)), max(len(str(row[i])) for row in rows))
for i, h in enumerate(headers)
]
# Create separator line
separator = "+-" + "-+-".join("-" * w for w in widths) + "-+"
# Print table
print(separator)
print(format_table_row(headers, widths))
print(separator)
@ -259,44 +223,22 @@ def run_benchmarks(verbose: bool = False):
"""Run benchmarks for a set of common shapes."""
print("===== STARTING FP8 GEMM BENCHMARK =====")
# Make sure we're using the GPU
if not torch.cuda.is_available():
print("CUDA not available! Tests require GPU.")
return
# Print system information
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Triton version: {triton.__version__}")
print(f"Using device: {torch.cuda.get_device_name()}")
# Enable TF32 for better performance
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Set seeds for reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Define benchmark shapes (m, n, k)
shapes = [
(8, 4096, 7168),
(8, 7168, 18432),
(8, 18432, 7168),
(64, 4096, 7168),
(64, 7168, 18432),
(64, 18432, 7168),
(64, 24576, 1536),
(64, 32768, 512),
(64, 7168, 16384),
(128, 4096, 7168),
(128, 7168, 18432),
(128, 18432, 7168),
(1024, 4096, 7168),
(1024, 18432, 7168),
(2048, 4096, 7168),
(4096, 4096, 7168),
]
shapes = [
# (64, 2112, 7168),
(64, 24576, 1536),
@ -323,7 +265,6 @@ def run_benchmarks(verbose: bool = False):
result = benchmark_shape(m, n, k, verbose=verbose)
all_results.append(result)
# Print results in a nicely formatted table
print("\n===== PERFORMANCE COMPARISON =====")
# Print DeepGEMM table
@ -332,38 +273,50 @@ def run_benchmarks(verbose: bool = False):
for result in all_results:
shape = result["shape"]
impl_data = result["implementations"]["DeepGEMM"]
deepgemm_rows.append([
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
])
deepgemm_rows.append(
[
shape["m"],
shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
]
)
print_table(deepgemm_headers,
deepgemm_rows,
title="DeepGEMM Implementation:")
print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")
# Print vLLM Triton table
triton_headers = [
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
]
triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
triton_rows = []
for result in all_results:
shape = result["shape"]
impl_data = result["implementations"]["vLLM Triton"]
speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
triton_rows.append([
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
format_speedup(speedup)
])
triton_rows.append(
[
shape["m"],
shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(speedup),
]
)
print_table(triton_headers,
triton_rows,
title="vLLM Triton Implementation:")
print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")
# Print vLLM CUTLASS table
cutlass_headers = [
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
"vs Triton"
"m",
"n",
"k",
"Time (μs)",
"TFLOPS",
"GB/s",
"vs DeepGEMM",
"vs Triton",
]
cutlass_rows = []
for result in all_results:
@ -371,28 +324,27 @@ def run_benchmarks(verbose: bool = False):
impl_data = result["implementations"]["vLLM CUTLASS"]
vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
vs_triton = impl_data.get("speedup_vs_triton", 1.0)
cutlass_rows.append([
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
format_speedup(vs_deepgemm),
format_speedup(vs_triton)
])
cutlass_rows.append(
[
shape["m"],
shape["n"],
shape["k"],
f"{impl_data['time_us']:.1f}",
f"{impl_data['tflops']:.1f}",
f"{impl_data['gb_s']:.1f}",
format_speedup(vs_deepgemm),
format_speedup(vs_triton),
]
)
print_table(cutlass_headers,
cutlass_rows,
title="vLLM CUTLASS Implementation:")
print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")
# Calculate and print averages
print("\n===== AVERAGE PERFORMANCE =====")
implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
avg_metrics = {
impl: {
"tflops": 0,
"gb_s": 0,
"time_ms": 0
}
for impl in implementations
impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
}
for result in all_results:
@ -410,9 +362,9 @@ def run_benchmarks(verbose: bool = False):
avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
avg_time = avg_metrics[impl]["time_ms"] / num_shapes
avg_rows.append([
impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
])
avg_rows.append(
[impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
)
print_table(avg_headers, avg_rows)
@ -420,21 +372,19 @@ def run_benchmarks(verbose: bool = False):
avg_speedups = {
"DeepGEMM vs vLLM Triton": 0,
"DeepGEMM vs vLLM CUTLASS": 0,
"vLLM CUTLASS vs vLLM Triton": 0
"vLLM CUTLASS vs vLLM Triton": 0,
}
for result in all_results:
deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
"time_ms"]
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]
avg_speedups[
"DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
avg_speedups[
"DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
avg_speedups[
"vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
vllm_triton_time / vllm_cutlass_time
)
print("\n===== AVERAGE SPEEDUPS =====")
speedup_headers = ["Comparison", "Speedup"]
@ -446,14 +396,12 @@ def run_benchmarks(verbose: bool = False):
print_table(speedup_headers, speedup_rows)
# Average accuracy comparison
print("\n===== ACCURACY COMPARISON =====")
avg_diff = {impl: 0 for impl in implementations}
for result in all_results:
for impl in implementations:
avg_diff[impl] += result["implementations"][impl]["diff"][
"Reference"]
avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]
diff_headers = ["Implementation", "Avg Diff vs Reference"]
diff_rows = []

View File

@ -66,25 +66,6 @@ def next_power_of_2(x):
return 2**math.ceil(math.log2(x))
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights(
e: int,
n: int,
@ -125,8 +106,8 @@ def make_block_quant_fp8_weights(
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
w1[i], w1_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w2_bf16[i])
return w1, w2, w1_s, w2_s

View File

@ -18,7 +18,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
get_col_major_tma_aligned_tensor, per_token_group_quant_fp8,
w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
dg_available = False
@ -263,25 +264,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
assert rel_diff < 0.03
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@ -299,10 +281,8 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
_, block_k = block_size[0], block_size[1]
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
A_fp8, As_fp8 = deep_gemm.utils.math.per_token_cast_to_fp8(A_fp32)
B_fp8, Bs_fp8 = deep_gemm.utils.math.per_block_cast_to_fp8(B_fp32)
As = As_fp8.to(torch.float32)
Bs = Bs_fp8.to(torch.float32)
@ -310,15 +290,12 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) //
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
deep_gemm.fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
@ -382,16 +359,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
dtype=torch.bfloat16,
device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
inter_out, m_indices)
deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a_q, a_s), (w1, w1_s),
inter_out, m_indices)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
deep_gemm.m_grouped_fp8_gemm_nt_contiguous((act_out_q, act_out_s),
(w2, w2_s), out, m_indices)
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
@ -441,15 +418,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
w1_s = get_col_major_tma_aligned_tensor(w1_s).contiguous()
w2_s = get_col_major_tma_aligned_tensor(w2_s).contiguous()
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(E):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
w1[i], w1_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w2_bf16[i])
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
@ -460,14 +437,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
topk, block_size)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))

View File

@ -266,19 +266,16 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale),
(w1, w1_scale),
out=workspace1,
masked_m=expert_num_tokens,
expected_m=expected_m)
dg.fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
out=workspace1,
masked_m=expert_num_tokens,
expected_m=expected_m)
assert expert_num_tokens is not None
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
out=output,
masked_m=expert_num_tokens,
expected_m=expected_m)
dg.fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
out=output,
masked_m=expert_num_tokens,
expected_m=expected_m)

View File

@ -144,8 +144,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
(M_sum, N // 2))
mm2_out = _resize_cache(workspace2, (M_sum, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
dg.m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
mm1_out, expert_ids)
self.activation(activation, act_out, mm1_out.view(-1, N))
@ -154,9 +154,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.block_shape[1],
column_major_scales=True,
out_q=quant_out)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
dg.m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
mm2_out, expert_ids)
torch.index_select(mm2_out, 0, inv_perm, out=output)

View File

@ -58,7 +58,7 @@ def w8a8_block_fp8_matmul_deepgemm(
output_dtype)
# Deepgemm only supports output tensor type as bfloat16
assert C.dtype == torch.bfloat16
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
deep_gemm.fp8_gemm_nt((A, As), (B, Bs), C)
return C

View File

@ -114,6 +114,10 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear(
@ -158,9 +162,6 @@ def apply_w8a8_block_fp8_linear(
if current_platform.is_cuda():
if current_platform.has_device_capability(100):
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
use_cutlass = cutlass_block_fp8_supported and (
ceil_div(weight.shape[0], 128) == weight_scale.shape[0]
and ceil_div(weight.shape[1], 128) == weight_scale.shape[1])
@ -655,3 +656,67 @@ def w8a8_block_fp8_matmul(
)
return C
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of
16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return ceil_div(x, alignment) * alignment
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
"""
Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along
the M axis (thus meets the requirement of LHS scaling tensor in
DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTES: for the extreme performance, you may rewrite/fuse this function in
# CUDA
assert x.dim() in (2, 3)
remove_dim = False
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
if x.stride(0) == 1 and x.stride(1) == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
2) == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
aligned_x = torch.transpose(
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x