mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-06 08:59:14 +08:00
[Bugfix] Fix Maverick correctness by filling zero to cache space in cutlass_moe (#20167)
Signed-off-by: Ming Yang <yming@meta.com>
This commit is contained in:
parent
d2e841a10a
commit
afb7cff1b9
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
from math import prod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -8,9 +9,12 @@ import torch
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
|
cutlass_moe_fp8, run_cutlass_moe_fp8)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
||||||
fused_topk)
|
fused_topk)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
|
moe_kernel_quantize_input)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
NUM_EXPERTS = [40, 64]
|
NUM_EXPERTS = [40, 64]
|
||||||
@ -236,6 +240,7 @@ def test_cutlass_moe_8_bit_no_graph(
|
|||||||
per_act_token: bool,
|
per_act_token: bool,
|
||||||
per_out_ch: bool,
|
per_out_ch: bool,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
|
ep_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||||
@ -254,7 +259,13 @@ def test_cutlass_moe_8_bit_no_graph(
|
|||||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||||
topk_ids)
|
topk_ids)
|
||||||
|
|
||||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
|
if ep_size is not None:
|
||||||
|
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
||||||
|
number_local_experts = e // ep_size
|
||||||
|
else:
|
||||||
|
number_local_experts = None
|
||||||
|
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
|
||||||
|
number_local_experts)
|
||||||
|
|
||||||
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||||
# the rest.
|
# the rest.
|
||||||
@ -340,9 +351,62 @@ def test_cutlass_moe_8_bit_EP(
|
|||||||
per_out_channel: bool,
|
per_out_channel: bool,
|
||||||
ep_size: int,
|
ep_size: int,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
|
):
|
||||||
|
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
|
||||||
|
per_out_channel, monkeypatch, ep_size)
|
||||||
|
|
||||||
|
|
||||||
|
LARGE_MNK_FACTORS = [
|
||||||
|
(1, 8192, 5120, 31),
|
||||||
|
(32768, 1024, 1024, 16),
|
||||||
|
(65536, 512, 1024, 16),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
|
||||||
|
@pytest.mark.parametrize("e", [128])
|
||||||
|
@pytest.mark.parametrize("per_act_token", [False])
|
||||||
|
@pytest.mark.parametrize("per_out_channel", [True])
|
||||||
|
@pytest.mark.parametrize("ep_size", [8])
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||||
|
current_platform.get_device_capability()),
|
||||||
|
reason="Grouped gemm is not supported on this GPU type.")
|
||||||
|
def test_cutlass_moe_8_bit_EP_large(
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
e: int,
|
||||||
|
topk: int,
|
||||||
|
per_act_token: bool,
|
||||||
|
per_out_channel: bool,
|
||||||
|
ep_size: int,
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
|
||||||
|
per_out_channel, monkeypatch, ep_size)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
|
||||||
|
@pytest.mark.parametrize("e", [128])
|
||||||
|
@pytest.mark.parametrize("per_act_token", [False])
|
||||||
|
@pytest.mark.parametrize("per_out_channel", [True])
|
||||||
|
@pytest.mark.parametrize("ep_size", [8])
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||||
|
current_platform.get_device_capability()),
|
||||||
|
reason="Grouped gemm is not supported on this GPU type.")
|
||||||
|
def test_run_cutlass_moe_fp8(
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
e: int,
|
||||||
|
topk: int,
|
||||||
|
per_act_token: bool,
|
||||||
|
per_out_channel: bool,
|
||||||
|
ep_size: int,
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||||
per_out_channel)
|
per_out_channel)
|
||||||
@ -352,20 +416,53 @@ def test_cutlass_moe_8_bit_EP(
|
|||||||
score,
|
score,
|
||||||
topk,
|
topk,
|
||||||
renormalize=False)
|
renormalize=False)
|
||||||
|
# we want to make sure there is at least one token that's generated in
|
||||||
|
# this expert shard and at least one token that's NOT generated in this
|
||||||
|
# expert shard
|
||||||
|
topk_ids[0][0] = -1
|
||||||
|
topk_ids[0][1] = 1
|
||||||
|
|
||||||
# Note that we are using the dequantized versions of the tensors.
|
workspace13_shape = (m * topk, max(2 * n, k))
|
||||||
# Using a, w1 and w2 directly results in minor output differences.
|
workspace2_shape = (m * topk, n)
|
||||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
output_shape = (m * topk, k)
|
||||||
topk_ids)
|
|
||||||
|
|
||||||
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
workspace13 = torch.empty(prod(workspace13_shape),
|
||||||
cutlass_output = run_8_bit(mt,
|
device="cuda",
|
||||||
topk_weights,
|
dtype=mt.a.dtype)
|
||||||
topk_ids,
|
workspace2 = torch.empty(prod(workspace2_shape),
|
||||||
per_act_token,
|
device="cuda",
|
||||||
num_local_experts=e // ep_size)
|
dtype=mt.a.dtype)
|
||||||
|
|
||||||
torch.testing.assert_close(triton_output,
|
num_local_experts = e // ep_size
|
||||||
cutlass_output,
|
start, end = 0, num_local_experts
|
||||||
atol=5e-2,
|
expert_map = [-1] * e
|
||||||
rtol=1e-2)
|
expert_map[start:end] = list(range(num_local_experts))
|
||||||
|
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
|
||||||
|
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
per_act_token)
|
||||||
|
global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
|
||||||
|
func = lambda output: run_cutlass_moe_fp8(
|
||||||
|
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
|
||||||
|
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
|
||||||
|
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
|
||||||
|
per_act_token, per_out_channel, False)
|
||||||
|
|
||||||
|
workspace13.random_()
|
||||||
|
output_random_workspace = torch.empty(output_shape,
|
||||||
|
device="cuda",
|
||||||
|
dtype=mt.a.dtype)
|
||||||
|
func(output_random_workspace)
|
||||||
|
|
||||||
|
workspace13.fill_(0)
|
||||||
|
output_zero_workspace = torch.zeros(output_shape,
|
||||||
|
device="cuda",
|
||||||
|
dtype=mt.a.dtype)
|
||||||
|
func(output_zero_workspace)
|
||||||
|
|
||||||
|
torch.testing.assert_close(output_random_workspace,
|
||||||
|
output_zero_workspace,
|
||||||
|
atol=5e-3,
|
||||||
|
rtol=1e-3)
|
||||||
|
|||||||
@ -180,7 +180,11 @@ def run_cutlass_moe_fp8(
|
|||||||
c2 = _resize_cache(workspace2, (M * topk, N))
|
c2 = _resize_cache(workspace2, (M * topk, N))
|
||||||
c3 = _resize_cache(workspace13, (M * topk, K))
|
c3 = _resize_cache(workspace13, (M * topk, K))
|
||||||
|
|
||||||
c1.fill_(0)
|
if not per_act_token and (expert_map is not None or use_batched_format):
|
||||||
|
# this is necessary to avoid imprecise scale calculation caused by
|
||||||
|
# random data in the unused workspace. The workspace is unused when
|
||||||
|
# this rank handles only partial tokens, or when it is batched .
|
||||||
|
c1.fill_(0)
|
||||||
|
|
||||||
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
|
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
|
||||||
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
|
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
|
||||||
@ -303,7 +307,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
):
|
):
|
||||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||||
activation_callable = lambda i, o: self.activation(activation, i, o)
|
activation_callable = lambda o, i: self.activation(activation, o, i)
|
||||||
in_dtype = hidden_states.dtype
|
in_dtype = hidden_states.dtype
|
||||||
run_cutlass_moe_fp8(
|
run_cutlass_moe_fp8(
|
||||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user