Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 21:15:28 +08:00)
Support expert parallel load balancing in Transformers backend (#26287)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

Parent: 19a00eb210
Commit: 0340f45553
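With this commit, MoE models served through the Transformers backend can take part in expert parallel load balancing (EPLB) instead of raising NotImplementedError. A rough usage sketch follows; the engine arguments and the model name are illustrative assumptions, not something prescribed by this diff.

    # Hedged sketch: serving a MoE checkpoint through the Transformers backend
    # with expert parallelism and EPLB enabled. Argument values and the model
    # name are assumptions for illustration, not taken from this commit.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen1.5-MoE-A2.7B",  # hypothetical example model
        model_impl="transformers",       # force the Transformers backend
        tensor_parallel_size=4,
        enable_expert_parallel=True,     # shard experts across ranks
        enable_eplb=True,                # expert parallel load balancing
    )

    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)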
@@ -40,7 +40,7 @@ from vllm.config import (
 )
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.config.utils import getattr_iter
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -506,9 +506,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config
 
         self.pp_group = get_pp_group()
-        self.pp_size = self.pp_group.world_size
-        self.pp_rank = self.pp_group.rank_in_group
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_group = get_tp_group()
 
         # Weights to skip in `self.load_weights`
         self.skip_prefixes: list[str] = []
@@ -576,7 +574,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         """
         Apply the model's pipeline parallelization plan.
         """
-        if self.pp_size <= 1:
+        if self.pp_group.world_size <= 1:
             return
 
         if not self.model.supports_pp_plan:
@@ -613,7 +611,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
 
         # Module list
         start_layer, end_layer = get_pp_indices(
-            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
+            self.text_config.num_hidden_layers,
+            self.pp_group.rank_in_group,
+            self.pp_group.world_size,
         )
         layers_name = pp_plan[module_list_idx]
         layers = getattr(self.model, layers_name)
@@ -638,7 +638,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         """
         tp_plan = self.model.tp_plan
 
-        if not tp_plan and self.tp_size > 1:
+        if not tp_plan and self.tp_group.world_size > 1:
             tip = get_feature_request_tip(
                 self.model_config.model, self.model_config.trust_remote_code
             )
@@ -687,7 +687,9 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(
-            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
+            self.text_config.num_hidden_layers,
+            self.pp_group.rank_in_group,
+            self.pp_group.world_size,
         )
 
         attention_instances = {}
@@ -749,7 +751,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        if not get_pp_group().is_first_rank:
+        if not self.pp_group.is_first_rank:
             assert intermediate_tensors is not None
             input_ids = None
             inputs_embeds = intermediate_tensors["hidden_states"]
@@ -773,7 +775,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             return_dict=False,
         )[0][0, ...]  # we remove batch dimension for now
 
-        if not get_pp_group().is_last_rank:
+        if not self.pp_group.is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
 
         return hidden_states
@@ -811,7 +813,7 @@ class TransformersForCausalLM(TransformersBase):
         if self.text_config.tie_word_embeddings:
             self.skip_prefixes.append("lm_head.")
 
-        if get_pp_group().is_last_rank:
+        if self.pp_group.is_last_rank:
             self.unpadded_vocab_size = self.text_config.vocab_size
             self.lm_head = ParallelLMHead(
                 self.text_config.vocab_size,
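The hunks above all make the same refactor: instead of reading pipeline/tensor parallel sizes and ranks through module-level helpers, the base class caches the group handles once (`self.pp_group`, `self.tp_group`) and reads `world_size`, `rank_in_group`, `is_first_rank` and `is_last_rank` from them. The hunks that follow belong to the commit's second file, the MoE variant of the backend. Below is a minimal self-contained sketch of the caching pattern, using a stand-in group object rather than vLLM's real distributed groups:

    # Minimal sketch of the "cache the group handle" pattern applied above.
    # GroupCoordinator here is a stand-in that models only the attributes the
    # diff reads; it is not vLLM's real implementation.
    from dataclasses import dataclass

    @dataclass
    class GroupCoordinator:
        world_size: int
        rank_in_group: int

        @property
        def is_first_rank(self) -> bool:
            return self.rank_in_group == 0

        @property
        def is_last_rank(self) -> bool:
            return self.rank_in_group == self.world_size - 1

    class Model:
        def __init__(self, pp_group: GroupCoordinator, tp_group: GroupCoordinator):
            # Store the handles once; later code derives sizes and ranks from
            # them instead of re-querying helper functions at every call site.
            self.pp_group = pp_group
            self.tp_group = tp_group

        def needs_pipeline_parallel(self) -> bool:
            # Mirrors `if self.pp_group.world_size <= 1: return`
            return self.pp_group.world_size > 1

    model = Model(GroupCoordinator(world_size=2, rank_in_group=1),
                  GroupCoordinator(world_size=4, rank_in_group=0))
    print(model.needs_pipeline_parallel(), model.pp_group.is_last_rank)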
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from .interfaces import MixtureOfExperts
 from .transformers import (
     TransformersBase,
     TransformersForCausalLM,
@@ -116,17 +117,41 @@ direct_register_custom_op(
 )
 
 
-class TransformersMoEBase(TransformersBase):
+class TransformersMoEBase(TransformersBase, MixtureOfExperts):
     def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
+        self.ep_group = get_ep_group()
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        if self.parallel_config.enable_eplb:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel load "
-                "balancing yet."
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ):
+        for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
+            mlp_layer.experts.set_eplb_state(
+                moe_layer_idx=moe_layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
             )
 
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ):
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for mlp in self.mlp_layers:
+            mlp.n_local_physical_experts = num_local_physical_experts
+            mlp.n_physical_experts = num_physical_experts
+            mlp.n_redundant_experts = self.num_redundant_experts
+            mlp.experts.update_expert_map()
+
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         """
         Params for weights, fp8 weight scales, fp8 activation scales
@@ -138,6 +163,8 @@ class TransformersMoEBase(TransformersBase):
             ("w1", "w2", "w3"),  # Granite, Mixtral, Phi MoE style
             ("linear", "linear_1", "linear_v"),  # Grok1 style
         ]
+        num_experts = self.model_config.get_num_experts()
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
         expert_mapping = []
         for gate_proj, down_proj, up_proj in ckpt_names:
             expert_mapping.extend(
@@ -145,8 +172,8 @@ class TransformersMoEBase(TransformersBase):
                     ckpt_gate_proj_name=gate_proj,
                     ckpt_down_proj_name=down_proj,
                     ckpt_up_proj_name=up_proj,
-                    num_experts=self.model_config.get_num_experts(),
-                    num_redundant_experts=0,  # TODO: enable EPLB
+                    num_experts=num_experts,
+                    num_redundant_experts=num_redundant_experts,
                 )
             )
         return expert_mapping
@@ -167,12 +194,15 @@ class TransformersMoEBase(TransformersBase):
 
         # If there are shared experts, the results are
         # reduced after mlp.forward() not inside FusedMoE
-        num_experts_shared = getattr_iter(
+        num_shared_experts = getattr_iter(
             text_config,
-            ["num_experts_shared", "n_shared_experts", "moe_num_shared_experts"],
+            [
+                "n_shared_experts",  # DeepSeek, Docs, GLM
+                "moe_num_shared_experts",  # Aria, Ernie
+            ],
             0,
         )
-        reduce_results = num_experts_shared == 0
+        reduce_results = num_shared_experts == 0
 
         def add_all_reduce(mlp: nn.Module):
             """Adds an all-reduce to the output of `mlp.forward()`."""
@@ -207,13 +237,23 @@ class TransformersMoEBase(TransformersBase):
         # Expert mapping for `AutoWeightsLoader`
         expert_mapping = self.get_expert_mapping()
 
-        # Configs
-        parallel_config = self.parallel_config
-        eplb_config = parallel_config.eplb_config
-
         # Expert parallel load balancing kwargs
-        enable_eplb = parallel_config.enable_eplb
-        num_redundant_experts = eplb_config.num_redundant_experts
+        enable_eplb = self.parallel_config.enable_eplb
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
 
+        # MixtureOfExperts mixin settings
+        ep_size = self.ep_group.world_size
+
+        self.mlp_layers = []  # Used for MixtureOfExperts methods
+        self.expert_weights = []
+        self.num_moe_layers = 0
+        self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
+        self.num_logical_experts = num_experts
+        self.num_physical_experts = num_experts + num_redundant_experts
+        self.num_local_physical_experts = self.num_physical_experts // ep_size
+        self.num_routed_experts = num_experts
+        self.num_shared_experts = num_shared_experts
+        self.num_redundant_experts = num_redundant_experts
+
         # Recursively fuse MoE layers
         def _recursive_replace(module: nn.Module, prefix: str):
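The new mixin bookkeeping above boils down to simple counting: redundant experts are extra physical replicas layered on top of the logical experts, and each expert-parallel rank owns an equal slice of the physical experts. A small worked sketch with made-up numbers:

    # Illustrative arithmetic for the MixtureOfExperts bookkeeping above; the
    # concrete numbers are invented for the example and assume an even split
    # of physical experts across expert-parallel ranks.
    num_logical_experts = 64      # experts defined by the checkpoint
    num_redundant_experts = 8     # extra replicas added for load balancing
    ep_size = 4                   # expert parallel world size

    num_physical_experts = num_logical_experts + num_redundant_experts
    num_local_physical_experts = num_physical_experts // ep_size

    print(num_physical_experts, num_local_physical_experts)  # 72 18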
@@ -235,6 +275,9 @@ class TransformersMoEBase(TransformersBase):
                     for mlp_param_name, _ in mlp.named_parameters():
                         if "shared_expert" in mlp_param_name:
                             reduce_results = False
+                            # If the config does not specify num_shared_experts, but
+                            # the model has shared experts, we assume there is one.
+                            self.num_shared_experts = 1
                             break
                     # Replace experts module with FusedMoE
                     fused_experts = TransformersFusedMoE(
@@ -258,6 +301,10 @@ class TransformersMoEBase(TransformersBase):
                     )
                     mlp.experts = fused_experts
                     log_replacement(qual_name, experts, fused_experts)
+                    # Update MixtureOfExperts mixin state
+                    self.mlp_layers.append(mlp)
+                    self.expert_weights.append(fused_experts.get_expert_weights())
+                    self.num_moe_layers += 1
                     # If results are not all-reduced in FusedMoE, ensure they
                     # are all-reduced at the end of mlp.forward() if tensor
                     # parallel or expert parallel is enabled