From 4548c03c50d8ef067b296fb8d610f9d8a8178482 Mon Sep 17 00:00:00 2001
From: Chengji Yao
Date: Sat, 5 Jul 2025 21:19:09 -0700
Subject: [PATCH] [TPU][Bugfix] fix the MoE OOM issue (#20339)

Signed-off-by: Chengji Yao
---
 vllm/model_executor/layers/fused_moe/layer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 648dfca374c5b..36ac75a8df4b8 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1320,8 +1320,13 @@ class FusedMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
-        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                          self.layer_name)
+        # TODO: Once the OOM issue for the TPU backend is resolved, we will
+        # switch to using the moe_forward custom op.
+        if current_platform.is_tpu():
+            return self.forward_impl(hidden_states, router_logits)
+        else:
+            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                              self.layer_name)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
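
For context, the patch bypasses the registered torch.ops.vllm.moe_forward custom op on TPU and calls the layer's forward_impl directly; other platforms keep the custom-op path. Below is a minimal, self-contained sketch of that platform-gated dispatch pattern, not vLLM's actual implementation: is_tpu(), ToyMoE, and the toylib::toy_moe_forward op are hypothetical stand-ins, and the custom-op registration assumes the torch.library.custom_op API available in PyTorch 2.4+.

import torch


def is_tpu() -> bool:
    """Hypothetical stand-in for vllm.platforms.current_platform.is_tpu()."""
    return False


class ToyMoE(torch.nn.Module):
    """Illustrative top-1 MoE layer; not vLLM's FusedMoE."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.experts = torch.nn.ModuleList([
            torch.nn.Linear(hidden_size, hidden_size)
            for _ in range(num_experts)
        ])

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
        # Dense top-1 routing, written as a plain loop for clarity.
        top1 = router_logits.argmax(dim=-1)
        out = torch.zeros_like(hidden_states)
        for idx, expert in enumerate(self.experts):
            mask = top1 == idx
            if mask.any():
                out[mask] = expert(hidden_states[mask])
        return out

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        if is_tpu():
            # TPU branch added by the patch: plain Python call, no custom op.
            return self.forward_impl(hidden_states, router_logits)
        # Non-TPU branch goes through the dispatcher-registered op, analogous
        # to torch.ops.vllm.moe_forward(hidden_states, router_logits, name).
        return torch.ops.toylib.toy_moe_forward(hidden_states, router_logits)


_MOE = ToyMoE(hidden_size=8, num_experts=4)


@torch.library.custom_op("toylib::toy_moe_forward", mutates_args=())
def toy_moe_forward(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor) -> torch.Tensor:
    # The custom op is only a dispatcher-visible wrapper around forward_impl,
    # loosely mirroring how vLLM's moe_forward op reaches the layer.
    return _MOE.forward_impl(hidden_states, router_logits)


if __name__ == "__main__":
    x = torch.randn(5, 8)
    logits = torch.randn(5, 4)
    print(_MOE(x, logits).shape)  # torch.Size([5, 8])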