[Model]: support Ling2.0 (#24627)

Signed-off-by: vito.yy <vito.yy@antgroup.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
ant-yy 2025-09-15 20:09:30 +08:00 committed by GitHub
parent bf214ca226
commit 72c99f2a75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 169 additions and 49 deletions

View File

@ -328,6 +328,7 @@ th {
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |

View File

@ -180,6 +180,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True), trust_remote_code=True),
"BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
trust_remote_code=True), trust_remote_code=True),
"BailingMoeV2ForCausalLM": _HfExamplesInfo("inclusionAI/Ling-mini-2.0",
trust_remote_code=True),
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1", "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
min_transformers_version="4.55.3", min_transformers_version="4.55.3",
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501

View File

@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
@ -68,6 +67,7 @@ class BailingAttention(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
@ -84,10 +84,11 @@ class BailingAttention(nn.Module):
self.head_dim = config.head_dim or (self.hidden_size // self.head_dim = config.head_dim or (self.hidden_size //
self.total_num_heads) self.total_num_heads)
self.q_size_per_rank = self.head_dim * self.num_heads self.q_size_per_rank = self.head_dim * self.num_heads
self.num_kv_heads = self.total_kv_heads // tp_size self.num_kv_heads = self.total_kv_heads // tp_size
self.kv_size_per_rank = self.num_kv_heads * self.head_dim self.kv_size_per_rank = self.num_kv_heads * self.head_dim
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.use_rmsnorm = getattr(config, "use_rmsnorm", False)
self.query_key_value = QKVParallelLinear( self.query_key_value = QKVParallelLinear(
self.hidden_size, self.hidden_size,
@ -99,28 +100,45 @@ class BailingAttention(nn.Module):
prefix=f"{prefix}.query_key_value", prefix=f"{prefix}.query_key_value",
) )
if self.use_qk_norm:
self.query_layernorm = (RMSNorm(
self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
else nn.LayerNorm(self.head_dim, eps=1e-6))
self.key_layernorm = (RMSNorm(
self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
else nn.LayerNorm(self.head_dim, eps=1e-6))
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=config.use_bias, bias=config.use_bias,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.dense", prefix=f"{prefix}.dense",
) )
self.attn = Attention(self.num_heads, self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
self.head_dim, 1.0)
self.scale,
num_kv_heads=self.num_kv_heads, self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)
cache_config=cache_config,
prefix=f"{prefix}.attn")
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
base=config.rope_theta, base=config.rope_theta,
is_neox_style=True, is_neox_style=True,
rope_scaling=config.rope_scaling, rope_scaling=config.rope_scaling,
partial_rotary_factor=self.partial_rotary_factor,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn",
) )
def forward( def forward(
@ -135,6 +153,14 @@ class BailingAttention(nn.Module):
], ],
dim=-1) dim=-1)
if self.use_qk_norm:
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
q = self.query_layernorm(q)
k = self.key_layernorm(k)
q = q.view(-1, self.q_size_per_rank)
k = k.view(-1, self.kv_size_per_rank)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(q, k, v) context_layer = self.attn(q, k, v)
@ -198,24 +224,72 @@ class BailingMoE(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.quant_config = quant_config self.quant_config = quant_config
self.num_shared_experts = config.num_shared_experts self.num_shared_experts = config.num_shared_experts
# Gate always runs at half / full precision for now. self.score_function = getattr(config, "score_function", None)
self.gate = ReplicatedLinear(self.hidden_size, self.n_group = getattr(config, "n_group", None)
self.num_experts, self.topk_group = getattr(config, "topk_group", None)
bias=False, self.use_grouped_topk = (self.n_group is not None
quant_config=None) and self.topk_group is not None)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor",
1.0)
self.experts = FusedMoE(num_experts=self.num_experts, router_dtype = getattr(config, "router_dtype", None)
top_k=self.top_k, if router_dtype is None:
hidden_size=self.hidden_size, self.router_dtype = None
intermediate_size=config.moe_intermediate_size, elif router_dtype == "fp32":
reduce_results=False, self.router_dtype = torch.float32
renormalize=self.norm_expert_prob, else:
quant_config=quant_config, self.router_dtype = torch.bfloat16
prefix=f"{prefix}.experts")
self.gate = nn.Linear(
self.hidden_size,
self.num_experts,
bias=False,
dtype=self.router_dtype,
)
if getattr(config, "moe_router_enable_expert_bias", False):
self.gate.expert_bias = nn.Parameter(
torch.empty((config.num_experts, ), dtype=torch.float32))
else:
self.gate.expert_bias = None
self.correction_bias = (self.gate.expert_bias.data
if self.gate.expert_bias is not None else None)
if self.score_function is not None:
assert (
self.score_function == "softmax"
and self.correction_bias is None
) or (
self.score_function == "sigmoid"
and self.correction_bias is not None
), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)" # noqa: E501
else:
# default value for scoring_func
self.score_function = "softmax"
self.experts = FusedMoE(
num_experts=self.num_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
scoring_func=self.score_function,
e_score_correction_bias=self.gate.expert_bias,
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
)
if self.num_shared_experts > 0: if self.num_shared_experts > 0:
intermediate_size = (config.moe_intermediate_size * if hasattr(config, "moe_shared_expert_intermediate_size"):
self.num_shared_experts) intermediate_size = config.moe_shared_expert_intermediate_size
else:
intermediate_size = config.moe_intermediate_size
intermediate_size *= config.num_shared_experts
self.shared_experts = BailingMLP( self.shared_experts = BailingMLP(
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
config=config, config=config,
@ -228,14 +302,18 @@ class BailingMoE(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size) hidden_states = hidden_states.view(-1, hidden_size)
if self.num_shared_experts > 0: if self.shared_experts:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits = self.gate(hidden_states.to(self.router_dtype))
router_logits = router_logits.to(hidden_states.dtype)
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if self.num_shared_experts > 0: final_hidden_states *= self.routed_scaling_factor
if self.shared_experts:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1: if self.tp_size > 1:
@ -254,20 +332,30 @@ class BailingMoeBlock(nn.Module):
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
layer_idx = int(prefix.split('.')[-1])
self.config = config
hidden_size = config.hidden_size hidden_size = config.hidden_size
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
self.attention = BailingAttention(config, self.attention = BailingAttention(config,
cache_config, cache_config,
quant_config, quant_config,
prefix=f"{prefix}.attention") prefix=f"{prefix}.attention")
self.post_attention_layernorm = RMSNorm(hidden_size, self.post_attention_layernorm = RMSNorm(hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.mlp = BailingMoE(intermediate_size,
config, # Choose MLP class based on the number of experts and layer index
quant_config, if layer_idx < config.first_k_dense_replace:
True, mlp_class = BailingMLP
prefix=f"{prefix}.mlp") else:
mlp_class = BailingMoE
self.mlp = mlp_class(intermediate_size,
config,
quant_config,
True,
prefix=f"{prefix}.mlp")
def forward( def forward(
self, self,
@ -310,11 +398,17 @@ class BailingMoeModel(nn.Module):
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
False)
if get_pp_group().is_first_rank or (config.tie_word_embeddings if get_pp_group().is_first_rank or (self.tie_word_embeddings
and get_pp_group().is_last_rank): and get_pp_group().is_last_rank):
self.word_embeddings = VocabParallelEmbedding( self.word_embeddings = VocabParallelEmbedding(
self.vocab_size, self.embed_dim) self.vocab_size,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.word_embeddings",
)
else: else:
self.word_embeddings = PPMissingLayer() self.word_embeddings = PPMissingLayer()
@ -372,8 +466,11 @@ class BailingMoeModel(nn.Module):
"hidden_states": hidden_states, "hidden_states": hidden_states,
"residual": residual "residual": residual
}) })
else:
hidden_states, _ = self.norm(hidden_states, residual) if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@ -396,7 +493,8 @@ class BailingMoeModel(nn.Module):
loaded_params: set[str] = set() loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping() expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if self.config.norm_head and "lm_head.weight" in name: if (hasattr(self.config, "norm_head") and self.config.norm_head
and "lm_head.weight" in name):
loaded_weight = F.normalize(loaded_weight, loaded_weight = F.normalize(loaded_weight,
dim=0, dim=0,
p=2, p=2,
@ -430,13 +528,17 @@ class BailingMoeModel(nn.Module):
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
if name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(
loaded_weight, param,
name, loaded_weight,
shard_id=shard_id, name,
expert_id=expert_id) shard_id=shard_id,
expert_id=expert_id,
)
break break
else: else:
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
@ -473,19 +575,30 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config.get_text_config()
vllm_config.model_config.hf_config = config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.model = BailingMoeModel(vllm_config=vllm_config, self.model = BailingMoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
False)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.lm_head = (self.word_embeddings if config.tie_word_embeddings if self.tie_word_embeddings:
else ParallelLMHead(config.vocab_size, self.lm_head = self.model.word_embeddings
config.hidden_size, else:
quant_config=quant_config)) self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
@ -520,10 +633,13 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None),
if self.config.tie_word_embeddings else None),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping() return self.model.get_expert_mapping()
class BailingMoeV2ForCausalLM(BailingMoeForCausalLM):
pass

View File

@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = {
# baichuan-13b, lower case 'c' in the class name # baichuan-13b, lower case 'c' in the class name
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
"BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
"BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
"BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"),
"BloomForCausalLM": ("bloom", "BloomForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),