[TPU][Bugfix] fix the MoE OOM issue (#20339)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent: 40b86aa05e
commit: 4548c03c50
@@ -1320,8 +1320,13 @@ class FusedMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
-        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                          self.layer_name)
+        # TODO: Once the OOM issue for the TPU backend is resolved, we will
+        # switch to using the moe_forward custom op.
+        if current_platform.is_tpu():
+            return self.forward_impl(hidden_states, router_logits)
+        else:
+            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                              self.layer_name)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
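For readers skimming the diff, here is a minimal runnable sketch of the dispatch pattern this change introduces. ToyMoE, its is_tpu constructor flag, and _custom_op are stand-ins invented for illustration; in vLLM the real pieces are FusedMoE, current_platform.is_tpu(), and the torch.ops.vllm.moe_forward custom op, and the commit only states that the custom-op path OOMs on TPU, not the underlying cause.

import torch
from torch import nn


class ToyMoE(nn.Module):
    """Toy stand-in for FusedMoE, mirroring the dispatch added here.

    The `is_tpu` flag replaces vLLM's `current_platform.is_tpu()` and
    `_custom_op` replaces `torch.ops.vllm.moe_forward`; both are invented
    for this sketch so it runs anywhere.
    """

    def __init__(self, is_tpu: bool = False):
        super().__init__()
        self.is_tpu = is_tpu

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
        # Placeholder math standing in for the real fused-MoE computation:
        # scale each token's activations by its summed router probabilities.
        weights = torch.softmax(router_logits, dim=-1).sum(-1, keepdim=True)
        return hidden_states * weights

    def _custom_op(self, hidden_states: torch.Tensor,
                   router_logits: torch.Tensor) -> torch.Tensor:
        # Stand-in for the registered custom op, which in vLLM resolves the
        # layer by `self.layer_name` and then calls its forward_impl.
        return self.forward_impl(hidden_states, router_logits)

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        if self.is_tpu:
            # The TPU path from this commit: call the implementation
            # directly, since per the commit message going through the
            # custom op caused OOM on the TPU backend.
            return self.forward_impl(hidden_states, router_logits)
        return self._custom_op(hidden_states, router_logits)


# Usage: 4 tokens, hidden size 16, 8 experts.
x = torch.randn(4, 16)
logits = torch.randn(4, 8)
out = ToyMoE(is_tpu=True)(x, logits)
assert out.shape == x.shape

The design trade-off, as the TODO notes, is temporary: the custom-op wrapper presumably exists so that graph capture and torch.compile can treat the MoE as a single opaque node, and the TPU branch gives that up for a direct Python call until the OOM issue is resolved.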