mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 10:34:58 +08:00
[CI/Build] Only use supported types and features on ROCm in MoE kernel tests (#29149)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
parent
77e1c035d0
commit
fd65015a14
@ -39,6 +39,11 @@ MNK_FACTORS = [
|
||||
NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [1, 2, 6]
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
|
||||
if not current_platform.is_fp8_fnuz():
|
||||
DTYPES.append(torch.float8_e4m3fn)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
|
||||
@ -96,7 +101,7 @@ class BatchedMMTensors:
|
||||
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
|
||||
@pytest.mark.parametrize("K", [128, 1024])
|
||||
@pytest.mark.parametrize("N", [128, 1024])
|
||||
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
def test_batched_mm(
|
||||
@ -229,7 +234,7 @@ def test_batched_mm(
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("input_scales", [False])
|
||||
|
||||
@ -31,6 +31,11 @@ dg_available = has_deep_gemm()
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
|
||||
@ -270,6 +270,11 @@ class Case:
|
||||
@pytest.mark.parametrize("num_token", [2])
|
||||
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
|
||||
def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
from triton_kernels.tensor_details import layout
|
||||
|
||||
if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
|
||||
pytest.skip("make_default_matmul_mxfp4_w_layout not available")
|
||||
|
||||
M = num_token
|
||||
E = ModelConfig.num_experts
|
||||
K = ModelConfig.hidden_size
|
||||
|
||||
@ -46,6 +46,12 @@ meets_multi_gpu_requirements = pytest.mark.skipif(
|
||||
reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def format_result(verbose, msg, ex=None):
|
||||
if ex is not None:
|
||||
|
||||
@ -23,6 +23,12 @@ TOP_KS = [2, 6, 8]
|
||||
EP_SIZE = [1, 4, 16]
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"moe_permute_unpermute_supported is not defined for ROCm",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def torch_permute(
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -14,6 +14,12 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
|
||||
CASES = [
|
||||
|
||||
@ -19,6 +19,12 @@ if current_platform.get_device_capability() < (9, 0):
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
pytest.skip(
|
||||
"Tests in this file require float8_e4m3fn and platform does not support",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
"""Matrix multiplication function that supports per-token input
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user