[Refactor] Remove Duplicate per_block_cast_to_fp8, Remove Dependency on DeepGEMM (#21787)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-07-31 21:13:27 -04:00 committed by GitHub
parent 0bd409cf01
commit 3700642013
8 changed files with 55 additions and 132 deletions
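
After this change, callers import a single shared helper from vllm.utils.deep_gemm instead of keeping local copies. A minimal sketch of the consolidated signature as it appears in the diffs below (the weight name and shape are illustrative):

import torch

from vllm.utils.deep_gemm import per_block_cast_to_fp8

# Quantize a 2-D bf16 weight to FP8 with one scale per 128x128 block.
w = torch.randn(512, 7168, dtype=torch.bfloat16)
w_fp8, w_scale = per_block_cast_to_fp8(w, block_size=[128, 128], use_ue8m0=True)
# w_fp8 keeps the original shape; w_scale holds one value per block,
# here (512 // 128, 7168 // 128) == (4, 56).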

View File

@@ -4,49 +4,16 @@
# ruff: noqa: E501
import time
# Import DeepGEMM functions
import deep_gemm
import torch
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor,
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
from vllm.triton_utils import triton
# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
def per_token_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert tensor to FP8 format with per-token scaling."""
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
def per_block_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert tensor to FP8 format with per-block scaling."""
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
x_amax / 448.0).view(x_view.size(0), x_view.size(2))
from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
def benchmark_shape(m: int,
@@ -69,14 +36,14 @@ def benchmark_shape(m: int,
# Pre-quantize B for all implementations
# (weights can be pre-quantized offline)
B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)
B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
B_vllm, B_scale_vllm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
# Block size configuration
block_size = [128, 128]
# Pre-quantize A for all implementations
A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
@@ -85,7 +52,7 @@ def benchmark_shape(m: int,
# === DeepGEMM Implementation ===
def deepgemm_gemm():
deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
(B_deepgemm, B_scale_deepgemm),
C_deepgemm)
return C_deepgemm
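
Put together, the benchmark's DeepGEMM path after this commit reduces to roughly the sketch below (shapes are illustrative; running it requires a CUDA device with DeepGEMM installed):

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, per_token_group_quant_fp8)
from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8

m, n, k = 64, 4096, 7168
A = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

# Activations: per-token-group FP8 (group size 128), scales made TMA-aligned.
A_q, A_s = per_token_group_quant_fp8(A, 128)
A_s = get_col_major_tma_aligned_tensor(A_s)
# Weights: per-128x128-block FP8 with UE8M0 (power-of-two) scales.
B_q, B_s = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)

C = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)
fp8_gemm_nt((A_q, A_s), (B_q, B_s), C)  # bf16 result is written into C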

View File

@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import torch
import vllm._custom_ops as ops
from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
@@ -20,29 +20,6 @@ def per_token_cast_to_fp8(
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(
x: torch.Tensor, block_size_k: int,
block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(
int(math.ceil(m / block_size_k)) * block_size_k,
int(math.ceil(n / block_size_n)) * block_size_n,
),
dtype=x.dtype,
device=x.device,
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_size_k,
x_padded.size(1) // block_size_k, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_non_quant_weights(
e: int,
n: int,
@@ -99,11 +76,9 @@ def make_block_quant_fp8_weights(
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size_k=block_k,
block_size_n=block_n)
block_size=[block_k, block_n])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size_k=block_k,
block_size_n=block_n)
block_size=[block_k, block_n])
return w1, w2, w1_s, w2_s

View File

@@ -12,10 +12,8 @@ import torch
from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
def cdiv(a, b):
return (a + b - 1) // b
from vllm.utils import cdiv
from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
@@ -32,21 +30,6 @@ def per_token_cast_to_fp8(
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((cdiv(m, 128) * 128, cdiv(n, 128) * 128),
device=x.device,
dtype=x.dtype)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
x_amax / 448.0).view(x_view.size(0), x_view.size(2))
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
(4, 8192, 7168, 4096),
(4, 8192, 2048, 7168),

View File

@@ -69,8 +69,12 @@ def make_block_quant_fp8_weights(
dtype=torch.float32)
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size=block_size,
use_ue8m0=True)
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size=block_size,
use_ue8m0=True)
return w1, w2, w1_s, w2_s

View File

@@ -5,8 +5,7 @@ from typing import Optional
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
per_block_cast_to_int8)
from tests.kernels.quant_utils import per_block_cast_to_int8
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.utils import round_up
from vllm.utils.deep_gemm import per_block_cast_to_fp8
def triton_moe(

View File

@@ -222,25 +222,6 @@ def native_per_token_group_quant_int8(x,
DEFAULT_BLOCK_SHAPE = [128, 128]
def per_block_cast_to_fp8(
x: torch.Tensor,
block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
block_m, block_n = block_shape
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def per_block_cast_to_int8(
x: torch.Tensor,
block_shape: list[int] = DEFAULT_BLOCK_SHAPE,

View File

@@ -117,7 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
As = As_fp8.to(torch.float32)
Bs = Bs_fp8.to(torch.float32)

View File

@@ -14,7 +14,7 @@ import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils import cdiv, has_deep_gemm
@functools.cache
@@ -37,7 +37,7 @@ def is_blackwell_deep_gemm_used() -> bool:
return False
_lazy_init()
if _per_block_cast_impl is None:
if _fp8_gemm_nt_impl is None:
return False
return (current_platform.is_cuda()
@@ -63,18 +63,15 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None
def _lazy_init() -> None:
"""Import deep_gemm and resolve symbols on first use."""
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
_per_block_cast_impl
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
# fast path
if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
or _grouped_masked_impl is not None
or _per_block_cast_impl is not None):
or _grouped_masked_impl is not None):
return
if not has_deep_gemm():
@@ -90,14 +87,6 @@ def _lazy_init() -> None:
_grouped_masked_impl = _resolve_symbol(
_dg, "fp8_m_grouped_gemm_nt_masked",
"m_grouped_gemm_fp8_fp8_bf16_nt_masked")
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
try:
_math_mod = importlib.import_module(
"deep_gemm.utils.math") # type: ignore
_per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
None)
except ModuleNotFoundError:
_per_block_cast_impl = None
def fp8_gemm_nt(*args, **kwargs):
@@ -121,13 +110,37 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
return _grouped_masked_impl(*args, **kwargs)
def per_block_cast_to_fp8(x, *args, **kwargs):
_lazy_init()
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
return _per_block_cast_impl(x, use_ue8m0=True)
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
return _pbcf(x, *args, **kwargs)
def _ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
def _align(x: int, y: int) -> int:
return cdiv(x, y) * y
DEFAULT_BLOCK_SIZE = [128, 128]
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
# TODO(wentao): optimize this function, using triton or cuda kernel
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size: list[int] = DEFAULT_BLOCK_SIZE,
use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
block_m, block_n = block_size
x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(2))
def calc_diff(x: torch.Tensor, y: torch.Tensor):
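
A note on the use_ue8m0 flag introduced above: _ceil_to_ue8m0 rounds each block scale up to the nearest power of two before the cast, which is what the UE8M0 scale format expects. A tiny illustrative check of that rounding:

import torch

def _ceil_to_ue8m0(x: torch.Tensor):
    # Round a positive scale up to the next power of two.
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

print(_ceil_to_ue8m0(torch.tensor([0.3, 1.0, 1.5])))
# tensor([0.5000, 1.0000, 2.0000])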