[TPU][Bugfix] fix the MoE OOM issue (#20339)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent 40b86aa05e
commit 4548c03c50
@@ -1320,8 +1320,13 @@ class FusedMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
-        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
-                                          self.layer_name)
+        # TODO: Once the OOM issue for the TPU backend is resolved, we will
+        # switch to using the moe_forward custom op.
+        if current_platform.is_tpu():
+            return self.forward_impl(hidden_states, router_logits)
+        else:
+            return torch.ops.vllm.moe_forward(hidden_states, router_logits,
+                                              self.layer_name)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                              full_router_logits: torch.Tensor):
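For context, the change swaps the custom-op wrapper for a direct Python call on TPU only. Below is a minimal, self-contained sketch of that dispatch pattern, not vLLM's actual code: the demo::moe_forward op name, the stubbed compute, and the running_on_tpu() helper are hypothetical stand-ins, while the branch structure mirrors the diff above.

import torch


def running_on_tpu() -> bool:
    # Hypothetical stand-in for vLLM's current_platform.is_tpu().
    return False


def _moe_compute(hidden_states: torch.Tensor,
                 router_logits: torch.Tensor) -> torch.Tensor:
    # Placeholder for the real fused-MoE math; clone() avoids aliasing
    # the input, which custom ops must not do.
    return hidden_states.clone()


# Register the compute as a custom op (torch.library API, PyTorch 2.4+).
@torch.library.custom_op("demo::moe_forward", mutates_args=())
def moe_forward(hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
    return _moe_compute(hidden_states, router_logits)


class TinyMoE(torch.nn.Module):

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor) -> torch.Tensor:
        return _moe_compute(hidden_states, router_logits)

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor) -> torch.Tensor:
        if running_on_tpu():
            # Direct Python call, as in the TPU branch of the diff:
            # skips the custom-op wrapper tied to the OOM on TPU.
            return self.forward_impl(hidden_states, router_logits)
        # Other platforms keep dispatching through the registered op.
        return torch.ops.demo.moe_forward(hidden_states, router_logits)


layer = TinyMoE()
out = layer(torch.randn(4, 8), torch.randn(4, 2))

The design point is that torch.ops dispatch adds wrapper machinery around the call, and the commit's TODO indicates the TPU backend OOMs under that wrapper; calling forward_impl directly sidesteps it until the underlying issue is resolved, at which point the TPU path is meant to return to the custom op.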