[Model]: support Ling2.0 (#24627)

Signed-off-by: vito.yy <vito.yy@antgroup.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

parent bf214ca226, commit 72c99f2a75
@@ -328,6 +328,7 @@ th {
 | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
 | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
 | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ |
 | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
@@ -180,6 +180,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                          trust_remote_code=True),
     "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
                                              trust_remote_code=True),
+    "BailingMoeV2ForCausalLM": _HfExamplesInfo("inclusionAI/Ling-mini-2.0",
+                                               trust_remote_code=True),
     "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
                                         min_transformers_version="4.55.3",
                                         extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}),  # noqa: E501
@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -68,6 +67,7 @@ class BailingAttention(nn.Module):
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
         prefix: str = "",
     ):
         super().__init__()
@@ -84,10 +84,11 @@ class BailingAttention(nn.Module):
         self.head_dim = config.head_dim or (self.hidden_size //
                                             self.total_num_heads)
         self.q_size_per_rank = self.head_dim * self.num_heads
-
         self.num_kv_heads = self.total_kv_heads // tp_size
         self.kv_size_per_rank = self.num_kv_heads * self.head_dim
         self.scale = self.head_dim**-0.5
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
+        self.use_rmsnorm = getattr(config, "use_rmsnorm", False)
 
         self.query_key_value = QKVParallelLinear(
             self.hidden_size,
@@ -99,28 +100,45 @@ class BailingAttention(nn.Module):
             prefix=f"{prefix}.query_key_value",
         )
 
+        if self.use_qk_norm:
+            self.query_layernorm = (RMSNorm(
+                self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
+                                    else nn.LayerNorm(self.head_dim, eps=1e-6))
+            self.key_layernorm = (RMSNorm(
+                self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
+                                  else nn.LayerNorm(self.head_dim, eps=1e-6))
+
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=config.use_bias,
             quant_config=quant_config,
+            reduce_results=reduce_results,
             prefix=f"{prefix}.dense",
         )
 
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scale,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              prefix=f"{prefix}.attn")
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
+                                             1.0)
+        self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
+            rotary_dim=self.rotary_dim,
             max_position=config.max_position_embeddings,
             base=config.rope_theta,
             is_neox_style=True,
             rope_scaling=config.rope_scaling,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            prefix=f"{prefix}.attn",
         )
 
     def forward(
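The attention rework threads `partial_rotary_factor` and `rotary_dim` through to `get_rope`, so only a prefix of each head's dimensions is rotated and the remainder passes through untouched. A minimal standalone sketch of partial rotary embedding follows; it is illustrative only (it uses the interleaved pairing, whereas `is_neox_style=True` in the diff uses the half-split layout), not vLLM's `get_rope` internals:

```python
import torch


def apply_partial_rope(x: torch.Tensor, positions: torch.Tensor,
                       rotary_dim: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate only the first `rotary_dim` dims of each head; rest pass through."""
    # x: [num_tokens, num_heads, head_dim], positions: [num_tokens]
    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2,
                                            dtype=torch.float32) / rotary_dim))
    freqs = positions.to(torch.float32)[:, None] * inv_freq[None, :]
    cos = freqs.cos()[:, None, :]  # broadcast over the heads axis
    sin = freqs.sin()[:, None, :]
    x1, x2 = rot[..., ::2], rot[..., 1::2]  # interleaved even/odd pairing
    rotated = torch.stack((x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos), dim=-1).flatten(-2)
    return torch.cat((rotated, rest), dim=-1)  # non-rotary dims unchanged
```

When `rotary_dim == head_dim` (the default above, `partial_rotary_factor=1.0`) this degenerates to ordinary full RoPE, so existing Ling checkpoints keep their behavior.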
@@ -135,6 +153,14 @@ class BailingAttention(nn.Module):
         ],
                             dim=-1)
 
+        if self.use_qk_norm:
+            q = q.view(-1, self.num_heads, self.head_dim)
+            k = k.view(-1, self.num_kv_heads, self.head_dim)
+            q = self.query_layernorm(q)
+            k = self.key_layernorm(k)
+            q = q.view(-1, self.q_size_per_rank)
+            k = k.view(-1, self.kv_size_per_rank)
+
         q, k = self.rotary_emb(position_ids, q, k)
 
         context_layer = self.attn(q, k, v)
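QK-Norm is applied per head and before the rotary embedding, which is why the flattened `[num_tokens, q_size_per_rank]` tensor is reshaped to `[num_tokens, num_heads, head_dim]`, normalized over `head_dim`, and flattened back. A rough sketch of that step, assuming RMS normalization and omitting the learned scale:

```python
import torch


def qk_rmsnorm(q: torch.Tensor, num_heads: int, head_dim: int,
               eps: float = 1e-6) -> torch.Tensor:
    """Per-head RMS normalization over head_dim (learned weight omitted)."""
    q = q.view(-1, num_heads, head_dim)  # split the fused projection per head
    q = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + eps)
    return q.view(-1, num_heads * head_dim)  # restore the flattened layout
```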
@@ -198,24 +224,72 @@ class BailingMoE(nn.Module):
         self.hidden_size = config.hidden_size
         self.quant_config = quant_config
         self.num_shared_experts = config.num_shared_experts
-        # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_experts,
-                                     bias=False,
-                                     quant_config=None)
+        self.score_function = getattr(config, "score_function", None)
+        self.n_group = getattr(config, "n_group", None)
+        self.topk_group = getattr(config, "topk_group", None)
+        self.use_grouped_topk = (self.n_group is not None
+                                 and self.topk_group is not None)
+        self.routed_scaling_factor = getattr(config, "routed_scaling_factor",
+                                             1.0)
 
-        self.experts = FusedMoE(num_experts=self.num_experts,
-                                top_k=self.top_k,
-                                hidden_size=self.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=self.norm_expert_prob,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+        router_dtype = getattr(config, "router_dtype", None)
+        if router_dtype is None:
+            self.router_dtype = None
+        elif router_dtype == "fp32":
+            self.router_dtype = torch.float32
+        else:
+            self.router_dtype = torch.bfloat16
+
+        self.gate = nn.Linear(
+            self.hidden_size,
+            self.num_experts,
+            bias=False,
+            dtype=self.router_dtype,
+        )
+
+        if getattr(config, "moe_router_enable_expert_bias", False):
+            self.gate.expert_bias = nn.Parameter(
+                torch.empty((config.num_experts, ), dtype=torch.float32))
+        else:
+            self.gate.expert_bias = None
+
+        self.correction_bias = (self.gate.expert_bias.data
+                                if self.gate.expert_bias is not None else None)
+
+        if self.score_function is not None:
+            assert (
+                self.score_function == "softmax"
+                and self.correction_bias is None
+            ) or (
+                self.score_function == "sigmoid"
+                and self.correction_bias is not None
+            ), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"  # noqa: E501
+        else:
+            # default value for scoring_func
+            self.score_function = "softmax"
+
+        self.experts = FusedMoE(
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            hidden_size=self.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=self.norm_expert_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            scoring_func=self.score_function,
+            e_score_correction_bias=self.gate.expert_bias,
+            num_expert_group=self.n_group,
+            topk_group=self.topk_group,
+            use_grouped_topk=self.use_grouped_topk,
+        )
 
         if self.num_shared_experts > 0:
-            intermediate_size = (config.moe_intermediate_size *
-                                 self.num_shared_experts)
+            if hasattr(config, "moe_shared_expert_intermediate_size"):
+                intermediate_size = config.moe_shared_expert_intermediate_size
+            else:
+                intermediate_size = config.moe_intermediate_size
+            intermediate_size *= config.num_shared_experts
             self.shared_experts = BailingMLP(
                 intermediate_size=intermediate_size,
                 config=config,
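Several routing features land here at once: the gate becomes a plain `nn.Linear` so its matmul can run in an explicit `router_dtype`, an optional non-trainable `expert_bias` is attached for selection-time correction, and grouped top-k (`n_group`/`topk_group`) is forwarded to `FusedMoE`. The assert pins down the two supported pairings: softmax scoring without a bias, or sigmoid scoring with one. Below is a simplified sketch of the sigmoid-plus-bias rule, a paraphrase of what `FusedMoE` does with `e_score_correction_bias` rather than its actual kernel code:

```python
import torch


def route_sigmoid_with_bias(router_logits: torch.Tensor,
                            expert_bias: torch.Tensor,
                            top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Bias steers which experts win; combine weights stay unbiased."""
    scores = router_logits.sigmoid()            # [num_tokens, num_experts]
    topk_ids = (scores + expert_bias).topk(top_k, dim=-1).indices
    topk_weights = scores.gather(-1, topk_ids)  # weights come from raw scores
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```

The bias only reorders which experts are selected; the combine weights are read from the unbiased scores, matching how correction biases are used in DeepSeek-V3-style auxiliary-loss-free routers.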
@@ -228,14 +302,18 @@ class BailingMoE(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_size)
-        if self.num_shared_experts > 0:
+        if self.shared_experts:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states.to(self.router_dtype))
+        router_logits = router_logits.to(hidden_states.dtype)
+
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
-        if self.num_shared_experts > 0:
+        final_hidden_states *= self.routed_scaling_factor
+
+        if self.shared_experts:
             final_hidden_states = final_hidden_states + shared_output
 
         if self.tp_size > 1:
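Note the dtype round-trip in `forward`: hidden states are cast to `router_dtype` for the gate matmul (fp32 buys numerically stable routing at negligible cost, since the gate is tiny) and the logits are cast back before entering the fused experts, which expect the activation dtype. A tiny runnable illustration of the same pattern:

```python
import torch
from torch import nn

hidden = torch.randn(4, 8, dtype=torch.bfloat16)
gate = nn.Linear(8, 16, bias=False, dtype=torch.float32)  # router_dtype == "fp32"

router_logits = gate(hidden.to(torch.float32))  # route in fp32 for stability
router_logits = router_logits.to(hidden.dtype)  # back to the activation dtype
```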
@@ -254,20 +332,30 @@ class BailingMoeBlock(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        layer_idx = int(prefix.split('.')[-1])
+        self.config = config
         hidden_size = config.hidden_size
         intermediate_size = config.intermediate_size
 
         self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
         self.attention = BailingAttention(config,
                                           cache_config,
                                           quant_config,
                                           prefix=f"{prefix}.attention")
 
         self.post_attention_layernorm = RMSNorm(hidden_size,
                                                 eps=config.rms_norm_eps)
-        self.mlp = BailingMoE(intermediate_size,
-                              config,
-                              quant_config,
-                              True,
-                              prefix=f"{prefix}.mlp")
+
+        # Choose MLP class based on the number of experts and layer index
+        if layer_idx < config.first_k_dense_replace:
+            mlp_class = BailingMLP
+        else:
+            mlp_class = BailingMoE
+        self.mlp = mlp_class(intermediate_size,
+                             config,
+                             quant_config,
+                             True,
+                             prefix=f"{prefix}.mlp")
 
     def forward(
         self,
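As in other DeepSeek-style MoE stacks, the first `first_k_dense_replace` layers keep a dense MLP while later layers use MoE, and the layer index is recovered from the module prefix string. A toy run of the selection logic, using a hypothetical config object:

```python
from types import SimpleNamespace

# Hypothetical config: one leading dense layer, then three MoE layers.
config = SimpleNamespace(first_k_dense_replace=1, num_hidden_layers=4)

for layer in range(config.num_hidden_layers):
    prefix = f"model.layers.{layer}"
    layer_idx = int(prefix.split('.')[-1])  # same recovery trick as the diff
    kind = "dense MLP" if layer_idx < config.first_k_dense_replace else "MoE"
    print(f"{prefix}: {kind}")
# model.layers.0: dense MLP
# model.layers.1 .. model.layers.3: MoE
```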
@@ -310,11 +398,17 @@ class BailingMoeModel(nn.Module):
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed_dim = config.hidden_size
-        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
+                                           False)
+
+        if get_pp_group().is_first_rank or (self.tie_word_embeddings
                                             and get_pp_group().is_last_rank):
             self.word_embeddings = VocabParallelEmbedding(
-                self.vocab_size, self.embed_dim)
+                self.vocab_size,
+                self.embed_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.word_embeddings",
+            )
         else:
             self.word_embeddings = PPMissingLayer()
@@ -372,8 +466,11 @@ class BailingMoeModel(nn.Module):
                 "hidden_states": hidden_states,
                 "residual": residual
             })
-        hidden_states, _ = self.norm(hidden_states, residual)
+        else:
+            if residual is None:
+                hidden_states = self.norm(hidden_states)
+            else:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
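This hunk guards the final norm for the case where no fused residual is being carried: vLLM's `RMSNorm` returns a bare tensor when called as `norm(x)` and a `(normed, residual)` tuple when called as `norm(x, residual)`. A self-contained imitation of that dual-return convention (illustrative only, not the vLLM implementation):

```python
import torch
from torch import nn


class FusedAddRMSNorm(nn.Module):
    """Mimics the dual return convention of vLLM's RMSNorm."""

    def __init__(self, size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(size))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):
        if residual is not None:
            x = x + residual  # fused residual add
            residual = x      # new residual flows to the next layer
        normed = x * torch.rsqrt(
            x.pow(2).mean(dim=-1, keepdim=True) + self.eps) * self.weight
        return normed if residual is None else (normed, residual)
```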
@@ -396,7 +493,8 @@ class BailingMoeModel(nn.Module):
         loaded_params: set[str] = set()
         expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
-            if self.config.norm_head and "lm_head.weight" in name:
+            if (hasattr(self.config, "norm_head") and self.config.norm_head
+                    and "lm_head.weight" in name):
                 loaded_weight = F.normalize(loaded_weight,
                                             dim=0,
                                             p=2,
@@ -430,13 +528,17 @@ class BailingMoeModel(nn.Module):
 
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param,
-                              loaded_weight,
-                              name,
-                              shard_id=shard_id,
-                              expert_id=expert_id)
+                weight_loader(
+                    param,
+                    loaded_weight,
+                    name,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                )
                 break
             else:
                 if name.endswith(".bias") and name not in params_dict:
@@ -473,19 +575,30 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     ) -> None:
         super().__init__()
-        config = vllm_config.model_config.hf_config
+
+        config = vllm_config.model_config.hf_config.get_text_config()
+        vllm_config.model_config.hf_config = config
         quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
         self.config = config
+        self.lora_config = lora_config
         self.quant_config = quant_config
         self.max_position_embeddings = config.max_position_embeddings
         self.model = BailingMoeModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))
+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
+                                           False)
+
         if get_pp_group().is_last_rank:
-            self.lm_head = (self.word_embeddings if config.tie_word_embeddings
-                            else ParallelLMHead(config.vocab_size,
-                                                config.hidden_size,
-                                                quant_config=quant_config))
+            if self.tie_word_embeddings:
+                self.lm_head = self.model.word_embeddings
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.lm_head",
+                )
             self.logits_processor = LogitsProcessor(config.vocab_size)
         else:
             self.lm_head = PPMissingLayer()
@@ -520,10 +633,13 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
                               torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
+            skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
+
+
+class BailingMoeV2ForCausalLM(BailingMoeForCausalLM):
+    pass
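Two threads tie off here. First, `skip_prefixes` now keys off the cached `self.tie_word_embeddings`: a checkpoint with tied embeddings ships no separate `lm_head.weight` tensor, so the loader must not look for one. Second, `BailingMoeV2ForCausalLM` is a bare alias class so that Ling 2.0 checkpoints, whose `config.json` names that architecture, resolve to the same implementation. A toy version of the weight-tying pattern, with hypothetical names:

```python
import torch
from torch import nn


class TinyLM(nn.Module):
    """Weight tying: the output projection reuses the embedding matrix."""

    def __init__(self, vocab: int, dim: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        if tie_word_embeddings:
            # One shared tensor: the checkpoint has no "lm_head.weight" entry,
            # so a loader must skip that prefix rather than search for it.
            self.lm_head.weight = self.embed.weight

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.embed(ids))
```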
@@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = {
     # baichuan-13b, lower case 'c' in the class name
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
     "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
+    "BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
     "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),