mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:15:42 +08:00)
[Performance] Reduce DeepGEMM N dim restriction from 128 to 64 multiplier (#28687)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 8e38e99829
commit 3aaa94ac99
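In short, the commit loosens the N-dimension alignment that gates DeepGEMM for fp8 linear layers: the weight's first dimension (N) now only has to be a multiple of 64 instead of 128, while K must still be a multiple of 128 and the bfloat16 output requirement is unchanged. A minimal sketch of the effect, not code from the commit (the real gate is should_use_deepgemm_for_fp8_linear in the vllm/utils/deep_gemm.py hunk below, which also checks output dtype and DeepGEMM availability):

# Illustrative helpers only; weight layout is (N, K).
def n_k_aligned_before(n: int, k: int) -> bool:
    return n % 128 == 0 and k % 128 == 0   # old requirement

def n_k_aligned_after(n: int, k: int) -> bool:
    return n % 64 == 0 and k % 128 == 0    # new requirement

# Example: N = 1344, K = 2048 was rejected before (1344 % 128 == 64) but is
# accepted now (1344 % 64 == 0), so more weight shapes can take the DeepGEMM path.
assert not n_k_aligned_before(1344, 2048) and n_k_aligned_after(1344, 2048)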
@@ -550,6 +550,26 @@ steps:
   commands:
     - pytest -v -s kernels/mamba
 
+- label: Kernels DeepGEMM Test (H100)
+  timeout_in_minutes: 45
+  gpu: h100
+  num_gpus: 1
+  optional: true
+  source_file_dependencies:
+    - tools/install_deepgemm.sh
+    - vllm/utils/deep_gemm.py
+    - vllm/model_executor/layers/fused_moe
+    - vllm/model_executor/layers/quantization
+    - tests/kernels/quantization/test_block_fp8.py
+    - tests/kernels/moe/test_deepgemm.py
+    - tests/kernels/moe/test_batched_deepgemm.py
+    - tests/kernels/attention/test_deepgemm_attention.py
+  commands:
+    - pytest -v -s tests/kernels/quantization/test_block_fp8.py -k deep_gemm
+    - pytest -v -s tests/kernels/moe/test_deepgemm.py
+    - pytest -v -s tests/kernels/moe/test_batched_deepgemm.py
+    - pytest -v -s tests/kernels/attention/test_deepgemm_attention.py
+
 - label: Model Executor Test # 23min
   timeout_in_minutes: 35
   torch_nightly: true
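The new CI step can be approximated locally, assuming an H100-class GPU and an installed DeepGEMM (for example via tools/install_deepgemm.sh); a sketch that mirrors the step's commands through pytest's Python entry point:

# Runs the same test files the new Buildkite step runs. On machines without
# DeepGEMM, the block_fp8 tests at least skip themselves (see the skipif below).
import pytest

pytest.main(["-v", "-s", "tests/kernels/quantization/test_block_fp8.py", "-k", "deep_gemm"])
pytest.main(["-v", "-s", "tests/kernels/moe/test_deepgemm.py"])
pytest.main(["-v", "-s", "tests/kernels/moe/test_batched_deepgemm.py"])
pytest.main(["-v", "-s", "tests/kernels/attention/test_deepgemm_attention.py"])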
@@ -22,6 +22,7 @@ from vllm.utils.deep_gemm import (
     fp8_gemm_nt,
     get_col_major_tma_aligned_tensor,
     per_block_cast_to_fp8,
+    should_use_deepgemm_for_fp8_linear,
 )
 from vllm.utils.import_utils import has_deep_gemm
@@ -157,10 +158,6 @@ def test_w8a8_block_fp8_cutlass_matmul():
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
-    # only aligned sizes
-    if M % 4 != 0 or K % 128 != 0 or N % 64 != 0:
-        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
-
     torch.manual_seed(seed)
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max = fp8_info.max
@@ -168,6 +165,12 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
     B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
 
+    # only aligned sizes are supported by deepgemm
+    if not should_use_deepgemm_for_fp8_linear(
+        output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
+    ):
+        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
+
     A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
     B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
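The same gate could be wrapped once for reuse by other DeepGEMM tests; a hypothetical helper (not part of this commit) built on the exact call shown above:

# Hypothetical test helper: centralizes the skip that
# test_w8a8_block_fp8_deep_gemm_matmul now performs inline.
import pytest
import torch
from vllm.utils.deep_gemm import should_use_deepgemm_for_fp8_linear

def skip_unless_deepgemm_eligible(weight: torch.Tensor, out_dtype=torch.bfloat16) -> None:
    """Skip the calling test when DeepGEMM would not be selected for this weight."""
    if not should_use_deepgemm_for_fp8_linear(
        output_dtype=out_dtype, weight=weight, supports_deep_gemm=True
    ):
        pytest.skip(f"DeepGEMM not applicable for weight shape {tuple(weight.shape)}")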
@@ -365,11 +365,18 @@ def should_use_deepgemm_for_fp8_linear(
 ):
     if supports_deep_gemm is None:
         supports_deep_gemm = is_deep_gemm_supported()
 
+    # Verify DeepGEMM N/K dims requirements
+    # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
+    # test inside kernels/quatization/test_block_fp8.py
+    N_MULTIPLE = 64
+    K_MULTIPLE = 128
+
     return (
         supports_deep_gemm
         and output_dtype == torch.bfloat16
-        and weight.shape[0] % 128 == 0
-        and weight.shape[1] % 128 == 0
+        and weight.shape[0] % N_MULTIPLE == 0
+        and weight.shape[1] % K_MULTIPLE == 0
     )
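A quick way to see the relaxed gate in action, assuming this version of vllm is installed; the shape is illustrative, and passing supports_deep_gemm=True bypasses the hardware probe so the shape/dtype checks can be exercised anywhere:

import torch
from vllm.utils.deep_gemm import should_use_deepgemm_for_fp8_linear

# (N, K) = (2112, 7168): N is a multiple of 64 but not of 128; K is a multiple of 128.
weight = torch.empty(2112, 7168)
ok = should_use_deepgemm_for_fp8_linear(
    output_dtype=torch.bfloat16, weight=weight, supports_deep_gemm=True
)
print(ok)  # True after this change; False before, since 2112 % 128 != 0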