Enable Eagle3 speculative decoding for GPT-OSS model (#25246)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Eldar Kurtić 2025-09-22 10:50:39 +02:00 committed by yewentao256
parent 2f237d3df4
commit ef85a438da
3 changed files with 41 additions and 12 deletions

View File

@ -527,7 +527,7 @@ class SpeculativeConfig:
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
eagle3_target_supported = ["llama", "qwen"]
eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
if self.method == "eagle3" and self.target_model_config and not any(
supported_model in
self.target_model_config.hf_text_config.model_type

View File

@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from .interfaces import SupportsPP
from .interfaces import SupportsEagle3, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
@ -238,6 +238,7 @@ class GptOssModel(nn.Module):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size))
self.aux_hidden_state_layers = tuple[int, ...]()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids)
@ -261,8 +262,12 @@ class GptOssModel(nn.Module):
x = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states = []
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(x if residual is None else x +
residual)
x, residual = layer(x, positions, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
@ -270,6 +275,9 @@ class GptOssModel(nn.Module):
"residual": residual
})
x, _ = self.norm(x, residual)
if len(aux_hidden_states) > 0:
return x, aux_hidden_states
return x
def _load_weights_mxfp4(
@ -610,7 +618,7 @@ class GptOssModel(nn.Module):
weights, stacked_params_mapping)
class GptOssForCausalLM(nn.Module, SupportsPP):
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
@ -658,6 +666,13 @@ class GptOssForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

View File

@ -823,15 +823,29 @@ class EagleProposer:
else:
target_language_model = target_model
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1 \
and self.model.model.embed_tokens.weight.shape \
== target_language_model.model.embed_tokens.weight.shape:
if get_pp_group().world_size == 1:
if hasattr(target_language_model.model, 'embed_tokens'):
target_embed_tokens = target_language_model.model.embed_tokens
elif hasattr(target_language_model.model, 'embedding'):
target_embed_tokens = target_language_model.model.embedding
else:
raise AttributeError(
"Target model does not have 'embed_tokens' or 'embedding' "
"attribute")
# Check if shapes match and we found the embedding
eagle_shape = self.model.model.embed_tokens.weight.shape
target_shape = target_embed_tokens.weight.shape
if eagle_shape == target_shape:
logger.info(
"Assuming the EAGLE head shares the same vocab embedding"
" with the target model.")
del self.model.model.embed_tokens
self.model.model.embed_tokens = (
target_language_model.model.embed_tokens)
self.model.model.embed_tokens = target_embed_tokens
else:
logger.info(
"The EAGLE head's vocab embedding will be loaded separately"
" from the target model.")
else:
logger.info(
"The EAGLE head's vocab embedding will be loaded separately"