diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 2e4b3d3a6b202..5f462442148f8 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -540,7 +540,7 @@ class SpeculativeConfig: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") - eagle3_target_supported = ["llama", "qwen", "gpt_oss"] + eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"] if self.method == "eagle3" and self.target_model_config and not any( supported_model in self.target_model_config.hf_text_config.model_type diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 0986ea07406a9..55fe3e2ae3ae7 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -55,7 +55,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -381,6 +381,9 @@ class MiniCPMModel(nn.Module): self.num_experts = getattr(self.config, "num_experts", 0) self._init_layers(prefix, config, cache_config, quant_config) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.aux_hidden_state_layers = tuple[int, ...]() + self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], self.config.hidden_size)) @@ -408,7 +411,8 @@ class MiniCPMModel(nn.Module): positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, + list[torch.Tensor]]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -419,18 +423,29 @@ class MiniCPMModel(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): + aux_hidden_states = [] + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer)): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append( + hidden_states + + residual if residual is not None else hidden_states) hidden_states, residual = layer( positions, hidden_states, residual, ) + if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, "residual": residual }) + hidden_states = self.norm(hidden_states) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states def load_weights(self, weights: Iterable[tuple[str, @@ -502,7 +517,7 @@ class MiniCPMModel(nn.Module): return loaded_params -class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -568,16 +583,36 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) / self.scale_width - return hidden_states + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, + list[torch.Tensor]]]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + + if isinstance(model_output, tuple) and len(model_output) == 2: + # Aux hidden states are present. + hidden_states, aux_hidden_states = model_output + hidden_states = hidden_states / self.scale_width + return hidden_states, aux_hidden_states + else: + # Only hidden states or IntermediateTensors + if isinstance(model_output, IntermediateTensors): + return model_output + else: + hidden_states = model_output / self.scale_width + return hidden_states def compute_logits( self,