From 4b795020eda910ecf16c289a23c4a6c119a4b43b Mon Sep 17 00:00:00 2001
From: 22quinn <33176974+22quinn@users.noreply.github.com>
Date: Wed, 20 Aug 2025 16:46:06 -0700
Subject: [PATCH] [EP] Add logging for experts map (#22685)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Simon Mo
---
 vllm/model_executor/layers/fused_moe/layer.py | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index aa8ceda1bb25a..b16c21b7013a0 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -695,6 +695,26 @@ def determine_expert_map(
     return (local_num_experts, expert_map)
 
 
+def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
+    """
+    Compresses the expert map by removing any -1 entries.
+
+    Args:
+        expert_map (torch.Tensor): A tensor of shape (global_num_experts,)
+            mapping from global to local index. Contains -1 for experts not
+            assigned to the current rank.
+
+    Returns:
+        str: A string mapping from local to global index.
+            Using str to support hashing for logging once only.
+    """
+    global_indices = torch.where(expert_map != -1)[0]
+    local_indices = expert_map[global_indices]
+    return ", ".join(
+        f"{local_index.item()}->{global_index.item()}"
+        for local_index, global_index in zip(local_indices, global_indices))
+
+
 @CustomOp.register("fused_moe")
 class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.
@@ -795,6 +815,12 @@ class FusedMoE(CustomOp):
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts)
+            logger.info_once(
+                "[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
+                " number of experts: %s/%s. Experts local to global index map:"
+                " %s.", self.ep_rank, self.ep_size, self.local_num_experts,
+                self.global_num_experts,
+                get_compressed_expert_map(self.expert_map))
         else:
             self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
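
Note (illustration, not part of the patch): the sketch below shows what the
new get_compressed_expert_map helper produces for a hypothetical expert
parallel layout with 2 ranks and 4 global experts, assuming only that
PyTorch is installed. The tensor values are made up for the example; in
vLLM the map comes from determine_expert_map.

    # Illustrative values only: expert_map maps global expert index ->
    # local index, with -1 marking experts not hosted on this rank.
    # Here EP rank 1 of 2 owns global experts 2 and 3 as local 0 and 1.
    import torch

    expert_map = torch.tensor([-1, -1, 0, 1])

    # Same logic as the helper added in this patch.
    global_indices = torch.where(expert_map != -1)[0]  # tensor([2, 3])
    local_indices = expert_map[global_indices]         # tensor([0, 1])
    print(", ".join(
        f"{local.item()}->{glob.item()}"
        for local, glob in zip(local_indices, global_indices)))
    # prints: 0->2, 1->3

With that map, the new logger.info_once call would emit a line along the
lines of: "[EP Rank 1/2] Expert parallelism is enabled. Local/global number
of experts: 2/4. Experts local to global index map: 0->2, 1->3."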