[CI/Build] Only use supported types and features on ROCm in MoE kernel tests (#29149)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>

commit fd65015a14
parent 77e1c035d0
@@ -39,6 +39,11 @@ MNK_FACTORS = [
 NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]
 
+DTYPES = [torch.bfloat16]
+
+if not current_platform.is_fp8_fnuz():
+    DTYPES.append(torch.float8_e4m3fn)
+
 vllm_config = VllmConfig()
 
 
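Note: the DTYPES gating above decides at import time whether float8_e4m3fn is usable. A minimal standalone sketch of the same idea, assuming PyTorch >= 2.1 for the fp8 dtypes; the HIP check below is only a rough stand-in for vLLM's current_platform.is_fp8_fnuz(), which consults the actual GPU architecture:

    import torch

    def gated_test_dtypes() -> list[torch.dtype]:
        # Mirror the DTYPES gating: only offer OCP float8_e4m3fn where the
        # hardware supports it. Some ROCm GPUs implement the "fnuz" fp8
        # variant instead (no inf, a single NaN encoding), so kernels and
        # tests hard-coded to float8_e4m3fn cannot run there.
        dtypes = [torch.bfloat16]
        is_fp8_fnuz = torch.version.hip is not None  # rough stand-in only
        if not is_fp8_fnuz:
            dtypes.append(torch.float8_e4m3fn)
        return dtypes

    print(gated_test_dtypes())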
@@ -96,7 +101,7 @@ class BatchedMMTensors:
 @pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
 @pytest.mark.parametrize("K", [128, 1024])
 @pytest.mark.parametrize("N", [128, 1024])
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 def test_batched_mm(
@@ -229,7 +234,7 @@ def test_batched_mm(
 @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
 @pytest.mark.parametrize("input_scales", [False])
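Note: the two dtype hunks above work because @pytest.mark.parametrize evaluates its argument list once, at collection time, so a platform-gated module-level list removes unsupported fp8 cases from the test matrix entirely instead of skipping them one by one. A self-contained sketch (the gate condition here is hypothetical):

    import pytest
    import torch

    # Built once at import; parametrize snapshots this list during collection.
    DTYPES = [torch.bfloat16]
    if hasattr(torch, "float8_e4m3fn"):  # stand-in platform gate
        DTYPES.append(torch.float8_e4m3fn)

    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cast_roundtrip(dtype: torch.dtype) -> None:
        # 1.0 is exactly representable in bfloat16 and in fp8 e4m3,
        # so a cast there and back must be lossless.
        x = torch.ones(4, dtype=torch.float32).to(dtype)
        assert torch.allclose(x.to(torch.float32), torch.ones(4))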
@@ -31,6 +31,11 @@ dg_available = has_deep_gemm()
 
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
 
 vllm_config = VllmConfig()
 
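Note: pytest.skip(..., allow_module_level=True) raises during import and stops collection of the entire file, which is why these guards sit at module scope rather than as per-test markers. The bare pattern, with an illustrative condition:

    import pytest
    import torch

    # Runs at import; nothing below this point is collected if the guard fires.
    if not torch.cuda.is_available():  # illustrative condition only
        pytest.skip("requires a GPU", allow_module_level=True)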
@@ -270,6 +270,11 @@ class Case:
 @pytest.mark.parametrize("num_token", [2])
 @pytest.mark.parametrize("tp", [1, 2, 4, 8])
 def test_equiv(num_token, a_dtype, w_dtype, tp):
+    from triton_kernels.tensor_details import layout
+
+    if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
+        pytest.skip("make_default_matmul_mxfp4_w_layout not available")
+
     M = num_token
     E = ModelConfig.num_experts
     K = ModelConfig.hidden_size
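Note: the hasattr probe above degrades gracefully when an older triton_kernels build lacks the layout helper. The same feature-detection pattern, sketched against a different library attribute that also varies across releases (numpy.linalg.matrix_transpose, added in NumPy 2.0):

    import pytest

    def test_optional_feature() -> None:
        # importorskip skips the test if the package is missing entirely;
        # hasattr then guards against releases that predate the attribute.
        np = pytest.importorskip("numpy")
        if not hasattr(np.linalg, "matrix_transpose"):
            pytest.skip("matrix_transpose not available in this numpy")
        eye = np.eye(2)
        assert (np.linalg.matrix_transpose(eye) == eye).all()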
@@ -46,6 +46,12 @@ meets_multi_gpu_requirements = pytest.mark.skipif(
     reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
 )
 
+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
+
 
 def format_result(verbose, msg, ex=None):
     if ex is not None:
@@ -23,6 +23,12 @@ TOP_KS = [2, 6, 8]
 EP_SIZE = [1, 4, 16]
 current_platform.seed_everything(0)
 
+if current_platform.is_rocm():
+    pytest.skip(
+        "moe_permute_unpermute_supported is not defined for ROCm",
+        allow_module_level=True,
+    )
+
 
 def torch_permute(
     hidden_states: torch.Tensor,
@@ -14,6 +14,12 @@ from vllm.platforms import current_platform
 from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT, has_deep_gemm
 from vllm.utils.math_utils import cdiv, round_up
 
+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
+
 fp8_dtype = torch.float8_e4m3fn
 
 CASES = [
@@ -19,6 +19,12 @@ if current_platform.get_device_capability() < (9, 0):
 
 vllm_config = VllmConfig()
 
+if current_platform.is_fp8_fnuz():
+    pytest.skip(
+        "Tests in this file require float8_e4m3fn and platform does not support",
+        allow_module_level=True,
+    )
+
 
 def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
     """Matrix multiplication function that supports per-token input