From 934a9c3b79e6cb860a8d23b7f317a5f63adf0fae Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 8 Nov 2025 13:01:27 +0800 Subject: [PATCH] [Model] Consolidate Deepseek-MoE implementation with DeepSeek-v2 (#28101) Signed-off-by: Kunshang Ji Signed-off-by: Isotr0py Co-authored-by: Kunshang Ji --- tests/models/registry.py | 5 +- vllm/model_executor/models/deepseek.py | 517 --------------------- vllm/model_executor/models/deepseek_ocr.py | 8 - vllm/model_executor/models/deepseek_v2.py | 152 +++++- vllm/model_executor/models/deepseek_vl2.py | 8 - vllm/model_executor/models/registry.py | 2 +- 6 files changed, 144 insertions(+), 548 deletions(-) delete mode 100644 vllm/model_executor/models/deepseek.py diff --git a/tests/models/registry.py b/tests/models/registry.py index b52f241719e85..7b865c578dd43 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -219,7 +219,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "nvidia/Llama-3_3-Nemotron-Super-49B-v1", trust_remote_code=True, ), - "DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"), + "DeepseekForCausalLM": _HfExamplesInfo( + "deepseek-ai/deepseek-moe-16b-base", + trust_remote_code=True, + ), "DeepseekV2ForCausalLM": _HfExamplesInfo( "deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py deleted file mode 100644 index 36cc12b51f13f..0000000000000 --- a/vllm/model_executor/models/deepseek.py +++ /dev/null @@ -1,517 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only Deepseek model.""" - -from collections.abc import Iterable -from itertools import islice -from typing import Any - -import torch -from torch import nn -from transformers import PretrainedConfig - -from vllm.attention import Attention -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import ( - get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsLoRA, SupportsPP -from .utils import ( - AutoWeightsLoader, - extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, - make_layers, - maybe_prefix, -) - - -class DeepseekMLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: QuantizationConfig | None = None, - reduce_results: bool = True, - prefix: str = "", - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj", - ) - if hidden_act != "silu": - raise ValueError( - f"Unsupported activation: {hidden_act}. Only silu is supported for now." - ) - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - - -class DeepseekMoE(nn.Module): - def __init__( - self, - config: PretrainedConfig, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.n_routed_experts = config.n_routed_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.n_routed_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.n_routed_experts}." 
- ) - - self.experts = nn.ModuleList( - [ - DeepseekMLP( - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - ) - for idx in range(self.n_routed_experts) - ] - ) - self.pack_params() - - self.gate = ReplicatedLinear( - config.hidden_size, self.n_routed_experts, bias=False, quant_config=None - ) - - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekMLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - ) - - def pack_params(self): - w1 = [] - w2 = [] - for expert in self.experts: - w1.append(expert.gate_up_proj.weight) - w2.append(expert.down_proj.weight) - self.w1 = torch._utils._flatten_dense_tensors(w1) - w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) - for data, param in zip(w1s, w1): - param.data = data - self.w1 = self.w1.view(len(w1), *w1s[0].shape) - - self.w2 = torch._utils._flatten_dense_tensors(w2) - w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) - for data, param in zip(w2s, w2): - param.data = data - - self.w2 = self.w2.view(len(w2), *w2s[0].shape) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - if self.config.n_shared_experts is not None: - shared_output = self.shared_experts(hidden_states) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - topk_weights, topk_ids, _ = fused_topk( - hidden_states, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - ) - - final_hidden_states = fused_experts( - hidden_states, self.w1, self.w2, topk_weights, topk_ids, inplace=True - ) - - if self.config.n_shared_experts is not None: - final_hidden_states = final_hidden_states + shared_output - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_dim) - - -class DeepseekAttention(nn.Module): - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: dict[str, Any] | None = None, - max_position_embeddings: int = 8192, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.o_proj(attn_output) - return output - - -class DeepseekDecoderLayer(nn.Module): - def __init__( - self, - config: PretrainedConfig, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ) -> None: - super().__init__() - layer_idx = extract_layer_index(prefix) - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - moe_layer_freq = getattr(config, "moe_layer_freq", 1) - self.self_attn = DeepseekAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % moe_layer_freq == 0 - ): - self.mlp = DeepseekMoE( - config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" - ) - else: - self.mlp = DeepseekMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor | None, - ) -> torch.Tensor: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - 
hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -class DeepseekModel(nn.Module): - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: DeepseekDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix - ), - prefix=f"{prefix}.layers", - ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None, - inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) - if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if ( - "mlp.experts." in name or "mlp.shared_experts." in name - ) and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if ( - "mlp.experts." in name or "mlp.shared_experts." 
in name - ) and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class DeepseekForCausalLM(nn.Module, SupportsLoRA, SupportsPP): - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = DeepseekModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - logits = self.logits_processor(self.lm_head, hidden_states) - return logits - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py index bfde8328da6e1..0432567521843 100644 --- a/vllm/model_executor/models/deepseek_ocr.py +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -417,18 +417,10 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" ) - if self.text_config.topk_method == "noaux_tc": - architectures = ["DeepseekV3ForCausalLM"] - elif not self.text_config.use_mla: - architectures = ["DeepseekForCausalLM"] - else: - architectures = ["DeepseekV2ForCausalLM"] - self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=self.text_config, prefix=maybe_prefix(prefix, "language_model"), - architectures=architectures, ) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 4858c30baab84..63eaf63cc3c48 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -58,6 +58,7 @@ from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, + QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) @@ -104,6 +105,92 @@ elif current_platform.is_xpu(): logger = init_logger(__name__) +class DeepseekAttention(nn.Module): + """Normal MHA implementation used by Deepseek v1.""" + + def __init__( + self, + 
vllm_config: VllmConfig, + config: DeepseekV2Config | DeepseekV3Config, + hidden_size: int, + num_heads: int, + rope_theta: float = 10000, + rope_scaling: dict[str, Any] | None = None, + max_position_embeddings: int = 8192, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + class DeepseekV2MLP(nn.Module): def __init__( self, @@ -163,7 +250,7 @@ class DeepseekV2MoE(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() - self.routed_scaling_factor = config.routed_scaling_factor + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) self.ep_group = get_ep_group().device_group self.ep_rank = get_ep_group().rank_in_group @@ -186,7 +273,7 @@ class DeepseekV2MoE(nn.Module): quant_config=None, prefix=f"{prefix}.gate", ) - if config.topk_method == "noaux_tc": + if getattr(config, "topk_method", None) == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( torch.empty(config.n_routed_experts, dtype=torch.float32) ) @@ -236,10 +323,10 @@ class DeepseekV2MoE(nn.Module): renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, + num_expert_group=getattr(config, "n_group", 1), + topk_group=getattr(config, "topk_group", 1), prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, + scoring_func=getattr(config, "scoring_func", 
"softmax"), # we do scaling outside, set factor to 1.0 to avoid double mul # aiter applies routed_scaling_factor internally routed_scaling_factor=1.0 @@ -999,7 +1086,19 @@ class DeepseekV2DecoderLayer(nn.Module): # with the layer's index. layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx - if model_config.use_mla: + + # verify MLA attention specific fields + qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) + qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + v_head_dim = getattr(config, "v_head_dim", 0) + kv_lora_rank = getattr(config, "kv_lora_rank", 0) + use_mha = config.model_type == "deepseek" or all( + dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) + ) + + if use_mha: + attn_cls = DeepseekAttention + elif model_config.use_mla: attn_cls = DeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention @@ -1008,11 +1107,11 @@ class DeepseekV2DecoderLayer(nn.Module): config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, + kv_lora_rank=kv_lora_rank, rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -1045,7 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module): self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) - self.routed_scaling_factor = config.routed_scaling_factor + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) def forward( self, @@ -1064,7 +1163,10 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states=hidden_states, ) - if hidden_states.dtype == torch.float16: + if ( + not isinstance(self.self_attn, DeepseekAttention) + and hidden_states.dtype == torch.float16 + ): # Fix FP16 overflow # We scale both hidden_states and residual before # rmsnorm, and rmsnorm result would not affect by scale. 
@@ -1227,6 +1329,15 @@ class DeepseekV2ForCausalLM( self.config = config self.quant_config = quant_config + qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) + qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + self.use_mha = config.model_type == "deepseek" or all( + dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) + ) + + if self.use_mha: + self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] + # `packed_modules_mapping` needs to be modified before # initializing DeepseekV2Model, as it is passed inplace to # quantization config init and may be used to select the @@ -1265,7 +1376,7 @@ class DeepseekV2ForCausalLM( def set_moe_parameters(self): self.expert_weights = [] - self.num_expert_groups = self.config.n_group + self.num_expert_groups = getattr(self.config, "n_group", 1) self.moe_layers = [] self.moe_mlp_layers = [] @@ -1321,9 +1432,20 @@ class DeepseekV2ForCausalLM( # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), + ] + mla_params_mapping = [ ("fused_qkv_a_proj", "q_a_proj", 0), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ] + mha_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + if self.use_mha: + stacked_params_mapping.extend(mha_params_mapping) + else: + stacked_params_mapping.extend(mla_params_mapping) # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) @@ -1506,6 +1628,10 @@ class DeepseekV2ForCausalLM( return loaded_params +class DeepseekForCausalLM(DeepseekV2ForCausalLM): + pass + + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index ea10245a84ee1..306eef3dca990 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -403,18 +403,10 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" ) - if self.text_config.topk_method == "noaux_tc": - architectures = ["DeepseekV3ForCausalLM"] - elif not self.text_config.use_mla: - architectures = ["DeepseekForCausalLM"] - else: - architectures = ["DeepseekV2ForCausalLM"] - self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=self.text_config, prefix=maybe_prefix(prefix, "language"), - architectures=architectures, ) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index dddbc88069ef1..4af8fa01f562b 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -76,7 +76,7 @@ _TEXT_GENERATION_MODELS = { "CwmForCausalLM": ("llama", "LlamaForCausalLM"), "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), - "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), + "DeepseekForCausalLM": ("deepseek_v2", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
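Note (not part of the patch): with DeepseekForCausalLM now resolved from deepseek_v2 (last hunk above) and the test registry switched to deepseek-ai/deepseek-moe-16b-base, a minimal smoke test of the consolidated path could look like the sketch below. The prompt and sampling settings are arbitrary, and it assumes a GPU with enough memory for the 16B checkpoint:

from vllm import LLM, SamplingParams

# The example checkpoint registered in tests/models/registry.py above;
# trust_remote_code is needed because the config ships custom code.
llm = LLM(model="deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
outputs = llm.generate(
    ["Deepseek-MoE now shares the DeepSeek-V2 code path, which means"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)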