mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 11:06:15 +08:00
664 lines
24 KiB
Python
664 lines
24 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from collections.abc import Iterable
|
|
from typing import Any
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
|
|
from vllm.distributed import (
|
|
get_pp_group,
|
|
get_tensor_model_parallel_world_size,
|
|
tensor_model_parallel_all_reduce,
|
|
)
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
from vllm.model_executor.layers.kda import KimiDeltaAttention
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.mamba.mamba_utils import (
|
|
MambaStateDtypeCalculator,
|
|
MambaStateShapeCalculator,
|
|
)
|
|
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
|
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
default_weight_loader,
|
|
maybe_remap_kv_scale_name,
|
|
)
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
|
|
|
|
from .interfaces import HasInnerState, IsHybrid, MixtureOfExperts, SupportsPP
|
|
from .utils import (
|
|
PPMissingLayer,
|
|
is_pp_missing_parameter,
|
|
make_layers,
|
|
maybe_prefix,
|
|
)
|
|
|
|
# Module-level logger created through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)
|
|
|
|
|
|
class KimiMLP(nn.Module):
    """Feed-forward block: merged gate/up projection, activation, down projection.

    Only the ``"silu"`` activation is accepted; the activation module used is
    ``SiluAndMul``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of ``down_proj``.
            hidden_act: Activation name; must be ``"silu"``.
            quant_config: Quantization config forwarded to both linear layers.
            reduce_results: Forwarded to ``RowParallelLinear`` for ``down_proj``.
            prefix: Parameter-name prefix of this module.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # Settings shared by both projections.
        linear_kwargs = {"bias": False, "quant_config": quant_config}
        merged_output_sizes = [intermediate_size, intermediate_size]

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            prefix=f"{prefix}.gate_up_proj",
            **linear_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **linear_kwargs,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply ``gate_up_proj``, the activation and ``down_proj`` in sequence."""
        # Both linear layers return a tuple; only the first element is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
|
|
|
|
|
|
class KimiMoE(nn.Module):
    """Mixture-of-experts feed-forward block with an optional shared-expert MLP.

    A replicated, unquantized gate produces router logits for ``FusedMoE``.
    The routed output is multiplied by ``config.routed_scaling_factor`` and,
    when ``config.num_shared_experts`` is not ``None``, added to the output of
    a ``KimiMLP`` shared-expert branch.
    """

    def __init__(
        self,
        config: KimiLinearConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        layer_idx: int = 0,
    ):
        """Build the gate, the fused experts and (optionally) the shared experts.

        Args:
            config: Model config; the MoE-related fields read here are
                ``hidden_size``, ``moe_intermediate_size``, ``num_experts``,
                ``moe_renormalize``, ``routed_scaling_factor``,
                ``num_shared_experts``, ``num_experts_per_token``,
                ``use_grouped_topk``, ``num_expert_group``, ``topk_group``,
                ``moe_router_activation_func`` and ``hidden_act``.
            quant_config: Quantization config for the experts and shared
                experts (the gate is always built with ``quant_config=None``).
            prefix: Parameter-name prefix of this module.
            layer_idx: Index of the owning decoder layer; stored as-is.

        Raises:
            ValueError: If ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        moe_intermediate_size = config.moe_intermediate_size
        num_experts = config.num_experts
        moe_renormalize = config.moe_renormalize
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.num_shared_experts = config.num_shared_experts
        self.layer_idx = layer_idx

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # The correction bias is registered on the gate module and the same
        # Parameter object is handed to FusedMoE below.
        self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts))

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=config.num_experts_per_token,
            hidden_size=hidden_size,
            intermediate_size=moe_intermediate_size,
            reduce_results=False,
            renormalize=moe_renormalize,
            quant_config=quant_config,
            use_grouped_topk=config.use_grouped_topk,
            num_expert_group=config.num_expert_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.moe_router_activation_func,
            e_score_correction_bias=self.gate.e_score_correction_bias,
        )

        if self.num_shared_experts is not None:
            shared_intermediate_size = moe_intermediate_size * self.num_shared_experts
            self.shared_experts = KimiMLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and add the shared-expert output.

        Args:
            hidden_states: 2-D tensor; its two dimensions are unpacked as
                ``(num_tokens, hidden_size)``.

        Returns:
            Tensor viewed as ``(num_tokens, hidden_size)``.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_size)

        # Must be bound even when there are no shared experts: it is tested
        # below, and leaving it unassigned raised UnboundLocalError.
        shared_output = None
        if self.num_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)

        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = (
            self.experts(hidden_states=hidden_states, router_logits=router_logits)
            * self.routed_scaling_factor
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        # Both branches were built with reduce_results=False, so the summed
        # output is all-reduced once here when tensor parallelism is in use.
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)
|
|
|
|
|
|
class KimiMLAAttention(nn.Module):
    """
    Multi-head latent attention layer built on MultiHeadLatentAttentionWrapper.

    Main reference: DeepseekV2 vllm Implementation
    """

    def __init__(
        self,
        config: KimiLinearConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        use_nope: bool = False,
        rope_scaling: dict[str, Any] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Create the projections and wrap them in the MLA attention module.

        The constructor asserts that ``use_nope`` is ``True``, that
        ``q_lora_rank`` and ``rope_scaling`` are ``None``, and that
        ``num_heads`` is divisible by the tensor-parallel world size.
        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()

        # Plain configuration attributes.
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.rope_theta = rope_theta
        self.use_nope = use_nope

        tp_size = get_tensor_model_parallel_world_size()
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5

        # Only this configuration is supported by the layer.
        assert self.use_nope is True
        assert self.q_lora_rank is None
        assert rope_scaling is None
        assert num_heads % tp_size == 0

        # Settings shared by every linear projection below.
        proj_kwargs = {"bias": False, "quant_config": quant_config}
        kv_a_out_features = self.kv_lora_rank + self.qk_rope_head_dim
        q_out_features = self.num_heads * self.qk_head_dim
        kv_b_out_features = self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)
        o_in_features = self.num_heads * self.v_head_dim

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            kv_a_out_features,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
            **proj_kwargs,
        )
        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            q_out_features,
            prefix=f"{prefix}.q_proj",
            **proj_kwargs,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            kv_b_out_features,
            prefix=f"{prefix}.kv_b_proj",
            **proj_kwargs,
        )
        self.o_proj = RowParallelLinear(
            o_in_features,
            self.hidden_size,
            prefix=f"{prefix}.o_proj",
            **proj_kwargs,
        )

        # No rotary embedding, no low-rank query path and no sparse indexer
        # are passed to the MLA module bundle.
        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=None,
            o_proj=self.o_proj,
            fused_qkv_a_proj=None,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            q_a_layernorm=None,
            q_b_proj=None,
            q_proj=self.q_proj,
            indexer=None,
            is_sparse=False,
            topk_indices_buffer=None,
        )
        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        """Run MLA attention and write the result into ``output`` in place."""
        attn_result = self.mla_attn(positions, hidden_states)
        output[:] = attn_result
|
|
|
|
|
|
class KimiDecoderLayer(nn.Module):
    """Decoder layer: attention followed by an MoE or dense MLP block.

    The attention module is ``KimiDeltaAttention`` when
    ``config.is_kda_layer(layer_idx)`` is true and ``KimiMLAAttention``
    otherwise. An ``RMSNorm`` is applied before each of the two sub-blocks.
    """

    def __init__(
        self,
        config: KimiLinearConfig,
        layer_idx: int,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        parallel_config: ParallelConfig | None = None,
        model_config: ModelConfig | None = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Build the attention, feed-forward and normalization submodules."""
        super().__init__()
        self.hidden_size = config.hidden_size
        self.is_moe = config.is_moe

        attn_prefix = f"{prefix}.self_attn"
        if config.is_kda_layer(layer_idx):
            # The KDA layer receives the HF-style config as ``model_config``.
            self.self_attn = KimiDeltaAttention(
                layer_idx=layer_idx,
                hidden_size=config.hidden_size,
                quant_config=quant_config,
                cache_config=cache_config,
                model_config=config,
                prefix=attn_prefix,
            )
        else:
            mla_dims = {
                "qk_nope_head_dim": config.qk_nope_head_dim,
                "qk_rope_head_dim": config.qk_rope_head_dim,
                "v_head_dim": config.v_head_dim,
                "q_lora_rank": config.q_lora_rank,
                "kv_lora_rank": config.kv_lora_rank,
            }
            self.self_attn = KimiMLAAttention(
                layer_idx=layer_idx,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                quant_config=quant_config,
                cache_config=cache_config,
                model_config=model_config,
                prefix=attn_prefix,
                config=config,
                use_nope=config.mla_use_nope,
                **mla_dims,
            )

        # A layer uses MoE only past the leading dense layers and on the
        # configured layer frequency.
        uses_moe = (
            self.is_moe
            and config.num_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if uses_moe:
            self.block_sparse_moe = KimiMoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.block_sparse_moe",
            )
            # ``mlp`` is an alias of the same module object.
            self.mlp = self.block_sparse_moe
        else:
            self.mlp = KimiMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one decoder layer.

        Returns:
            ``(hidden_states, residual)`` where ``hidden_states`` is the
            feed-forward output and ``residual`` is the value returned by the
            post-attention norm.
        """
        # Pre-attention norm; the first layer has no incoming residual.
        if residual is None:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        else:
            normed, residual = self.input_layernorm(hidden_states, residual)

        # The attention module writes its result into a preallocated buffer.
        attn_buffer = torch.empty_like(normed)
        self.self_attn(
            hidden_states=normed,
            positions=positions,
            output=attn_buffer,
        )

        # Post-attention norm followed by the feed-forward block.
        ffn_input, residual = self.post_attention_layernorm(attn_buffer, residual)
        ffn_output = self.mlp(ffn_input)
        return ffn_output, residual
|
|
|
|
|
|
@support_torch_compile
class KimiLinearModel(nn.Module):
    """Stack of ``KimiDecoderLayer`` modules with embedding and final norm.

    The embedding is only instantiated on the first pipeline-parallel rank and
    the final norm only on the last one; other ranks hold ``PPMissingLayer``
    placeholders.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final norm."""
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        def build_layer(prefix: str):
            # The layer index is the last dot-separated component of the prefix.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            return KimiDecoderLayer(
                config=config,
                layer_idx=layer_idx,
                cache_config=cache_config,
                quant_config=quant_config,
                parallel_config=parallel_config,
                model_config=model_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        world_size = get_tensor_model_parallel_world_size()
        assert config.num_attention_heads % world_size == 0, (
            "num_attention_heads must be divisible by world_size"
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``input_ids`` via ``embed_tokens``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run this rank's slice of decoder layers.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``. On later ranks the hidden
        states and residual come from ``intermediate_tensors``. Non-last ranks
        return an ``IntermediateTensors`` carrying both; the last rank returns
        the normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.layers[self.start_layer : self.end_layer]:
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
|
|
|
|
|
class KimiLinearForCausalLM(
    nn.Module, HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid
):
    """Causal-LM wrapper around ``KimiLinearModel`` with LM head and logits processor.

    Also provides the class-level hooks that report the KDA state dtypes and
    shapes, and the checkpoint weight loader.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the inner model, the LM head and the logits processor.

        The LM head is only instantiated on the last pipeline-parallel rank;
        other ranks hold a ``PPMissingLayer`` placeholder.
        """
        super().__init__()
        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.config = self.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.quant_config = quant_config
        self.model = KimiLinearModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.config.vocab_size,
                self.config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        # ``logit_scale`` is optional on the config; 1.0 is used when absent.
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size, scale=logit_scale
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the inner model and return its output unchanged."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
        )
        return hidden_states

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
        """Return the KDA state dtypes for the given config.

        The result comes from ``MambaStateDtypeCalculator.kda_state_dtype``,
        called with the model dtype and the configured mamba cache dtype.
        """
        return MambaStateDtypeCalculator.kda_state_dtype(
            vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the KDA state shapes for the given config.

        The result comes from ``MambaStateShapeCalculator.kda_state_shape``,
        using the tensor-parallel size, the ``linear_attn_config`` entries
        ``num_heads``, ``head_dim`` and ``short_conv_kernel_size``, and the
        number of speculative tokens (0 when no speculative config is set).
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        tp_size = parallel_config.tensor_parallel_size
        num_spec = (
            vllm_config.speculative_config.num_speculative_tokens
            if vllm_config.speculative_config
            else 0
        )
        return MambaStateShapeCalculator.kda_state_shape(
            tp_size,
            hf_config.linear_attn_config["num_heads"],
            hf_config.linear_attn_config["head_dim"],
            conv_kernel_size=hf_config.linear_attn_config["short_conv_kernel_size"],
            num_spec=num_spec,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits by applying the logits processor with the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Each item of ``weights`` is indexed as ``(name, tensor)`` with an
        optional third element used as keyword arguments for the final
        fallback weight loader. Every tensor is tried against three routes,
        in order: the stacked gate/up mapping, the per-expert mapping (only
        populated when ``config.is_moe``), and a direct name lookup.

        NOTE(review): ``loaded_params`` is filled but never returned, so this
        method returns ``None``. Confirm whether the caller expects the set of
        loaded names before changing this.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        if self.config.is_moe:
            # Params for weights, fp8 weight scales, fp8 activation scales
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="w1",
                ckpt_down_proj_name="w2",
                ckpt_up_proj_name="w3",
                num_experts=self.config.num_experts,
            )
        else:
            expert_params_mapping = []
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for args in weights:
            name, loaded_weight = args[:2]
            # Optional third element: extra kwargs for the fallback loader.
            kwargs = args[2] if len(args) > 2 else {}
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # Route 1: gate_proj / up_proj shards of the merged gate_up_proj.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Route 2: reached only when no stacked mapping loaded the
                # tensor (the loop above finished without ``break``).
                for idx, (param_name, weight_name, expert_id, shard_id) in enumerate(
                    expert_params_mapping
                ):
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        expert_id=expert_id,
                        shard_id=shard_id,
                    )
                    break
                else:
                    # Route 3: direct lookup by (possibly remapped) name.
                    # The ``continue`` statements here apply to the outer
                    # ``for args in weights`` loop.
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias")
                        and name not in params_dict
                        and not self.config.is_linear_attn
                    ):  # noqa: E501
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight, **kwargs)
            loaded_params.add(name)
|
|
|
|
|
|
def get_spec_layer_idx_from_weight_name(
    config: KimiLinearConfig, weight_name: str
) -> int | None:
    """Return the index of the next-n predict layer a weight belongs to.

    Next-n predict (spec decode) layers are numbered directly after the main
    layers, i.e. from ``config.num_hidden_layers`` up to
    ``config.num_hidden_layers + config.num_nextn_predict_layers - 1``.

    Args:
        config: Model config. ``num_nextn_predict_layers`` is optional; a
            missing attribute or a ``None`` value means there are no such
            layers.
        weight_name: Checkpoint weight name, e.g. ``"model.layers.3.mlp..."``.

    Returns:
        The matching layer index, or ``None`` when the weight does not belong
        to a next-n predict layer.
    """
    # ``or 0`` covers configs that define the field but leave it as None,
    # which previously raised TypeError on the ``> 0`` comparison.
    num_spec_layers = getattr(config, "num_nextn_predict_layers", None) or 0
    if num_spec_layers <= 0:
        return None

    first_spec_layer = config.num_hidden_layers
    for layer_idx in range(first_spec_layer, first_spec_layer + num_spec_layers):
        # The trailing dot keeps e.g. layer 4 from matching layer 40.
        if weight_name.startswith(f"model.layers.{layer_idx}."):
            return layer_idx
    return None
|