[Speculators][Speculative Decoding] Fix Qwen 2 Eagle3 Support (#23337)

PapaGoose 2025-08-22 19:39:19 +03:00 committed by GitHub
parent 613a23b57f
commit 88491c1b6b


@@ -52,7 +52,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import is_interleaved
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
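
Note: SupportsEagle3, newly imported above, is the marker interface that vLLM's EAGLE3 speculative-decoding path probes for. A minimal sketch of the shape such an interface takes (the Protocol form and the supports_eagle3 flag are assumptions for illustration, following the pattern of SupportsLoRA/SupportsPP; only the two method names are confirmed by this diff):

    from typing import ClassVar, Literal, Protocol, runtime_checkable

    @runtime_checkable
    class SupportsEagle3(Protocol):
        # Class-level flag that callers can probe at runtime (assumed name).
        supports_eagle3: ClassVar[Literal[True]] = True

        # Record which decoder layers to tap for auxiliary hidden states.
        def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
            ...

        # Report the default layers to tap.
        def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
            ...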
@@ -442,7 +442,7 @@ class Qwen2Model(nn.Module):
         return loaded_params
 
 
-class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -488,6 +488,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor,
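
Note: the two methods added above are what wire Qwen2 into EAGLE3. set_aux_hidden_state_layers stores the layer indices on the inner Qwen2Model, whose forward pass can then emit the hidden states of those layers alongside the final output; get_eagle3_aux_hidden_state_layers picks an early, a middle, and a late layer by default, since EAGLE3 drafters fuse features from several depths rather than only the last layer. A self-contained sketch of that default selection (plain Python, not vLLM code; the function name is hypothetical):

    def default_eagle3_layers(num_layers: int) -> tuple[int, int, int]:
        # Same formula as the commit: layer 2, the middle layer,
        # and the third-from-last layer.
        return (2, num_layers // 2, num_layers - 3)

    # For a 28-layer decoder (e.g. Qwen2-7B) this taps layers 2, 14, and 25.
    assert default_eagle3_layers(28) == (2, 14, 25)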