mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 01:25:01 +08:00
Enable Eagle3 speculative decoding for GPT-OSS model (#25246)
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
2f237d3df4
commit
ef85a438da
@ -527,7 +527,7 @@ class SpeculativeConfig:
|
|||||||
"speculative decoding is > 1, but got "
|
"speculative decoding is > 1, but got "
|
||||||
f"{self.disable_by_batch_size=}")
|
f"{self.disable_by_batch_size=}")
|
||||||
|
|
||||||
eagle3_target_supported = ["llama", "qwen"]
|
eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
|
||||||
if self.method == "eagle3" and self.target_model_config and not any(
|
if self.method == "eagle3" and self.target_model_config and not any(
|
||||||
supported_model in
|
supported_model in
|
||||||
self.target_model_config.hf_text_config.model_type
|
self.target_model_config.hf_text_config.model_type
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
|
|
||||||
from .interfaces import SupportsPP
|
from .interfaces import SupportsEagle3, SupportsPP
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
|
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
|
||||||
is_pp_missing_parameter,
|
is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
@ -238,6 +238,7 @@ class GptOssModel(nn.Module):
|
|||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
make_empty_intermediate_tensors_factory(
|
make_empty_intermediate_tensors_factory(
|
||||||
["hidden_states", "residual"], self.config.hidden_size))
|
["hidden_states", "residual"], self.config.hidden_size))
|
||||||
|
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return self.embedding(input_ids)
|
return self.embedding(input_ids)
|
||||||
@ -261,8 +262,12 @@ class GptOssModel(nn.Module):
|
|||||||
x = intermediate_tensors["hidden_states"]
|
x = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
|
aux_hidden_states = []
|
||||||
for i in range(self.start_layer, self.end_layer):
|
for i in range(self.start_layer, self.end_layer):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
|
if i in self.aux_hidden_state_layers:
|
||||||
|
aux_hidden_states.append(x if residual is None else x +
|
||||||
|
residual)
|
||||||
x, residual = layer(x, positions, residual)
|
x, residual = layer(x, positions, residual)
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return IntermediateTensors({
|
return IntermediateTensors({
|
||||||
@ -270,6 +275,9 @@ class GptOssModel(nn.Module):
|
|||||||
"residual": residual
|
"residual": residual
|
||||||
})
|
})
|
||||||
x, _ = self.norm(x, residual)
|
x, _ = self.norm(x, residual)
|
||||||
|
|
||||||
|
if len(aux_hidden_states) > 0:
|
||||||
|
return x, aux_hidden_states
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _load_weights_mxfp4(
|
def _load_weights_mxfp4(
|
||||||
@ -610,7 +618,7 @@ class GptOssModel(nn.Module):
|
|||||||
weights, stacked_params_mapping)
|
weights, stacked_params_mapping)
|
||||||
|
|
||||||
|
|
||||||
class GptOssForCausalLM(nn.Module, SupportsPP):
|
class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
|
||||||
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
|
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
|
||||||
|
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
@ -658,6 +666,13 @@ class GptOssForCausalLM(nn.Module, SupportsPP):
|
|||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.model.make_empty_intermediate_tensors)
|
self.model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||||
|
self.model.aux_hidden_state_layers = layers
|
||||||
|
|
||||||
|
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||||
|
num_layers = len(self.model.layers)
|
||||||
|
return (2, num_layers // 2, num_layers - 3)
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return self.model.get_input_embeddings(input_ids)
|
return self.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
|||||||
@ -823,15 +823,29 @@ class EagleProposer:
|
|||||||
else:
|
else:
|
||||||
target_language_model = target_model
|
target_language_model = target_model
|
||||||
# share embed_tokens with the target model if needed
|
# share embed_tokens with the target model if needed
|
||||||
if get_pp_group().world_size == 1 \
|
if get_pp_group().world_size == 1:
|
||||||
and self.model.model.embed_tokens.weight.shape \
|
if hasattr(target_language_model.model, 'embed_tokens'):
|
||||||
== target_language_model.model.embed_tokens.weight.shape:
|
target_embed_tokens = target_language_model.model.embed_tokens
|
||||||
logger.info(
|
elif hasattr(target_language_model.model, 'embedding'):
|
||||||
"Assuming the EAGLE head shares the same vocab embedding"
|
target_embed_tokens = target_language_model.model.embedding
|
||||||
" with the target model.")
|
else:
|
||||||
del self.model.model.embed_tokens
|
raise AttributeError(
|
||||||
self.model.model.embed_tokens = (
|
"Target model does not have 'embed_tokens' or 'embedding' "
|
||||||
target_language_model.model.embed_tokens)
|
"attribute")
|
||||||
|
|
||||||
|
# Check if shapes match and we found the embedding
|
||||||
|
eagle_shape = self.model.model.embed_tokens.weight.shape
|
||||||
|
target_shape = target_embed_tokens.weight.shape
|
||||||
|
if eagle_shape == target_shape:
|
||||||
|
logger.info(
|
||||||
|
"Assuming the EAGLE head shares the same vocab embedding"
|
||||||
|
" with the target model.")
|
||||||
|
del self.model.model.embed_tokens
|
||||||
|
self.model.model.embed_tokens = target_embed_tokens
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"The EAGLE head's vocab embedding will be loaded separately"
|
||||||
|
" from the target model.")
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"The EAGLE head's vocab embedding will be loaded separately"
|
"The EAGLE head's vocab embedding will be loaded separately"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user