mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 01:45:01 +08:00
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
329 lines
12 KiB
Python
329 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Inference-only GraniteMoeShared model.
|
|
|
|
The architecture is the same as granitemoe but with the addition of shared
|
|
experts.
|
|
"""
|
|
|
|
from collections.abc import Iterable
|
|
from itertools import islice
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers.models.granitemoeshared import GraniteMoeSharedConfig
|
|
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import get_pp_group
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (
|
|
MergedColumnParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
|
|
from .interfaces import SupportsLoRA, SupportsPP
|
|
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
|
|
|
|
|
|
class GraniteMoeSharedMLP(nn.Module):
    """Shared-expert feed-forward block used next to the routed MoE.

    The input goes through a merged column-parallel linear with two outputs
    of ``config.shared_intermediate_size`` each, then ``SiluAndMul``, then a
    row-parallel linear back to ``config.hidden_size``. Neither linear layer
    has a bias.

    Args:
        config: Model config; ``hidden_size``, ``shared_intermediate_size``
            and ``hidden_act`` are read from it.
        quant_config: Passed unchanged to both linear layers.
        prefix: Name prefix of this module; ``.input_linear`` and
            ``.output_linear`` are appended for the two sub-layers.

    Raises:
        ValueError: If ``config.hidden_act`` is anything other than
            ``"silu"``.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        model_dim = config.hidden_size
        shared_dim = config.shared_intermediate_size
        self.input_size = model_dim
        self.hidden_size = shared_dim

        # One merged projection producing both halves consumed by SiluAndMul.
        self.input_linear = MergedColumnParallelLinear(
            input_size=model_dim,
            output_sizes=[shared_dim, shared_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.input_linear",
        )
        self.output_linear = RowParallelLinear(
            shared_dim,
            model_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_linear",
        )

        activation = config.hidden_act
        if activation != "silu":
            raise ValueError(
                f"Unsupported activation: {activation}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply input projection, activation and output projection.

        Both linear layers return a 2-tuple; only the first element is used.
        """
        merged, _ = self.input_linear(hidden_states)
        activated = self.act_fn(merged)
        output, _ = self.output_linear(activated)
        return output
|
|
|
|
|
|
class GraniteMoeSharedDecoderLayer(nn.Module):
    """Single decoder layer: self-attention, routed MoE, optional shared MLP.

    Each of the two sub-blocks is preceded by an ``RMSNorm`` and its output
    is scaled by ``config.residual_multiplier`` before being added to the
    residual stream.

    Args:
        config: Model config supplying attention, MoE, norm and multiplier
            settings.
        cache_config: Forwarded to the attention module.
        quant_config: Forwarded to attention, MoE and the shared MLP.
        prefix: Name prefix of this layer; sub-module names are appended.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        hidden_dim = config.hidden_size
        self.hidden_size = hidden_dim

        self.self_attn = GraniteMoeAttention(
            hidden_size=hidden_dim,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier,
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=hidden_dim,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )

        # A missing or zero `shared_intermediate_size` disables the shared MLP.
        if getattr(config, "shared_intermediate_size", 0) == 0:
            self.shared_mlp = None
        else:
            self.shared_mlp = GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(hidden_dim, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(hidden_dim, eps=norm_eps)

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the attention and MoE sub-blocks with scaled residual adds.

        Args:
            positions: Passed to the attention module as ``positions``.
            hidden_states: Input to the layer.

        Returns:
            The layer output after both residual additions.
        """
        # Self-attention sub-block; `hidden_states` itself is the residual.
        attn_output = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
        )
        hidden_states = hidden_states + attn_output * self.residual_multiplier

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is not None:
            # The MoE gets a copy since block_sparse_moe modifies its input
            # in-place and the shared MLP still needs the untouched tensor.
            routed = self.block_sparse_moe(hidden_states.clone())
            hidden_states = routed + self.shared_mlp(hidden_states)
            del routed
        else:
            hidden_states = self.block_sparse_moe(hidden_states)

        return residual + hidden_states * self.residual_multiplier
|
|
|
|
|
|
@support_torch_compile
class GraniteMoeSharedModel(nn.Module):
    """Decoder stack of GraniteMoeShared: embeddings, layers and final norm.

    On the first pipeline-parallel rank the input is embedded (or taken from
    ``inputs_embeds``) and scaled by ``config.embedding_multiplier``; other
    ranks start from the ``"hidden_states"`` entry of the received
    ``IntermediateTensors``. Only the last rank applies the final ``RMSNorm``.

    Args:
        vllm_config: Source of the HF model config, cache config and
            quantization config.
        prefix: Name prefix of this module; ``.layers`` is appended for the
            decoder layers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config  # Required by MixtralModel
        self.padding_idx = config.pad_token_id

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.embedding_multiplier = config.embedding_multiplier

        # `make_layers` hands back the layer index range used in `forward`
        # together with the layer container.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeSharedDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (without the embedding multiplier)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``inputs_embeds`` is ``None``.
            positions: Forwarded to every decoder layer.
            intermediate_tensors: Must be provided on non-first ranks; its
                ``"hidden_states"`` entry is the starting activation.
            inputs_embeds: Optional embeddings used instead of ``input_ids``.

        Returns:
            Normalised hidden states on the last rank, otherwise an
            ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # NOTE: the multiply is in-place, so a caller-supplied
            # `inputs_embeds` tensor is scaled in place as well.
            hidden_states *= self.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                }
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Rename fused MoE checkpoint tensors, then delegate the loading.

        Three checkpoint name suffixes are rewritten:

        * ``.block_sparse_moe.input_linear.weight`` is indexed along dim 0
          (one slice per expert); every slice is split into two chunks along
          its dim 0 and stored as ``experts.{e}.w1.weight`` and
          ``experts.{e}.w3.weight``.
        * ``.block_sparse_moe.output_linear.weight`` is indexed along dim 0
          and each slice is stored as ``experts.{e}.w2.weight``.
        * ``.block_sparse_moe.router.layer.weight`` is renamed to
          ``.block_sparse_moe.gate.weight``.

        Every other tensor keeps its name.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The result of ``GraniteMoeModel._load_weights`` for the renamed
            weights.

        Raises:
            ValueError: If a remapped name was already produced earlier.
        """
        fused_in = ".block_sparse_moe.input_linear.weight"
        fused_out = ".block_sparse_moe.output_linear.weight"
        router = ".block_sparse_moe.router.layer.weight"

        new_weights: dict[str, torch.Tensor] = {}

        def add_remapped(name: str, tensor: torch.Tensor) -> None:
            # Raise explicitly instead of `assert`: assertions are removed
            # under `python -O`, which would let a colliding name silently
            # overwrite an earlier tensor.
            if name in new_weights:
                raise ValueError(
                    f"Duplicate weight name {name!r} produced while remapping "
                    "checkpoint weights."
                )
            new_weights[name] = tensor

        for n, p in weights:
            if n.endswith(fused_in):
                for e in range(p.size(0)):
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    add_remapped(
                        n.replace(fused_in, f".block_sparse_moe.experts.{e}.w1.weight"),
                        w1_param,
                    )
                    add_remapped(
                        n.replace(fused_in, f".block_sparse_moe.experts.{e}.w3.weight"),
                        w3_param,
                    )
            elif n.endswith(fused_out):
                for e in range(p.size(0)):
                    add_remapped(
                        n.replace(
                            fused_out, f".block_sparse_moe.experts.{e}.w2.weight"
                        ),
                        p[e],
                    )
            elif n.endswith(router):
                add_remapped(n.replace(router, ".block_sparse_moe.gate.weight"), p)
            else:
                # Pass-through names are stored as-is (no duplicate check).
                new_weights[n] = p
        return GraniteMoeModel._load_weights(self, new_weights.items())
|
|
|
|
|
|
class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GraniteMoeShared decoder with an LM head and logits processor.

    Args:
        vllm_config: Source of the HF model config and quantization config;
            also passed on to the wrapped ``GraniteMoeSharedModel``.
        prefix: Name prefix; ``model`` and ``lm_head`` are joined to it via
            ``maybe_prefix``.
    """

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = hf_config

        self.model = GraniteMoeSharedModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        vocab_size = hf_config.vocab_size
        self.lm_head = ParallelLMHead(
            vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # With tied embeddings the LM head reuses the input embedding weight.
        if hf_config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            vocab_size,
            vocab_size,
            scale=1 / hf_config.logits_scaling,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the wrapped model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the wrapped model and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Feed ``lm_head`` and ``hidden_states`` to the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states)

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Build zero-filled ``IntermediateTensors``.

        The single ``"hidden_states"`` entry has shape
        ``(batch_size, config.hidden_size)`` with the given dtype and device.
        """
        placeholder = torch.zeros(
            (batch_size, self.config.hidden_size), dtype=dtype, device=device
        )
        return IntermediateTensors({"hidden_states": placeholder})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader``.

        ``lm_head.`` entries are skipped when the config ties word embeddings.
        """
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return loader.load_weights(weights)
|