diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 2c861723c3966..d533930e1c7aa 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -527,7 +527,7 @@ class SpeculativeConfig: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") - eagle3_target_supported = ["llama", "qwen"] + eagle3_target_supported = ["llama", "qwen", "gpt_oss"] if self.method == "eagle3" and self.target_model_config and not any( supported_model in self.target_model_config.hf_text_config.model_type diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 4fe59f91124dd..7c755a00e1c98 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.utils import cdiv -from .interfaces import SupportsPP +from .interfaces import SupportsEagle3, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -238,6 +238,7 @@ class GptOssModel(nn.Module): self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], self.config.hidden_size)) + self.aux_hidden_state_layers = tuple[int, ...]() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embedding(input_ids) @@ -261,8 +262,12 @@ class GptOssModel(nn.Module): x = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + aux_hidden_states = [] for i in range(self.start_layer, self.end_layer): layer = self.layers[i] + if i in self.aux_hidden_state_layers: + aux_hidden_states.append(x if residual is None else x + + residual) x, residual = layer(x, positions, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -270,6 +275,9 @@ class GptOssModel(nn.Module): "residual": residual }) x, _ = self.norm(x, residual) + + if len(aux_hidden_states) > 0: + return x, aux_hidden_states return x def _load_weights_mxfp4( @@ -610,7 +618,7 @@ class GptOssModel(nn.Module): weights, stacked_params_mapping) -class GptOssForCausalLM(nn.Module, SupportsPP): +class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]} hf_to_vllm_mapper = WeightsMapper( @@ -658,6 +666,13 @@ class GptOssForCausalLM(nn.Module, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5dacf60886966..dc97d5c8f39d4 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -823,15 +823,29 @@ class EagleProposer: else: target_language_model = target_model # share embed_tokens with the target model if needed - if get_pp_group().world_size == 1 \ - and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: - logger.info( - "Assuming the EAGLE head shares the same vocab embedding" - " with the target model.") - del self.model.model.embed_tokens - self.model.model.embed_tokens = ( - target_language_model.model.embed_tokens) + if get_pp_group().world_size == 1: + if hasattr(target_language_model.model, 'embed_tokens'): + target_embed_tokens = target_language_model.model.embed_tokens + elif hasattr(target_language_model.model, 'embedding'): + target_embed_tokens = target_language_model.model.embedding + else: + raise AttributeError( + "Target model does not have 'embed_tokens' or 'embedding' " + "attribute") + + # Check if shapes match and we found the embedding + eagle_shape = self.model.model.embed_tokens.weight.shape + target_shape = target_embed_tokens.weight.shape + if eagle_shape == target_shape: + logger.info( + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model.") + del self.model.model.embed_tokens + self.model.model.embed_tokens = target_embed_tokens + else: + logger.info( + "The EAGLE head's vocab embedding will be loaded separately" + " from the target model.") else: logger.info( "The EAGLE head's vocab embedding will be loaded separately"