[Model][Bugfix] Fix Ernie 4.5 load failure caused by the Ernie 4.5 EPLB code (#26684)

Signed-off-by: wangyafeng <wangyafeng@baidu.com>
CSWYF3634076 2025-10-14 14:55:23 +08:00 committed by GitHub
parent 481545b397
commit 01ad27faff


@@ -23,7 +23,8 @@
 # limitations under the License.
 """Inference-only ErnieMoE model compatible with HuggingFace weights."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from itertools import islice
 from typing import Any
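The new `typing` and `Callable` imports exist to support the `typing.cast(Callable[..., bool], param.weight_loader)` call in the weight-loading hunk further down. The cast is purely a static-typing hint: it costs nothing at runtime and tells the type checker that this particular loader returns a bool when asked to report success. A minimal sketch of the idiom, with a hypothetical `_loader` stand-in:

import typing
from collections.abc import Callable

def _loader(param, weight, *, return_success: bool = False):
    # Hypothetical stand-in for an expert weight loader.
    return True if return_success else None

# cast() is a no-op at runtime; it only narrows the declared type so the
# checker accepts the bool-returning call below.
loader = typing.cast(Callable[..., bool], _loader)
ok: bool = loader(None, None, return_success=True)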
@@ -139,10 +140,10 @@ class Ernie4_5_MoeMoE(nn.Module):
         # Load balancing settings.
         vllm_config = get_current_vllm_config()
-        parallel_config = vllm_config.parallel_config
+        eplb_config = vllm_config.parallel_config.eplb_config
         self.enable_eplb = enable_eplb
 
-        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
         self.n_logical_experts = self.n_routed_experts
         self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
         self.n_local_physical_experts = self.n_physical_experts // self.ep_size
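This hunk is the actual bugfix: vLLM moved the EPLB settings out of `ParallelConfig` into a nested `eplb_config` object, so the old `parallel_config.num_redundant_experts` read failed at model init and broke Ernie 4.5 loading. The expert bookkeeping that follows is unchanged; a minimal sketch of that arithmetic, with made-up numbers:

# Hypothetical sizes, for illustration only.
n_routed_experts = 64       # logical experts defined by the checkpoint
n_redundant_experts = 8     # extra replicas contributed by EPLB
ep_size = 8                 # expert-parallel world size

n_logical_experts = n_routed_experts
n_physical_experts = n_logical_experts + n_redundant_experts  # 72
n_local_physical_experts = n_physical_experts // ep_size      # 9 per EP rank
assert n_physical_experts % ep_size == 0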
@@ -426,8 +427,10 @@ class Ernie4_5_MoeModel(nn.Module):
         self.vocab_size = config.vocab_size
         self.config = config
         parallel_config = vllm_config.parallel_config
+        eplb_config = parallel_config.eplb_config
         enable_eplb = parallel_config.enable_eplb
-        self.num_redundant_experts = parallel_config.num_redundant_experts
+        self.num_redundant_experts = eplb_config.num_redundant_experts
+
         if get_pp_group().is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(
@@ -570,20 +573,27 @@
                     # Skip loading extra bias for GPTQ models.
                     if (
-                        name.endswith(".bias") or name.endswith("_bias")
-                    ) and name not in params_dict:
+                        name_mapped.endswith(".bias") or name_mapped.endswith("_bias")
+                    ) and name_mapped not in params_dict:
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(
+                        Callable[..., bool], param.weight_loader
+                    )
+                    success = weight_loader(
                         param,
                         loaded_weight,
-                        name,
+                        name_mapped,
                         shard_id=shard_id,
                         expert_id=expert_id,
+                        return_success=True,
                     )
-                    break
+                    if success:
+                        name = name_mapped
+                        break
                 else:
                     if is_expert_weight:
                         # We've checked that this is an expert weight
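The for/else rewrite is the behavioral half of the change. With redundant experts enabled, one checkpoint tensor can map to several physical replicas, and a given rank may host none of them, so a loader call can legitimately decline a weight. Breaking out of the mapping loop on the first name match, as the old code did, would silently drop weights whose replicas live on other ranks; only a successful load should end the search. A condensed sketch of the pattern (the helper name and the mapping/loader shapes are illustrative, not vLLM's exact API):

def try_load_expert_weight(name, loaded_weight, expert_params_mapping, params_dict):
    """Return the mapped name on success, or None if no local replica took it."""
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        name_mapped = name.replace(weight_name, param_name)
        param = params_dict[name_mapped]
        # Ask the loader whether it actually consumed the weight; False means
        # this replica belongs to another rank, so keep trying other mappings.
        success = param.weight_loader(
            param,
            loaded_weight,
            name_mapped,
            shard_id=shard_id,
            expert_id=expert_id,
            return_success=True,
        )
        if success:
            return name_mapped
    return None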