diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index bf533bf14e55c..3d06d60105af2 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -618,7 +618,7 @@ class SpeculativeConfig:
                 f"{self.disable_by_batch_size=}"
             )
 
-        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
+        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss", "kimi_k2"]
         if (
             self.method == "eagle3"
             and self.target_model_config
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index b22cdb6d6c80c..14bb1a86edfcc 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -85,7 +85,13 @@ from vllm.v1.attention.backends.mla.indexer import (
 from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
 from vllm.v1.worker.workspace import current_workspace_manager
 
-from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
+from .interfaces import (
+    MixtureOfExperts,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     PPMissingLayer,
     is_pp_missing_parameter,
@@ -1381,7 +1387,12 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):
 
 
 class DeepseekV2ForCausalLM(
-    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
+    nn.Module,
+    SupportsPP,
+    DeepseekV2MixtureOfExperts,
+    SupportsLoRA,
+    SupportsEagle,
+    SupportsEagle3,
 ):
     packed_modules_mapping = {
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -1463,6 +1474,13 @@ class DeepseekV2ForCausalLM(
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor,
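
With these changes, `kimi_k2` targets pass the EAGLE3 support check in `SpeculativeConfig`, and `DeepseekV2ForCausalLM` exposes the auxiliary-hidden-state hooks (`set_aux_hidden_state_layers` / `get_eagle3_aux_hidden_state_layers`) that the EAGLE3 drafter consumes. The default layer selection, `(2, num_layers // 2, num_layers - 3)`, mirrors the early/middle/late pattern used by other EAGLE3-enabled targets. A minimal sketch of how one might exercise the new path with vLLM's offline `LLM` API; the model names and draft checkpoint path here are illustrative placeholders, not part of this diff:

```python
from vllm import LLM, SamplingParams

# Illustrative target/draft pair: any kimi_k2 / DeepSeek-V2-style target
# paired with an EAGLE3 draft trained against it should follow the same path.
llm = LLM(
    model="moonshotai/Kimi-K2-Instruct",          # target; model_type "kimi_k2"
    speculative_config={
        "method": "eagle3",                       # checked against eagle3_target_supported
        "model": "path/to/eagle3-draft",          # hypothetical draft checkpoint
        "num_speculative_tokens": 3,
    },
)

outputs = llm.generate(
    ["Explain speculative decoding in one paragraph."],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```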