diff --git a/tests/models/registry.py b/tests/models/registry.py index 28fe9063169e..739d96227971 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -530,6 +530,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 trust_remote_code=True), + "EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random", + speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501 + trust_remote_code=True), "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", trust_remote_code=True, speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 7b3f45831279..bd0fa6b80781 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -144,6 +144,8 @@ def test_ngram_correctness( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + (("eagle", "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", 1), False), ], ids=[ # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 @@ -151,7 +153,8 @@ def test_ngram_correctness( "llama3_eagle", "llama3_eagle3", "llama4_eagle", - "llama4_eagle_mm" + "llama4_eagle_mm", + "deepseek_eagle" ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @@ -177,6 +180,7 @@ def test_eagle_correctness( ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_MLA_DISABLE", "1") m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) if (attn_backend == "TRITON_ATTN_VLLM_V1" diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py new file mode 100644 index 000000000000..0c9c83cf6100 --- /dev/null +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer, + DeepseekV3ForCausalLM) +from vllm.model_executor.sampling_metadata import SamplingMetadata + +from .utils import AutoWeightsLoader, maybe_prefix + + +@support_torch_compile +class DeepseekV2Model(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer_id: int = 0, + ) -> None: + super().__init__() + self.config = vllm_config. \ + speculative_config.draft_model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.vocab_size = self.config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + + self.layers = nn.ModuleList([ + DeepseekV2DecoderLayer( + self.config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + ) for i in range(self.config.num_hidden_layers) + ]) + + self.fc = nn.Linear( + self.config.model.hidden_size * 2, + self.config.model.hidden_size, + bias=False, + ) + + self.enorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.hnorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + input_embeds = self.embed_tokens(input_ids) + + inputs = torch.cat( + [self.enorm(input_embeds), + self.hnorm(hidden_states)], dim=-1) + hidden_states = self.fc(inputs) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + # if go with fusion option, then update name + if ((param_name == "fused_qkv_a_proj") + and name_mapped not in params_dict): + continue + else: + name = name_mapped + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config. \ + speculative_config.draft_model_config.hf_config + quant_config = vllm_config.quant_config + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + self.model = DeepseekV2Model(vllm_config=vllm_config, + prefix="model", + start_layer_id=target_layer_num) + + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config) + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support multimodal inputs yet." + ) + return self.model(input_ids, positions, hidden_states) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=None, + ) + + model_weights = {} + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." + name + model_weights[name] = loaded_weight + loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8728684d8e68..a94231b0f846 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -264,6 +264,7 @@ _SPECULATIVE_DECODING_MODELS = { "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"),