[EPLB] Add EPLB support for hunyuan_v1 (#23078)

This commit is contained in:
YiwenC 2025-09-17 21:51:35 -07:00 committed by GitHub
parent 3bc18127ff
commit 9d8a2d86d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 123 additions and 16 deletions

View File

@ -1508,8 +1508,8 @@ class FusedMoE(CustomOp):
return [
weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS
and not name.startswith("_shared_experts.")
if name not in NON_EXPERT_WEIGHTS and weight.shape != torch.Size(
[]) and not name.startswith("_shared_experts.")
]
def set_eplb_state(

View File

@ -23,7 +23,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only HunYuan model compatible with HuggingFace weights."""
from collections.abc import Iterable
import typing
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union
import regex as re
@ -33,8 +34,8 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
@ -56,7 +57,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_layers, maybe_prefix)
@ -355,10 +356,16 @@ class HunYuanSparseMoeBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
layer_id: int = -1,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
@ -379,8 +386,23 @@ class HunYuanSparseMoeBlock(nn.Module):
config.moe_intermediate_size, int) else
config.moe_intermediate_size[layer_id])
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.physical_expert_start = (self.ep_rank *
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
self.experts = FusedMoE(
num_experts=config.num_experts,
num_experts=self.n_routed_experts,
top_k=top_k,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
@ -388,6 +410,8 @@ class HunYuanSparseMoeBlock(nn.Module):
renormalize=top_k > 1,
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
)
self.gate = ReplicatedLinear(config.hidden_size,
@ -446,6 +470,7 @@ class HunYuanDecoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
layer_id: int = -1,
enable_eplb: bool = False,
) -> None:
super().__init__()
assert layer_id >= 0
@ -509,6 +534,7 @@ class HunYuanDecoderLayer(nn.Module):
quant_config=quant_config,
layer_id=layer_id,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = HunYuanMLP(
@ -562,6 +588,9 @@ class HunYuanModel(nn.Module):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
eplb_config = vllm_config.parallel_config.eplb_config
enable_eplb = vllm_config.parallel_config.enable_eplb
self.num_redundant_experts = eplb_config.num_redundant_experts
self.config = config
self.quant_config = quant_config
@ -588,6 +617,7 @@ class HunYuanModel(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
enable_eplb=enable_eplb,
),
prefix=f"{prefix}.layers",
)
@ -674,6 +704,7 @@ class HunYuanModel(nn.Module):
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
num_redundant_experts=self.num_redundant_experts,
)
else:
return []
@ -803,25 +834,43 @@ class HunYuanModel(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
# this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(
param,
loaded_weight,
name,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
break
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
@ -841,7 +890,7 @@ class HunYuanModel(nn.Module):
return loaded_params
class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP):
class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -883,6 +932,64 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP):
else:
self.lm_head = PPMissingLayer()
# Set MoE hyperparameters
self.expert_weights = []
self.num_expert_groups = 1
self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, HunYuanDecoderLayer)
if isinstance(layer.mlp, HunYuanSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No HunYuanMoE layer found in model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
self.expert_weights.append(layer.get_expert_weights())
# Register the expert weights.
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = (num_physical_experts -
self.num_logical_experts)
for layer in self.model.layers:
if isinstance(layer.mlp, HunYuanSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def forward(
self,
input_ids: torch.Tensor,