[Bugfix] Fix Maverick correctness by filling zero to cache space in cutlass_moe (#20167)
Signed-off-by: Ming Yang <yming@meta.com>
parent d2e841a10a
commit afb7cff1b9
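The bug this fixes: under expert parallelism, tokens routed to experts owned by other ranks never write their rows of the intermediate cache, so those rows keep whatever garbage the allocation left behind. A per-tensor FP8 activation scale is computed over the entire buffer, so the garbage inflates the dynamic scale and corrupts the quantization of every valid row as well, which is what produced the wrong Maverick outputs. A minimal sketch of the failure mode (illustrative only; quantize_per_tensor is a hypothetical stand-in, not vLLM's quantization API):

    import torch

    def quantize_per_tensor(x: torch.Tensor):
        # Dynamic per-tensor scale: derived from the max magnitude of the
        # WHOLE buffer, whether or not every row was actually written.
        scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
        return (x / scale).to(torch.float8_e4m3fn), scale

    workspace = torch.empty(16, 8).random_(0, 1000)  # uninitialized cache rows
    workspace[:4] = torch.randn(4, 8)   # only 4 rows written by the grouped GEMM

    _, dirty_scale = quantize_per_tensor(workspace)  # stale rows inflate the max
    workspace[4:] = 0                                # the fix: zero unused rows
    _, clean_scale = quantize_per_tensor(workspace)
    print(dirty_scale.item(), clean_scale.item())    # dirty scale is far larger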
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import dataclasses
+from math import prod
 from typing import Optional
 
 import pytest
@@ -8,9 +9,12 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+    cutlass_moe_fp8, run_cutlass_moe_fp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                             fused_topk)
+from vllm.model_executor.layers.fused_moe.utils import (
+    moe_kernel_quantize_input)
 from vllm.platforms import current_platform
 
 NUM_EXPERTS = [40, 64]
@@ -236,6 +240,7 @@ def test_cutlass_moe_8_bit_no_graph(
     per_act_token: bool,
     per_out_ch: bool,
     monkeypatch,
+    ep_size: Optional[int] = None,
 ):
     current_platform.seed_everything(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
@@ -254,7 +259,13 @@ def test_cutlass_moe_8_bit_no_graph(
         triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
                                       topk_ids)
 
-        cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
+        if ep_size is not None:
+            assert e % ep_size == 0, "Cannot distribute experts evenly"
+            number_local_experts = e // ep_size
+        else:
+            number_local_experts = None
+        cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
+                                   number_local_experts)
 
         # Note 5.5 only needed for larger problem sizes, 5 works ok for
         # the rest.
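Taken together, the two test hunks above and the one below fold EP coverage into the existing no-graph path: test_cutlass_moe_8_bit_no_graph gains an optional ep_size, and when it is set only e // ep_size experts are treated as local and forwarded to run_8_bit as its num_local_experts argument; the EP variants below then simply delegate to it.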
@@ -340,9 +351,62 @@ def test_cutlass_moe_8_bit_EP(
     per_out_channel: bool,
     ep_size: int,
     monkeypatch,
 ):
+    test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
+                                    per_out_channel, monkeypatch, ep_size)
+
+
+LARGE_MNK_FACTORS = [
+    (1, 8192, 5120, 31),
+    (32768, 1024, 1024, 16),
+    (65536, 512, 1024, 16),
+]
+
+
+@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_cutlass_moe_8_bit_EP_large(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
+    monkeypatch,
+):
+    test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
+                                    per_out_channel, monkeypatch, ep_size)
+
+
+@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
+@pytest.mark.parametrize("e", [128])
+@pytest.mark.parametrize("per_act_token", [False])
+@pytest.mark.parametrize("per_out_channel", [True])
+@pytest.mark.parametrize("ep_size", [8])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+def test_run_cutlass_moe_fp8(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_channel: bool,
+    ep_size: int,
+):
     current_platform.seed_everything(7)
-    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
         mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                   per_out_channel)
@@ -352,20 +416,53 @@ def test_cutlass_moe_8_bit_EP(
                                                score,
                                                topk,
                                                renormalize=False)
+        # we want to make sure there is at least one token that's generated in
+        # this expert shard and at least one token that's NOT generated in this
+        # expert shard
+        topk_ids[0][0] = -1
+        topk_ids[0][1] = 1
 
-        # Note that we are using the dequantized versions of the tensors.
-        # Using a, w1 and w2 directly results in minor output differences.
-        triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
-                                      topk_ids)
+        workspace13_shape = (m * topk, max(2 * n, k))
+        workspace2_shape = (m * topk, n)
+        output_shape = (m * topk, k)
 
-        assert e % ep_size == 0, "Cannot distribute experts evenly"
-        cutlass_output = run_8_bit(mt,
-                                   topk_weights,
-                                   topk_ids,
-                                   per_act_token,
-                                   num_local_experts=e // ep_size)
+        workspace13 = torch.empty(prod(workspace13_shape),
+                                  device="cuda",
+                                  dtype=mt.a.dtype)
+        workspace2 = torch.empty(prod(workspace2_shape),
+                                 device="cuda",
+                                 dtype=mt.a.dtype)
 
-        torch.testing.assert_close(triton_output,
-                                   cutlass_output,
-                                   atol=5e-2,
-                                   rtol=1e-2)
+        num_local_experts = e // ep_size
+        start, end = 0, num_local_experts
+        expert_map = [-1] * e
+        expert_map[start:end] = list(range(num_local_experts))
+        expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
+
+        activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
+        a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
+                                                   torch.float8_e4m3fn,
+                                                   per_act_token)
+        global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
+        func = lambda output: run_cutlass_moe_fp8(
+            output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
+            global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
+            a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
+            per_act_token, per_out_channel, False)
+
+        workspace13.random_()
+        output_random_workspace = torch.empty(output_shape,
+                                              device="cuda",
+                                              dtype=mt.a.dtype)
+        func(output_random_workspace)
+
+        workspace13.fill_(0)
+        output_zero_workspace = torch.zeros(output_shape,
+                                            device="cuda",
+                                            dtype=mt.a.dtype)
+        func(output_zero_workspace)
+
+        torch.testing.assert_close(output_random_workspace,
+                                   output_zero_workspace,
+                                   atol=5e-3,
+                                   rtol=1e-3)
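The remaining hunks land in the kernel module itself (imported above as vllm.model_executor.layers.fused_moe.cutlass_moe). The new test above pins the regression down directly: it builds an expert_map in which this rank owns only the first e // ep_size experts, forces topk_ids to contain both a non-local (-1) and a local (1) assignment, and then runs run_cutlass_moe_fp8 twice, once with workspace13 randomized and once with it zeroed. Before the fix the two outputs diverged, because the random workspace contents leaked into the per-tensor scale.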
@@ -180,7 +180,11 @@ def run_cutlass_moe_fp8(
     c2 = _resize_cache(workspace2, (M * topk, N))
     c3 = _resize_cache(workspace13, (M * topk, K))
 
-    c1.fill_(0)
+    if not per_act_token and (expert_map is not None or use_batched_format):
+        # this is necessary to avoid imprecise scale calculation caused by
+        # random data in the unused workspace. The workspace is unused when
+        # this rank handles only partial tokens, or when it is batched.
+        c1.fill_(0)
 
     ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
                        problem_sizes1, ab_strides1, ab_strides1, c_strides1,
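Note that the zero-fill is now applied under a guard rather than unconditionally. Per-token scales (per_act_token) read only their own row, so stale rows elsewhere cannot affect them, and when all experts are local and the format is not batched, every row of c1 is written before it is read; zeroing is only needed in the remaining cases, presumably to keep the memset off the common path. A small contrast of the two scale modes (a sketch, not vLLM's quantization code):

    import torch

    x = torch.empty(8, 16).random_(0, 1000)  # stale workspace contents
    x[:2] = torch.randn(2, 16)               # only rows 0..1 were written

    # Per-token: each row's scale sees only that row, immune to stale rows.
    per_token_scale = x[:2].abs().amax(dim=1, keepdim=True)

    # Per-tensor: one scale over everything, polluted unless the stale rows
    # are zeroed first.
    per_tensor_scale = x.abs().max()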
@@ -303,7 +307,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
     ):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
-        activation_callable = lambda i, o: self.activation(activation, i, o)
+        activation_callable = lambda o, i: self.activation(activation, o, i)
         in_dtype = hidden_states.dtype
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
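A note on the one-line lambda change above: it does not alter behavior, since both lambdas forward their two positional arguments in order; it renames the parameters from (i, o) to (o, i) so the code reads correctly, because the first argument handed to the callable is the output buffer. This matches the convention in the new test's activation = lambda o, i: torch.ops._C.silu_and_mul(o, i), where the fused kernel writes in place into its first argument. A pure-torch reference of that out-first contract (a sketch under that assumption, not the CUDA kernel):

    import torch
    import torch.nn.functional as F

    def silu_and_mul_ref(out: torch.Tensor, x: torch.Tensor) -> None:
        # out[...] = silu(x[..., :d]) * x[..., d:], written in place into the
        # first argument, matching the (output, input) ordering above.
        d = x.shape[-1] // 2
        out.copy_(F.silu(x[..., :d]) * x[..., d:])

    x = torch.randn(4, 16)
    out = torch.empty(4, 8)
    silu_and_mul_ref(out, x)  # output buffer first, input second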