[EPLB] Add EPLB support for hunyuan_v1 (#23078)

Author: YiwenC
Date: 2025-09-17 21:51:35 -07:00 (committed by GitHub)
Parent: 3bc18127ff
Commit: 9d8a2d86d2
2 changed files with 123 additions and 16 deletions

Changed file 1 (FusedMoE layer):

@@ -1508,8 +1508,8 @@ class FusedMoE(CustomOp):
         return [
             weight.view(self.local_num_experts, -1) for name, weight in weights
-            if name not in NON_EXPERT_WEIGHTS
-            and not name.startswith("_shared_experts.")
+            if name not in NON_EXPERT_WEIGHTS and weight.shape != torch.Size(
+                []) and not name.startswith("_shared_experts.")
         ]

     def set_eplb_state(
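Note on the guard added above: `weight.shape != torch.Size([])` filters 0-dim tensors (for example per-tensor quantization scales) out of the per-expert weight views, since a scalar cannot be viewed as `(local_num_experts, -1)`. A minimal sketch with hypothetical shapes, not taken from the diff:

import torch

scale = torch.tensor(1.0)             # 0-dim tensor, like a per-tensor scale
print(scale.shape == torch.Size([]))  # True -> excluded by the new check

w = torch.randn(64 * 128, 256)        # hypothetical packed expert weight
print(w.view(64, -1).shape)           # torch.Size([64, 32768]) -> kept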

Changed file 2 (HunYuan v1 model):

@@ -23,7 +23,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only HunYuan model compatible with HuggingFace weights."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union

 import regex as re
@@ -33,8 +34,8 @@ from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group,
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -56,7 +57,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_layers, maybe_prefix)
@@ -355,10 +356,16 @@ class HunYuanSparseMoeBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         layer_id: int = -1,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_experts
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -379,8 +386,23 @@ class HunYuanSparseMoeBlock(nn.Module):
                              config.moe_intermediate_size, int) else
                              config.moe_intermediate_size[layer_id])

+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        eplb_config = vllm_config.parallel_config.eplb_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
         self.experts = FusedMoE(
-            num_experts=config.num_experts,
+            num_experts=self.n_routed_experts,
             top_k=top_k,
             hidden_size=config.hidden_size,
             intermediate_size=intermediate_size,
@@ -388,6 +410,8 @@ class HunYuanSparseMoeBlock(nn.Module):
             renormalize=top_k > 1,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
+            enable_eplb=self.enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
         )

         self.gate = ReplicatedLinear(config.hidden_size,
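For reference, a small sketch of the partitioning arithmetic introduced above, using assumed example values (16 routed experts, 2 redundant experts, EP size 2; none of these numbers come from the diff):

n_routed_experts = 16      # config.num_experts (assumed example value)
n_redundant_experts = 2    # eplb_config.num_redundant_experts (assumed)
ep_size, ep_rank = 2, 1    # expert-parallel world size and this rank

n_logical_experts = n_routed_experts
n_physical_experts = n_logical_experts + n_redundant_experts        # 18
n_local_physical_experts = n_physical_experts // ep_size            # 9
physical_expert_start = ep_rank * n_local_physical_experts          # 9
physical_expert_end = physical_expert_start + n_local_physical_experts  # 18

# Rank 1 hosts physical experts [9, 18); rank 0 hosts [0, 9).
print(physical_expert_start, physical_expert_end)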
@@ -446,6 +470,7 @@ class HunYuanDecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         layer_id: int = -1,
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         assert layer_id >= 0
@@ -509,6 +534,7 @@ class HunYuanDecoderLayer(nn.Module):
                 quant_config=quant_config,
                 layer_id=layer_id,
                 prefix=f"{prefix}.mlp",
+                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = HunYuanMLP(
@@ -562,6 +588,9 @@ class HunYuanModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        eplb_config = vllm_config.parallel_config.eplb_config
+        enable_eplb = vllm_config.parallel_config.enable_eplb
+        self.num_redundant_experts = eplb_config.num_redundant_experts

         self.config = config
         self.quant_config = quant_config
@@ -588,6 +617,7 @@ class HunYuanModel(nn.Module):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 prefix=prefix,
+                enable_eplb=enable_eplb,
             ),
             prefix=f"{prefix}.layers",
         )
@@ -674,6 +704,7 @@ class HunYuanModel(nn.Module):
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
                 num_experts=self.config.num_experts,
+                num_redundant_experts=self.num_redundant_experts,
             )
         else:
             return []
@@ -803,25 +834,43 @@ class HunYuanModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+
+                    # This is an expert weight and should not be attempted
+                    # to load as other weights later.
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here.
+                    # Instead, create a new variable.
+                    name_mapped = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
+
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here, since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(
                         param,
                         loaded_weight,
-                        name,
+                        name_mapped,
                         shard_id=shard_id,
                         expert_id=expert_id,
+                        return_success=True,
                     )
-                    break
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight,
+                        # but it is not mapped locally to this rank,
+                        # so we simply skip it.
+                        continue
+
                     # Remapping the name of FP8 kv-scale.
                     name = maybe_remap_kv_scale_name(name, params_dict)
                     if name is None:
@@ -841,7 +890,7 @@ class HunYuanModel(nn.Module):
         return loaded_params


-class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP):
+class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -883,6 +932,64 @@ class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP):
         else:
             self.lm_head = PPMissingLayer()

+        # Set MoE hyperparameters.
+        self.expert_weights = []
+        self.num_expert_groups = 1
+
+        self.moe_layers: list[FusedMoE] = []
+        example_layer = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, HunYuanDecoderLayer)
+            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
+                example_layer = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_layer is None:
+            raise RuntimeError("No HunYuanMoE layer found in model.layers.")
+
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_logical_experts = example_layer.n_logical_experts
+        self.num_physical_experts = example_layer.n_physical_experts
+        self.num_local_physical_experts = example_layer.n_local_physical_experts
+        self.num_routed_experts = example_layer.n_routed_experts
+        self.num_redundant_experts = example_layer.n_redundant_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            self.expert_weights.append(layer.get_expert_weights())
+            # Register the expert weights.
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, HunYuanSparseMoeBlock):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def forward(
         self,
         input_ids: torch.Tensor,