[flashinfer] fix FI all2all with FI cutlass moe (#28166)

Signed-off-by: Xiaozhu <mxz297@gmail.com>
This commit is contained in:
Xiaozhu Meng 2025-11-05 21:52:16 -08:00 committed by GitHub
parent bde5039325
commit e31946f86e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -233,12 +233,13 @@ def flashinfer_alltoall_dispatch(
max_num_token = (
max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
)
orig_topk_weights_dtype = topk_weights.dtype
alltoall_info, topk_ids, topk_weights, _ = (
MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
topk_ids,
topk_weights,
None,
all2all_manager.prepare_workspace,
all2all_manager.prepare_workspace_tensor,
max_num_token,
ep_rank,
ep_size,
@ -247,6 +248,7 @@ def flashinfer_alltoall_dispatch(
top_k,
)
)
topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)
x, x_sf = moe_kernel_quantize_input(
x,