[TPU][Bugfix] fix the MoE OOM issue (#20339)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Author: Chengji Yao, committed via GitHub on 2025-07-05 21:19:09 -07:00
parent 40b86aa05e
commit 4548c03c50


@@ -1320,8 +1320,13 @@ class FusedMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
-        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                          self.layer_name)
+        # TODO: Once the OOM issue for the TPU backend is resolved, we will
+        # switch to using the moe_forward custom op.
+        if current_platform.is_tpu():
+            return self.forward_impl(hidden_states, router_logits)
+        else:
+            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                              self.layer_name)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
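
For context, a minimal sketch of the platform-conditional dispatch this diff introduces: on TPU the layer calls its Python forward_impl directly, while other platforms keep routing through the registered custom op. The names current_platform.is_tpu() and forward_impl come from the diff above; Platform and ToyMoE below are hypothetical stand-ins, not vLLM's real classes, and the sketch reuses forward_impl on the non-TPU branch only to stay runnable anywhere.

# Hypothetical, self-contained sketch of the dispatch pattern shown above.
import torch


class Platform:
    """Toy platform detector (stand-in for vLLM's current_platform)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def is_tpu(self) -> bool:
        return self.name == "tpu"


class ToyMoE(torch.nn.Module):
    """Minimal MoE-like layer demonstrating the conditional forward path."""

    def __init__(self, platform: Platform) -> None:
        super().__init__()
        self.platform = platform
        self.proj = torch.nn.Linear(8, 8)

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
        # Plain eager path; on TPU this side-steps the custom-op wrapper
        # associated with the OOM in the commit title.
        return self.proj(hidden_states)

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        if self.platform.is_tpu():
            # TPU: call the Python implementation directly.
            return self.forward_impl(hidden_states, router_logits)
        # Elsewhere the real code dispatches through the registered custom op
        # (torch.ops.vllm.moe_forward); the sketch reuses forward_impl here
        # so it runs without that op being registered.
        return self.forward_impl(hidden_states, router_logits)


layer = ToyMoE(Platform("tpu"))
out = layer(torch.randn(4, 8), torch.randn(4, 2))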