From 302ef403a2305e9158064f8e386d1b5284d12cb2 Mon Sep 17 00:00:00 2001
From: Mengqing Cao
Date: Wed, 15 Oct 2025 15:16:44 +0800
Subject: [PATCH] [DSA][MLA] Tiny refactor on DeepSeek to make it reusable for
 different backends (#26656)

Signed-off-by: MengqingCao
---
 vllm/attention/layer.py                    |  2 ++
 vllm/model_executor/models/deepseek_mtp.py | 10 ++++++++--
 vllm/model_executor/models/deepseek_v2.py  |  3 ++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 8b5b87cba4044..16c5799f7d0be 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -587,6 +587,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         prefix: str = "",
         use_sparse: bool = False,
         indexer: object | None = None,
+        **extra_impl_args,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -639,6 +640,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             v_head_dim=self.v_head_dim,
             kv_b_proj=kv_b_proj,
             indexer=indexer,
+            **extra_impl_args,
         )
 
         self.use_direct_call = not current_platform.opaque_attention_op()
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index de80833130179..576977b00e616 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -17,9 +17,13 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
-from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name
+from .deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    get_spec_layer_idx_from_weight_name,
+)
 from .interfaces import SupportsPP
 from .utils import maybe_prefix
@@ -56,6 +60,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
 
+        self.device = current_platform.device_type
+
         self.is_v32 = hasattr(config, "index_topk")
         if self.is_v32:
             topk_tokens = config.index_topk
@@ -63,7 +69,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
             )
         else:
             topk_indices_buffer = None
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 970fa80826aba..3d26327c732ea 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1165,6 +1165,7 @@ class DeepseekV2Model(nn.Module):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
+        self.device = current_platform.device_type
         self.vocab_size = config.vocab_size
 
         self.is_v32 = hasattr(config, "index_topk")
@@ -1174,7 +1175,7 @@ class DeepseekV2Model(nn.Module):
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
             )
         else:
             topk_indices_buffer = None
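
Note (not part of the patch): below is a minimal, self-contained Python sketch of the two patterns this diff applies, kept separate from the patch text. It shows (a) forwarding backend-specific keyword arguments through a generic attention layer via **extra_impl_args and (b) choosing the index-buffer device from the platform abstraction instead of hard-coding "cuda". ToyPlatform, ToyMLAImpl, and ToyMLAAttention are hypothetical stand-ins, not vLLM classes; the real classes live in vllm/attention/layer.py and vllm/platforms.

# Toy illustration of the refactor; runnable with only torch installed.
import torch

class ToyPlatform:
    # Stand-in for vllm.platforms.current_platform: reports a device type
    # string instead of assuming CUDA.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"

class ToyMLAImpl:
    # Backend-specific attention implementation; accepts backend-only kwargs.
    def __init__(self, num_heads: int, **extra_impl_args):
        self.num_heads = num_heads
        self.extra_impl_args = extra_impl_args  # e.g. {"use_fused_rope": True}

class ToyMLAAttention:
    # Generic layer: forwards unknown kwargs to the impl untouched, so a new
    # backend can add options without widening the layer's signature.
    def __init__(self, num_heads: int, **extra_impl_args):
        self.impl = ToyMLAImpl(num_heads=num_heads, **extra_impl_args)

current_platform = ToyPlatform()
attn = ToyMLAAttention(num_heads=8, use_fused_rope=True)

# Allocate the top-k index buffer on whatever device the platform reports,
# mirroring the device="cuda" -> device=self.device change in the diff.
topk_indices_buffer = torch.zeros(
    16, 4, dtype=torch.int32, device=current_platform.device_type
)
print(attn.impl.extra_impl_args, topk_indices_buffer.device)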