From 10042057953cd1528701234925de3d7b109e26de Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Fri, 31 Oct 2025 01:27:39 +0800 Subject: [PATCH] [MTP] Refactor mtp predictor to avoid d2h operation (#27643) Signed-off-by: MengqingCao --- vllm/model_executor/models/deepseek_mtp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index aa176ef05fccb..3984d23970ac5 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -97,7 +97,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ) -> torch.Tensor: assert inputs_embeds is not None # masking inputs at position 0, as not needed by MTP - inputs_embeds[positions == 0] = 0 + inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds) inputs_embeds = self.enorm(inputs_embeds) previous_hidden_states = self.hnorm(previous_hidden_states)