[Speculators][Speculative Decoding] Fix Kimi K2 Eagle3 Support

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
chaunceyjiang 2025-12-24 03:05:37 +00:00
parent 27c6c2f98c
commit 3a4929e9ce

View File

@ -85,7 +85,13 @@ from vllm.v1.attention.backends.mla.indexer import (
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .interfaces import (
MixtureOfExperts,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
PPMissingLayer,
is_pp_missing_parameter,
@ -1378,7 +1384,12 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):
class DeepseekV2ForCausalLM(
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
nn.Module,
SupportsPP,
DeepseekV2MixtureOfExperts,
SupportsLoRA,
SupportsEagle,
SupportsEagle3,
):
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
@ -1460,6 +1471,13 @@ class DeepseekV2ForCausalLM(
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Convert token ids to embeddings by delegating to the wrapped model."""
    token_embeddings = self.model.embed_input_ids(input_ids)
    return token_embeddings
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Record which decoder layers should also emit auxiliary hidden states
    (consumed by the Eagle3 draft model)."""
    setattr(self.model, "aux_hidden_state_layers", layers)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Default Eagle3 tap points: one early, one middle, and one late layer.

    NOTE(review): for very shallow models (fewer than ~6 layers) these three
    indices can collide or go negative — presumably callers only use this for
    deep models; confirm upstream.
    """
    layer_count = len(self.model.layers)
    early = 2
    middle = layer_count // 2
    late = layer_count - 3
    return (early, middle, late)
def forward(
self,
input_ids: torch.Tensor,