[Refactor] Remove Duplicate per_block_cast_to_fp8, Remove Dependencies of DeepGEMM (#21787)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in: parent 0bd409cf01, commit 3700642013
@@ -4,49 +4,16 @@
# ruff: noqa: E501
import time

# Import DeepGEMM functions
import deep_gemm
import torch
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor,
    per_token_group_quant_fp8,
    w8a8_block_fp8_matmul,
)
from vllm.triton_utils import triton


# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
def per_token_cast_to_fp8(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert tensor to FP8 format with per-token scaling."""
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
        torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
def per_block_cast_to_fp8(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert tensor to FP8 format with per-block scaling."""
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
        x_amax / 448.0).view(x_view.size(0), x_view.size(2))
from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8


def benchmark_shape(m: int,

@@ -69,14 +36,14 @@ def benchmark_shape(m: int,

    # Pre-quantize B for all implementations
    # (weights can be pre-quantized offline)
    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)
    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)

    # Block size configuration
    block_size = [128, 128]

    # Pre-quantize A for all implementations
    A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
    A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
    A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
    C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])

@@ -85,7 +52,7 @@ def benchmark_shape(m: int,

    # === DeepGEMM Implementation ===
    def deepgemm_gemm():
        deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
        fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
                    (B_deepgemm, B_scale_deepgemm),
                    C_deepgemm)
        return C_deepgemm
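
Note: after this change the benchmark no longer imports deep_gemm directly; weight quantization, activation quantization, and the GEMM all go through vLLM wrappers. A minimal sketch of the quantization side only, assuming a CUDA device with this vLLM tree installed (shapes are arbitrary; fp8_gemm_nt itself additionally needs a DeepGEMM-supported GPU, so it is omitted here):

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    get_col_major_tma_aligned_tensor, per_token_group_quant_fp8)
from vllm.utils.deep_gemm import per_block_cast_to_fp8

m, k, n = 64, 7168, 4096
A = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

# Weights: per-128x128-block FP8 with UE8M0 scales, as passed to DeepGEMM above.
B_fp8, B_scale = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)

# Activations: per-token-group FP8, with scales laid out for DeepGEMM's TMA loads.
A_fp8, A_scale = per_token_group_quant_fp8(A, 128)
A_scale = get_col_major_tma_aligned_tensor(A_scale)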

@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import torch

import vllm._custom_ops as ops
from vllm.utils.deep_gemm import per_block_cast_to_fp8


def per_token_cast_to_fp8(

@@ -20,29 +20,6 @@ def per_token_cast_to_fp8(
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(
        x: torch.Tensor, block_size_k: int,
        block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (
            int(math.ceil(m / block_size_k)) * block_size_k,
            int(math.ceil(n / block_size_n)) * block_size_n,
        ),
        dtype=x.dtype,
        device=x.device,
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_size_k,
                           x_padded.size(1) // block_size_k, block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales


def make_non_quant_weights(
    e: int,
    n: int,

@@ -99,11 +76,9 @@ def make_block_quant_fp8_weights(

    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
                                               block_size_k=block_k,
                                               block_size_n=block_n)
                                               block_size=[block_k, block_n])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
                                               block_size_k=block_k,
                                               block_size_n=block_n)
                                               block_size=[block_k, block_n])

    return w1, w2, w1_s, w2_s
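
Note: the shared helper takes the block shape as a single block_size=[block_k, block_n] list rather than the separate block_size_k=/block_size_n= keywords of the deleted local copy. A small sketch of the new call, assuming this vLLM tree and a recent PyTorch with float8 support (the weight shape is arbitrary):

import torch

from vllm.utils.deep_gemm import per_block_cast_to_fp8

w = torch.randn(512, 768, dtype=torch.bfloat16)

# One list argument replaces the old block_size_k=/block_size_n= keywords.
w_fp8, w_scale = per_block_cast_to_fp8(w, block_size=[128, 128])

assert w_fp8.shape == w.shape                      # quantized weight keeps its shape
assert w_scale.shape == (512 // 128, 768 // 128)   # one scale per 128x128 block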

@@ -12,10 +12,8 @@ import torch
from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops
from vllm.platforms import current_platform


def cdiv(a, b):
    return (a + b - 1) // b
from vllm.utils import cdiv
from vllm.utils.deep_gemm import per_block_cast_to_fp8


def per_token_cast_to_fp8(

@@ -32,21 +30,6 @@ def per_token_cast_to_fp8(
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((cdiv(m, 128) * 128, cdiv(n, 128) * 128),
                           device=x.device,
                           dtype=x.dtype)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
        x_amax / 448.0).view(x_view.size(0), x_view.size(2))


@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
    (4, 8192, 7168, 4096),
    (4, 8192, 2048, 7168),

@@ -69,8 +69,12 @@ def make_block_quant_fp8_weights(
                      dtype=torch.float32)

    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
                                               block_size=block_size,
                                               use_ue8m0=True)
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
                                               block_size=block_size,
                                               use_ue8m0=True)

    return w1, w2, w1_s, w2_s

@@ -5,8 +5,7 @@ from typing import Optional
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
                                       per_block_cast_to_int8)
from tests.kernels.quant_utils import per_block_cast_to_int8
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)

@@ -15,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)
from vllm.utils import round_up
from vllm.utils.deep_gemm import per_block_cast_to_fp8


def triton_moe(

@@ -222,25 +222,6 @@ def native_per_token_group_quant_int8(x,
DEFAULT_BLOCK_SHAPE = [128, 128]


def per_block_cast_to_fp8(
    x: torch.Tensor,
    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
    block_m, block_n = block_shape
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales


def per_block_cast_to_int8(
    x: torch.Tensor,
    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
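
Note: with use_ue8m0 left at its default (False), vllm.utils.deep_gemm.per_block_cast_to_fp8 reproduces the per-block amax/448 scaling of the duplicate helper removed above, so tests that relied on the old numerics are unchanged. A minimal sanity check, assuming this vLLM tree and a recent PyTorch with float8 support (the shape is arbitrary and deliberately not a multiple of 128 to exercise the padding path):

import torch

from vllm.utils.deep_gemm import per_block_cast_to_fp8

x = torch.randn(300, 500)

q, s = per_block_cast_to_fp8(x, block_size=[128, 128])  # use_ue8m0 defaults to False

# Recompute the expected scales: per-128x128-block amax / 448 on a zero-padded copy.
padded = torch.nn.functional.pad(x, (0, 512 - 500, 0, 384 - 300))
expected = padded.view(3, 128, 4, 128).abs().amax(dim=(1, 3)).clamp(1e-4) / 448.0

assert q.shape == x.shape
torch.testing.assert_close(s, expected)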

@@ -117,7 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max

    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)

    As = As_fp8.to(torch.float32)
    Bs = Bs_fp8.to(torch.float32)

@@ -14,7 +14,7 @@ import torch

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils import cdiv, has_deep_gemm


@functools.cache

@@ -37,7 +37,7 @@ def is_blackwell_deep_gemm_used() -> bool:
        return False

    _lazy_init()
    if _per_block_cast_impl is None:
    if _fp8_gemm_nt_impl is None:
        return False

    return (current_platform.is_cuda()

@@ -63,18 +63,15 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None


def _lazy_init() -> None:
    """Import deep_gemm and resolve symbols on first use."""
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
        _per_block_cast_impl
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl

    # fast path
    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
            or _grouped_masked_impl is not None
            or _per_block_cast_impl is not None):
            or _grouped_masked_impl is not None):
        return

    if not has_deep_gemm():

@@ -90,14 +87,6 @@ def _lazy_init() -> None:
    _grouped_masked_impl = _resolve_symbol(
        _dg, "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked")
    # Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
    try:
        _math_mod = importlib.import_module(
            "deep_gemm.utils.math")  # type: ignore
        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
                                       None)
    except ModuleNotFoundError:
        _per_block_cast_impl = None


def fp8_gemm_nt(*args, **kwargs):

@@ -121,13 +110,37 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    return _grouped_masked_impl(*args, **kwargs)


def per_block_cast_to_fp8(x, *args, **kwargs):
    _lazy_init()
    if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
        return _per_block_cast_impl(x, use_ue8m0=True)
    # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
    from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
    return _pbcf(x, *args, **kwargs)
def _ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))


def _align(x: int, y: int) -> int:
    return cdiv(x, y) * y


DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
# TODO(wentao): optimize this function, using triton or cuda kernel
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size: list[int] = DEFAULT_BLOCK_SIZE,
        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    block_m, block_n = block_size
    x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2))


def calc_diff(x: torch.Tensor, y: torch.Tensor):
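
Note: the use_ue8m0 flag rounds each per-block scale up to the next power of two, the UE8M0 scale format DeepGEMM expects on Blackwell, which is why the DeepGEMM-facing call sites above pass use_ue8m0=True. A small illustration, assuming this vLLM tree and a recent PyTorch with float8 support:

import torch

from vllm.utils.deep_gemm import per_block_cast_to_fp8

x = torch.randn(256, 256, dtype=torch.bfloat16)

_, sf_plain = per_block_cast_to_fp8(x, block_size=[128, 128])
_, sf_ue8m0 = per_block_cast_to_fp8(x, block_size=[128, 128], use_ue8m0=True)

# Rounding up to a power of two can only grow each scale ...
assert torch.all(sf_ue8m0 >= sf_plain)
# ... and leaves every scale with a frexp mantissa of exactly 0.5 (i.e. 2**k).
mantissa, _ = torch.frexp(sf_ue8m0)
assert torch.all(mantissa == 0.5)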