diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index aa176ef05fccb..3984d23970ac5 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -97,7 +97,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module): ) -> torch.Tensor: assert inputs_embeds is not None # masking inputs at position 0, as not needed by MTP - inputs_embeds[positions == 0] = 0 + inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds) inputs_embeds = self.enorm(inputs_embeds) previous_hidden_states = self.hnorm(previous_hidden_states)