# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from collections.abc import Callable, Iterable
from typing import Any
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.distributed import (
get_ep_group,
get_pp_group,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.interfaces import (
MixtureOfExperts,
SupportsLoRA,
SupportsPP,
)
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
sequence_parallel_chunk,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import set_default_rope_theta


def check_ffn_act_fn(act_fn: str) -> None:
if act_fn != "silu":
raise ValueError(
f"Unsupported activation: {act_fn}. Only silu is supported for now."
    )


class OpenPanguMLP(nn.Module):
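    """Dense SwiGLU feed-forward block: a merged gate/up projection, a
    SiluAndMul activation, and a down projection. When is_sequence_parallel
    is set, tensor parallelism is disabled on both projections because the
    tokens are already split along the sequence dimension."""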
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
reduce_results: bool = True,
        is_sequence_parallel: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj",
)
check_ffn_act_fn(hidden_act)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_up_proj(x)[0]))[0]


class OpenPanguMoE(nn.Module):
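    """Mixture-of-experts FFN: a replicated sigmoid-scored router, a fused
    routed-expert layer (SharedFusedMoE), and optional shared experts, with
    support for expert parallelism and EPLB redundant experts."""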
def __init__(
self,
config: PretrainedConfig,
parallel_config: ParallelConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
check_ffn_act_fn(config.hidden_act)
self.gate = ReplicatedLinear(
config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
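        # The router has no expert-score correction bias; the attribute is
        # kept (as None) so it can be passed to SharedFusedMoE below.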
self.gate.e_score_correction_bias = None
# Load balancing settings.
eplb_config = parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
self.physical_expert_end = (
self.physical_expert_start + self.n_local_physical_experts
)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = OpenPanguMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
is_sequence_parallel=self.is_sequence_parallel,
reduce_results=False,
prefix=f"{prefix}.shared_experts",
)
else:
self.shared_experts = None
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=1,
topk_group=1,
prefix=f"{prefix}.experts",
scoring_func="sigmoid",
            # We apply the routed scaling factor outside the fused MoE,
            # so set it to 1.0 here to avoid scaling twice.
routed_scaling_factor=1.0,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
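        # Under sequence-parallel MoE, each TP rank processes only its chunk
        # of the token sequence; the all-gather below restores the full
        # (possibly padded) sequence before it is trimmed to num_tokens.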
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
        shared_output, final_hidden_states = fused_moe_out
        if self.shared_experts is None:
            assert shared_output is None
        # Fix FP16 overflow: see the scaling in OpenPanguDecoderLayer.forward.
        if hidden_states.dtype != torch.float16:
            final_hidden_states *= self.routed_scaling_factor
        elif self.shared_experts is not None:
            assert shared_output is not None
            shared_output *= 1.0 / self.routed_scaling_factor
        if self.shared_experts is not None:
            assert shared_output is not None
            final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states
)
        return final_hidden_states.view(num_tokens, hidden_dim)


class OpenPanguMLAAttention(nn.Module):
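    """Multi-head Latent Attention (MLA), in the DeepSeek style. Queries
    (when q_lora_rank is set) and KV pairs are compressed through low-rank
    projections before being expanded to per-head dimensions; the attention
    itself is delegated to MultiHeadLatentAttentionWrapper."""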
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.tp_size = get_tensor_model_parallel_world_size()
if num_heads % self.tp_size != 0:
raise ValueError(
f"num_heads {num_heads} is not divisible by tp_size {self.tp_size}."
)
self.num_local_heads = num_heads // self.tp_size
self.scaling = self.qk_head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
self.prefix = prefix
if self.q_lora_rank is not None:
self.fused_qkv_a_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj",
disable_tp=True,
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
# TODO: remove hard coding
set_default_rope_theta(config, default_theta=10000)
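        # DeepSeek-style YaRN rope parameters; with factor=1 and mscale=1.0
        # this performs no long-context extension, so it is effectively plain
        # RoPE evaluated in the deepseek_yarn layout.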
rope_parameters = {
"rope_theta": config.rope_parameters["rope_theta"],
"beta_fast": 32,
"beta_slow": 1,
"factor": 1,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": max_position_embeddings,
"type": "yarn",
"rope_type": "deepseek_yarn",
}
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=False,
)
mla_modules = MLAModules(
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
rotary_emb=self.rotary_emb,
o_proj=self.o_proj,
fused_qkv_a_proj=self.fused_qkv_a_proj
if self.q_lora_rank is not None
else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
if self.q_lora_rank is None
else None,
q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None,
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else None,
indexer=None,
is_sparse=False,
topk_indices_buffer=None,
)
self.mla_attn = MultiHeadLatentAttentionWrapper(
self.hidden_size,
self.num_local_heads,
self.scaling,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
self.v_head_dim,
self.q_lora_rank,
self.kv_lora_rank,
mla_modules,
cache_config,
quant_config,
prefix,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
        return self.mla_attn(positions, hidden_states)


class OpenPanguEmbeddedAttention(nn.Module):
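    """Standard multi-head / grouped-query attention used by the dense
    PanguEmbedded variant, with an optional per-layer sliding window taken
    from config.interleaved_sliding_window."""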
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: CacheConfig | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
super().__init__()
layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
if self.total_num_heads % tp_size != 0:
raise ValueError(
f"total_num_heads {self.total_num_heads} "
f"is not divisible by tp_size {tp_size}."
)
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads > tp_size and self.total_num_kv_heads % tp_size != 0:
            # More KV heads than TP ranks: the KV heads are partitioned
            # across ranks, so the head count must divide evenly.
            raise ValueError(
                "Number of KV heads is greater than TP size, "
                f"but total_num_kv_heads {self.total_num_kv_heads} "
                f"is not divisible by tp_size {tp_size}."
            )
        elif (
            self.total_num_kv_heads < tp_size and tp_size % self.total_num_kv_heads != 0
        ):
            # Fewer KV heads than TP ranks: the KV heads are replicated
            # across ranks, so the TP size must divide evenly.
            raise ValueError(
                f"Number of KV heads is less than TP size, but tp_size {tp_size} "
                f"is not divisible by total_num_kv_heads {self.total_num_kv_heads}."
            )
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = self.hidden_size // self.total_num_heads
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self._init_rotary_emb(config, quant_config=quant_config)
if hasattr(config, "interleaved_sliding_window"):
interleaved_sliding_window = config.interleaved_sliding_window
if isinstance(interleaved_sliding_window, int):
sliding_window = interleaved_sliding_window
elif isinstance(interleaved_sliding_window, list):
sw_idx = layer_idx % len(interleaved_sliding_window)
sliding_window = interleaved_sliding_window[sw_idx]
else:
raise ValueError(
f"{type(interleaved_sliding_window)} "
"for interleaved_sliding_window is not supported."
)
else:
sliding_window = None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
attn_type=attn_type,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
def _init_rotary_emb(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None,
) -> None:
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
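        # GGUF conversions of PanguEmbedded appear to store rotary weights in
        # the GPT-J (interleaved) layout, so NeoX-style rotation is disabled.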
if is_gguf and config.model_type == "PanguEmbedded":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=config.rope_parameters,
is_neox_style=is_neox_style,
        )


class OpenPanguDecoderLayer(nn.Module):
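    """A single decoder layer. Uses MLA attention when the MLA head-dim
    fields are present in the HF config, otherwise standard attention; uses
    a MoE MLP from first_k_dense_replace onward when n_routed_experts is
    set, otherwise a dense MLP. Optionally applies sandwich normalization
    (extra norms after the attention and MLP outputs)."""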
def __init__(
self,
        config: PretrainedConfig | None,
prefix: str,
vllm_config: VllmConfig,
) -> None:
super().__init__()
if config is None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
layer_idx = int(prefix.split(sep=".")[-1])
self.layer_idx = layer_idx
self.use_mla = (
hasattr(config, "qk_nope_head_dim")
and hasattr(config, "qk_rope_head_dim")
and hasattr(config, "v_head_dim")
and hasattr(config, "kv_lora_rank")
)
if self.use_mla:
self.self_attn = OpenPanguMLAAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=(
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
),
kv_lora_rank=config.kv_lora_rank,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
else:
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False
)
bias_o_proj = attention_bias
if hasattr(config, "qkv_bias"):
attention_bias = config.qkv_bias
            # By default, PanguEmbedded uses causal attention, as it is a
            # decoder-only model. You can override this in the HF config with
            # `is_causal=False` to enable bidirectional attention, which some
            # embedding models use.
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = OpenPanguEmbeddedAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
bias_o_proj=bias_o_proj,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
if (
getattr(config, "n_routed_experts", None) is not None
and layer_idx >= config.first_k_dense_replace
):
self.mlp = OpenPanguMoE(
config=config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
else:
self.mlp = OpenPanguMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.num_hidden_layers = config.num_hidden_layers
self.first_k_dense_replace = getattr(
config, "first_k_dense_replace", self.num_hidden_layers
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.tp_group = get_tp_group().device_group
self.sandwich_norm = getattr(config, "sandwich_norm", False)
if self.sandwich_norm:
self.pre_mlp_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_mlp_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states.clone()
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if (
self.routed_scaling_factor is not None
and hidden_states.dtype == torch.float16
):
            # Fix FP16 overflow:
            # scale both hidden_states and residual before RMSNorm;
            # RMSNorm is scale-invariant, so its output is unaffected.
hidden_states *= 1.0 / self.routed_scaling_factor
if self.layer_idx == 0:
                # The residual is shared across all layers, so we only
                # scale it once, at the first layer.
residual *= 1.0 / self.routed_scaling_factor
if self.sandwich_norm:
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.pre_mlp_layernorm(hidden_states, residual)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
# Fully Connected
hidden_states = self.mlp(hidden_states)
if (
self.routed_scaling_factor is not None
and isinstance(self.mlp, OpenPanguMLP)
and hidden_states.dtype == torch.float16
):
hidden_states *= 1.0 / self.routed_scaling_factor
if self.sandwich_norm:
hidden_states = self.post_mlp_layernorm(hidden_states)
        return hidden_states, residual


@support_torch_compile
class OpenPanguModel(nn.Module):
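    """The bare OpenPangu transformer: token embeddings (on the first PP
    rank), the decoder layers, and the final RMSNorm (on the last PP rank),
    with activations passed between PP stages via IntermediateTensors."""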
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
eplb_config = vllm_config.parallel_config.eplb_config
self.config = config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
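        # Embeddings live on the first PP rank; with tied word embeddings the
        # last rank also needs them so lm_head can share the weight.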
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OpenPanguDecoderLayer(config, prefix, vllm_config),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_attn_mlp_weight(
self,
attn_mlp_replace_mapping: list[tuple[str, str, int]],
params_dict: dict[str, Any],
weight_name: str,
loaded_weight: torch.Tensor,
loaded_params: set[str],
) -> bool:
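        """Try to load one checkpoint tensor into a fused attention/MLP
        parameter (qkv_proj, fused_qkv_a_proj, or gate_up_proj). Returns
        True if this mapping consumed the weight."""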
for param_name, origin_name, shard_id in attn_mlp_replace_mapping:
if origin_name not in weight_name or (
("mlp.experts." in weight_name) and weight_name not in params_dict
):
continue
weight_name_mapped = weight_name.replace(origin_name, param_name)
if (
param_name == "fused_qkv_a_proj"
and weight_name_mapped not in params_dict
):
continue
else:
weight_name = weight_name_mapped
if weight_name.endswith(".bias") and weight_name not in params_dict:
continue
if is_pp_missing_parameter(weight_name, self):
continue
param = params_dict[weight_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(weight_name)
return True
return False
def load_expert_weight(
self,
expert_merge_mapping: list[tuple[str, str, int, str]],
params_dict: dict[str, Any],
weight_name: str,
loaded_weight: torch.Tensor,
loaded_params: set[str],
flag_dict: dict[str, bool],
) -> bool:
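        """Try to load one checkpoint tensor into the fused MoE expert
        parameters. Returns True on success. Sets
        flag_dict["is_expert_weight"] whenever the name matches an expert
        weight, even if that expert lives on another rank."""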
for mapping in expert_merge_mapping:
param_name, origin_name, expert_id, shard_id = mapping
if origin_name not in weight_name:
continue
flag_dict["is_expert_weight"] = True
weight_name_mapped = weight_name.replace(origin_name, param_name)
if is_pp_missing_parameter(weight_name_mapped, self):
continue
param = params_dict[weight_name_mapped]
weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
success = weight_loader(
param,
loaded_weight,
weight_name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
weight_name = weight_name_mapped
loaded_params.add(weight_name_mapped)
return True
return False
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
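        # Maps checkpoint weight names onto vLLM's fused parameters:
        # (fused param suffix, checkpoint suffix, shard id).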
attn_mlp_replace_mapping = [
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".fused_qkv_a_proj", ".q_a_proj", 0),
(".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
has_experts = hasattr(self.config, "n_routed_experts")
if has_experts:
expert_merge_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_redundant_experts=self.num_redundant_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if (
"layers" in name
and hasattr(self.config, "num_nextn_predict_layers")
and (self.config.num_nextn_predict_layers > 0)
):
layer_idx = int(name.split("layers.")[-1].split(".")[0])
mtp_idx = layer_idx - self.config.num_hidden_layers
if mtp_idx >= 0 and mtp_idx < self.config.num_nextn_predict_layers:
continue # skip spec decode layers for main model
flag_dict = {"is_expert_weight": False}
            if self.load_attn_mlp_weight(
                attn_mlp_replace_mapping,
                params_dict,
                name,
                loaded_weight,
                loaded_params,
            ) or (
                has_experts
                and self.load_expert_weight(
                    expert_merge_mapping,
                    params_dict,
                    name,
                    loaded_weight,
                    loaded_params,
                    flag_dict,
                )
            ):
                continue
else:
if flag_dict["is_expert_weight"]:
continue
if name.endswith(".bias") and name not in params_dict:
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
        return loaded_params


class OpenPanguModelBase(nn.Module, SupportsPP, SupportsLoRA):
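    """Common causal-LM wrapper around OpenPanguModel: adds the parallel LM
    head (tied to the embeddings when configured), the logits processor, and
    LoRA / pipeline-parallel support."""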
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.fuse_qkv_a_proj = (
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
)
if self.fuse_qkv_a_proj:
self.packed_modules_mapping["fused_qkv_a_proj"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
self.model = OpenPanguModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
        return loader.load_weights(weights)


class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts):
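    """MoE variant implementing the MixtureOfExperts interface so that EPLB
    can inspect and rebalance the physical experts at runtime."""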
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
# Set MoE hyperparameters
self.expert_weights = []
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = 1
self.moe_layers = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, OpenPanguDecoderLayer)
if isinstance(layer.mlp, OpenPanguMoE):
                # Pick the last MoE layer, since the leading layers may be
                # dense (first_k_dense_replace).
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_moe is None:
            raise RuntimeError("No MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.n_routed_experts = example_moe.n_routed_experts
self.n_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer.mlp, OpenPanguMoE):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()


class OpenPanguEmbeddedModel(OpenPanguModelBase):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)


class PanguEmbeddedForCausalLM(OpenPanguEmbeddedModel):
    pass


class PanguUltraMoEForCausalLM(OpenPanguMoEModel):
pass