[CI/Build] Only use supported types and features on ROCm in MoE kernel tests (#29149)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Author: rasmith
Date: 2025-11-21 21:34:33 -06:00 (committed by GitHub)
parent 77e1c035d0
commit fd65015a14
7 changed files with 41 additions and 2 deletions

View File

@@ -39,6 +39,11 @@ MNK_FACTORS = [
 NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]
+DTYPES = [torch.bfloat16]
+if not current_platform.is_fp8_fnuz():
+    DTYPES.append(torch.float8_e4m3fn)
 vllm_config = VllmConfig()
@@ -96,7 +101,7 @@ class BatchedMMTensors:
 @pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
 @pytest.mark.parametrize("K", [128, 1024])
 @pytest.mark.parametrize("N", [128, 1024])
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 def test_batched_mm(
@@ -229,7 +234,7 @@ def test_batched_mm(
 @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
 @pytest.mark.parametrize("input_scales", [False])

View File

@@ -31,6 +31,11 @@ dg_available = has_deep_gemm()
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
 vllm_config = VllmConfig()
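pytest.skip(..., allow_module_level=True) raised at import time skips every test in the file before any fixtures or parametrized cases are collected; without allow_module_level=True, pytest treats a skip outside a test function as an error. A self-contained sketch of the same module-level guard, keyed here on whether the torch build ships the OCP fp8 dtype at all (an illustrative condition, not vLLM's exact check):

    import pytest
    import torch

    if not hasattr(torch, "float8_e4m3fn"):
        pytest.skip(
            "this torch build does not provide float8_e4m3fn",
            allow_module_level=True,
        )

    def test_fp8_fill():
        x = torch.ones(2, 2).to(torch.float8_e4m3fn)
        assert x.dtype == torch.float8_e4m3fn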

View File

@@ -270,6 +270,11 @@ class Case:
 @pytest.mark.parametrize("num_token", [2])
 @pytest.mark.parametrize("tp", [1, 2, 4, 8])
 def test_equiv(num_token, a_dtype, w_dtype, tp):
+    from triton_kernels.tensor_details import layout
+    if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
+        pytest.skip("make_default_matmul_mxfp4_w_layout not available")
     M = num_token
     E = ModelConfig.num_experts
     K = ModelConfig.hidden_size
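Here the guard sits inside the test body rather than at module level, because only test_equiv depends on the newer triton_kernels layout helper; a hasattr probe keeps the file importable against older versions of the package. The same shape, sketched against a torch API instead of triton_kernels so it runs anywhere:

    import pytest
    import torch
    import torch.nn.functional as F

    def test_uses_optional_api():
        # older torch builds lack this op; skip just this test, not the module
        if not hasattr(F, "scaled_dot_product_attention"):
            pytest.skip("scaled_dot_product_attention not available")
        q = k = v = torch.randn(1, 2, 4, 8)
        out = F.scaled_dot_product_attention(q, k, v)
        assert out.shape == q.shape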

View File

@@ -46,6 +46,12 @@ meets_multi_gpu_requirements = pytest.mark.skipif(
     reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
 )
+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
 def format_result(verbose, msg, ex=None):
     if ex is not None:
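This file already defines a reusable pytest.mark.skipif marker (meets_multi_gpu_requirements) for per-test gating; the new fp8_fnuz check instead skips the whole module, which is the right tool when no test in the file can run. For contrast, a sketch of the marker form, which gates individual tests while letting the rest of the file execute:

    import pytest
    import torch

    requires_cuda = pytest.mark.skipif(
        not torch.cuda.is_available(),
        reason="Requires a CUDA-capable device",
    )

    @requires_cuda
    def test_gpu_only():
        assert torch.zeros(1, device="cuda").is_cuda

    def test_cpu_path():
        # unmarked: still runs on machines without a GPU
        assert torch.zeros(1).device.type == "cpu"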

View File

@@ -23,6 +23,12 @@ TOP_KS = [2, 6, 8]
 EP_SIZE = [1, 4, 16]
 current_platform.seed_everything(0)
+if current_platform.is_rocm():
+    pytest.skip(
+        "moe_permute_unpermute_supported is not defined for ROCm",
+        allow_module_level=True,
+    )
 def torch_permute(
     hidden_states: torch.Tensor,
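For orientation, the permutation these tests reference groups token rows so that each expert's tokens are contiguous, and unpermute scatters them back. A simplified torch-only sketch of that round trip (top-1 routing, no padding — far simpler than the kernels under test):

    import torch

    def permute_by_expert(hidden_states, expert_ids):
        # stable sort keeps original token order within each expert group
        order = torch.argsort(expert_ids, stable=True)
        return hidden_states[order], order

    def unpermute(permuted, order):
        out = torch.empty_like(permuted)
        out[order] = permuted  # scatter rows back to their original slots
        return out

    tokens = torch.randn(6, 4)
    experts = torch.tensor([2, 0, 1, 0, 2, 1])
    grouped, order = permute_by_expert(tokens, experts)
    assert torch.equal(unpermute(grouped, order), tokens)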

View File

@@ -14,6 +14,12 @@ from vllm.platforms import current_platform
 from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
 from vllm.utils.math_utils import cdiv, round_up
+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
 fp8_dtype = torch.float8_e4m3fn
 CASES = [
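A note on the guard's condition: is_fp8_fnuz() distinguishes the OCP float8_e4m3fn encoding that fp8_dtype assumes from the float8_e4m3fnuz variant (different exponent bias, no negative zero, max 240 instead of 448) that some ROCm GPUs implement in hardware. A sketch of the distinction — the HIP check below is a crude stand-in, not vLLM's actual architecture detection:

    import torch

    # assumption: a HIP build implies the fnuz variant; vLLM's
    # current_platform.is_fp8_fnuz() inspects the actual GPU arch
    fp8_dtype = (
        torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
    )

    # the encodings are not interchangeable: their ranges differ
    print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
    print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0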

View File

@@ -19,6 +19,12 @@ if current_platform.get_device_capability() < (9, 0):
 vllm_config = VllmConfig()
+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
 def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
     """Matrix multiplication function that supports per-token input