Revert "[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (#20762) (#21334)

Signed-off-by: Ming Yang <minos.future@gmail.com>
Author: Ming Yang
Date: 2025-07-21 21:49:01 -07:00 (committed by GitHub)
parent 90f1e55421
commit e7b2042681
6 changed files with 38 additions and 174 deletions
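At a glance: #20762 hoisted the per-expert GEMM stride tensors out of the hot path and threaded them through every caller; this revert deletes that plumbing and rebuilds the strides inside run_cutlass_moe_fp8 on each call. A minimal sketch of the restored stride layout, with illustrative sizes (num_experts, K and N below are placeholders, not values from this diff):

    import torch

    num_experts, K, N = 8, 512, 1024  # illustrative only

    # One entry per expert; leading-dimension strides for the two grouped
    # GEMMs, exactly the values the reverted code rebuilds on every call.
    ab_strides1 = torch.full((num_experts,), K, device="cuda", dtype=torch.int64)
    c_strides1 = torch.full((num_experts,), 2 * N, device="cuda", dtype=torch.int64)
    ab_strides2 = torch.full((num_experts,), N, device="cuda", dtype=torch.int64)
    c_strides2 = torch.full((num_experts,), K, device="cuda", dtype=torch.int64)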

File 1 of 6: MoE benchmark script (bench_run)

@@ -80,11 +80,6 @@ def bench_run(
         a, score, topk, renormalize=False
     )
 
-    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
-    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
-    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
-    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
-
     def run_triton_moe(
         a: torch.Tensor,
         w1: torch.Tensor,
@@ -116,10 +111,6 @@ def bench_run(
         w2: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         per_act_token: bool,
@@ -134,10 +125,6 @@ def bench_run(
             topk_ids,
             w1_scale,
             w2_scale,
-            ab_strides1,
-            ab_strides2,
-            c_strides1,
-            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -149,10 +136,6 @@ def bench_run(
         w2_q: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
     ):
@@ -167,10 +150,6 @@ def bench_run(
             topk_ids,
             w1_scale,
             w2_scale,
-            ab_strides1,
-            ab_strides2,
-            c_strides1,
-            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -215,10 +194,6 @@ def bench_run(
             w2_q,
             w1_scale,
             w2_scale,
-            ab_strides1,
-            ab_strides2,
-            c_strides1,
-            c_strides2,
             topk_weights,
             topk_ids,
         )
@@ -256,10 +231,6 @@ def bench_run(
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
         "per_act_token": per_act_token,
-        "ab_strides1": ab_strides1,
-        "ab_strides2": ab_strides2,
-        "c_strides1": c_strides1,
-        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -318,10 +289,6 @@ def bench_run(
                 w2_q,
                 w1_scale,
                 w2_scale,
-                ab_strides1,
-                ab_strides2,
-                c_strides1,
-                c_strides2,
                 topk_weights,
                 topk_ids,
                 per_act_token,
@@ -330,7 +297,7 @@ def bench_run(
     results.append(
         benchmark.Timer(
-            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
+            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
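For context, the benchmark above drives both kernels through torch.utils.benchmark.Timer, handing the wrapper functions in via globals so the stmt string can call them. A stripped-down sketch of that pattern (the empty run_cutlass_moe body is a placeholder, not the real wrapper):

    import torch.utils.benchmark as benchmark

    def run_cutlass_moe(num_runs: int) -> None:
        for _ in range(num_runs):
            pass  # placeholder: launch the fp8 CUTLASS MoE here

    timer = benchmark.Timer(
        stmt="run_cutlass_moe(num_runs)",
        globals={"run_cutlass_moe": run_cutlass_moe, "num_runs": 8},
        label="grouped gemm",
        sub_label="cutlass fp8 moe",
    )
    print(timer.timeit(100))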

File 2 of 6: CUDA MoE shuffle kernels (shuffle_rows)

@@ -160,30 +160,6 @@ __global__ void shuffleInputRowsKernel(const T* input,
   }
 }
 
-template <typename T>
-__global__ void shuffleInputRowsKernelSlow(const T* input,
-                                           const int32_t* dst2src_map,
-                                           T* output, int64_t num_src_rows,
-                                           int64_t num_dst_rows,
-                                           int64_t num_cols) {
-  int64_t dest_row_idx = blockIdx.x;
-  int64_t const source_row_idx = dst2src_map[dest_row_idx];
-
-  if (blockIdx.x < num_dst_rows) {
-    // Duplicate and permute rows
-    auto const* source_row_ptr = input + source_row_idx * num_cols;
-    auto* dest_row_ptr = output + dest_row_idx * num_cols;
-
-    int64_t const start_offset = threadIdx.x;
-    int64_t const stride = blockDim.x;
-
-    for (int elem_index = start_offset; elem_index < num_cols;
-         elem_index += stride) {
-      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
-    }
-  }
-}
-
 void shuffle_rows(const torch::Tensor& input_tensor,
                   const torch::Tensor& dst2src_map,
                   torch::Tensor& output_tensor) {
@@ -197,24 +173,17 @@ void shuffle_rows(const torch::Tensor& input_tensor,
   int64_t const num_src_rows = input_tensor.size(0);
   int64_t const num_cols = input_tensor.size(1);
 
-  if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
-    // use slow kernel if num_cols can't be aligned to 128 bits
-    MOE_DISPATCH(input_tensor.scalar_type(), [&] {
-      shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
-          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
-          dst2src_map.data_ptr<int32_t>(),
-          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
-          num_dest_rows, num_cols);
-    });
-  } else {
-    MOE_DISPATCH(input_tensor.scalar_type(), [&] {
-      shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-          reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
-          dst2src_map.data_ptr<int32_t>(),
-          reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
-          num_dest_rows, num_cols);
-    });
-  }
+  TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
+              "num_cols must be divisible by 128 / "
+              "sizeof(input_tensor.scalar_type()) / 8");
+
+  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
+    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
+        dst2src_map.data_ptr<int32_t>(),
+        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
+        num_dest_rows, num_cols);
+  });
 }
 
 #else
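The revert trades the slow-path fallback for a hard TORCH_CHECK: rows must be 128-bit aligned so the remaining kernel can copy them in vectorized chunks. A hedged Python model of that requirement (the intent is inferred from the kernel; elem_bytes is the element size in bytes, e.g. 1 for fp8, 2 for fp16/bf16):

    def check_row_alignment(num_cols: int, elem_bytes: int) -> None:
        # A 128-bit chunk holds 128 // (8 * elem_bytes) elements, so a row
        # must contain a whole number of chunks for the vectorized shuffle.
        elems_per_chunk = 128 // (8 * elem_bytes)
        if num_cols % elems_per_chunk:
            raise ValueError(
                f"num_cols ({num_cols}) must be divisible by {elems_per_chunk}")

    check_row_alignment(7168, 1)  # fp8: 7168 % 16 == 0, passes
    check_row_alignment(4096, 2)  # fp16: 4096 % 8 == 0, passes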

File 3 of 6: CUTLASS MoE tests (run_8_bit, test_run_cutlass_moe_fp8)

@@ -207,10 +207,6 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
-        'ab_strides1': moe_tensors.ab_strides1,
-        'ab_strides2': moe_tensors.ab_strides2,
-        'c_strides1': moe_tensors.c_strides1,
-        'c_strides2': moe_tensors.c_strides2,
         'per_act_token': per_act_token,
         'a1_scale': None  #moe_tensors.a_scale
     }
@@ -444,11 +440,6 @@ def test_run_cutlass_moe_fp8(
     expert_map[start:end] = list(range(num_local_experts))
     expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
 
-    ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
-    ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
-    c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
-    c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
-
     activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
     a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
                                                torch.float8_e4m3fn,
@@ -457,9 +448,8 @@ def test_run_cutlass_moe_fp8(
     func = lambda output: run_cutlass_moe_fp8(
         output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
         global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
-        a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
-        workspace13, workspace2, None, mt.a.dtype, per_act_token,
-        per_out_channel, False)
+        a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
+        per_act_token, per_out_channel, False)
 
     workspace13.random_()
     output_random_workspace = torch.empty(output_shape,
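The workspace13.random_() context line above belongs to a check that the kernel's result must not depend on whatever garbage happens to be in its scratch workspace. The same idea in miniature, with a stand-in kernel (every name here is hypothetical):

    import torch

    def kernel_with_scratch(x: torch.Tensor, ws: torch.Tensor) -> torch.Tensor:
        # Stand-in kernel: writes its scratch area before reading it back.
        ws[:x.numel()].copy_(x.flatten())
        return (ws[:x.numel()] * 2).view_as(x)

    x = torch.arange(8, dtype=torch.float32).view(2, 4)
    ws = torch.zeros(64)
    clean = kernel_with_scratch(x, ws)
    ws.random_()  # poison the workspace, as the test does
    poisoned = kernel_with_scratch(x, ws)
    assert torch.equal(clean, poisoned), "output leaked workspace contents"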

File 4 of 6: pplx CUTLASS MoE test (pplx_cutlass_moe)

@@ -75,7 +75,6 @@ def pplx_cutlass_moe(
     assert torch.cuda.current_device() == pgi.local_rank
 
     num_tokens, hidden_dim = a.shape
-    intermediate_dim = w2.shape[2]
     num_experts = w1.shape[0]
     block_size = hidden_dim  # TODO support more cases
     device = pgi.device
@@ -124,31 +123,10 @@ def pplx_cutlass_moe(
         num_local_experts=num_local_experts,
         num_dispatchers=num_dispatchers)
 
-    ab_strides1 = torch.full((num_local_experts, ),
-                             hidden_dim,
-                             device="cuda",
-                             dtype=torch.int64)
-    ab_strides2 = torch.full((num_local_experts, ),
-                             intermediate_dim,
-                             device="cuda",
-                             dtype=torch.int64)
-    c_strides1 = torch.full((num_local_experts, ),
-                            2 * intermediate_dim,
-                            device="cuda",
-                            dtype=torch.int64)
-    c_strides2 = torch.full((num_local_experts, ),
-                            hidden_dim,
-                            device="cuda",
-                            dtype=torch.int64)
-
     experts = CutlassExpertsFp8(num_local_experts,
                                 out_dtype,
                                 per_act_token,
                                 per_out_ch,
-                                ab_strides1,
-                                ab_strides2,
-                                c_strides1,
-                                c_strides2,
                                 num_dispatchers=num_dispatchers,
                                 use_batched_format=True)
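The deleted block derived its stride values from hidden_dim and from intermediate_dim = w2.shape[2], whose defining line is also removed above. For orientation, the weight-shape convention this implies (a hedged reconstruction; w1 fuses the gate and up projections, hence the factor of two):

    import torch

    E, hidden_dim, intermediate_dim = 4, 512, 128  # illustrative sizes

    w1 = torch.empty(E, 2 * intermediate_dim, hidden_dim)  # fused gate+up proj
    w2 = torch.empty(E, hidden_dim, intermediate_dim)      # down proj

    assert w2.shape[2] == intermediate_dim  # the relation the deleted line used
    # Strides the deleted code built from these dims:
    #   ab_strides1 = hidden_dim, ab_strides2 = intermediate_dim,
    #   c_strides1 = 2 * intermediate_dim, c_strides2 = hidden_dim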

File 5 of 6: fused-MoE CUTLASS fp8 module (run_cutlass_moe_fp8, CutlassExpertsFp8, cutlass_moe_fp8)

@@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
+                                                        _fp8_quantize,
                                                         _resize_cache,
                                                         extract_required_args)
 from vllm.scalar_type import scalar_types
@@ -34,10 +35,6 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
-    ab_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides1: torch.Tensor,
-    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -156,11 +153,27 @@ def run_cutlass_moe_fp8(
                             problem_sizes1, problem_sizes2, a_map,
                             c_map, global_num_experts, N, K)
 
-        a1q = ops.shuffle_rows(a1q, a_map)
-        a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
-                     if per_act_token else a1q_scale)
+        a1q = _fp8_perm(a1q, a_map)
+        a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
 
         expert_offsets = expert_offsets[:-1]
 
+    ab_strides1 = torch.full((w1.size(0), ),
+                             K,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides1 = torch.full((w1.size(0), ),
+                            2 * N,
+                            device=device,
+                            dtype=torch.int64)
+    ab_strides2 = torch.full((w1.size(0), ),
+                             N,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides2 = torch.full((w1.size(0), ),
+                            K,
+                            device=device,
+                            dtype=torch.int64)
+
     if use_batched_format:
         c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
         c2 = _resize_cache(workspace2, (local_E * padded_M, N))
@@ -197,8 +210,7 @@ def run_cutlass_moe_fp8(
     else:
         # We can't do this inplace because output may point to the same tensor
         # as c3.
-        output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
-                     non_blocking=True)
+        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
 
 
 # TODO (bnell): split class batched vs. non-batched?
@@ -211,10 +223,6 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
-        ab_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides1: torch.Tensor,
-        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
         num_dispatchers: Optional[int] = None,
         use_batched_format: bool = False,
@@ -231,10 +239,6 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         self.max_experts_per_worker = max_experts_per_worker
         self.num_dispatchers = num_dispatchers
         self.out_dtype = out_dtype
-        self.ab_strides1 = ab_strides1
-        self.ab_strides2 = ab_strides2
-        self.c_strides1 = c_strides1
-        self.c_strides2 = c_strides2
         self.use_batched_format = use_batched_format
 
     @property
@@ -314,8 +318,7 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
-            self.c_strides2, workspace13, workspace2, expert_num_tokens,
+            a2_scale, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             self.use_batched_format)
@@ -329,10 +332,6 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
-    ab_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides1: torch.Tensor,
-    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -360,17 +359,6 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
-    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
-        Shape: [num_experts]
-    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
-        Shape: [num_experts]
-    - c_strides1 (torch.Tensor): The output strides for the first gemm.
-        Shape: [num_experts]
-    - c_strides2 (torch.Tensor): The output strides for the second gemm.
-        Shape: [num_experts]
-    - per_act_token (Optional[bool]): Whether the scale is per-token or
-        per-tensor.
-    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -403,10 +391,6 @@ def cutlass_moe_fp8(
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
-            ab_strides1=ab_strides1,
-            ab_strides2=ab_strides2,
-            c_strides1=c_strides1,
-            c_strides2=c_strides2,
             use_batched_format=False,
         ),
     )
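Besides inlining the stride construction, this file swaps ops.shuffle_rows back to plain PyTorch gathers (_fp8_perm for the activations, c3[c_map] for the output). Both are the same row permutation; a minimal, runnable demonstration of the indexing the reverted code relies on:

    import torch

    x = torch.arange(12, dtype=torch.float32).view(4, 3)
    a_map = torch.tensor([2, 0, 3, 1])

    gathered = x[a_map]  # row i of the result is x[a_map[i]]
    assert torch.equal(gathered, torch.index_select(x, 0, a_map))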

File 6 of 6: compressed-tensors MoE method (CompressedTensorsW8A8Fp8MoECutlassMethod)

@@ -859,21 +859,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)
 
-        device = layer.w13_weight.device
-        # ab_strides1 and c_strides2 are the same
-        self.ab_strides1_c_strides2 = torch.full((layer.local_num_experts, ),
-                                                 layer.hidden_size,
-                                                 device=device,
-                                                 dtype=torch.int64)
-        self.ab_strides2 = torch.full((layer.local_num_experts, ),
-                                      layer.intermediate_size_per_partition,
-                                      device=device,
-                                      dtype=torch.int64)
-        self.c_strides1 = torch.full((layer.local_num_experts, ),
-                                     2 * layer.intermediate_size_per_partition,
-                                     device=device,
-                                     dtype=torch.int64)
-
     def select_gemm_impl(
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
@@ -896,10 +881,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             moe.in_dtype,
             self.input_quant.strategy == QuantizationStrategy.TOKEN,
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
-            ab_strides1=self.ab_strides1_c_strides2,
-            ab_strides2=self.ab_strides2,
-            c_strides1=self.c_strides1,
-            c_strides2=self.ab_strides1_c_strides2,
             num_dispatchers=num_dispatchers,
             use_batched_format=use_batched_format,
         )
@@ -946,8 +927,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype)
+            e_score_correction_bias=e_score_correction_bias)
 
         per_act_token = (
             self.input_quant.strategy == QuantizationStrategy.TOKEN)
@@ -968,10 +948,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             expert_map=None if self.disable_expert_map else expert_map,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            ab_strides1=self.ab_strides1_c_strides2,
-            ab_strides2=self.ab_strides2,
-            c_strides1=self.c_strides1,
-            c_strides2=self.ab_strides1_c_strides2,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
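One detail of the deleted initialization worth keeping in mind: a single tensor served as both ab_strides1 and c_strides2, because the first GEMM reads activation rows of width hidden_size and the second GEMM writes output rows of that same width. A sketch of that identity (sizes illustrative):

    import torch

    local_num_experts, hidden_size, intermediate = 4, 512, 128  # illustrative

    # Shared tensor, as in the deleted code ("ab_strides1 and c_strides2
    # are the same"):
    ab_strides1_c_strides2 = torch.full((local_num_experts,), hidden_size,
                                        dtype=torch.int64)
    ab_strides2 = torch.full((local_num_experts,), intermediate,
                             dtype=torch.int64)
    c_strides1 = torch.full((local_num_experts,), 2 * intermediate,
                            dtype=torch.int64)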