[Kernel][Minor] Re-fuse triton moe weight application (#16071)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm 2025-04-04 19:27:34 -04:00 committed by GitHub
parent af51d80fa1
commit d6fc629f4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1297,30 +1297,24 @@ def fused_experts_impl(hidden_states: torch.Tensor,
qintermediate_cache2 = intermediate_cache2 qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale a2q_scale = a2_scale
invoke_fused_moe_kernel( invoke_fused_moe_kernel(qintermediate_cache2,
qintermediate_cache2, w2,
w2, intermediate_cache3,
intermediate_cache3, a2q_scale,
a2q_scale, w2_scale,
w2_scale, w2_zp,
w2_zp, curr_topk_weights,
curr_topk_weights, sorted_token_ids,
sorted_token_ids, expert_ids,
expert_ids, num_tokens_post_padded,
num_tokens_post_padded, True,
False, #True, 1,
1, config,
config, compute_type=compute_type,
compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8,
use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16,
use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a16=use_int4_w4a16, block_shape=block_shape)
block_shape=block_shape)
if True:
intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
intermediate_cache3.mul_(
curr_topk_weights.view(tokens_in_chunk, -1, 1))
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) out_hidden_states[begin_chunk_idx:end_chunk_idx])