mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 22:55:01 +08:00
Signed-off-by: Julien Denize <julien.denize@mistral.ai> Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Signed-off-by: Mickael Seznec <mickael@mistral.ai> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Mickael Seznec <mickael@mistral.ai>
166 lines
5.8 KiB
Python
166 lines
5.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from collections.abc import Iterable
|
|
from functools import partial
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed.parallel_state import get_pp_group
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import RowParallelLinear
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
|
from vllm.model_executor.models.deepseek_v2 import (
|
|
DeepseekV2DecoderLayer,
|
|
DeepseekV2Model,
|
|
)
|
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
|
from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
|
|
from vllm.multimodal.inputs import NestedTensors
|
|
|
|
from .utils import (
|
|
_merge_multimodal_embeddings,
|
|
make_empty_intermediate_tensors_factory,
|
|
maybe_prefix,
|
|
)
|
|
|
|
# Module-level logger for this file; not referenced by the code shown here.
logger = init_logger(__name__)
|
|
|
|
|
|
@support_torch_compile
class EagleMistralLarge3Model(DeepseekV2Model):
    """EAGLE draft backbone for Mistral Large 3.

    Reuses the DeepSeek-V2 decoder layers and the inherited
    ``DeepseekV2Model`` forward pass. Before that, it fuses the token
    embeddings with the hidden states supplied by the caller. The fusion is a
    linear projection ``fc`` that maps ``2 * hidden_size`` features back to
    ``hidden_size``.
    """

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", start_layer_id: int = 0
    ):
        """Build the draft backbone.

        Args:
            vllm_config: Engine config. ``model_config.hf_config`` sizes the
                layers, and ``quant_config`` is passed to the quantizable
                sub-layers.
            prefix: Dotted name of this module, used to build the prefixes of
                its sub-layers.
            start_layer_id: Offset added to each decoder-layer index in its
                prefix (``layers.{i + start_layer_id}``).
        """
        # Deliberately bypass DeepseekV2Model.__init__: this class builds its
        # own embedding, layer stack, norm and ``fc`` projection below.
        nn.Module.__init__(self)

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.vllm_config = vllm_config

        self.vocab_size = config.vocab_size

        # Only a single pipeline-parallel rank is supported. The embedding,
        # every decoder layer and the final norm are all built on this rank.
        assert get_pp_group().world_size == 1
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            # maybe_prefix avoids a leading "." when prefix is empty, matching
            # how the decoder-layer prefixes are built below.
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    vllm_config=vllm_config,
                    prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                )
                for i in range(self.config.num_hidden_layers)
            ]
        )
        # The whole layer stack is local. Presumably these bounds are read by
        # the inherited DeepseekV2Model.forward to pick the layers it runs.
        self.start_layer = 0
        self.end_layer = self.config.num_hidden_layers

        # Fusion projection: cat(token_embeds, hidden_states) -> hidden_size.
        # input_is_parallel=False: the full concatenated tensor is passed in,
        # not a pre-split shard. The prefix is passed like for the other
        # sub-layers so name-based quantization lookups can identify ``fc``.
        self.fc = RowParallelLinear(
            self.config.hidden_size * 2,
            self.config.hidden_size,
            bias=False,
            input_is_parallel=False,
            quant_config=quant_config,
            return_bias=False,
            prefix=maybe_prefix(prefix, "fc"),
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run one draft step.

        Args:
            input_ids: Token ids. They are embedded here unless
                ``inputs_embeds`` is given.
            positions: Positions forwarded to the inherited forward pass.
            hidden_states: Hidden states from the caller (presumably the
                target model). They are concatenated with the token embeddings
                along the last dimension before ``fc``.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            The tensor returned by ``DeepseekV2Model.forward``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_input_ids(input_ids)
        inputs_embeds = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
        output = super().forward(
            input_ids, positions, intermediate_tensors=None, inputs_embeds=inputs_embeds
        )
        # With a single PP rank the parent is expected to return a plain
        # tensor rather than intermediate tensors; narrow the type here.
        assert isinstance(output, torch.Tensor)
        return output
|
|
|
|
|
|
class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
    """EAGLE speculative-decoding draft model for Mistral Large 3.

    Swaps the backbone for ``EagleMistralLarge3Model``. ``load_weights``
    reports the token embedding and LM head as loaded, because they are copied
    from the target model later.
    """

    # Parent remapping plus the Eagle fusion projection (``eagle_linear``)
    # and its quantization scales.
    remapping = MistralLarge3ForCausalLM.remapping | {
        r"eagle_linear\.weight": r"model.fc.weight",
        r"eagle_linear\.qscale_act": r"model.fc.input_scale",
        r"eagle_linear\.qscale_weight": r"model.fc.weight_scale",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the draft model.

        Note:
            ``vllm_config`` is modified in place. ``model_config`` is replaced
            by ``speculative_config.draft_model_config``, and ``quant_config``
            by the quantization config derived from that draft config.

        Args:
            vllm_config: Engine config holding the target model config and the
                speculative (draft) config.
            prefix: Name prefix forwarded to the parent constructor.
        """
        # Read the target model's layer count before model_config is swapped
        # for the draft config. Draft layer prefixes are offset by it.
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        vllm_config.model_config = vllm_config.speculative_config.draft_model_config
        # draft model quantization config may differ from target model
        self.quant_config = VllmConfig.get_quantization_config(
            vllm_config.speculative_config.draft_model_config, vllm_config.load_config
        )
        vllm_config.quant_config = self.quant_config
        # Instance-level backbone factory. Presumably the parent constructor
        # uses it to build ``self.model`` (confirm in MistralLarge3ForCausalLM).
        self.model_cls = partial(
            EagleMistralLarge3Model, start_layer_id=target_layer_num
        )
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    @staticmethod
    def _maybe_merge_multimodal(
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None,
        is_multimodal: torch.Tensor | None,
    ) -> torch.Tensor:
        """Merge multimodal embeddings into ``inputs_embeds`` when any are given.

        Returns ``inputs_embeds`` unchanged when ``multimodal_embeddings`` is
        ``None`` or empty. Otherwise ``is_multimodal`` must be provided.
        """
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        assert is_multimodal is not None

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in any multimodal embeddings.

        Args:
            input_ids: Token ids, embedded via the parent's ``embed_input_ids``.
            multimodal_embeddings: Optional embeddings merged in through
                ``_merge_multimodal_embeddings``.
            is_multimodal: Presumably a mask over token positions marking the
                multimodal ones. Required when ``multimodal_embeddings`` is
                non-empty.
            handle_oov_mm_token: Accepted for interface compatibility; not
                used here.
        """
        inputs_embeds = super().embed_input_ids(input_ids)
        return self._maybe_merge_multimodal(
            inputs_embeds, multimodal_embeddings, is_multimodal
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft backbone.

        Returns:
            The backbone output as a ``(hidden_states, hidden_states)`` pair.
            Both elements are the same tensor.
        """
        hidden_states = self.model(input_ids, positions, hidden_states, inputs_embeds)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load draft weights and return the set of loaded parameter names.

        The embedding and LM-head weight names are always added to the result,
        whether or not ``weights`` contains them.
        """
        # Pretend we've loaded the embedding and lm_head weights
        # (later copied from target model)
        return super().load_weights(weights) | {
            "model.embed_tokens.weight",
            "lm_head.weight",
        }

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: NestedTensors | None = None,
        is_multimodal: torch.Tensor | None = None,
        *,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in any multimodal embeddings.

        Args:
            input_ids: Token ids, embedded via ``self.model.embed_input_ids``.
            multimodal_embeddings: Optional embeddings merged in through
                ``_merge_multimodal_embeddings``.
            is_multimodal: Presumably a mask over token positions marking the
                multimodal ones. Required when ``multimodal_embeddings`` is
                non-empty.
            handle_oov_mm_token: Accepted for parity with
                ``get_input_embeddings``; not used here.
        """
        inputs_embeds = self.model.embed_input_ids(input_ids)
        # Previously the multimodal arguments were silently dropped here.
        return self._maybe_merge_multimodal(
            inputs_embeds, multimodal_embeddings, is_multimodal
        )
|