diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 648dfca374c5b..36ac75a8df4b8 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1320,8 +1320,13 @@ class FusedMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
-        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                          self.layer_name)
+        # TODO: Once the OOM issue for the TPU backend is resolved, we will
+        # switch to using the moe_forward custom op.
+        if current_platform.is_tpu():
+            return self.forward_impl(hidden_states, router_logits)
+        else:
+            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                              self.layer_name)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
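
For context, the sketch below mirrors the dispatch pattern this diff introduces: on TPU the layer skips the `torch.ops.vllm.moe_forward` custom op and calls the eager `forward_impl` directly. `Platform`, `ToyMoE`, and the eager fallback taken on the non-TPU branch are hypothetical stand-ins so the snippet runs without vLLM installed; they are not vLLM APIs.

```python
# Minimal, self-contained sketch of the platform-conditional dispatch above.
# "Platform", "ToyMoE", and "forward_impl" are hypothetical stand-ins, not
# vLLM APIs; only the branching structure matches the diff.
import torch


class Platform:
    def __init__(self, device_name: str):
        self.device_name = device_name

    def is_tpu(self) -> bool:
        return self.device_name == "tpu"


current_platform = Platform("tpu")  # pretend we are running on TPU


class ToyMoE(torch.nn.Module):
    layer_name = "toy_moe"

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
        # Eager fallback path; the real layer does expert routing here.
        return hidden_states

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        if current_platform.is_tpu():
            # Bypass the custom op on TPU (works around the OOM issue).
            return self.forward_impl(hidden_states, router_logits)
        # On other platforms the registered custom op would be called, e.g.
        # torch.ops.vllm.moe_forward(hidden_states, router_logits,
        #                            self.layer_name); this sketch falls back
        # to the eager path so it runs without vLLM installed.
        return self.forward_impl(hidden_states, router_logits)


if __name__ == "__main__":
    layer = ToyMoE()
    out = layer(torch.zeros(2, 16), torch.zeros(2, 8))
    print(out.shape)  # torch.Size([2, 16])
```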