[EP] Add logging for experts map (#22685)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
22quinn 2025-08-20 16:46:06 -07:00 committed by GitHub
parent c86af22f31
commit 4b795020ed


@@ -695,6 +695,26 @@ def determine_expert_map(
    return (local_num_experts, expert_map)


def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
    """
    Compresses the expert map by removing any -1 entries.

    Args:
        expert_map (torch.Tensor): A tensor of shape (global_num_experts,)
            mapping from global to local index. Contains -1 for experts not
            assigned to the current rank.

    Returns:
        str: A string mapping from local to global index. A str is returned
            so the value is hashable, which allows it to be logged only once.
    """
    global_indices = torch.where(expert_map != -1)[0]
    local_indices = expert_map[global_indices]
    return ", ".join(
        f"{local_index.item()}->{global_index.item()}"
        for local_index, global_index in zip(local_indices, global_indices))


@CustomOp.register("fused_moe")
class FusedMoE(CustomOp):
"""FusedMoE layer for MoE models.
@@ -795,6 +815,12 @@ class FusedMoE(CustomOp):
                ep_size=self.ep_size,
                ep_rank=self.ep_rank,
                global_num_experts=self.global_num_experts)
            logger.info_once(
                "[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
                " number of experts: %s/%s. Experts local to global index map:"
                " %s.", self.ep_rank, self.ep_size, self.local_num_experts,
                self.global_num_experts,
                get_compressed_expert_map(self.expert_map))
        else:
            self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                       None)
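
With the helper in place, the new info_once line renders roughly as follows. The rank, sizes, and contiguous expert assignment below are assumed for the example; the real assignment comes from determine_expert_map.

    # Assumed values: ep_rank=0, ep_size=2, 4 local / 8 global experts,
    # with global experts 0-3 assigned contiguously to this rank.
    msg = ("[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
           " number of experts: %s/%s. Experts local to global index map:"
           " %s." % (0, 2, 4, 8, "0->0, 1->1, 2->2, 3->3"))
    print(msg)
    # [EP Rank 0/2] Expert parallelism is enabled. Local/global number of
    # experts: 4/8. Experts local to global index map: 0->0, 1->1, 2->2, 3->3.
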