mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 10:46:05 +08:00
Eagle3 that supports the Minicpm3 model (#24243)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: liudan <adan@minicpm.com> Co-authored-by: liudan <liudan@qq.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
This commit is contained in:
parent
56aafa8c0b
commit
33f6aaf972
@ -540,7 +540,7 @@ class SpeculativeConfig:
|
||||
"speculative decoding is > 1, but got "
|
||||
f"{self.disable_by_batch_size=}")
|
||||
|
||||
eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
|
||||
eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
|
||||
if self.method == "eagle3" and self.target_model_config and not any(
|
||||
supported_model in
|
||||
self.target_model_config.hf_text_config.model_type
|
||||
|
||||
@ -55,7 +55,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
@ -381,6 +381,9 @@ class MiniCPMModel(nn.Module):
|
||||
self.num_experts = getattr(self.config, "num_experts", 0)
|
||||
self._init_layers(prefix, config, cache_config, quant_config)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], self.config.hidden_size))
|
||||
@ -408,7 +411,8 @@ class MiniCPMModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
|
||||
list[torch.Tensor]]]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
@ -419,18 +423,29 @@ class MiniCPMModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
aux_hidden_states = []
|
||||
for idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(
|
||||
hidden_states +
|
||||
residual if residual is not None else hidden_states)
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if len(aux_hidden_states) > 0:
|
||||
return hidden_states, aux_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
@ -502,7 +517,7 @@ class MiniCPMModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -568,16 +583,36 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds) / self.scale_width
|
||||
return hidden_states
|
||||
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
|
||||
list[torch.Tensor]]]:
|
||||
model_output = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
if isinstance(model_output, tuple) and len(model_output) == 2:
|
||||
# Aux hidden states are present.
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
hidden_states = hidden_states / self.scale_width
|
||||
return hidden_states, aux_hidden_states
|
||||
else:
|
||||
# Only hidden states or IntermediateTensors
|
||||
if isinstance(model_output, IntermediateTensors):
|
||||
return model_output
|
||||
else:
|
||||
hidden_states = model_output / self.scale_width
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user