Eagle3 support for the MiniCPM3 model (#24243)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: liudan <adan@minicpm.com>
Co-authored-by: liudan <liudan@qq.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Authored by 阿丹 (adan) on 2025-09-27 01:04:57 +08:00; committed by GitHub
parent 56aafa8c0b
commit 33f6aaf972
2 changed files with 44 additions and 9 deletions
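
For context, this change wires MiniCPM targets into vLLM's EAGLE-3 speculative decoding path. Below is a minimal usage sketch against the offline LLM API, following vLLM's documented speculative_config usage as I understand it; the MiniCPM3 target and the EAGLE-3 draft checkpoint names are placeholders, not artifacts shipped with this commit.

from vllm import LLM, SamplingParams

# Placeholder checkpoints: substitute a real MiniCPM3 target model and an
# EAGLE-3 draft head trained against it.
llm = LLM(
    model="openbmb/MiniCPM3-4B",
    trust_remote_code=True,
    speculative_config={
        "method": "eagle3",
        "model": "your-org/minicpm3-eagle3-draft",  # hypothetical draft model
        "num_speculative_tokens": 3,
    },
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)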


@@ -540,7 +540,7 @@ class SpeculativeConfig:
                     "speculative decoding is > 1, but got "
                     f"{self.disable_by_batch_size=}")

-        eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
+        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
         if self.method == "eagle3" and self.target_model_config and not any(
                 supported_model in
                 self.target_model_config.hf_text_config.model_type
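
Worth noting: the check above is a substring match against the target's HF model_type, so the single "minicpm" entry covers both "minicpm" and "minicpm3" targets. A quick standalone illustration of that matching logic:

# Illustration only: substring matching against the HF model_type string.
eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]

for model_type in ("minicpm", "minicpm3", "mixtral"):
    ok = any(supported in model_type for supported in eagle3_target_supported)
    print(f"{model_type}: {'supported' if ok else 'not supported'}")
# minicpm: supported
# minicpm3: supported
# mixtral: not supported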


@@ -55,7 +55,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -381,6 +381,9 @@ class MiniCPMModel(nn.Module):
         self.num_experts = getattr(self.config, "num_experts", 0)
         self._init_layers(prefix, config, cache_config, quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.aux_hidden_state_layers = tuple[int, ...]()
+
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], self.config.hidden_size))
@@ -408,7 +411,8 @@ class MiniCPMModel(nn.Module):
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
+                                                         list[torch.Tensor]]]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
@@ -419,18 +423,29 @@ class MiniCPMModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(
+                    hidden_states +
+                    residual if residual is not None else hidden_states)
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
             )
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
+
         hidden_states = self.norm(hidden_states)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states

     def load_weights(self, weights: Iterable[tuple[str,
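
One subtlety in the new capture logic: the appended expression relies on Python's conditional-expression precedence, so it parses as (hidden_states + residual) if residual is not None else hidden_states, i.e. the full pre-layer activation is recorded before the selected layer runs. A tiny standalone check of that parse (plain integers, purely illustrative):

hidden_states, residual = 10, 5
captured = hidden_states + residual if residual is not None else hidden_states
assert captured == 15  # parses as (hidden_states + residual) if ... else ...

residual = None
captured = hidden_states + residual if residual is not None else hidden_states
assert captured == 10  # falls back to hidden_states alone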
@@ -502,7 +517,7 @@ class MiniCPMModel(nn.Module):
         return loaded_params


-class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -568,16 +583,36 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds) / self.scale_width
-        return hidden_states
+    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
+                                                         list[torch.Tensor]]]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+
+        if isinstance(model_output, tuple) and len(model_output) == 2:
+            # Aux hidden states are present.
+            hidden_states, aux_hidden_states = model_output
+            hidden_states = hidden_states / self.scale_width
+            return hidden_states, aux_hidden_states
+        else:
+            # Only hidden states or IntermediateTensors
+            if isinstance(model_output, IntermediateTensors):
+                return model_output
+            else:
+                hidden_states = model_output / self.scale_width
+                return hidden_states

     def compute_logits(
         self,
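
For a concrete feel for the default draft-feature taps, get_eagle3_aux_hidden_state_layers picks an early, a middle, and a late layer of the target model. A quick check of the arithmetic (the depths below are example values, not tied to a specific checkpoint):

def default_eagle3_layers(num_layers: int) -> tuple[int, ...]:
    # Mirrors the selection added in this commit: early / middle / late layer.
    return (2, num_layers // 2, num_layers - 3)

print(default_eagle3_layers(62))  # (2, 31, 59)
print(default_eagle3_layers(40))  # (2, 20, 37)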