diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 52fcbbfc58be..b02030b6d627 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -23,7 +23,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from itertools import islice from typing import Optional, Union @@ -33,8 +34,9 @@ from transformers import MixtralConfig from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -50,8 +52,8 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -74,10 +76,32 @@ class MixtralMoE(nn.Module): quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, dp_size: Optional[int] = None, - prefix: str = ""): + prefix: str = "", + enable_eplb: bool = False): super().__init__() self.hidden_size = hidden_size + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + + # Expert Parallelism Load balancing settings. + vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_routed_experts = num_experts + self.n_logical_experts = num_experts + self.n_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts) + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear(hidden_size, @@ -97,7 +121,9 @@ class MixtralMoE(nn.Module): quant_config=quant_config, tp_size=tp_size, dp_size=dp_size, - prefix=f"{prefix}.experts") + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -200,6 +226,7 @@ class MixtralDecoderLayer(nn.Module): cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_eplb: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -221,7 +248,8 @@ class MixtralDecoderLayer(nn.Module): hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") + prefix=f"{prefix}.block_sparse_moe", + enable_eplb=enable_eplb) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -262,6 +290,7 @@ class MixtralModel(nn.Module): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config self.config = config self.quant_config = quant_config @@ -276,10 +305,18 @@ class MixtralModel(nn.Module): org_num_embeddings=config.vocab_size, ) + self.enable_eplb = parallel_config.enable_eplb + self.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts) + self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config, prefix=prefix + config, + cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=self.enable_eplb, ), prefix=f"{prefix}.layers") @@ -325,7 +362,8 @@ class MixtralModel(nn.Module): ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) + num_experts=self.config.num_local_experts, + num_redundant_experts=self.num_redundant_experts) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -373,26 +411,40 @@ class MixtralModel(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: continue - name = name.replace(weight_name, param_name) + + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + if is_pp_missing_parameter(name_mapped, self): continue - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + + if ((name_mapped.endswith(".bias") + or name_mapped.endswith("_bias")) + and name_mapped not in params_dict): continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break + + param = params_dict[name_mapped] + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break else: + if is_expert_weight: + continue # Skip loading extra bias for GPTQ models. if ((name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict): @@ -413,7 +465,8 @@ class MixtralModel(nn.Module): return loaded_params -class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, + MixtureOfExperts): fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -462,6 +515,67 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + self.moe_layers: list[FusedMoE] = [] + example_moe = None + + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + assert isinstance(layer, MixtralDecoderLayer) + if hasattr(layer, "block_sparse_moe") and isinstance( + layer.block_sparse_moe, MixtralMoE): + example_moe = layer.block_sparse_moe + self.moe_layers.append(layer.block_sparse_moe.experts) + + self.num_moe_layers = len(self.moe_layers) + + if example_moe is None: + raise RuntimeError("No MixtralMoE layer found in model.layers.") + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + self.num_expert_groups = 1 + self.num_shared_experts = 0 + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if hasattr(layer, "block_sparse_moe") and isinstance( + layer.block_sparse_moe, MixtralMoE): + moe = layer.block_sparse_moe + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids)