diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py
index 929db91775375..5fac7166bc262 100644
--- a/tests/kernels/moe/test_cutlass_moe.py
+++ b/tests/kernels/moe/test_cutlass_moe.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import dataclasses
+from math import prod
 from typing import Optional
 
 import pytest
@@ -8,9 +9,12 @@
 import torch
 
 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+    cutlass_moe_fp8, run_cutlass_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                             fused_topk)
+from vllm.model_executor.layers.fused_moe.utils import (
+    moe_kernel_quantize_input)
 from vllm.platforms import current_platform
 
 NUM_EXPERTS = [40, 64]
@@ -236,6 +240,7 @@ def test_cutlass_moe_8_bit_no_graph(
     per_act_token: bool,
     per_out_ch: bool,
     monkeypatch,
+    ep_size: Optional[int] = None,
 ):
     current_platform.seed_everything(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
@@ -254,7 +259,13 @@ def test_cutlass_moe_8_bit_no_graph(
         triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
                                       topk_ids)
 
-        cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
+        if ep_size is not None:
+            assert e % ep_size == 0, "Cannot distribute experts evenly"
+            number_local_experts = e // ep_size
+        else:
+            number_local_experts = None
+        cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
+                                   number_local_experts)
 
         # Note 5.5 only needed for larger problem sizes, 5 works ok for
         # the rest.
@@ -340,9 +351,62 @@ def test_cutlass_moe_8_bit_EP(
     per_out_channel: bool,
     ep_size: int,
     monkeypatch,
+):
+    test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
+                                    per_out_channel, monkeypatch, ep_size)
+
+
+LARGE_MNK_FACTORS = [
+    (1, 8192, 5120, 31),
+    (32768, 1024, 1024, 16),
+    (65536, 512, 1024, 16),
+]
+
+
+@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_cutlass_moe_8_bit_EP_large(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
+    monkeypatch,
+):
+    test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
+                                    per_out_channel, monkeypatch, ep_size)
+
+
+@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_run_cutlass_moe_fp8(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
 ):
     current_platform.seed_everything(7)
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
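+    # Run the kernel twice on identical inputs, once with randomized
+    # workspace buffers and once with zeroed ones; the outputs must match,
+    # i.e. stale workspace contents must not leak into the result.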
     with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_channel)
@@ -352,20 +416,53 @@ def test_cutlass_moe_8_bit_EP(
                                             score,
                                             topk,
                                             renormalize=False)
+        # make sure that at least one token is routed to an expert on this
+        # shard and at least one token is routed to an expert that is NOT
+        # on this shard
+        topk_ids[0][0] = -1
+        topk_ids[0][1] = 1
 
-        # Note that we are using the dequantized versions of the tensors.
-        # Using a, w1 and w2 directly results in minor output differences.
-        triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
-                                      topk_ids)
+        workspace13_shape = (m * topk, max(2 * n, k))
+        workspace2_shape = (m * topk, n)
+        output_shape = (m * topk, k)
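+        # workspace13 must be wide enough for both the 2 * n wide output of
+        # the first grouped gemm and the k wide final output, hence
+        # max(2 * n, k); workspace2 only holds the n wide activation output.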
 
-        assert e % ep_size == 0, "Cannot distribute experts evenly"
-        cutlass_output = run_8_bit(mt,
-                                   topk_weights,
-                                   topk_ids,
-                                   per_act_token,
-                                   num_local_experts=e // ep_size)
+        workspace13 = torch.empty(prod(workspace13_shape),
+                                  device="cuda",
+                                  dtype=mt.a.dtype)
+        workspace2 = torch.empty(prod(workspace2_shape),
+                                 device="cuda",
+                                 dtype=mt.a.dtype)
 
-        torch.testing.assert_close(triton_output,
-                                   cutlass_output,
-                                   atol=5e-2,
-                                   rtol=1e-2)
+        num_local_experts = e // ep_size
+        start, end = 0, num_local_experts
+        expert_map = [-1] * e
+        expert_map[start:end] = list(range(num_local_experts))
+        expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
+
+        activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
+        a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
+                                                   torch.float8_e4m3fn,
+                                                   per_act_token)
+        global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
+        func = lambda output: run_cutlass_moe_fp8(
+            output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
+            global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
+            a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
+            per_act_token, per_out_channel, False)
+
+        workspace13.random_()
+        output_random_workspace = torch.empty(output_shape,
+                                              device="cuda",
+                                              dtype=mt.a.dtype)
+        func(output_random_workspace)
+
+        workspace13.fill_(0)
+        output_zero_workspace = torch.zeros(output_shape,
+                                            device="cuda",
+                                            dtype=mt.a.dtype)
+        func(output_zero_workspace)
+
+        torch.testing.assert_close(output_random_workspace,
+                                   output_zero_workspace,
+                                   atol=5e-3,
+                                   rtol=1e-3)
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index 0f41414c4896d..d771a7a54cfc1 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -180,7 +180,11 @@ def run_cutlass_moe_fp8(
     c2 = _resize_cache(workspace2, (M * topk, N))
     c3 = _resize_cache(workspace13, (M * topk, K))
 
-    c1.fill_(0)
+    if not per_act_token and (expert_map is not None or use_batched_format):
+        # This is necessary to avoid imprecise scale calculation caused by
+        # random data in the unused parts of the workspace, which occur when
+        # this rank handles only a subset of the tokens or in batched format.
+        c1.fill_(0)
 
     ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
                        problem_sizes1, ab_strides1, ab_strides1, c_strides1,
@@ -303,7 +307,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
     ):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
-        activation_callable = lambda i, o: self.activation(activation, i, o)
+        activation_callable = lambda o, i: self.activation(activation, o, i)
         in_dtype = hidden_states.dtype
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,