diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
index 43c54d56ca8c1..b99c2099f2c38 100644
--- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
+++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
@@ -4,49 +4,16 @@
 # ruff: noqa: E501
 import time
 
-# Import DeepGEMM functions
-import deep_gemm
 import torch
-from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
 
-# Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    get_col_major_tma_aligned_tensor,
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
 from vllm.triton_utils import triton
-
-
-# Copied from
-# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
-def per_token_cast_to_fp8(
-        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    """Convert tensor to FP8 format with per-token scaling."""
-    assert x.dim() == 2 and x.size(1) % 128 == 0
-    m, n = x.shape
-    x_view = x.view(m, -1, 128)
-    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
-        torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
-
-
-# Copied from
-# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
-def per_block_cast_to_fp8(
-        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    """Convert tensor to FP8 format with per-block scaling."""
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
-                           dtype=x.dtype,
-                           device=x.device)
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
-        x_amax / 448.0).view(x_view.size(0), x_view.size(2))
+from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
 
 
 def benchmark_shape(m: int,
@@ -69,14 +36,14 @@ def benchmark_shape(m: int,
 
     # Pre-quantize B for all implementations
    # (weights can be pre-quantized offline)
-    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
-    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)
+    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
+    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
 
     # Block size configuration
     block_size = [128, 128]
 
     # Pre-quantize A for all implementations
-    A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
+    A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
     A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
     C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
     A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
@@ -85,7 +52,7 @@ def benchmark_shape(m: int,
 
     # === DeepGEMM Implementation ===
     def deepgemm_gemm():
-        deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
+        fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
                                        (B_deepgemm, B_scale_deepgemm),
                                        C_deepgemm)
         return C_deepgemm
diff --git a/tests/kernels/moe/modular_kernel_tools/utils.py b/tests/kernels/moe/modular_kernel_tools/utils.py
index 09bb4a34f3189..866f52882beee 100644
--- a/tests/kernels/moe/modular_kernel_tools/utils.py
+++ b/tests/kernels/moe/modular_kernel_tools/utils.py
@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
 
 import torch
 
 import vllm._custom_ops as ops
+from vllm.utils.deep_gemm import per_block_cast_to_fp8
 
 
 def per_token_cast_to_fp8(
@@ -20,29 +20,6 @@ def per_token_cast_to_fp8(
     return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
 
 
-def per_block_cast_to_fp8(
-        x: torch.Tensor, block_size_k: int,
-        block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros(
-        (
-            int(math.ceil(m / block_size_k)) * block_size_k,
-            int(math.ceil(n / block_size_n)) * block_size_n,
-        ),
-        dtype=x.dtype,
-        device=x.device,
-    )
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, block_size_k,
-                           x_padded.size(1) // block_size_k, block_size_n)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
-    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
-    return x_scaled_sub, scales
-
-
 def make_non_quant_weights(
     e: int,
     n: int,
@@ -99,11 +76,9 @@ def make_block_quant_fp8_weights(
 
     for i in range(e):
         w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
-                                               block_size_k=block_k,
-                                               block_size_n=block_n)
+                                               block_size=[block_k, block_n])
         w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
-                                               block_size_k=block_k,
-                                               block_size_n=block_n)
+                                               block_size=[block_k, block_n])
 
     return w1, w2, w1_s, w2_s
 
diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py
index 67984fe7319a3..1aee1ed8c3762 100644
--- a/tests/kernels/moe/test_cutlass_grouped_gemm.py
+++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py
@@ -12,10 +12,8 @@ import torch
 from tests.kernels.utils import baseline_scaled_mm
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-
-
-def cdiv(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv
+from vllm.utils.deep_gemm import per_block_cast_to_fp8
 
 
 def per_token_cast_to_fp8(
@@ -32,21 +30,6 @@ def per_token_cast_to_fp8(
     return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
 
 
-def per_block_cast_to_fp8(
-        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros((cdiv(m, 128) * 128, cdiv(n, 128) * 128),
-                           device=x.device,
-                           dtype=x.dtype)
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(dtype=torch.float8_e4m3fn)
-    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
-        x_amax / 448.0).view(x_view.size(0), x_view.size(2))
-
-
 @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
     (4, 8192, 7168, 4096),
     (4, 8192, 2048, 7168),
diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py
index 759d2814eefb9..b6ea4ee2324c9 100644
--- a/tests/kernels/moe/test_deepgemm.py
+++ b/tests/kernels/moe/test_deepgemm.py
@@ -69,8 +69,12 @@ def make_block_quant_fp8_weights(
                        dtype=torch.float32)
 
     for i in range(e):
-        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
-        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
+        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
+                                               block_size=block_size,
+                                               use_ue8m0=True)
+        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
+                                               block_size=block_size,
+                                               use_ue8m0=True)
 
     return w1, w2, w1_s, w2_s
 
diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py
index df89ad7e6da6f..c33134981acc0 100644
--- a/tests/kernels/moe/utils.py
+++ b/tests/kernels/moe/utils.py
@@ -5,8 +5,7 @@ from typing import Optional
 import torch
 
 import vllm._custom_ops as ops
-from tests.kernels.quant_utils import (per_block_cast_to_fp8,
-                                       per_block_cast_to_int8)
+from tests.kernels.quant_utils import per_block_cast_to_int8
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
 from vllm.utils import round_up
+from vllm.utils.deep_gemm import per_block_cast_to_fp8
 
 
 def triton_moe(
diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py
index 6f43d1111c98e..01a1ad2e7a0a5 100644
--- a/tests/kernels/quant_utils.py
+++ b/tests/kernels/quant_utils.py
@@ -222,25 +222,6 @@ def native_per_token_group_quant_int8(x,
 DEFAULT_BLOCK_SHAPE = [128, 128]
 
 
-def per_block_cast_to_fp8(
-    x: torch.Tensor,
-    block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    block_m, block_n = block_shape
-    assert x.dim() == 2
-    m, n = x.shape
-    x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
-                           dtype=x.dtype,
-                           device=x.device)
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
-    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
-    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
-    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
-    return x_scaled_sub, scales
-
-
 def per_block_cast_to_int8(
     x: torch.Tensor,
     block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py
index 26aa8d652e639..d9154d3fd7f33 100644
--- a/tests/kernels/quantization/test_block_fp8.py
+++ b/tests/kernels/quantization/test_block_fp8.py
@@ -117,7 +117,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
 
     A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
-    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
+    B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
 
     As = As_fp8.to(torch.float32)
     Bs = Bs_fp8.to(torch.float32)
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index a49a59bd81253..4dedee2a3f862 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -14,7 +14,7 @@ import torch
 
 import vllm.envs as envs
 from vllm.platforms import current_platform
-from vllm.utils import has_deep_gemm
+from vllm.utils import cdiv, has_deep_gemm
 
 
 @functools.cache
@@ -37,7 +37,7 @@ def is_blackwell_deep_gemm_used() -> bool:
         return False
 
     _lazy_init()
-    if _per_block_cast_impl is None:
+    if _fp8_gemm_nt_impl is None:
         return False
 
     return (current_platform.is_cuda()
@@ -63,18 +63,15 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
 _fp8_gemm_nt_impl: Callable[..., Any] | None = None
 _grouped_impl: Callable[..., Any] | None = None
 _grouped_masked_impl: Callable[..., Any] | None = None
-_per_block_cast_impl: Callable[..., Any] | None = None
 
 
 def _lazy_init() -> None:
     """Import deep_gemm and resolve symbols on first use."""
-    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
-        _per_block_cast_impl
+    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
 
     # fast path
     if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
-            or _grouped_masked_impl is not None
-            or _per_block_cast_impl is not None):
+            or _grouped_masked_impl is not None):
         return
 
     if not has_deep_gemm():
@@ -90,14 +87,6 @@ def _lazy_init() -> None:
     _grouped_masked_impl = _resolve_symbol(
         _dg, "fp8_m_grouped_gemm_nt_masked",
         "m_grouped_gemm_fp8_fp8_bf16_nt_masked")
-    # Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
-    try:
-        _math_mod = importlib.import_module(
-            "deep_gemm.utils.math")  # type: ignore
-        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
-                                       None)
-    except ModuleNotFoundError:
-        _per_block_cast_impl = None
 
 
 def fp8_gemm_nt(*args, **kwargs):
@@ -121,13 +110,37 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     return _grouped_masked_impl(*args, **kwargs)
 
 
-def per_block_cast_to_fp8(x, *args, **kwargs):
-    _lazy_init()
-    if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
-        return _per_block_cast_impl(x, use_ue8m0=True)
-    # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
-    from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
-    return _pbcf(x, *args, **kwargs)
+def _ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+
+
+def _align(x: int, y: int) -> int:
+    return cdiv(x, y) * y
+
+
+DEFAULT_BLOCK_SIZE = [128, 128]
+
+
+# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
+# TODO(wentao): optimize this function, using triton or cuda kernel
+def per_block_cast_to_fp8(
+        x: torch.Tensor,
+        block_size: list[int] = DEFAULT_BLOCK_SIZE,
+        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    block_m, block_n = block_size
+    x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)),
+                           dtype=x.dtype,
+                           device=x.device)
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    sf = x_amax / 448.0
+    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
+    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
+        x_view.size(0), x_view.size(2))
 
 
 def calc_diff(x: torch.Tensor, y: torch.Tensor):