[DSA][MLA] Tiny refactor on DeepSeek to make it reusable for different backends (#26656)

Signed-off-by: MengqingCao <cmq0113@163.com>
Mengqing Cao 2025-10-15 15:16:44 +08:00 committed by GitHub
parent 8865da157b
commit 302ef403a2
3 changed files with 12 additions and 3 deletions

@@ -587,6 +587,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         prefix: str = "",
         use_sparse: bool = False,
         indexer: object | None = None,
+        **extra_impl_args,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -639,6 +640,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             v_head_dim=self.v_head_dim,
             kv_b_proj=kv_b_proj,
             indexer=indexer,
+            **extra_impl_args,
         )

         self.use_direct_call = not current_platform.opaque_attention_op()
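
The new **extra_impl_args pass-through is what makes the layer reusable across backends: a platform-specific MLA implementation can accept constructor arguments that the generic MLAAttention layer never has to know about. A minimal standalone sketch of the pattern follows; DummyMLAImpl and npu_fusion_level are hypothetical names for illustration, not vLLM code:

class DummyMLAImpl:
    def __init__(self, num_heads: int, indexer: object | None = None,
                 npu_fusion_level: int = 0):
        # npu_fusion_level (hypothetical) arrives only because the layer
        # forwards **extra_impl_args; the generic layer never names it.
        self.num_heads = num_heads
        self.indexer = indexer
        self.npu_fusion_level = npu_fusion_level


class TinyMLALayer:
    def __init__(self, num_heads: int, indexer: object | None = None,
                 **extra_impl_args):
        # Forward any kwargs the layer does not consume to the backend impl.
        self.impl = DummyMLAImpl(num_heads=num_heads, indexer=indexer,
                                 **extra_impl_args)


layer = TinyMLALayer(num_heads=16, npu_fusion_level=2)
assert layer.impl.npu_fusion_level == 2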

@@ -17,9 +17,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

-from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name
+from .deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    get_spec_layer_idx_from_weight_name,
+)
 from .interfaces import SupportsPP
 from .utils import maybe_prefix
@@ -56,6 +60,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)

+        self.device = current_platform.device_type
+
         self.is_v32 = hasattr(config, "index_topk")
         if self.is_v32:
             topk_tokens = config.index_topk
@@ -63,7 +69,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
             )
         else:
             topk_indices_buffer = None
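
Both DeepSeek models now read the target device from vLLM's platform abstraction instead of hard-coding "cuda", so the indexer buffer lands on whatever device the current backend reports. A sketch of the resulting pattern; the sizes below are illustrative stand-ins, not values from a real config:

import torch

from vllm.platforms import current_platform

device = current_platform.device_type  # e.g. "cuda", "cpu", "npu"
topk_indices_buffer = torch.zeros(
    1024,  # stand-in for scheduler_config.max_num_batched_tokens
    64,    # stand-in for topk_tokens (config.index_topk)
    dtype=torch.int32,
    device=device,  # previously the literal "cuda", which broke other backends
)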

@@ -1165,6 +1165,7 @@ class DeepseekV2Model(nn.Module):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
+        self.device = current_platform.device_type
         self.vocab_size = config.vocab_size

         self.is_v32 = hasattr(config, "index_topk")
@@ -1174,7 +1175,7 @@
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
             )
         else:
             topk_indices_buffer = None
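
The buffer is only allocated when the config exposes index_topk, which is how the is_v32 = hasattr(config, "index_topk") check distinguishes DeepSeek-V3.2-style sparse-indexer configs from older ones. A toy illustration of that gate; the helper function and SimpleNamespace configs are stand-ins, not vLLM code:

from types import SimpleNamespace

import torch


def make_topk_indices_buffer(config, max_tokens: int, device: str):
    # V3.2-style configs define index_topk; older configs do not.
    if hasattr(config, "index_topk"):
        return torch.zeros(max_tokens, config.index_topk,
                           dtype=torch.int32, device=device)
    return None


assert make_topk_indices_buffer(SimpleNamespace(), 16, "cpu") is None
buf = make_topk_indices_buffer(SimpleNamespace(index_topk=4), 16, "cpu")
assert buf.shape == (16, 4)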