[EPLB] Support EPLB for Mixtral Model (#22842)

Signed-off-by: rouchenzi <ruochenwen@gmail.com>
Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com>
Co-authored-by: Bowen Wang <abmfy@icloud.com>
Authored by rouchenzi on 2025-09-17 00:27:34 -07:00, committed by GitHub
parent dd39baf717
commit b77bf34e53
vllm/model_executor/models/mixtral.py

@@ -23,7 +23,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from itertools import islice
 from typing import Optional, Union
@@ -33,8 +34,9 @@ from transformers import MixtralConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -50,8 +52,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -74,10 +76,32 @@ class MixtralMoE(nn.Module):
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
                  dp_size: Optional[int] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 enable_eplb: bool = False):
         super().__init__()
         self.hidden_size = hidden_size

+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+
+        # Expert Parallelism Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_routed_experts = num_experts
+        self.n_logical_experts = num_experts
+        self.n_redundant_experts = (
+            parallel_config.eplb_config.num_redundant_experts)
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(hidden_size,
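Note (not part of the diff): the constructor above statically partitions the physical experts evenly across the expert-parallel group. A minimal arithmetic sketch, with hypothetical numbers and the same formulas as the code:

# Hypothetical sizes; the formulas mirror MixtralMoE.__init__ above.
num_experts = 8            # logical (routed) experts in the checkpoint
num_redundant_experts = 2  # parallel_config.eplb_config.num_redundant_experts
ep_size = 2                # ranks in the expert-parallel group

n_physical_experts = num_experts + num_redundant_experts      # 10
n_local_physical_experts = n_physical_experts // ep_size      # 5

for ep_rank in range(ep_size):
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    # EP rank 0 hosts physical experts [0, 5), rank 1 hosts [5, 10).
    print(f"EP rank {ep_rank} hosts physical experts [{start}, {end})")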
@@ -97,7 +121,9 @@ class MixtralMoE(nn.Module):
                                 quant_config=quant_config,
                                 tp_size=tp_size,
                                 dp_size=dp_size,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -200,6 +226,7 @@ class MixtralDecoderLayer(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -221,7 +248,8 @@ class MixtralDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe")
+            prefix=f"{prefix}.block_sparse_moe",
+            enable_eplb=enable_eplb)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -262,6 +290,7 @@ class MixtralModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        parallel_config = vllm_config.parallel_config
         self.config = config
         self.quant_config = quant_config
@@ -276,10 +305,18 @@ class MixtralModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )

+        self.enable_eplb = parallel_config.enable_eplb
+        self.num_redundant_experts = (
+            parallel_config.eplb_config.num_redundant_experts)
+
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: MixtralDecoderLayer(
-                config, cache_config, quant_config=quant_config, prefix=prefix
+                config,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+                enable_eplb=self.enable_eplb,
             ),
             prefix=f"{prefix}.layers")
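Note (not part of the diff): MixtralModel only reads two knobs here and threads enable_eplb down through MixtralDecoderLayer into MixtralMoE and FusedMoE. A hedged sketch of that wiring with stand-in config classes; only the attribute names (enable_eplb, eplb_config.num_redundant_experts) come from the diff:

from dataclasses import dataclass, field

# Stand-ins for vLLM's parallel / EPLB config objects; not the real classes.
@dataclass
class EplbConfigSketch:
    num_redundant_experts: int = 0

@dataclass
class ParallelConfigSketch:
    enable_eplb: bool = False
    eplb_config: EplbConfigSketch = field(default_factory=EplbConfigSketch)

parallel_config = ParallelConfigSketch(
    enable_eplb=True,
    eplb_config=EplbConfigSketch(num_redundant_experts=2))

# These two values are what the model-level code above consumes.
enable_eplb = parallel_config.enable_eplb
num_redundant_experts = parallel_config.eplb_config.num_redundant_experts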
@@ -325,7 +362,8 @@ class MixtralModel(nn.Module):
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts)
+            num_experts=self.config.num_local_experts,
+            num_redundant_experts=self.num_redundant_experts)

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
@@ -373,26 +411,40 @@ class MixtralModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
+
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+
                     # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
+
+                    if ((name_mapped.endswith(".bias")
+                         or name_mapped.endswith("_bias"))
+                            and name_mapped not in params_dict):
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+
+                    param = params_dict[name_mapped]
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        continue
+
                     # Skip loading extra bias for GPTQ models.
                     if ((name.endswith(".bias") or name.endswith("_bias"))
                             and name not in params_dict):
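Note (not part of the diff): the rewritten loop relies on Python's for/else together with return_success=True, so an expert weight hosted on another EP rank is recognized and skipped rather than treated as a dense parameter. A self-contained sketch of that control flow, with a hypothetical try_load callable standing in for FusedMoE's weight loader:

def load_expert_or_skip(name, expert_params_mapping, params_dict, try_load):
    """Sketch of the for/else pattern used above; try_load stands in for the
    expert weight loader called with return_success=True and returns True only
    when the current EP rank actually owns the (possibly redundant) shard."""
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        is_expert_weight = True  # the name matched an expert weight
        name_mapped = name.replace(weight_name, param_name)
        if name_mapped not in params_dict:
            continue
        if try_load(params_dict[name_mapped], shard_id, expert_id):
            name = name_mapped
            break  # loaded on this rank
    else:
        # Loop finished without a break.
        if is_expert_weight:
            # Expert weight hosted on another EP rank: skip it instead of
            # falling through to the dense-parameter path.
            return None
        return name  # non-expert weight: left for the dense path
    return name  # successfully loaded (renamed) expert weight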
@@ -413,7 +465,8 @@ class MixtralModel(nn.Module):
         return loaded_params


-class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
+                         MixtureOfExperts):
     fall_back_to_pt_during_load = False

     packed_modules_mapping = {
@@ -462,6 +515,67 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+        self.expert_weights = []
+
+        self.moe_layers: list[FusedMoE] = []
+        example_moe = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, MixtralDecoderLayer)
+            if hasattr(layer, "block_sparse_moe") and isinstance(
+                    layer.block_sparse_moe, MixtralMoE):
+                example_moe = layer.block_sparse_moe
+                self.moe_layers.append(layer.block_sparse_moe.experts)
+
+        self.num_moe_layers = len(self.moe_layers)
+
+        if example_moe is None:
+            raise RuntimeError("No MixtralMoE layer found in model.layers.")
+
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+        self.num_expert_groups = 1
+        self.num_shared_experts = 0
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if hasattr(layer, "block_sparse_moe") and isinstance(
+                    layer.block_sparse_moe, MixtralMoE):
+                moe = layer.block_sparse_moe
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
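Note (not part of the diff): a hedged sketch of how an EPLB rebalancer might drive the MixtureOfExperts hooks implemented above. The buffer shapes are illustrative assumptions, not vLLM's actual EPLB layout:

import torch

def init_eplb_state(model, max_replicas: int = 2):
    """Allocate and register the state tensors the model wires into each
    FusedMoE layer via set_eplb_state. Shapes here are assumptions."""
    n_layers = model.num_moe_layers
    expert_load_view = torch.zeros(n_layers, model.num_physical_experts,
                                   dtype=torch.int64)
    logical_to_physical_map = torch.full(
        (n_layers, model.num_logical_experts, max_replicas), -1,
        dtype=torch.int64)
    logical_replica_count = torch.ones(n_layers, model.num_logical_experts,
                                       dtype=torch.int64)
    # The model keeps references, so the balancer can read live expert load
    # and rewrite the logical -> physical mapping in place between steps.
    model.set_eplb_state(expert_load_view, logical_to_physical_map,
                         logical_replica_count)
    return expert_load_view, logical_to_physical_map, logical_replica_count

If elastic scaling later changes the expert-parallel world size, the engine would call update_physical_experts_metadata(...) so each MixtralMoE refreshes its expert counts and FusedMoE rebuilds its local expert map via update_expert_map().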