[EPLB] Support EPLB for Mixtral Model (#22842)
Signed-off-by: rouchenzi <ruochenwen@gmail.com>
Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com>
Co-authored-by: Bowen Wang <abmfy@icloud.com>
parent dd39baf717
commit b77bf34e53
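For context, a minimal usage sketch (not part of the commit) that enables EPLB for Mixtral through the offline LLM API. The argument names are assumed from the ParallelConfig fields this diff reads (enable_eplb, eplb_config.num_redundant_experts); check the exact engine arguments in your vLLM build.

from vllm import LLM

# Hypothetical invocation: flag names mirror the config fields used in
# this diff and may differ across vLLM versions.
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    tensor_parallel_size=8,
    enable_expert_parallel=True,  # shard MoE experts across ranks
    enable_eplb=True,             # enable expert-parallel load balancing
)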
vllm/model_executor/models/mixtral.py

@@ -23,7 +23,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from itertools import islice
 from typing import Optional, Union
 
@@ -33,8 +34,9 @@ from transformers import MixtralConfig
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -50,8 +52,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -74,10 +76,32 @@ class MixtralMoE(nn.Module):
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
                  dp_size: Optional[int] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 enable_eplb: bool = False):
         super().__init__()
         self.hidden_size = hidden_size
 
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+
+        # Expert Parallelism Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_routed_experts = num_experts
+        self.n_logical_experts = num_experts
+        self.n_redundant_experts = (
+            parallel_config.eplb_config.num_redundant_experts)
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
         # Gate always runs at half / full precision for now.
 
         self.gate = ReplicatedLinear(hidden_size,
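To make the new bookkeeping concrete, here is a small self-contained sketch of the physical-expert partitioning computed in this hunk, with illustrative numbers (8 logical experts, 2 redundant, expert-parallel size 2); the variable names follow the attributes set above.

# Worked example of the EPLB expert-partition arithmetic.
num_experts = 8            # Mixtral's n_routed_experts / n_logical_experts
num_redundant_experts = 2  # eplb_config.num_redundant_experts
ep_size = 2                # expert-parallel world size

n_physical_experts = num_experts + num_redundant_experts   # 10
n_local_physical_experts = n_physical_experts // ep_size   # 5 per rank

for ep_rank in range(ep_size):
    start = ep_rank * n_local_physical_experts
    end = start + n_local_physical_experts
    print(f"EP rank {ep_rank} hosts physical experts [{start}, {end})")
# EP rank 0 hosts physical experts [0, 5)
# EP rank 1 hosts physical experts [5, 10)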
@@ -97,7 +121,9 @@ class MixtralMoE(nn.Module):
                                 quant_config=quant_config,
                                 tp_size=tp_size,
                                 dp_size=dp_size,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -200,6 +226,7 @@ class MixtralDecoderLayer(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -221,7 +248,8 @@ class MixtralDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
-            prefix=f"{prefix}.block_sparse_moe")
+            prefix=f"{prefix}.block_sparse_moe",
+            enable_eplb=enable_eplb)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -262,6 +290,7 @@ class MixtralModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        parallel_config = vllm_config.parallel_config
 
         self.config = config
         self.quant_config = quant_config
@@ -276,10 +305,18 @@ class MixtralModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
 
+        self.enable_eplb = parallel_config.enable_eplb
+        self.num_redundant_experts = (
+            parallel_config.eplb_config.num_redundant_experts)
+
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: MixtralDecoderLayer(
-                config, cache_config, quant_config=quant_config, prefix=prefix
+                config,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+                enable_eplb=self.enable_eplb,
             ),
             prefix=f"{prefix}.layers")
 
@@ -325,7 +362,8 @@ class MixtralModel(nn.Module):
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts)
+            num_experts=self.config.num_local_experts,
+            num_redundant_experts=self.num_redundant_experts)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
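The extra num_redundant_experts argument widens the checkpoint-to-parameter mapping so that redundant expert copies also get entries. A hedged toy version of the mapping's shape follows; the exact strings come from FusedMoE.make_expert_params_mapping, and these names are illustrative only.

def toy_expert_params_mapping(num_experts: int,
                              num_redundant_experts: int) -> list[tuple]:
    # One (param_name, checkpoint_weight_name, expert_id, shard_id) entry
    # per expert copy and per projection (w1/w3 fused, w2 separate).
    entries = []
    for expert_id in range(num_experts + num_redundant_experts):
        for shard_id in ("w1", "w2", "w3"):
            param_name = ("experts.w13_" if shard_id in ("w1", "w3")
                          else "experts.w2_")
            entries.append((param_name, f"experts.{expert_id}.{shard_id}.",
                            expert_id, shard_id))
    return entries

# 8 logical + 2 redundant experts -> 30 entries (3 shards x 10 copies).
assert len(toy_expert_params_mapping(8, 2)) == 30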
@@ -373,26 +411,40 @@ class MixtralModel(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
+
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
+
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+
                     # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
+
+                    if ((name_mapped.endswith(".bias")
+                         or name_mapped.endswith("_bias"))
+                            and name_mapped not in params_dict):
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+
+                    param = params_dict[name_mapped]
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        continue
                     # Skip loading extra bias for GPTQ models.
                     if ((name.endswith(".bias") or name.endswith("_bias"))
                             and name not in params_dict):
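The control flow above is subtle, so here is a stripped-down illustration (a hypothetical helper, not vLLM API) of why the is_expert_weight flag and return_success matter: an expert weight whose physical copy is not hosted on this EP rank makes weight_loader return False, and the for/else must then skip it instead of falling through to dense loading.

def classify(name: str, expert_params_mapping, hosted_expert_ids) -> str:
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        is_expert_weight = True
        name_mapped = name.replace(weight_name, param_name)
        # Stand-in for weight_loader(..., return_success=True).
        if expert_id in hosted_expert_ids:
            loaded = name_mapped
            break
    else:
        if is_expert_weight:
            return "expert weight not hosted on this rank: skip"
        return "not an expert weight: fall through to dense loading"
    return f"expert weight loaded into {loaded}"

mapping = [("experts.w13_", "experts.0.w1.", 0, "w1"),
           ("experts.w13_", "experts.5.w1.", 5, "w1")]
print(classify("layers.0.experts.5.w1.weight", mapping, {0, 1, 2}))
# -> expert weight not hosted on this rank: skip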
@@ -413,7 +465,8 @@ class MixtralModel(nn.Module):
         return loaded_params
 
 
-class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
+                         MixtureOfExperts):
     fall_back_to_pt_during_load = False
 
     packed_modules_mapping = {
@@ -462,6 +515,67 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+        self.expert_weights = []
+        self.moe_layers: list[FusedMoE] = []
+        example_moe = None
+
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+            assert isinstance(layer, MixtralDecoderLayer)
+            if hasattr(layer, "block_sparse_moe") and isinstance(
+                    layer.block_sparse_moe, MixtralMoE):
+                example_moe = layer.block_sparse_moe
+                self.moe_layers.append(layer.block_sparse_moe.experts)
+
+        self.num_moe_layers = len(self.moe_layers)
+
+        if example_moe is None:
+            raise RuntimeError("No MixtralMoE layer found in model.layers.")
+
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+        self.num_expert_groups = 1
+        self.num_shared_experts = 0
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if hasattr(layer, "block_sparse_moe") and isinstance(
+                    layer.block_sparse_moe, MixtralMoE):
+                moe = layer.block_sparse_moe
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
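Finally, a hedged sketch of the state an EPLB rebalancer would hand to set_eplb_state: a per-layer load counter plus logical-to-physical routing tables sized by the counts this class now exposes. The tensor names match the method signature above; the shapes and initialization are assumptions, not the exact vLLM layout.

import torch

num_moe_layers = 32          # one MoE block per Mixtral-8x7B layer
num_logical_experts = 8
num_physical_experts = 10    # 8 logical + 2 redundant
max_replicas = 2             # replicas of the most-duplicated expert

# Token counts routed to each physical expert, per layer.
expert_load_view = torch.zeros(num_moe_layers, num_physical_experts,
                               dtype=torch.int64)
# Physical slots backing each logical expert (-1 = unused replica slot).
logical_to_physical_map = torch.full(
    (num_moe_layers, num_logical_experts, max_replicas), -1,
    dtype=torch.int64)
logical_to_physical_map[..., 0] = torch.arange(num_logical_experts)
# How many physical replicas each logical expert currently has.
logical_replica_count = torch.ones(num_moe_layers, num_logical_experts,
                                   dtype=torch.int64)

# model.set_eplb_state(expert_load_view, logical_to_physical_map,
#                      logical_replica_count)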