Enable Eagle3 speculative decoding for GPT-OSS model (#25246)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Eldar Kurtić 2025-09-22 10:50:39 +02:00 committed by yewentao256
parent 2f237d3df4
commit ef85a438da
3 changed files with 41 additions and 12 deletions
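For context, this is how a user would switch the feature on once an Eagle3 draft head trained for GPT-OSS exists. A minimal sketch against vLLM's offline API; the draft-model path is a placeholder, since this commit only adds the plumbing and ships no checkpoint:

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="openai/gpt-oss-20b",
        speculative_config={
            "method": "eagle3",
            # Placeholder path; no GPT-OSS Eagle3 draft ships with this commit.
            "model": "path/to/gpt-oss-eagle3-draft",
            "num_speculative_tokens": 3,
        },
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32))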


@@ -527,7 +527,7 @@ class SpeculativeConfig:
                     "speculative decoding is > 1, but got "
                     f"{self.disable_by_batch_size=}")

-        eagle3_target_supported = ["llama", "qwen"]
+        eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
         if self.method == "eagle3" and self.target_model_config and not any(
                 supported_model in
                 self.target_model_config.hf_text_config.model_type
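The guard matches by substring against the target's `hf_text_config.model_type`, so any model type containing "gpt_oss" now passes. A standalone sketch of the check, with names simplified from the hunk above:

    # Simplified from the hunk above: a substring match against the
    # target's HF `model_type`.
    eagle3_target_supported = ["llama", "qwen", "gpt_oss"]

    def check_eagle3_target(method: str, model_type: str) -> None:
        if method == "eagle3" and not any(
                supported in model_type
                for supported in eagle3_target_supported):
            raise ValueError(f"Eagle3 does not support {model_type!r} targets")

    check_eagle3_target("eagle3", "gpt_oss")  # passes after this change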


@@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv

-from .interfaces import SupportsPP
+from .interfaces import SupportsEagle3, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -238,6 +238,7 @@ class GptOssModel(nn.Module):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], self.config.hidden_size))
+        self.aux_hidden_state_layers = tuple[int, ...]()

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embedding(input_ids)
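The default relies on the fact that calling a subscripted generic alias constructs an instance of its origin type, so `tuple[int, ...]()` is just a typed spelling of the empty tuple:

    # Calling the parameterized alias instantiates plain `tuple`.
    assert tuple[int, ...]() == ()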
@@ -261,8 +262,12 @@ class GptOssModel(nn.Module):
             x = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
+            if i in self.aux_hidden_state_layers:
+                aux_hidden_states.append(x if residual is None else x +
+                                         residual)
             x, residual = layer(x, positions, residual)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -270,6 +275,9 @@ class GptOssModel(nn.Module):
                 "residual": residual
             })
         x, _ = self.norm(x, residual)
+
+        if len(aux_hidden_states) > 0:
+            return x, aux_hidden_states
         return x

     def _load_weights_mxfp4(
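Eagle3 feeds its draft head hidden states tapped from several target layers, and each tap records the full hidden state entering a layer (`x + residual`, since the residual stream is carried separately). A self-contained sketch of the collection loop, using toy layers that mimic the `(x, residual)` calling convention (the real layers also take rotary positions):

    import torch
    import torch.nn as nn

    class ToyLayer(nn.Module):
        # Mimics vLLM's (x, residual) in / (x, residual) out convention.
        def __init__(self, dim: int):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x, residual):
            residual = x if residual is None else x + residual
            return self.proj(residual), residual

    def forward_with_taps(layers, x, aux_layers: tuple[int, ...]):
        residual = None
        aux_hidden_states = []
        for i, layer in enumerate(layers):
            if i in aux_layers:
                # Snapshot the full hidden state entering layer i.
                aux_hidden_states.append(
                    x if residual is None else x + residual)
            x, residual = layer(x, residual)
        return x + residual, aux_hidden_states

    layers = nn.ModuleList(ToyLayer(16) for _ in range(8))
    out, taps = forward_with_taps(layers, torch.randn(4, 16), (2, 4, 5))
    print(out.shape, len(taps))  # torch.Size([4, 16]) 3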
@@ -610,7 +618,7 @@ class GptOssModel(nn.Module):
                                    weights, stacked_params_mapping)


-class GptOssForCausalLM(nn.Module, SupportsPP):
+class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

     hf_to_vllm_mapper = WeightsMapper(
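`SupportsEagle3` is the marker interface the Eagle3 proposer probes for; roughly, it promises the two hooks added in the next hunk. A sketch of the expected shape, not the exact definition in vllm/model_executor/models/interfaces.py:

    from typing import Protocol, runtime_checkable

    # Rough shape only; see interfaces.py for the real definition.
    @runtime_checkable
    class SupportsEagle3Sketch(Protocol):
        def set_aux_hidden_state_layers(
                self, layers: tuple[int, ...]) -> None: ...

        def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: ...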
@@ -658,6 +666,13 @@ class GptOssForCausalLM(nn.Module, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
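The default selection takes one early, one middle, and one late layer, matching the usual Eagle3 recipe of giving the draft head low-, mid-, and high-level features. For example, assuming GPT-OSS-20B's 24 decoder layers:

    # Worked example of the default above for a 24-layer target.
    num_layers = 24
    print((2, num_layers // 2, num_layers - 3))  # (2, 12, 21)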


@@ -823,15 +823,29 @@ class EagleProposer:
         else:
             target_language_model = target_model
         # share embed_tokens with the target model if needed
-        if get_pp_group().world_size == 1 \
-            and self.model.model.embed_tokens.weight.shape \
-            == target_language_model.model.embed_tokens.weight.shape:
-            logger.info(
-                "Assuming the EAGLE head shares the same vocab embedding"
-                " with the target model.")
-            del self.model.model.embed_tokens
-            self.model.model.embed_tokens = (
-                target_language_model.model.embed_tokens)
+        if get_pp_group().world_size == 1:
+            if hasattr(target_language_model.model, 'embed_tokens'):
+                target_embed_tokens = target_language_model.model.embed_tokens
+            elif hasattr(target_language_model.model, 'embedding'):
+                target_embed_tokens = target_language_model.model.embedding
+            else:
+                raise AttributeError(
+                    "Target model does not have 'embed_tokens' or 'embedding' "
+                    "attribute")
+
+            # Check if shapes match and we found the embedding
+            eagle_shape = self.model.model.embed_tokens.weight.shape
+            target_shape = target_embed_tokens.weight.shape
+            if eagle_shape == target_shape:
+                logger.info(
+                    "Assuming the EAGLE head shares the same vocab embedding"
+                    " with the target model.")
+                del self.model.model.embed_tokens
+                self.model.model.embed_tokens = target_embed_tokens
+            else:
+                logger.info(
+                    "The EAGLE head's vocab embedding will be loaded separately"
+                    " from the target model.")
         else:
             logger.info(
                 "The EAGLE head's vocab embedding will be loaded separately"