[MTP] Refactor mtp predictor to avoid d2h operation (#27643)

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao 2025-10-31 01:27:39 +08:00 committed by GitHub
parent ba33e8830d
commit 1004205795
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -97,7 +97,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)