mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-17 05:37:03 +08:00
[MTP] Refactor mtp predictor to avoid d2h operation (#27643)
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
parent
ba33e8830d
commit
1004205795
@ -97,7 +97,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds[positions == 0] = 0
|
||||
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user