[Refactor] Use DeepGEMM Col Major TMA Aligned Tensor (#25517)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in: parent 6160ba4151, commit 1f29141258
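Summary: get_col_major_tma_aligned_tensor is no longer maintained as a local copy in vllm's fp8_utils; it is now exposed from vllm.utils.deep_gemm as a thin wrapper around DeepGEMM's get_mn_major_tma_aligned_tensor, and call sites only change their import. A minimal sketch of the new-style call (the scale tensor below is illustrative and assumes DeepGEMM is installed and a CUDA device is available; it is not part of this diff):

import torch

# Old import path (removed by this commit):
# from vllm.model_executor.layers.quantization.utils.fp8_utils import (
#     get_col_major_tma_aligned_tensor)

# New import path (wrapper delegating to DeepGEMM):
from vllm.utils.deep_gemm import get_col_major_tma_aligned_tensor

# Hypothetical LHS scaling tensor for an fp8 block GEMM.
x_scale = torch.rand(128, 16, dtype=torch.float32, device="cuda")
x_scale = get_col_major_tma_aligned_tensor(x_scale)  # column-major, M axis padded for TMA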
@@ -8,12 +8,16 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    get_col_major_tma_aligned_tensor,
     per_token_group_quant_fp8,
     w8a8_triton_block_scaled_mm,
 )
 from vllm.triton_utils import triton
-from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
+from vllm.utils.deep_gemm import (
+    calc_diff,
+    fp8_gemm_nt,
+    get_col_major_tma_aligned_tensor,
+    per_block_cast_to_fp8,
+)


 def benchmark_shape(m: int,

@@ -11,11 +11,12 @@ from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                        native_w8a8_block_matmul)
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
-    per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
+    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
+from vllm.utils.deep_gemm import (fp8_gemm_nt,
+                                  get_col_major_tma_aligned_tensor,
+                                  per_block_cast_to_fp8)

 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",

@@ -34,8 +34,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    expert_weight_is_col_major, get_col_major_tma_aligned_tensor,
-    requant_weight_ue8m0_inplace)
+    expert_weight_is_col_major, requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
@@ -50,7 +49,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.scalar_type import scalar_types
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor,
+                                  is_deep_gemm_e8m0_used)

 logger = init_logger(__name__)

@@ -34,9 +34,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
     create_fp8_input_scale, create_fp8_scale_parameter,
     create_fp8_weight_parameter, expert_weight_is_col_major,
-    get_col_major_tma_aligned_tensor, maybe_post_process_fp8_weight_block,
-    process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy,
-    requant_weight_ue8m0_inplace, validate_fp8_block_shape)
+    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
+    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
+    validate_fp8_block_shape)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin,
     prepare_moe_fp8_layer_for_marlin)
@@ -53,7 +53,9 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
+from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor,
+                                  is_deep_gemm_e8m0_used,
+                                  is_deep_gemm_supported)
 from vllm.utils.flashinfer import has_flashinfer_moe

 if TYPE_CHECKING:

@@ -23,7 +23,7 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            PerTensorScaleParameter)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import direct_register_custom_op
 from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
                                   is_deep_gemm_supported,
                                   should_use_deepgemm_for_fp8_linear)
@@ -749,70 +749,6 @@ def w8a8_triton_block_scaled_mm(
     return C


-# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
-# TODO(wentao): remove this function when DeepGEMM exposes this function
-def get_tma_aligned_size(x: int, element_size: int) -> int:
-    """
-    Global memory address of TMA must be 16-byte aligned.
-    Since we use column-major layout for the LHS scaling tensor,
-    the M-axis of the LHS scaling tensor needs to be padded to a multiple of
-    16 bytes.
-
-    Arguments:
-        x: original M-axis shape of the LHS scaling tensor.
-        element_size: element size of the LHS scaling tensor.
-
-    Returns:
-        M-axis shape of the LHS scaling tensor after padding.
-    """
-    tma_alignment_bytes = 16
-    assert tma_alignment_bytes % element_size == 0
-    alignment = tma_alignment_bytes // element_size
-    return cdiv(x, alignment) * alignment
-
-
-# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
-# TODO(wentao): remove this function when DeepGEMM exposes this function
-def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
-    """
-    Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
-    will be called if necessary.
-    If the input tensor is already column-major layout and 16-byte aligned along
-    the M axis (thus meets the requirement of LHS scaling tensor in
-    DeepGEMM), this function will do nothing.
-
-    Arguments:
-        x: usually the LHS scaling tensor in GEMM.
-
-    Returns:
-        The LHS scaling tensor of TMA-aligned transposed format.
-    """
-    # NOTES: for the extreme performance, you may rewrite/fuse this function in
-    # CUDA
-    assert x.dim() in (2, 3)
-    remove_dim = False
-    m, n = x.shape[-2], x.shape[-1]
-    aligned_m = get_tma_aligned_size(m, x.element_size())
-    if x.dim() == 2:
-        if x.stride(0) == 1 and x.stride(1) == aligned_m:
-            return x
-        x, remove_dim = x.unsqueeze(0), True
-
-    b = x.shape[0]
-
-    # The last kernel gives a column-major TMA aligned layout
-    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
-            2) == aligned_m:
-        return x.squeeze(0) if remove_dim else x
-
-    # Normal layout requires transposing
-    aligned_x = torch.transpose(
-        torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
-    aligned_x[:, :m, :] = x
-    aligned_x = aligned_x[:, :m, :]
-    return aligned_x.squeeze(0) if remove_dim else aligned_x
-
-
 def requant_weight_ue8m0_inplace(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,

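For reference, the padding rule that the removed get_tma_aligned_size helper implemented is: pad the M axis of the LHS scaling tensor so that M * element_size becomes a multiple of the 16-byte TMA alignment. A standalone sketch of the same arithmetic (not vllm code, just the formula worked through):

def tma_aligned_size(m: int, element_size: int) -> int:
    # TMA addresses must be 16-byte aligned, so pad m until
    # m * element_size is a multiple of 16 bytes.
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return (m + alignment - 1) // alignment * alignment  # ceil-divide, scale back up

# For an fp32 scaling tensor (element_size = 4), alignment = 16 // 4 = 4 rows:
assert tma_aligned_size(7, 4) == 8      # 7 rows padded up to 8
assert tma_aligned_size(128, 4) == 128  # already a multiple of 4, unchanged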
@@ -70,11 +70,13 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
 _fp8_gemm_nt_impl: Callable[..., Any] | None = None
 _grouped_impl: Callable[..., Any] | None = None
 _grouped_masked_impl: Callable[..., Any] | None = None
+_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None


 def _lazy_init() -> None:
     """Import deep_gemm and resolve symbols on first use."""
-    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
+    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl,\
+        _get_mn_major_tma_aligned_tensor_impl

     # fast path
     if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
@@ -95,6 +97,16 @@ def _lazy_init() -> None:
     _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
     _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
     _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
+    _get_mn_major_tma_aligned_tensor_impl = getattr(
+        _dg, "get_mn_major_tma_aligned_tensor", None)
+
+
+def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+    """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
+    _lazy_init()
+    if _get_mn_major_tma_aligned_tensor_impl is None:
+        return _missing()
+    return _get_mn_major_tma_aligned_tensor_impl(x)


 def fp8_gemm_nt(*args, **kwargs):
@@ -191,4 +203,5 @@ __all__ = [
     "is_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
     "should_use_deepgemm_for_fp8_linear",
+    "get_col_major_tma_aligned_tensor",
 ]

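Like the other wrappers in vllm.utils.deep_gemm, the new function resolves the DeepGEMM symbol lazily on first use and falls back to _missing() (declared NoReturn, so it raises) when DeepGEMM is not installed, so callers are expected to guard it. A hedged usage sketch; the tensor shape, device, and guard placement are illustrative assumptions, not taken from this diff:

import torch

from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import get_col_major_tma_aligned_tensor

# Hypothetical per-group scaling tensor for an fp8 block GEMM.
a_scale = torch.rand(256, 32, dtype=torch.float32, device="cuda")

if has_deep_gemm():
    # Repacks (or returns unchanged) the scale tensor so its M axis is
    # TMA-aligned and column-major, as DeepGEMM's fp8 GEMM expects.
    a_scale = get_col_major_tma_aligned_tensor(a_scale)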