mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 14:26:07 +08:00
253 lines
9.2 KiB
Python
253 lines
9.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from collections.abc import Iterable
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed.parallel_state import get_pp_group
|
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
default_weight_loader,
|
|
maybe_remap_kv_scale_name,
|
|
)
|
|
from vllm.model_executor.models.deepseek_v2 import (
|
|
DeepseekV2DecoderLayer,
|
|
DeepseekV3ForCausalLM,
|
|
)
|
|
|
|
from .utils import AutoWeightsLoader, maybe_prefix
|
|
|
|
|
|
@support_torch_compile
class DeepseekV2Model(nn.Module):
    """Decoder stack configured from the speculative *draft* model config.

    The forward pass RMS-normalizes the token embeddings and the incoming
    hidden states separately, concatenates them along the last dimension,
    projects the result with ``fc`` and feeds it through the decoder layers.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        """Build the embedding, decoder layers, fusion projection and norms.

        Args:
            vllm_config: Global config; the HF config is read from
                ``speculative_config.draft_model_config``.
            prefix: Name prefix handed to ``maybe_prefix`` for sub-modules.
            start_layer_id: Offset added to each layer index when building
                the per-layer prefix.
        """
        super().__init__()
        hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = hf_config
        quant_config = vllm_config.quant_config
        self.vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # Layer prefixes are numbered starting at ``start_layer_id``.
        decoder_layers = []
        for offset in range(hf_config.num_hidden_layers):
            layer_prefix = maybe_prefix(prefix, f"layers.{offset + start_layer_id}")
            decoder_layers.append(
                DeepseekV2DecoderLayer(
                    vllm_config,
                    prefix=layer_prefix,
                    config=hf_config,
                )
            )
        self.layers = nn.ModuleList(decoder_layers)

        # Projects the concatenated (embedding, hidden state) pair back down.
        # NOTE(review): the width comes from ``config.model.hidden_size`` while
        # every other module here uses ``config.hidden_size`` - confirm that
        # the draft config exposes both.
        fc_width = hf_config.model.hidden_size
        self.fc = nn.Linear(fc_width * 2, fc_width, bias=False)

        hidden_size = hf_config.hidden_size
        norm_eps = hf_config.rms_norm_eps
        self.enorm = RMSNorm(hidden_size, eps=norm_eps)
        self.hnorm = RMSNorm(hidden_size, eps=norm_eps)
        self.norm = RMSNorm(hidden_size, eps=norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft decoder stack.

        Args:
            input_ids: Token ids to embed.
            positions: Positions forwarded unchanged to every decoder layer.
            hidden_states: Incoming hidden states that are fused with the
                token embeddings.

        Returns:
            The final normalized hidden states, returned twice as a 2-tuple.
        """
        normed_embeds = self.enorm(self.embed_tokens(input_ids))
        normed_hidden = self.hnorm(hidden_states)
        fused = torch.cat([normed_embeds, normed_hidden], dim=-1)

        hidden_states = self.fc(fused)
        residual = None
        for decoder_layer in self.layers:
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Each checkpoint entry is tried, in order, against the stacked
        (fused) parameter table, then the MoE expert table, and finally
        loaded under its own (possibly kv-scale-remapped) name.

        Args:
            weights: ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The parameter names recorded as loaded.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        # Entries are (param_name, weight_name, expert_id, shard_id) and cover
        # weights, fp8 weight scales and fp8 activation scales.
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                # Not a shard of this fused parameter.
                if shard_name not in name:
                    continue
                # Expert weights (e.g. mlp.experts[0].gate_proj) belong to the
                # expert table below. Bail out BEFORE renaming, otherwise the
                # already-renamed name would be renamed a second time there
                # and the load would break.
                if "mlp.experts." in name and name not in params_dict:
                    continue
                candidate = name.replace(shard_name, fused_name)

                # QKV fusion is optional: when the fused parameter does not
                # exist, leave the name alone so the plain path handles it.
                if fused_name == "fused_qkv_a_proj" and candidate not in params_dict:
                    continue
                name = candidate

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for target_name, ckpt_name, expert_id, shard_id in expert_params_mapping:
                    if ckpt_name not in name:
                        continue
                    name = name.replace(ckpt_name, target_name)

                    param = params_dict[name]
                    param.weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # With pipeline parallelism disabled the draft shares its
                    # embedding with the target, so it is not loaded here.
                    if get_pp_group().world_size == 1 and "embed_tokens." in name:
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remap the name of an FP8 kv-scale, if applicable.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    load_fn = getattr(param, "weight_loader", default_weight_loader)
                    load_fn(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
    """Causal-LM wrapper pairing ``DeepseekV2Model`` with an LM head.

    The HF config is taken from the speculative draft model config.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the draft backbone, LM head and logits processor.

        Args:
            vllm_config: Global config; the draft HF config and the target
                model's layer count are both read from it.
            prefix: Name prefix used for the LM head.
        """
        # Only nn.Module is initialised; the parent class constructor is
        # deliberately not invoked here.
        nn.Module.__init__(self)
        hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = hf_config
        quant_config = vllm_config.quant_config

        # Draft layer prefixes are numbered after the target model's layers.
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.model = DeepseekV2Model(
            vllm_config=vllm_config,
            prefix="model",
            start_layer_id=target_layer_num,
        )

        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        # Falls back to a scale of 1.0 when the config has no ``logit_scale``.
        self.logits_processor = LogitsProcessor(
            hf_config.vocab_size,
            scale=getattr(hf_config, "logit_scale", 1.0),
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the backbone model.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is provided.
        """
        if inputs_embeds is None:
            return self.model(input_ids, positions, hidden_states)
        raise NotImplementedError(
            f"{type(self).__name__} does not support multimodal inputs yet."
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Apply the logits processor with the LM head to ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Every name that does not contain ``"lm_head"`` is prefixed with
        ``"model."`` so it resolves under the backbone sub-module. The
        loader's result is not propagated; this method returns ``None``.
        """

        def _with_model_prefix(pairs):
            # Lazily rewrite names, one checkpoint entry at a time.
            for name, loaded_weight in pairs:
                if "lm_head" not in name:
                    name = "model." + name
                yield name, loaded_weight

        loader = AutoWeightsLoader(self, skip_prefixes=None)
        loader.load_weights(_with_model_prefix(weights))
|