From 1f291412586110bffdebe597bd9d5f49c1cd4f73 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Wed, 24 Sep 2025 18:52:36 -0400
Subject: [PATCH] [Refactor] Use DeepGEMM Col Major TMA Aligned Tensor (#25517)

Signed-off-by: yewentao256
---
 .../benchmark_fp8_block_dense_gemm.py         |  8 ++-
 tests/kernels/quantization/test_block_fp8.py  |  7 +-
 .../compressed_tensors_moe.py                 |  6 +-
 .../model_executor/layers/quantization/fp8.py | 10 +--
 .../layers/quantization/utils/fp8_utils.py    | 66 +------------------
 vllm/utils/deep_gemm.py                       | 15 ++++-
 6 files changed, 34 insertions(+), 78 deletions(-)

diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
index b3c3742825de..2010b8038563 100644
--- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
+++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
@@ -8,12 +8,16 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    get_col_major_tma_aligned_tensor,
     per_token_group_quant_fp8,
     w8a8_triton_block_scaled_mm,
 )
 from vllm.triton_utils import triton
-from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
+from vllm.utils.deep_gemm import (
+    calc_diff,
+    fp8_gemm_nt,
+    get_col_major_tma_aligned_tensor,
+    per_block_cast_to_fp8,
+)
 
 
 def benchmark_shape(m: int,
diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py
index c0b934fc55ae..e02df540ce9d 100644
--- a/tests/kernels/quantization/test_block_fp8.py
+++ b/tests/kernels/quantization/test_block_fp8.py
@@ -11,11 +11,12 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                        native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
-    per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
+    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
+from vllm.utils.deep_gemm import (fp8_gemm_nt,
+                                  get_col_major_tma_aligned_tensor,
+                                  per_block_cast_to_fp8)
 
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index a7d3e920414d..3a81a0059df8 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -34,8 +34,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
-    requant_weight_ue8m0_inplace)
+    expert_weight_is_col_major, requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
@@ -50,7 +49,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.scalar_type import scalar_types
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor,
+                                  is_deep_gemm_e8m0_used)
 
 logger = init_logger(__name__)
 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index c4951712baa7..f77e5880209d 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -34,9 +34,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
     create_fp8_weight_parameter, expert_weight_is_col_major,
-    get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block,
-    process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy,
-    requant_weight_ue8m0_inplace, validate_fp8_block_shape)
+    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
+    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
+    validate_fp8_block_shape)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
     prepare_moe_fp8_layer_for_marlin)
@@ -53,7 +53,9 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
+from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor,
+                                  is_deep_gemm_e8m0_used,
+                                  is_deep_gemm_supported)
 from vllm.utils.flashinfer import has_flashinfer_moe
 
 if TYPE_CHECKING:
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index a4cfc7d6c15c..441bba6baacc 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -23,7 +23,7 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            PerTensorScaleParameter)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
                                   is_deep_gemm_supported,
                                   should_use_deepgemm_for_fp8_linear)
@@ -749,70 +749,6 @@ def w8a8_triton_block_scaled_mm(
     return C
 
 
-# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
-# TODO(wentao): remove this function when DeepGEMM exposes this function
-def get_tma_aligned_size(x: int, element_size: int) -> int:
-    """
-    Global memory address of TMA must be 16-byte aligned.
-    Since we use column-major layout for the LHS scaling tensor,
-    the M-axis of the LHS scaling tensor needs to be padded to a multiple of
-    16 bytes.
-
-    Arguments:
-        x: original M-axis shape of the LHS scaling tensor.
-        element_size: element size of the LHS scaling tensor.
-
-    Returns:
-        M-axis shape of the LHS scaling tensor after padding.
-    """
-    tma_alignment_bytes = 16
-    assert tma_alignment_bytes % element_size == 0
-    alignment = tma_alignment_bytes // element_size
-    return cdiv(x, alignment) * alignment
-
-
-# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
-# TODO(wentao): remove this function when DeepGEMM exposes this function
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
-    """
-    Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
-    will be called if necessary.
-    If the input tensor is already column-major layout and 16-byte aligned along
-    the M axis (thus meets the requirement of LHS scaling tensor in
-    DeepGEMM), this function will do nothing.
-
-    Arguments:
-        x: usually the LHS scaling tensor in GEMM.
-
-    Returns:
-        The LHS scaling tensor of TMA-aligned transposed format.
-    """
-    # NOTES: for the extreme performance, you may rewrite/fuse this function in
-    # CUDA
-    assert x.dim() in (2, 3)
-    remove_dim = False
-    m, n = x.shape[-2], x.shape[-1]
-    aligned_m = get_tma_aligned_size(m, x.element_size())
-    if x.dim() == 2:
-        if x.stride(0) == 1 and x.stride(1) == aligned_m:
-            return x
-        x, remove_dim = x.unsqueeze(0), True
-
-    b = x.shape[0]
-
-    # The last kernel gives a column-major TMA aligned layout
-    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
-            2) == aligned_m:
-        return x.squeeze(0) if remove_dim else x
-
-    # Normal layout requires transposing
-    aligned_x = torch.transpose(
-        torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
-    aligned_x[:, :m, :] = x
-    aligned_x = aligned_x[:, :m, :]
-    return aligned_x.squeeze(0) if remove_dim else aligned_x
-
-
 def requant_weight_ue8m0_inplace(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 2f533ca0639f..979c10f2c3e9 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -70,11 +70,13 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
 _fp8_gemm_nt_impl: Callable[..., Any] | None = None
 _grouped_impl: Callable[..., Any] | None = None
 _grouped_masked_impl: Callable[..., Any] | None = None
+_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
 
 
 def _lazy_init() -> None:
     """Import deep_gemm and resolve symbols on first use."""
-    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
+    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl,\
+        _get_mn_major_tma_aligned_tensor_impl
 
     # fast path
     if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
@@ -95,6 +97,16 @@
     _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
     _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
     _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
+    _get_mn_major_tma_aligned_tensor_impl = getattr(
+        _dg, "get_mn_major_tma_aligned_tensor", None)
+
+
+def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+    """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
+    _lazy_init()
+    if _get_mn_major_tma_aligned_tensor_impl is None:
+        return _missing()
+    return _get_mn_major_tma_aligned_tensor_impl(x)
 
 
 def fp8_gemm_nt(*args, **kwargs):
@@ -191,4 +203,5 @@ __all__ = [
     "is_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
     "should_use_deepgemm_for_fp8_linear",
+    "get_col_major_tma_aligned_tensor",
 ]
\ No newline at end of file
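
Illustrative usage (not part of the patch): a minimal sketch of calling the new `get_col_major_tma_aligned_tensor` wrapper exported from `vllm.utils.deep_gemm` after this refactor. The (m, k // 128) scale shape, the CUDA device, and an installed DeepGEMM build are assumptions; without DeepGEMM the wrapper falls back to `_missing()` and raises.

import torch

from vllm.utils.deep_gemm import get_col_major_tma_aligned_tensor

# Hypothetical per-token-group FP32 scale tensor for an (m, k) activation
# quantized with 128-wide groups (shape and device are assumptions).
m, k = 64, 4096
scales = torch.rand(m, k // 128, device="cuda", dtype=torch.float32)

# Returns the scales in the MN-major (column-major) layout DeepGEMM expects,
# with the M axis padded to the TMA alignment in the underlying storage,
# matching the behavior of the local helper this patch removes.
aligned = get_col_major_tma_aligned_tensor(scales)
print(aligned.shape, aligned.stride())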