[Model]: support Ling2.0 (#24627)

Signed-off-by: vito.yy <vito.yy@antgroup.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
ant-yy 2025-09-15 20:09:30 +08:00 committed by GitHub
parent bf214ca226
commit 72c99f2a75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 169 additions and 49 deletions

View File

@ -328,6 +328,7 @@ th {
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BailingMoeV2ForCausalLM` | Ling | `inclusionAI/Ling-mini-2.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |

View File

@ -180,6 +180,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5",
trust_remote_code=True),
"BailingMoeV2ForCausalLM": _HfExamplesInfo("inclusionAI/Ling-mini-2.0",
trust_remote_code=True),
"BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1",
min_transformers_version="4.55.3",
extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501

View File

@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
@ -68,6 +67,7 @@ class BailingAttention(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
):
super().__init__()
@ -84,10 +84,11 @@ class BailingAttention(nn.Module):
self.head_dim = config.head_dim or (self.hidden_size //
self.total_num_heads)
self.q_size_per_rank = self.head_dim * self.num_heads
self.num_kv_heads = self.total_kv_heads // tp_size
self.kv_size_per_rank = self.num_kv_heads * self.head_dim
self.scale = self.head_dim**-0.5
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.use_rmsnorm = getattr(config, "use_rmsnorm", False)
self.query_key_value = QKVParallelLinear(
self.hidden_size,
@ -99,28 +100,45 @@ class BailingAttention(nn.Module):
prefix=f"{prefix}.query_key_value",
)
if self.use_qk_norm:
self.query_layernorm = (RMSNorm(
self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
else nn.LayerNorm(self.head_dim, eps=1e-6))
self.key_layernorm = (RMSNorm(
self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
else nn.LayerNorm(self.head_dim, eps=1e-6))
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=config.use_bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.dense",
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn")
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
1.0)
self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings,
base=config.rope_theta,
is_neox_style=True,
rope_scaling=config.rope_scaling,
partial_rotary_factor=self.partial_rotary_factor,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
prefix=f"{prefix}.attn",
)
def forward(
@ -135,6 +153,14 @@ class BailingAttention(nn.Module):
],
dim=-1)
if self.use_qk_norm:
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
q = self.query_layernorm(q)
k = self.key_layernorm(k)
q = q.view(-1, self.q_size_per_rank)
k = k.view(-1, self.kv_size_per_rank)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(q, k, v)
@ -198,24 +224,72 @@ class BailingMoE(nn.Module):
self.hidden_size = config.hidden_size
self.quant_config = quant_config
self.num_shared_experts = config.num_shared_experts
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.num_experts,
bias=False,
quant_config=None)
self.score_function = getattr(config, "score_function", None)
self.n_group = getattr(config, "n_group", None)
self.topk_group = getattr(config, "topk_group", None)
self.use_grouped_topk = (self.n_group is not None
and self.topk_group is not None)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor",
1.0)
self.experts = FusedMoE(num_experts=self.num_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts")
router_dtype = getattr(config, "router_dtype", None)
if router_dtype is None:
self.router_dtype = None
elif router_dtype == "fp32":
self.router_dtype = torch.float32
else:
self.router_dtype = torch.bfloat16
self.gate = nn.Linear(
self.hidden_size,
self.num_experts,
bias=False,
dtype=self.router_dtype,
)
if getattr(config, "moe_router_enable_expert_bias", False):
self.gate.expert_bias = nn.Parameter(
torch.empty((config.num_experts, ), dtype=torch.float32))
else:
self.gate.expert_bias = None
self.correction_bias = (self.gate.expert_bias.data
if self.gate.expert_bias is not None else None)
if self.score_function is not None:
assert (
self.score_function == "softmax"
and self.correction_bias is None
) or (
self.score_function == "sigmoid"
and self.correction_bias is not None
), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)" # noqa: E501
else:
# default value for scoring_func
self.score_function = "softmax"
self.experts = FusedMoE(
num_experts=self.num_experts,
top_k=self.top_k,
hidden_size=self.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=self.norm_expert_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
scoring_func=self.score_function,
e_score_correction_bias=self.gate.expert_bias,
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
)
if self.num_shared_experts > 0:
intermediate_size = (config.moe_intermediate_size *
self.num_shared_experts)
if hasattr(config, "moe_shared_expert_intermediate_size"):
intermediate_size = config.moe_shared_expert_intermediate_size
else:
intermediate_size = config.moe_intermediate_size
intermediate_size *= config.num_shared_experts
self.shared_experts = BailingMLP(
intermediate_size=intermediate_size,
config=config,
@ -228,14 +302,18 @@ class BailingMoE(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
if self.num_shared_experts > 0:
if self.shared_experts:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = self.gate(hidden_states.to(self.router_dtype))
router_logits = router_logits.to(hidden_states.dtype)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if self.num_shared_experts > 0:
final_hidden_states *= self.routed_scaling_factor
if self.shared_experts:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
@ -254,20 +332,30 @@ class BailingMoeBlock(nn.Module):
prefix: str = "",
):
super().__init__()
layer_idx = int(prefix.split('.')[-1])
self.config = config
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
self.attention = BailingAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attention")
self.post_attention_layernorm = RMSNorm(hidden_size,
eps=config.rms_norm_eps)
self.mlp = BailingMoE(intermediate_size,
config,
quant_config,
True,
prefix=f"{prefix}.mlp")
# Choose MLP class based on the number of experts and layer index
if layer_idx < config.first_k_dense_replace:
mlp_class = BailingMLP
else:
mlp_class = BailingMoE
self.mlp = mlp_class(intermediate_size,
config,
quant_config,
True,
prefix=f"{prefix}.mlp")
def forward(
self,
@ -310,11 +398,17 @@ class BailingMoeModel(nn.Module):
self.config = config
self.vocab_size = config.vocab_size
self.embed_dim = config.hidden_size
self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
False)
if get_pp_group().is_first_rank or (config.tie_word_embeddings
if get_pp_group().is_first_rank or (self.tie_word_embeddings
and get_pp_group().is_last_rank):
self.word_embeddings = VocabParallelEmbedding(
self.vocab_size, self.embed_dim)
self.vocab_size,
self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.word_embeddings",
)
else:
self.word_embeddings = PPMissingLayer()
@ -372,8 +466,11 @@ class BailingMoeModel(nn.Module):
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
else:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@ -396,7 +493,8 @@ class BailingMoeModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if self.config.norm_head and "lm_head.weight" in name:
if (hasattr(self.config, "norm_head") and self.config.norm_head
and "lm_head.weight" in name):
loaded_weight = F.normalize(loaded_weight,
dim=0,
p=2,
@ -430,13 +528,17 @@ class BailingMoeModel(nn.Module):
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
if name.endswith(".bias") and name not in params_dict:
@ -473,19 +575,30 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_config.get_text_config()
vllm_config.model_config.hf_config = config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.max_position_embeddings = config.max_position_embeddings
self.model = BailingMoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
False)
if get_pp_group().is_last_rank:
self.lm_head = (self.word_embeddings if config.tie_word_embeddings
else ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config))
if self.tie_word_embeddings:
self.lm_head = self.model.word_embeddings
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
)
self.logits_processor = LogitsProcessor(config.vocab_size)
else:
self.lm_head = PPMissingLayer()
@ -520,10 +633,13 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None),
)
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
class BailingMoeV2ForCausalLM(BailingMoeForCausalLM):
pass

View File

@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = {
# baichuan-13b, lower case 'c' in the class name
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
"BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
"BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
"BambaForCausalLM": ("bamba", "BambaForCausalLM"),
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),