mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:14:54 +08:00
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
215 lines
7.3 KiB
Python
215 lines
7.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import copy
|
|
from itertools import product
|
|
from typing import Optional
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
|
|
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
|
BatchedTritonOrDeepGemmExperts)
|
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
|
BatchedTritonExperts)
|
|
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
|
|
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
|
TritonOrDeepGemmExperts)
|
|
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
|
|
|
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
|
|
reference_moe_impl,
|
|
run_modular_kernel)
|
|
from .modular_kernel_tools.mk_objects import (
|
|
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
|
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
|
|
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
|
|
parallel_launch_with_config)
|
|
|
|
# TODO (varun): These requirements are very strict and could be relaxed.
# Every test below is gated on the deep_ep, deep_gemm and pplx availability
# helpers from vllm.utils all reporting True. `all` over a generator stops at
# the first False, in the same order a chained `and` would.
has_all_packages = all(
    has_pkg() for has_pkg in (has_deep_ep, has_deep_gemm, has_pplx))

# Reusable skip marker applied to each test in this module.
meets_package_requirements = pytest.mark.skipif(
    not has_all_packages,
    reason="Requires deep_ep & deep_gemm & pplx packages",
)
|
|
|
|
|
|
def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
):
    """Per-rank body: compare the modular kernel against the reference impl.

    For every (m, topk) pair in ``config.Ms`` x ``config.topks``, this makes a
    deep copy of ``config`` with ``Ms``/``topks`` overridden by the scalar pair,
    builds this rank's inputs, runs the modular kernel and the reference MoE
    implementation, and asserts the outputs are close.

    Args:
        pgi: Process-group info for this rank; ``pgi.rank`` seeds the RNG.
        vllm_config: vLLM config made current while the reference runs.
        cpu_group: Not used in this body. It is presumably passed positionally
            by ``parallel_launch_with_config``; confirm against that helper.
        config: Test configuration; ``Ms`` and ``topks`` must be lists.
        weights: Expert weights, moved via ``to_current_device()`` before use.

    Raises:
        AssertionError: On a chunk-size mismatch, non-list ``Ms``/``topks``,
            or kernel output outside the tolerance of the reference.
    """
    current_platform.seed_everything(pgi.rank)

    # sanity check: an explicit chunk size in the test config must agree with
    # the value exposed by vllm.envs.
    from vllm import envs
    if config.fused_moe_chunk_size is not None:
        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)

    # get weights to this device
    weights.to_current_device()

    Ms = config.Ms
    assert isinstance(Ms, list)
    TOPKs = config.topks
    assert isinstance(TOPKs, list)

    for m, topk in product(Ms, TOPKs):
        print(f"Running m={m}, topk={topk} ...")
        # override m and topk on a private copy so `config` stays untouched
        cfgx = copy.deepcopy(config)
        cfgx.Ms = m
        cfgx.topks = topk

        # inputs for rank
        rank_tensors = RankTensors.make(cfgx, pgi)

        # modular kernel out
        mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
                                    rank_tensors)

        with set_current_vllm_config(vllm_config):
            ref_out = reference_moe_impl(cfgx, weights, rank_tensors)

        # assert_close(actual, expected): the kernel output is the value under
        # test and the reference is the expected value. rtol then scales with
        # |ref_out|, and failure messages label the tensors correctly.
        torch.testing.assert_close(mk_out, ref_out, atol=3e-2, rtol=3e-2)
|
|
|
|
|
|
def run(config: Config):
    """Run ``rank_worker`` for *config* across ``config.world_size`` ranks.

    *config* must already be valid. Weights are created once here. They are
    passed, together with the vLLM config and env data from
    ``config.make_env_data()``, to ``parallel_launch_with_config``.
    """
    assert config.is_valid()
    print(f"Testing config \n{config.describe()} ...")

    shared_weights: WeightTensors = WeightTensors.make(config)
    vllm_config, env_dict = config.make_env_data()

    worker_args = (vllm_config, env_dict, config, shared_weights)
    parallel_launch_with_config(config.world_size, rank_worker, *worker_args)
|
|
|
|
|
|
# Parameter sweeps for the tests below.
# Ms and TOPKs are passed to Config as whole lists. rank_worker asserts they
# are lists and iterates every (m, topk) pair, so keep them as lists.
Ms = [32, 64]
Ks = [7168]  # hidden sizes
# NOTE(review): presumably the expert intermediate size; confirm vs Config.N.
Ns = [2048]
TOPKs = [4, 1]
# NOTE(review): presumably the number of experts; confirm vs Config.E.
Es = [32]
DTYPEs = [torch.bfloat16]
# None leaves the chunk size unset, which skips rank_worker's env check.
FUSED_MOE_CHUNK_SIZEs = [None, 16]
|
|
|
|
|
|
def is_nyi_config(config: Config) -> bool:
    """Return True for configs that are legitimate but not yet implemented.

    The tests call this only after ``config.is_valid()``, and skip (rather
    than fail) any config for which it returns True.
    """
    experts_type = config.fused_experts_type
    triton_experts = (BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
                      TritonExperts, TritonOrDeepGemmExperts)

    if experts_type not in triton_experts:
        # Cutlass kernels don't support expert_maps yet.
        return experts_type == CutlassExpertsFp8

    # The triton kernels expect both per-act-token-quant and per-out-ch-quant,
    # or neither; having exactly one of the two enabled is unsupported.
    enabled_flags = (config.is_per_act_token_quant +
                     config.is_per_out_ch_quant)
    return enabled_flags == 1
|
|
|
|
|
|
@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@meets_package_requirements
def test_modular_kernel_combinations_multigpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int):
    """Check multi-GPU prepare/finalize x fused-experts combinations.

    ``combination`` is a (prepare-finalize type, fused-experts type) pair.
    Hosts with too few GPUs, invalid configs, and known-NYI configs are
    skipped. Everything else is checked against the reference via ``run``.
    """
    # NOTE(review): assumes parallel_launch_with_config puts each rank on its
    # own GPU; skip instead of erroring on hosts with fewer devices.
    if torch.cuda.device_count() < world_size:
        pytest.skip(f"Requires at least {world_size} GPUs. Skipping ...")

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=combination[0],
        fused_experts_type=combination[1],
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )
    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    # run() prints config.describe() itself, so no separate print is needed.
    run(config)
|
|
|
|
|
|
@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
@meets_package_requirements
def test_modular_kernel_combinations_singlegpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int):
    """Check single-GPU prepare/finalize x fused-experts combinations.

    ``combination`` is a (prepare-finalize type, fused-experts type) pair.
    Invalid configs and known-NYI configs are skipped before ``run``.
    """
    pf_type, experts_type = combination
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=pf_type,
        fused_experts_type=experts_type,
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    # pytest.skip raises, so at most one of these branches fires.
    if not config.is_valid():
        pytest.skip(f"Tests config {config} is not valid. Skipping ...")
    elif is_nyi_config(config):
        pytest.skip(f"Tests config {config} is nyi. Skipping ...")

    run(config)
|
|
|
|
|
|
if __name__ == '__main__':
    # Ability to test individual PrepareAndFinalize and FusedExperts
    # combinations from the command line. The relative import below requires
    # running this file as a module (`python3 -m ...`), as in the example.
    from .modular_kernel_tools.cli_args import (make_config,
                                                make_config_arg_parser)

    # Adjacent string literals are concatenated with no separator, so the
    # first sentence carries its own trailing ". " before the example.
    parser = make_config_arg_parser(description=(
        "Run single prepare-finalize & fused-experts combination test. "
        "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "  #noqa: E501
        "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
    ))
    args = parser.parse_args()
    config = make_config(args)

    run(config)
|