Support expert parallel load balancing in Transformers backend (#26287)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-10-06 12:20:16 +01:00 committed by GitHub
parent 19a00eb210
commit 0340f45553
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 76 additions and 27 deletions

View File

@ -40,7 +40,7 @@ from vllm.config import (
) )
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
@ -506,9 +506,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config
self.pp_group = get_pp_group() self.pp_group = get_pp_group()
self.pp_size = self.pp_group.world_size self.tp_group = get_tp_group()
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()
# Weights to skip in `self.load_weights` # Weights to skip in `self.load_weights`
self.skip_prefixes: list[str] = [] self.skip_prefixes: list[str] = []
@ -576,7 +574,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
""" """
Apply the model's pipeline parallelization plan. Apply the model's pipeline parallelization plan.
""" """
if self.pp_size <= 1: if self.pp_group.world_size <= 1:
return return
if not self.model.supports_pp_plan: if not self.model.supports_pp_plan:
@ -613,7 +611,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# Module list # Module list
start_layer, end_layer = get_pp_indices( start_layer, end_layer = get_pp_indices(
self.text_config.num_hidden_layers, self.pp_rank, self.pp_size self.text_config.num_hidden_layers,
self.pp_group.rank_in_group,
self.pp_group.world_size,
) )
layers_name = pp_plan[module_list_idx] layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name) layers = getattr(self.model, layers_name)
@ -638,7 +638,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
""" """
tp_plan = self.model.tp_plan tp_plan = self.model.tp_plan
if not tp_plan and self.tp_size > 1: if not tp_plan and self.tp_group.world_size > 1:
tip = get_feature_request_tip( tip = get_feature_request_tip(
self.model_config.model, self.model_config.trust_remote_code self.model_config.model, self.model_config.trust_remote_code
) )
@ -687,7 +687,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
head_size = self.model_config.get_head_size() head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
start, end = get_pp_indices( start, end = get_pp_indices(
self.text_config.num_hidden_layers, self.pp_rank, self.pp_size self.text_config.num_hidden_layers,
self.pp_group.rank_in_group,
self.pp_group.world_size,
) )
attention_instances = {} attention_instances = {}
@ -749,7 +751,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if not get_pp_group().is_first_rank: if not self.pp_group.is_first_rank:
assert intermediate_tensors is not None assert intermediate_tensors is not None
input_ids = None input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"] inputs_embeds = intermediate_tensors["hidden_states"]
@ -773,7 +775,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
return_dict=False, return_dict=False,
)[0][0, ...] # we remove batch dimension for now )[0][0, ...] # we remove batch dimension for now
if not get_pp_group().is_last_rank: if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states return hidden_states
@ -811,7 +813,7 @@ class TransformersForCausalLM(TransformersBase):
if self.text_config.tie_word_embeddings: if self.text_config.tie_word_embeddings:
self.skip_prefixes.append("lm_head.") self.skip_prefixes.append("lm_head.")
if get_pp_group().is_last_rank: if self.pp_group.is_last_rank:
self.unpadded_vocab_size = self.text_config.vocab_size self.unpadded_vocab_size = self.text_config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.text_config.vocab_size, self.text_config.vocab_size,

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .interfaces import MixtureOfExperts
from .transformers import ( from .transformers import (
TransformersBase, TransformersBase,
TransformersForCausalLM, TransformersForCausalLM,
@ -116,17 +117,41 @@ direct_register_custom_op(
) )
class TransformersMoEBase(TransformersBase): class TransformersMoEBase(TransformersBase, MixtureOfExperts):
def __init__(self, *, vllm_config, prefix=""): def __init__(self, *, vllm_config, prefix=""):
self.check_version("4.57.0.dev0", "MoE models support") self.check_version("4.57.0.dev0", "MoE models support")
self.ep_group = get_ep_group()
super().__init__(vllm_config=vllm_config, prefix=prefix) super().__init__(vllm_config=vllm_config, prefix=prefix)
if self.parallel_config.enable_eplb: def set_eplb_state(
raise NotImplementedError( self,
"Transformers backend does not support expert parallel load " expert_load_view: torch.Tensor,
"balancing yet." logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
):
for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
mlp_layer.experts.set_eplb_state(
moe_layer_idx=moe_layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
) )
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
):
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for mlp in self.mlp_layers:
mlp.n_local_physical_experts = num_local_physical_experts
mlp.n_physical_experts = num_physical_experts
mlp.n_redundant_experts = self.num_redundant_experts
mlp.experts.update_expert_map()
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
""" """
Params for weights, fp8 weight scales, fp8 activation scales Params for weights, fp8 weight scales, fp8 activation scales
@ -138,6 +163,8 @@ class TransformersMoEBase(TransformersBase):
("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style ("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style
("linear", "linear_1", "linear_v"), # Grok1 style ("linear", "linear_1", "linear_v"), # Grok1 style
] ]
num_experts = self.model_config.get_num_experts()
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
expert_mapping = [] expert_mapping = []
for gate_proj, down_proj, up_proj in ckpt_names: for gate_proj, down_proj, up_proj in ckpt_names:
expert_mapping.extend( expert_mapping.extend(
@ -145,8 +172,8 @@ class TransformersMoEBase(TransformersBase):
ckpt_gate_proj_name=gate_proj, ckpt_gate_proj_name=gate_proj,
ckpt_down_proj_name=down_proj, ckpt_down_proj_name=down_proj,
ckpt_up_proj_name=up_proj, ckpt_up_proj_name=up_proj,
num_experts=self.model_config.get_num_experts(), num_experts=num_experts,
num_redundant_experts=0, # TODO: enable EPLB num_redundant_experts=num_redundant_experts,
) )
) )
return expert_mapping return expert_mapping
@ -167,12 +194,15 @@ class TransformersMoEBase(TransformersBase):
# If there are shared experts, the results are # If there are shared experts, the results are
# reduced after mlp.forward() not inside FusedMoE # reduced after mlp.forward() not inside FusedMoE
num_experts_shared = getattr_iter( num_shared_experts = getattr_iter(
text_config, text_config,
["num_experts_shared", "n_shared_experts", "moe_num_shared_experts"], [
"n_shared_experts", # DeepSeek, Docs, GLM
"moe_num_shared_experts", # Aria, Ernie
],
0, 0,
) )
reduce_results = num_experts_shared == 0 reduce_results = num_shared_experts == 0
def add_all_reduce(mlp: nn.Module): def add_all_reduce(mlp: nn.Module):
"""Adds an all-reduce to the output of `mlp.forward()`.""" """Adds an all-reduce to the output of `mlp.forward()`."""
@ -207,13 +237,23 @@ class TransformersMoEBase(TransformersBase):
# Expert mapping for `AutoWeightsLoader` # Expert mapping for `AutoWeightsLoader`
expert_mapping = self.get_expert_mapping() expert_mapping = self.get_expert_mapping()
# Configs
parallel_config = self.parallel_config
eplb_config = parallel_config.eplb_config
# Expert parallel load balancing kwargs # Expert parallel load balancing kwargs
enable_eplb = parallel_config.enable_eplb enable_eplb = self.parallel_config.enable_eplb
num_redundant_experts = eplb_config.num_redundant_experts num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
# MixtureOfExperts mixin settings
ep_size = self.ep_group.world_size
self.mlp_layers = [] # Used for MixtureOfExperts methods
self.expert_weights = []
self.num_moe_layers = 0
self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
self.num_logical_experts = num_experts
self.num_physical_experts = num_experts + num_redundant_experts
self.num_local_physical_experts = self.num_physical_experts // ep_size
self.num_routed_experts = num_experts
self.num_shared_experts = num_shared_experts
self.num_redundant_experts = num_redundant_experts
# Recursively fuse MoE layers # Recursively fuse MoE layers
def _recursive_replace(module: nn.Module, prefix: str): def _recursive_replace(module: nn.Module, prefix: str):
@ -235,6 +275,9 @@ class TransformersMoEBase(TransformersBase):
for mlp_param_name, _ in mlp.named_parameters(): for mlp_param_name, _ in mlp.named_parameters():
if "shared_expert" in mlp_param_name: if "shared_expert" in mlp_param_name:
reduce_results = False reduce_results = False
# If the config does not specify num_shared_experts, but
# the model has shared experts, we assume there is one.
self.num_shared_experts = 1
break break
# Replace experts module with FusedMoE # Replace experts module with FusedMoE
fused_experts = TransformersFusedMoE( fused_experts = TransformersFusedMoE(
@ -258,6 +301,10 @@ class TransformersMoEBase(TransformersBase):
) )
mlp.experts = fused_experts mlp.experts = fused_experts
log_replacement(qual_name, experts, fused_experts) log_replacement(qual_name, experts, fused_experts)
# Update MixtureOfExperts mixin state
self.mlp_layers.append(mlp)
self.expert_weights.append(fused_experts.get_expert_weights())
self.num_moe_layers += 1
# If results are not all-reduced in FusedMoE, ensure they # If results are not all-reduced in FusedMoE, ensure they
# are all-reduced at the end of mlp.forward() if tensor # are all-reduced at the end of mlp.forward() if tensor
# parallel or expert parallel is enabled # parallel or expert parallel is enabled