[Performance] Reduce DeepGEMM N dim restriction from 128 to 64 multiplier (#28687)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Author: Alexander Matveev, 2025-11-19 18:47:13 -05:00 (committed by GitHub)
commit 3aaa94ac99 (parent 8e38e99829)
3 changed files with 36 additions and 6 deletions

.buildkite/test-pipeline.yaml

@@ -550,6 +550,26 @@ steps:
   commands:
     - pytest -v -s kernels/mamba
 
+- label: Kernels DeepGEMM Test (H100)
+  timeout_in_minutes: 45
+  gpu: h100
+  num_gpus: 1
+  optional: true
+  source_file_dependencies:
+    - tools/install_deepgemm.sh
+    - vllm/utils/deep_gemm.py
+    - vllm/model_executor/layers/fused_moe
+    - vllm/model_executor/layers/quantization
+    - tests/kernels/quantization/test_block_fp8.py
+    - tests/kernels/moe/test_deepgemm.py
+    - tests/kernels/moe/test_batched_deepgemm.py
+    - tests/kernels/attention/test_deepgemm_attention.py
+  commands:
+    - pytest -v -s tests/kernels/quantization/test_block_fp8.py -k deep_gemm
+    - pytest -v -s tests/kernels/moe/test_deepgemm.py
+    - pytest -v -s tests/kernels/moe/test_batched_deepgemm.py
+    - pytest -v -s tests/kernels/attention/test_deepgemm_attention.py
+
 - label: Model Executor Test # 23min
   timeout_in_minutes: 35
   torch_nightly: true

tests/kernels/quantization/test_block_fp8.py

@@ -22,6 +22,7 @@ from vllm.utils.deep_gemm import (
     fp8_gemm_nt,
     get_col_major_tma_aligned_tensor,
     per_block_cast_to_fp8,
+    should_use_deepgemm_for_fp8_linear,
 )
 from vllm.utils.import_utils import has_deep_gemm
@@ -157,10 +158,6 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
 @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
-    # only aligned sizes
-    if M % 4 != 0 or K % 128 != 0 or N % 64 != 0:
-        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
     torch.manual_seed(seed)
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max = fp8_info.max
@@ -168,6 +165,12 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
     B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+    # only aligned sizes are supported by deepgemm
+    if not should_use_deepgemm_for_fp8_linear(
+        output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
+    ):
+        pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
     A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
     B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
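
For context, a minimal sketch of how the gate used by the updated test can be exercised on its own, assuming a vLLM checkout where vllm.utils.deep_gemm imports; the (N, K) = (192, 256) shape is illustrative, and only the keyword arguments shown in the hunk above are used:

import torch

from vllm.utils.deep_gemm import should_use_deepgemm_for_fp8_linear

# Weight is laid out as (N, K), matching B_fp32 in the test above.
# N = 192 is a multiple of 64 but not of 128, so with this commit the
# helper can return True for it (provided DeepGEMM itself is supported).
weight = torch.empty(192, 256, dtype=torch.float32)

eligible = should_use_deepgemm_for_fp8_linear(
    output_dtype=torch.bfloat16,   # DeepGEMM path requires bf16 output
    weight=weight,
    supports_deep_gemm=True,       # bypass the runtime capability probe, as in the test
)
print("DeepGEMM-eligible shape:", eligible)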

vllm/utils/deep_gemm.py

@@ -365,11 +365,18 @@ def should_use_deepgemm_for_fp8_linear(
 ):
     if supports_deep_gemm is None:
         supports_deep_gemm = is_deep_gemm_supported()
+    # Verify DeepGEMM N/K dim requirements
+    # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
+    # test inside kernels/quantization/test_block_fp8.py
+    N_MULTIPLE = 64
+    K_MULTIPLE = 128
     return (
         supports_deep_gemm
         and output_dtype == torch.bfloat16
-        and weight.shape[0] % 128 == 0
-        and weight.shape[1] % 128 == 0
+        and weight.shape[0] % N_MULTIPLE == 0
+        and weight.shape[1] % K_MULTIPLE == 0
     )
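
As a standalone illustration of what the relaxed multiplier changes, here is a small, hedged sketch that re-implements just the dimension check from the hunk above; it is not the vLLM API itself, only the N % 64 / K % 128 arithmetic, compared against the previous N % 128 rule:

import torch

N_MULTIPLE = 64   # new N-dim requirement from this commit
K_MULTIPLE = 128  # K-dim requirement, unchanged

def dims_ok(weight: torch.Tensor) -> bool:
    # weight is (N, K); mirrors only the shape part of
    # should_use_deepgemm_for_fp8_linear
    return weight.shape[0] % N_MULTIPLE == 0 and weight.shape[1] % K_MULTIPLE == 0

def dims_ok_old(weight: torch.Tensor) -> bool:
    # previous rule: both dims had to be multiples of 128
    return weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0

w = torch.empty(192, 512)                   # N=192, K=512
assert dims_ok(w) and not dims_ok_old(w)    # newly eligible under this commit
assert not dims_ok(torch.empty(96, 512))    # N=96 is still rejected (not a multiple of 64)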