mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-22 14:42:29 +08:00
[Kernel] Add expert_map support to Cutlass FP8 MOE (#16861)
Signed-off-by: varun sundar rabindranath <vsundarr@redhat.com> Co-authored-by: varun sundar rabindranath <vsundarr@redhat.com>
This commit is contained in:
parent
c9acbf1141
commit
7b8a2ab76f
@ -46,14 +46,26 @@ __global__ void compute_expert_offsets(
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ expert_offsets,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
int32_t* atomic_buffer, const int topk_length,
|
||||
const int topk) {
|
||||
int expert_id = blockIdx.x;
|
||||
int const blk_expert_id = blockIdx.x;
|
||||
int const num_experts = gridDim.x;
|
||||
int32_t const num_tokens = expert_offsets[num_experts];
|
||||
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
if (topk_ids[i] == expert_id) {
|
||||
int const expert_id = topk_ids[i];
|
||||
if (expert_id == -1 && blockIdx.x == 0) {
|
||||
// output_permutation is used to re-order the moe outputs. It is
|
||||
// used as c2 = c2[c_map], where c2 is a torch.tensor that is the
|
||||
// output of the cutlass kernels and c_map is the output_permutation.
|
||||
// c2 is initialized to zeros, therefore by setting the output_permutation
|
||||
// to num_tokens, we are guaranteed to fill the moe outputs to zero
|
||||
// for "invalid" topk_ids.
|
||||
output_permutation[i] = num_tokens;
|
||||
} else if (expert_id == blk_expert_id) {
|
||||
int start = atomicAdd(&atomic_buffer[expert_id], 1);
|
||||
input_permutation[start] = i / topk;
|
||||
output_permutation[i] = start;
|
||||
@ -83,6 +95,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@ -12,32 +15,204 @@ from vllm.platforms import current_platform
|
||||
NUM_EXPERTS = [40, 64]
|
||||
TOP_KS = [6, 8]
|
||||
|
||||
|
||||
def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
        w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
        ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
        ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
    """Call ``cutlass_moe_fp8`` inside a vLLM config context.

    The config uses ``pipeline_parallel_size=1``. Every tensor argument is
    forwarded positionally and unchanged; ``a_scale`` is forwarded as the
    ``a1_scale`` keyword. Returns whatever ``cutlass_moe_fp8`` returns.
    """
    single_stage = ParallelConfig(pipeline_parallel_size=1)
    vllm_config = VllmConfig(parallel_config=single_stage)

    # Positional order matters: it mirrors the cutlass_moe_fp8 signature.
    moe_args = (a, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids,
                ab_strides1, c_strides1, ab_strides2, c_strides2)

    with set_current_vllm_config(vllm_config):
        return cutlass_moe_fp8(*moe_args, a1_scale=a_scale)
|
||||
# (m, n, k) problem sizes used to parametrize the tests: every combination of
# the m, n and k values below, with m varying slowest and k varying fastest.
MNK_FACTORS = [(m, n, k) for m in (2, 64, 224) for n in (1024, 3072)
               for k in (1024, 1536)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [2, 64, 224])
|
||||
@pytest.mark.parametrize("n", [1024, 3072])
|
||||
@pytest.mark.parametrize("k", [1024, 1536])
|
||||
@dataclasses.dataclass
class MOETensors:
    """Unquantized MoE test inputs plus the per-expert stride tensors."""

    # Activations, shape (m, k).
    a: torch.Tensor
    # First expert weight stack, shape (e, 2 * n, k).
    w1: torch.Tensor
    # Second expert weight stack, shape (e, k, n).
    w2: torch.Tensor
    # Per-expert stride tensors, each of shape (e,) and dtype int64.
    ab_strides1: torch.Tensor
    c_strides1: torch.Tensor
    ab_strides2: torch.Tensor
    c_strides2: torch.Tensor

    @staticmethod
    def make_moe_tensors(m: int,
                         k: int,
                         n: int,
                         e: int,
                         dtype: torch.dtype,
                         device: str = "cuda") -> "MOETensors":
        """Create random MoE inputs and constant stride tensors.

        Args:
            m: Number of rows of the activation tensor ``a``.
            k: Number of columns of ``a``; also the last dim of ``w1``.
            n: Last dim of ``w2``; ``w1`` has ``2 * n`` rows per expert.
            e: Number of experts (leading dim of ``w1`` / ``w2`` and the
                length of every stride tensor).
            dtype: dtype of ``a``, ``w1`` and ``w2``.
            device: Device all tensors are created on. Defaults to
                ``"cuda"``, which is what existing callers rely on.

        Returns:
            A ``MOETensors`` whose ``a``, ``w1`` and ``w2`` are standard
            normal samples divided by 10, and whose stride tensors are
            filled with ``k``, ``2 * n``, ``n`` and ``k`` respectively.
        """

        def scaled_randn(shape):
            return torch.randn(shape, device=device, dtype=dtype) / 10

        def per_expert_stride(value: int) -> torch.Tensor:
            return torch.full((e, ), value, device=device, dtype=torch.int64)

        # Keep the a -> w1 -> w2 creation order so seeded runs reproduce
        # the same random values as before.
        a = scaled_randn((m, k))
        w1 = scaled_randn((e, 2 * n, k))
        w2 = scaled_randn((e, k, n))

        return MOETensors(a=a,
                          w1=w1,
                          w2=w2,
                          ab_strides1=per_expert_stride(k),
                          c_strides1=per_expert_stride(2 * n),
                          ab_strides2=per_expert_stride(n),
                          c_strides2=per_expert_stride(k))
|
||||
|
||||
|
||||
@dataclasses.dataclass
class MOETensors8Bit(MOETensors):
    """``MOETensors`` extended with FP8-quantized and dequantized copies."""

    # Quantized tensors: a -> a_q, w1 -> w1_q, w2 -> w2_q.
    a_q: Optional[torch.Tensor] = None
    w1_q: Optional[torch.Tensor] = None
    w2_q: Optional[torch.Tensor] = None
    # Scales returned by the quantization calls.
    a_scale: Optional[torch.Tensor] = None
    w1_scale: Optional[torch.Tensor] = None
    w2_scale: Optional[torch.Tensor] = None
    # Dequantized tensors: quantized value times its scale, cast back to
    # half precision (a -> a_q -> a_d, and likewise for w1 / w2).
    a_d: Optional[torch.Tensor] = None
    w1_d: Optional[torch.Tensor] = None
    w2_d: Optional[torch.Tensor] = None

    @staticmethod
    def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
                              per_act_token: bool,
                              per_out_channel: bool) -> "MOETensors8Bit":
        """Build half-precision MoE inputs and their FP8 counterparts.

        The base tensors come from ``MOETensors.make_moe_tensors`` with
        ``torch.half``. ``per_act_token`` and ``per_out_channel`` are passed
        as ``use_per_token_if_dynamic`` to ``ops.scaled_fp8_quant`` for the
        activations and the weights respectively; ``per_out_channel`` also
        selects the number of scale rows allocated per expert.
        """
        half_dtype = torch.half
        fp8_dtype = torch.float8_e4m3fn

        base = MOETensors.make_moe_tensors(m, k, n, e, half_dtype)

        # Get the right scale for tests: the first call is used only for the
        # scale, the second quantizes the activations with that scale.
        _, a_scale = ops.scaled_fp8_quant(
            base.a, use_per_token_if_dynamic=per_act_token)
        a_q, _ = ops.scaled_fp8_quant(base.a,
                                      a_scale,
                                      use_per_token_if_dynamic=per_act_token)
        a_d = a_q.float().mul(a_scale).to(half_dtype)

        # One scale row per output channel, or a single row per expert.
        if per_out_channel:
            w1_scale_rows, w2_scale_rows = 2 * n, k
        else:
            w1_scale_rows, w2_scale_rows = 1, 1

        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=fp8_dtype)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=fp8_dtype)
        w1_scale = torch.empty((e, w1_scale_rows, 1),
                               device="cuda",
                               dtype=torch.float32)
        w2_scale = torch.empty((e, w2_scale_rows, 1),
                               device="cuda",
                               dtype=torch.float32)
        w1_d = torch.empty_like(base.w1)
        w2_d = torch.empty_like(base.w2)

        for idx in range(e):
            # Quantize this expert's weights ...
            w1_q[idx], w1_scale[idx] = ops.scaled_fp8_quant(
                base.w1[idx], use_per_token_if_dynamic=per_out_channel)
            w2_q[idx], w2_scale[idx] = ops.scaled_fp8_quant(
                base.w2[idx], use_per_token_if_dynamic=per_out_channel)
            # ... and immediately dequantize them from the stored values.
            w1_d[idx] = (w1_q[idx].float() * w1_scale[idx]).half()
            w2_d[idx] = (w2_q[idx].float() * w2_scale[idx]).half()

        # vars(base) holds exactly the MOETensors fields; the tensors are
        # forwarded by reference, not copied.
        return MOETensors8Bit(**vars(base),
                              a_q=a_q,
                              w1_q=w1_q,
                              w2_q=w2_q,
                              a_scale=a_scale,
                              w1_scale=w1_scale,
                              w2_scale=w2_scale,
                              a_d=a_d,
                              w1_d=w1_d,
                              w2_d=w2_d)
|
||||
|
||||
|
||||
def run_with_expert_maps(num_experts: int, num_local_experts: int,
                         **cutlass_moe_kwargs):
    """Emulate expert parallelism by running ``cutlass_moe_fp8`` per shard.

    The experts are split into contiguous shards of ``num_local_experts``.
    For each shard, ``cutlass_moe_fp8`` is called with the per-expert
    arguments sliced to that shard and with an ``expert_map`` that maps the
    shard's global expert ids to ``0 .. num_local_experts - 1`` and every
    other id to ``-1``. The per-shard outputs are summed.

    Args:
        num_experts: Total number of experts.
        num_local_experts: Experts per shard; must be positive and divide
            ``num_experts`` evenly.
        **cutlass_moe_kwargs: Keyword arguments for ``cutlass_moe_fp8``.
            Must contain ``a``; the output accumulator is created with
            ``torch.zeros_like`` of it.

    Returns:
        The sum of the per-shard ``cutlass_moe_fp8`` outputs.

    Raises:
        ValueError: If ``num_local_experts`` is not positive or does not
            divide ``num_experts``. An uneven split would otherwise produce
            an ``expert_map`` longer than ``num_experts`` whose local ids do
            not match the sliced weights.
    """
    if num_local_experts <= 0 or num_experts % num_local_experts != 0:
        raise ValueError(
            f"num_local_experts ({num_local_experts}) must be positive and "
            f"evenly divide num_experts ({num_experts})")

    # Arguments whose leading dimension indexes experts.
    per_expert_params = ("w1_q", "w2_q", "ab_strides1", "ab_strides2",
                         "c_strides1", "c_strides2", "w1_scale", "w2_scale")
    full_tensors = {
        name: cutlass_moe_kwargs[name]
        for name in per_expert_params if name in cutlass_moe_kwargs
    }

    def shard_kwargs():
        for start in range(0, num_experts, num_local_experts):
            stop = start + num_local_experts

            # Global expert id -> local expert id, -1 for experts that
            # belong to another shard.
            global_to_local = [-1] * num_experts
            global_to_local[start:stop] = list(range(num_local_experts))
            expert_map = torch.tensor(global_to_local,
                                      dtype=torch.int32,
                                      device="cuda")

            # Fresh dict per shard so yielded values never alias each other.
            shard = dict(cutlass_moe_kwargs)
            shard["expert_map"] = expert_map
            for name, tensor in full_tensors.items():
                shard[name] = tensor[start:stop]
            yield shard

    out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
    for kwargs in shard_kwargs():
        out_tensor = out_tensor + cutlass_moe_fp8(**kwargs)

    return out_tensor
|
||||
|
||||
|
||||
def run_8_bit(moe_tensors: MOETensors8Bit,
              topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              num_local_experts: Optional[int] = None) -> torch.Tensor:
    """Run ``cutlass_moe_fp8`` on ``moe_tensors``, optionally sharded.

    Args:
        moe_tensors: Test tensors; the quantized weights, their scales and
            the activation scale must all be populated.
        topk_weights: Forwarded as ``topk_weights``.
        topk_ids: Forwarded as the ``topk_ids_`` keyword.
        num_local_experts: When ``None``, ``cutlass_moe_fp8`` is called
            once directly. Otherwise the call goes through
            ``run_with_expert_maps`` with this many experts per shard, even
            when it equals the total number of experts.

    Returns:
        The output of the direct call or of ``run_with_expert_maps``.
    """
    required = (moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale,
                moe_tensors.w2_scale, moe_tensors.a_scale)
    assert not any(t is None for t in required)

    kwargs = {
        'a': moe_tensors.a,
        'w1_q': moe_tensors.w1_q.transpose(1, 2),  # type: ignore[union-attr]
        'w2_q': moe_tensors.w2_q.transpose(1, 2),  # type: ignore[union-attr]
        'topk_weights': topk_weights,
        'topk_ids_': topk_ids,
        'ab_strides1': moe_tensors.ab_strides1,
        'c_strides1': moe_tensors.c_strides1,
        'ab_strides2': moe_tensors.ab_strides2,
        'c_strides2': moe_tensors.c_strides2,
        'w1_scale': moe_tensors.w1_scale,
        'w2_scale': moe_tensors.w2_scale,
        'a1_scale': moe_tensors.a_scale
    }

    # The previous condition was
    # `num_local_experts is not None or num_local_experts == num_experts`;
    # its second clause could never change the result, so the decision
    # reduces to this None check.
    if num_local_experts is None:
        return cutlass_moe_fp8(**kwargs)

    num_experts = moe_tensors.w1.size(0)
    return run_with_expert_maps(num_experts, num_local_experts, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@ -46,7 +221,7 @@ def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_moe_no_graph(
|
||||
def test_cutlass_moe_8_bit_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
@ -60,80 +235,21 @@ def test_cutlass_moe_no_graph(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
dtype = torch.half
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
|
||||
# Get the right scale for tests.
|
||||
_, a_scale1 = ops.scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=per_act_token)
|
||||
a_q, _ = ops.scaled_fp8_quant(a,
|
||||
a_scale1,
|
||||
use_per_token_if_dynamic=per_act_token)
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
a_d = a_q.float().mul(a_scale1).to(dtype)
|
||||
|
||||
n_b_scales = 2 * n if per_out_ch else 1
|
||||
k_b_scales = k if per_out_ch else 1
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w1_q = w1_q.transpose(1, 2)
|
||||
w2_q = w2_q.transpose(1, 2)
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
w1_d = torch.empty_like(w1)
|
||||
w2_d = torch.empty_like(w2)
|
||||
for expert in range(e):
|
||||
w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half()
|
||||
w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half()
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids)
|
||||
|
||||
cutlass_output = cutlass_moe_fp8(a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale1)
|
||||
|
||||
#print(triton_output)
|
||||
#print(cutlass_output)
|
||||
#print("*")
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
@ -141,9 +257,7 @@ def test_cutlass_moe_no_graph(
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [2, 64, 224])
|
||||
@pytest.mark.parametrize("n", [1024, 3072])
|
||||
@pytest.mark.parametrize("k", [1024, 1536])
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@ -152,7 +266,7 @@ def test_cutlass_moe_no_graph(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_moe_cuda_graph(
|
||||
def test_cutlass_moe_8_bit_cuda_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
@ -168,77 +282,83 @@ def test_cutlass_moe_cuda_graph(
|
||||
|
||||
dtype = torch.half
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
# Get the right scale for tests.
|
||||
_, a_scale1 = ops.scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=per_act_token)
|
||||
a_q, _ = ops.scaled_fp8_quant(a,
|
||||
a_scale1,
|
||||
use_per_token_if_dynamic=per_act_token)
|
||||
|
||||
a_d = a_q.float().mul(a_scale1).to(dtype)
|
||||
|
||||
n_b_scales = 2 * n if per_out_ch else 1
|
||||
k_b_scales = k if per_out_ch else 1
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=per_out_ch)
|
||||
w1_q = w1_q.transpose(1, 2)
|
||||
w2_q = w2_q.transpose(1, 2)
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
w1_d = torch.empty_like(w1)
|
||||
w2_d = torch.empty_like(w2)
|
||||
for expert in range(e):
|
||||
w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half()
|
||||
w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half()
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
||||
topk_weights, topk_ids = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
|
||||
triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids)
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale,
|
||||
topk_weights, topk_ids, ab_strides1,
|
||||
c_strides1, ab_strides2, c_strides2)
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
#print(triton_output)
|
||||
#print(cutlass_output)
|
||||
#print("*")
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=9e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [64])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("e", [16])
@pytest.mark.parametrize("topk", [1, 8])
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_EP(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
):
    """Compare sharded cutlass FP8 MoE against the ``fused_experts`` result.

    The ``e`` experts are split into ``ep_size`` equal shards and run through
    ``run_8_bit`` with ``num_local_experts=e // ep_size``; the result must
    match ``fused_experts`` on the dequantized tensors within tolerance.
    """
    current_platform.seed_everything(7)

    single_stage_config = VllmConfig(parallel_config=ParallelConfig(
        pipeline_parallel_size=1))
    with set_current_vllm_config(single_stage_config):
        tensors = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e,
                                                       per_act_token,
                                                       per_out_channel)

        router_scores = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids = fused_topk(tensors.a,
                                            router_scores,
                                            topk,
                                            renormalize=False)

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
        expected = fused_experts(tensors.a_d, tensors.w1_d, tensors.w2_d,
                                 topk_weights, topk_ids)

        assert e % ep_size == 0, "Cannot distribute experts evenly"
        experts_per_rank = e // ep_size
        actual = run_8_bit(tensors,
                           topk_weights,
                           topk_ids,
                           num_local_experts=experts_per_rank)

        torch.testing.assert_close(expected, actual, atol=5e-2, rtol=1e-2)
|
||||
|
||||
@ -1693,6 +1693,7 @@ class ParallelConfig:
|
||||
factors: list[Any] = []
|
||||
factors.append(self.pipeline_parallel_size)
|
||||
factors.append(self.tensor_parallel_size)
|
||||
factors.append(self.enable_expert_parallel)
|
||||
return hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
||||
@ -15,7 +15,7 @@ def cutlass_moe_fp8(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_ids_: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
@ -23,6 +23,7 @@ def cutlass_moe_fp8(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
out_dtype: torch.dtype = torch.half,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -57,12 +58,19 @@ def cutlass_moe_fp8(
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [M]
|
||||
- out_dtype (torch.dtype): The output tensor type.
|
||||
- expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
|
||||
every Rank is responsible for a subset of experts. expert_map is a
|
||||
mapping from global expert-id to local expert-id. When expert_map[i]
|
||||
is -1, it means that this Rank is not responsible for global
|
||||
expert-id i.
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is 1.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
|
||||
assert w1_q.dtype == torch.float8_e4m3fn
|
||||
assert w2_q.dtype == torch.float8_e4m3fn
|
||||
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
|
||||
@ -96,7 +104,13 @@ def cutlass_moe_fp8(
|
||||
k = w1_q.size(1)
|
||||
n = w2_q.size(1)
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
local_topk_ids = topk_ids_
|
||||
if expert_map is not None:
|
||||
"Translate info from expert_map to topk_ids"
|
||||
local_topk_ids = torch.where(expert_map[topk_ids_] != -1,
|
||||
expert_map[topk_ids_], -1)
|
||||
|
||||
topk = local_topk_ids.size(1)
|
||||
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
@ -120,10 +134,23 @@ def cutlass_moe_fp8(
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
a_map_initializer = torch.empty
|
||||
c2_initializer = torch.empty
|
||||
if expert_map is not None:
|
||||
# With expert_map each Rank processes only a subset of experts. As
|
||||
# a result not all of a_map and c2 tensors are filled. We fill them with
|
||||
# zeros for correctness.
|
||||
a_map_initializer = torch.zeros
|
||||
c2_initializer = torch.zeros
|
||||
|
||||
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
|
||||
a_map = a_map_initializer((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
c_map = torch.empty((local_topk_ids.numel()),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, a_map, c_map, num_experts, n,
|
||||
k)
|
||||
|
||||
@ -131,7 +158,7 @@ def cutlass_moe_fp8(
|
||||
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
|
||||
c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype)
|
||||
|
||||
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
|
||||
expert_offsets[:-1], problem_sizes1, ab_strides1,
|
||||
|
||||
@ -67,7 +67,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
else:
|
||||
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
|
||||
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
||||
and layer.activation == "silu" and layer.expert_map is None):
|
||||
and layer.activation == "silu"):
|
||||
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
|
||||
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
|
||||
@ -510,8 +510,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert activation == "silu"
|
||||
assert global_num_experts == layer.w13_weight.shape[0]
|
||||
assert expert_map is None
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
@ -542,6 +540,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
out_dtype=x.dtype,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user