[BugFix] Support EP/DP + EPLB with MTP (#25311)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
Ilya Markov 2025-11-05 16:22:17 +01:00 committed by GitHub
parent 5d16d0fa62
commit e50c454672
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 957 additions and 529 deletions

View File

@ -232,8 +232,8 @@ steps:
commands: commands:
- pytest -v -s distributed/test_eplb_algo.py - pytest -v -s distributed/test_eplb_algo.py
- label: EPLB Execution Test # 5min - label: EPLB Execution Test # 10min
timeout_in_minutes: 15 timeout_in_minutes: 20
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 4 num_gpus: 4
source_file_dependencies: source_file_dependencies:
@ -241,6 +241,7 @@ steps:
- tests/distributed/test_eplb_execute.py - tests/distributed/test_eplb_execute.py
commands: commands:
- pytest -v -s distributed/test_eplb_execute.py - pytest -v -s distributed/test_eplb_execute.py
- pytest -v -s distributed/test_eplb_spec_decode.py
- label: Metrics, Tracing Test # 12min - label: Metrics, Tracing Test # 12min
timeout_in_minutes: 20 timeout_in_minutes: 20

View File

@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import lm_eval
import pytest
from tests.utils import large_gpu_mark
def get_model_args(
model_name: str,
spec_model_name: str,
spec_method: str,
tp_size: int,
model_max_len: int,
) -> dict:
speculative_config = {
"method": spec_method,
"model": spec_model_name,
"num_speculative_tokens": 1,
"max_model_len": model_max_len,
}
model_args = {
"pretrained": model_name,
"dtype": "auto",
"add_bos_token": True,
"tensor_parallel_size": tp_size,
"gpu_memory_utilization": 0.7,
"speculative_config": speculative_config,
"enable_expert_parallel": True,
"num_redundant_experts": tp_size,
"eplb_window_size": 128,
"eplb_step_interval": 1024,
"eplb_log_balancedness": False,
"enable_eplb": True,
"max_model_len": model_max_len,
}
return model_args
@pytest.mark.parametrize(
    "model_setup",
    [
        pytest.param(
            ("mtp", "Qwen/Qwen3-Next-80B-A3B-Instruct", None, 4, 0.86),
            marks=large_gpu_mark(min_gb=80),
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
                0.92,
            ),
            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues"),
        ),
    ],
    ids=["qwen3_next_mtp", "llama4_eagle"],
)
def test_eplb_spec_decode(
    monkeypatch: pytest.MonkeyPatch,
    model_setup: tuple[str, str, str, int, float],
):
    """
    Check GSM8K accuracy when EPLB is combined with speculative decoding.

    Each parametrization supplies
    (spec method, target model, draft model, TP size, expected GSM8K score);
    the measured strict-match score must land within a fixed absolute
    window around the expected score.
    """
    method, model_name, spec_model_name, tp_size, expected_gsm8k_value = model_setup

    task_name = "gsm8k"
    metric_key = "exact_match,strict-match"
    tolerance = 0.03

    eval_model_args = get_model_args(
        model_name=model_name,
        spec_model_name=spec_model_name,
        spec_method=method,
        tp_size=tp_size,
        model_max_len=4096,
    )
    eval_output = lm_eval.simple_evaluate(
        model="vllm",
        model_args=eval_model_args,
        tasks=task_name,
        batch_size=64,
        num_fewshot=8,
    )

    measured_value = eval_output["results"][task_name][metric_key]
    # Same arithmetic as a two-sided check: the expected score must sit
    # strictly inside (measured - tolerance, measured + tolerance).
    window_low = measured_value - tolerance
    window_high = measured_value + tolerance
    assert window_low < expected_gsm8k_value < window_high, (
        f"Expected: {expected_gsm8k_value} | Measured: {measured_value}"
    )

View File

@ -33,7 +33,7 @@ from dataclasses import dataclass
import torch import torch
from torch.distributed import ProcessGroup, all_reduce from torch.distributed import ProcessGroup, all_reduce
from vllm.config import ParallelConfig from vllm.config import ModelConfig, ParallelConfig
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_ep_group, get_ep_group,
get_node_count, get_node_count,
@ -50,7 +50,7 @@ logger = init_logger(__name__)
@dataclass @dataclass
class EplbState: class EplbModelState:
"""EPLB metrics.""" """EPLB metrics."""
physical_to_logical_map: torch.Tensor physical_to_logical_map: torch.Tensor
@ -130,34 +130,46 @@ class EplbState:
See: See:
https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856 https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
""" """
expert_load_window_step: int = 0 model_name: str
""" model: MixtureOfExperts
Current step in the sliding window.
Different from `expert_rearrangement_step`, each EP rank may have its own
`expert_load_window_step`. class EplbState:
""" """
expert_load_window_size: int = 0 EplbState of each expert parallel model. Key is the model config hash.
"""
Size of the expert load sliding window.
This is a constant and is taken from the config.
""" """
expert_rearrangement_step: int = 0 def __init__(self, parallel_config: ParallelConfig, device: torch.device):
""" self.parallel_config = parallel_config
Steps after last rearrangement. self.device = device
Will trigger a rearrangement if it exceeds the threshold. self.model_states: dict[str, EplbModelState] = {}
"""
Current step in the sliding window.
NOTE: Keep in mind that all EP ranks need to have the same Different from `expert_rearrangement_step`,
`expert_rearrangement_step` value to ensure synchronization. each EP rank may have its own `expert_load_window_step`.
Otherwise, the rearrangement will hang at collective """
communication calls. self.expert_load_window_step: int = 0
""" """
expert_rearrangement_step_interval: int = 0 Size of the expert load sliding window.
""" This is a constant and is taken from the config.
Interval for expert rearrangement steps. """
This is a constant and is taken from the config. self.expert_load_window_size: int = 0
""" """
Steps after last rearrangement.
Will trigger a rearrangement if it exceeds the threshold.
NOTE: Keep in mind that all EP ranks need to have the same
`expert_rearrangement_step` value to ensure synchronization.
Otherwise, the rearrangement will hang at collective
communication calls.
"""
self.expert_rearrangement_step: int = 0
"""
Interval for expert rearrangement steps.
This is a constant and is taken from the config.
"""
self.expert_rearrangement_step_interval: int = 0
@staticmethod @staticmethod
def build_initial_global_physical_to_logical_map( def build_initial_global_physical_to_logical_map(
@ -179,26 +191,63 @@ class EplbState:
] ]
return global_physical_to_logical_map return global_physical_to_logical_map
@classmethod def validate_ep_configuration(self, new_model: MixtureOfExperts):
def build( """
cls, Validate that the expert parallel configuration of
the new model is the same as the existing models.
"""
if len(self.model_states) > 0:
model = next(iter(self.model_states.values())).model
if (
model.num_routed_experts != new_model.num_routed_experts
or model.num_redundant_experts != new_model.num_redundant_experts
or model.num_physical_experts != new_model.num_physical_experts
or model.num_logical_experts != new_model.num_logical_experts
or model.num_expert_groups != new_model.num_expert_groups
):
raise RuntimeError(
"Model: {} "
"with config {} "
"{} {} {} {} "
"mismatch with new model {} "
"with config {} "
"{} {} {} {}".format(
type(model),
model.num_routed_experts,
model.num_redundant_experts,
model.num_physical_experts,
model.num_logical_experts,
model.num_expert_groups,
type(new_model),
new_model.num_routed_experts,
new_model.num_redundant_experts,
new_model.num_physical_experts,
new_model.num_logical_experts,
new_model.num_expert_groups,
)
)
def add_model(
self,
model: MixtureOfExperts, model: MixtureOfExperts,
device: torch.device, model_config: ModelConfig,
parallel_config: ParallelConfig,
global_expert_load: torch.Tensor | None = None, global_expert_load: torch.Tensor | None = None,
old_global_expert_indices: torch.Tensor | None = None, old_global_expert_indices: torch.Tensor | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
) -> "EplbState": ):
""" """
Build the initial EPLB state. Build the initial EPLB state.
""" """
physical_to_logical_map_list = cls.build_initial_global_physical_to_logical_map( self.validate_ep_configuration(model)
model.num_routed_experts, physical_to_logical_map_list = (
model.num_redundant_experts, EplbState.build_initial_global_physical_to_logical_map(
model.num_routed_experts,
model.num_redundant_experts,
)
) )
physical_to_logical_map = torch.tensor( physical_to_logical_map = torch.tensor(
physical_to_logical_map_list, physical_to_logical_map_list,
device=device, device=self.device,
) )
# Assuming 8 GPUs per node, this supports up to # Assuming 8 GPUs per node, this supports up to
# (1023 + 1) / 8 = 128 nodes for now. # (1023 + 1) / 8 = 128 nodes for now.
@ -212,11 +261,11 @@ class EplbState:
logical_to_physical_map = torch.full( logical_to_physical_map = torch.full(
(model.num_logical_experts, max_slots_per_logical_expert), (model.num_logical_experts, max_slots_per_logical_expert),
-1, -1,
device=device, device=self.device,
) )
logical_replica_count = torch.zeros( logical_replica_count = torch.zeros(
(model.num_logical_experts,), (model.num_logical_experts,),
device=device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
@ -255,18 +304,25 @@ class EplbState:
expert_load_pass = torch.zeros( expert_load_pass = torch.zeros(
(model.num_moe_layers, model.num_physical_experts), (model.num_moe_layers, model.num_physical_experts),
dtype=torch.int32, dtype=torch.int32,
device=device, device=self.device,
) )
expert_load_window_size = parallel_config.eplb_config.window_size self.expert_load_window_size = self.parallel_config.eplb_config.window_size
expert_load_window = torch.zeros( expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers, model.num_physical_experts), (
self.expert_load_window_size,
model.num_moe_layers,
model.num_physical_experts,
),
dtype=torch.int32, dtype=torch.int32,
device=device, device=self.device,
) )
# Set the initial progress of rearrangement to 3/4 # Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_config.step_interval eplb_step_interval = self.parallel_config.eplb_config.step_interval
expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4) self.expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4
)
self.expert_rearrangement_step_interval = eplb_step_interval
if global_expert_load is not None: if global_expert_load is not None:
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
@ -309,7 +365,7 @@ class EplbState:
(0, logical_to_physical_map.shape[-1] - max_physical_slots), (0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1, value=-1,
) )
physical_to_logical_map = new_physical_to_logical_map.to(device) physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map) logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count) logical_replica_count.copy_(new_logical_replica_count)
@ -327,22 +383,20 @@ class EplbState:
False, False,
rank_mapping, rank_mapping,
) )
expert_rearrangement_step = 0 self.expert_rearrangement_step = 0
return cls( self.model_states[model_config.compute_hash()] = EplbModelState(
physical_to_logical_map, physical_to_logical_map,
logical_to_physical_map, logical_to_physical_map,
logical_replica_count, logical_replica_count,
expert_load_pass, expert_load_pass,
expert_load_window, expert_load_window,
expert_load_window_size=expert_load_window_size, model_config.model,
expert_rearrangement_step=expert_rearrangement_step, model,
expert_rearrangement_step_interval=eplb_step_interval,
) )
def step( def step(
self, self,
model: MixtureOfExperts,
is_dummy: bool = False, is_dummy: bool = False,
is_profile: bool = False, is_profile: bool = False,
log_stats: bool = False, log_stats: bool = False,
@ -351,7 +405,6 @@ class EplbState:
Step the EPLB state. Step the EPLB state.
Args: Args:
model (MixtureOfExperts): The MoE model.
is_dummy (bool): If `True`, this is a dummy step and the load is_dummy (bool): If `True`, this is a dummy step and the load
metrics recorded in this forward pass will not count. metrics recorded in this forward pass will not count.
Defaults to `False`. Defaults to `False`.
@ -369,60 +422,66 @@ class EplbState:
""" """
if is_profile: if is_profile:
self.rearrange(model, is_profile=True) self.rearrange(is_profile=True)
return return
if is_dummy: if is_dummy:
# Do not record load metrics for dummy steps # Do not record load metrics for dummy steps
self.expert_load_pass.zero_() for eplb_model_state in self.model_states.values():
eplb_model_state.expert_load_pass.zero_()
if log_stats: if log_stats:
# total_expert_load_pass: (num_moe_layers, num_physical_experts) # Sync the expert load pass for each model (main and drafter).
total_expert_load_pass = self.expert_load_pass.clone() # expert_load_pass: (num_moe_layers, num_physical_experts)
expert_load_pass_list = self._sync_load_pass()
# Collect load metrics from all ranks
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
all_reduce(total_expert_load_pass, group=ep_group) for expert_load_pass, eplb_model_state in zip(
expert_load_pass_list, self.model_states.values()
# num_tokens_per_rank: (num_moe_layers, num_ranks) ):
num_tokens_per_rank = ( # num_tokens_per_rank: (num_moe_layers, num_ranks)
total_expert_load_pass.reshape( num_tokens_per_rank = (
total_expert_load_pass.shape[0], ep_group.size(), -1 expert_load_pass.reshape(
expert_load_pass.shape[0], ep_group.size(), -1
)
.sum(dim=-1)
.float()
) )
.sum(dim=-1)
.float()
)
# Compute balancedness ratio: # Compute balancedness ratio:
# for each layer: # for each layer:
# (mean load across ranks) / (max load across ranks) # (mean load across ranks) / (max load across ranks)
avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0) max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0)
# Just to make type checker happy # Just to make type checker happy
tokens_tensors: list[float] = torch.stack( tokens_tensors: list[float] = torch.stack(
[avg_tokens_tensor, max_tokens_tensor] [avg_tokens_tensor, max_tokens_tensor]
).tolist() ).tolist()
avg_tokens, max_tokens = tokens_tensors avg_tokens, max_tokens = tokens_tensors
balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0
if ep_group.rank() == 0: if ep_group.rank() == 0:
logger.info( logger.info(
"EPLB step: avg_tokens=%.2f, max_tokens=%d, balancedness=%.4f", "EPLB step: %d for model %s: avg_tokens=%.2f, "
avg_tokens, "max_tokens=%d, balancedness=%.4f",
max_tokens, self.expert_rearrangement_step,
balancedness, eplb_model_state.model_name,
) avg_tokens,
max_tokens,
balancedness,
)
# Update the expert load sliding window # Update the expert load sliding window
if not is_dummy: if not is_dummy:
self.expert_load_window[self.expert_load_window_step] = ( for eplb_model_state in self.model_states.values():
self.expert_load_pass.clone() eplb_model_state.expert_load_window[self.expert_load_window_step] = (
) eplb_model_state.expert_load_pass.clone()
)
eplb_model_state.expert_load_pass.zero_()
self.expert_load_window_step += 1 self.expert_load_window_step += 1
if self.expert_load_window_step >= self.expert_load_window_size: if self.expert_load_window_step >= self.expert_load_window_size:
self.expert_load_window_step = 0 self.expert_load_window_step = 0
self.expert_load_pass.zero_()
# Step the expert rearrangement step # Step the expert rearrangement step
# Note that even if this is a dummy step, we still increment the # Note that even if this is a dummy step, we still increment the
@ -431,18 +490,30 @@ class EplbState:
self.expert_rearrangement_step += 1 self.expert_rearrangement_step += 1
if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval: if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval:
self.expert_rearrangement_step = 0 self.expert_rearrangement_step = 0
self.rearrange(model) self.rearrange()
def rearrange( def rearrange(
self, self,
model: MixtureOfExperts,
is_profile: bool = False, is_profile: bool = False,
execute_shuffle: bool = True, execute_shuffle: bool = True,
global_expert_load: torch.Tensor | None = None, global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None, rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
""" """
Rearrange the experts according to the current load. Rearrange the experts according to the current load.
Args:
is_profile (bool): If `True`, perform a dummy rearrangement.
This is used in `profile_run` to reserve enough memory,
no memory movement will be performed. Default is False.
execute_shuffle (bool): If `True`, execute the shuffle
in elastic expert parallel (EEP). Default is True.
global_expert_loads (list[torch.Tensor] | None): The global expert
loads when scaling is done in EEP.
List of expert loads for the main and drafter
(when spec decode is used) models.
rank_mapping (dict[int, int] | None): The rank mapping
when scaling is done in EEP.
""" """
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
@ -455,53 +526,71 @@ class EplbState:
time_start = time.perf_counter() time_start = time.perf_counter()
logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") logger.info("Rearranging experts %s...", "(profile)" if is_profile else "")
if global_expert_load is None: if global_expert_loads is None:
# Map the physical expert load to global logical experts # Map the physical expert load to global logical experts
logical_expert_load_window = torch.zeros( global_expert_load_windows = []
self.expert_load_window_size,
model.num_moe_layers,
model.num_logical_experts,
dtype=self.expert_load_window.dtype,
device=self.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=self.physical_to_logical_map.unsqueeze(0)
.expand_as(self.expert_load_window)
.long(),
src=self.expert_load_window,
)
if not execute_shuffle: if not execute_shuffle:
metadata = torch.tensor( num_models = torch.tensor(
[ [len(self.model_states)], dtype=torch.int32, device="cpu"
model.num_moe_layers,
model.num_logical_experts,
self.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
) )
torch.distributed.broadcast( torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0 num_models, group=get_ep_group().cpu_group, group_src=0
) )
# Perform all-reduce to get the expert load across all ranks for eplb_model_state in self.model_states.values():
global_expert_load_window = logical_expert_load_window.sum(dim=0) logical_expert_load_window = torch.zeros(
all_reduce(global_expert_load_window, group=ep_group) self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
dtype=eplb_model_state.expert_load_window.dtype,
device=eplb_model_state.expert_load_window.device,
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
.expand_as(eplb_model_state.expert_load_window)
.long(),
src=eplb_model_state.expert_load_window,
)
if not execute_shuffle:
metadata = torch.tensor(
[
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
eplb_model_state.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0
)
global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(
global_expert_load_windows
)
if not execute_shuffle: if not execute_shuffle:
# (num_moe_layers, old_num_physical_experts) for eplb_model_state, global_expert_load_window in zip(
old_global_expert_indices = self.physical_to_logical_map self.model_states.values(), global_expert_load_windows
torch.distributed.broadcast( ):
old_global_expert_indices, group=ep_group, group_src=0 # (num_moe_layers, old_num_physical_experts)
) old_global_expert_indices = eplb_model_state.physical_to_logical_map
return global_expert_load_window torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0
)
if not execute_shuffle:
return global_expert_load_windows
else: else:
assert execute_shuffle assert execute_shuffle
global_expert_load_window = global_expert_load global_expert_load_windows = global_expert_loads
# TODO(bowen): Treat differently for prefill and decode nodes # TODO(bowen): Treat differently for prefill and decode nodes
eplb_model_state = next(iter(self.model_states.values()))
model = eplb_model_state.model
num_replicas = model.num_physical_experts num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups num_groups = model.num_expert_groups
if rank_mapping is not None and len(rank_mapping) == ep_group.size(): if rank_mapping is not None and len(rank_mapping) == ep_group.size():
@ -526,48 +615,64 @@ class EplbState:
f"{num_gpus=}, {num_nodes=}" f"{num_gpus=}, {num_nodes=}"
) )
# Get new expert mappings for eplb_model_state, global_expert_load_window in zip(
( self.model_states.values(), global_expert_load_windows
new_physical_to_logical_map, ):
new_logical_to_physical_map, # Get new expert mappings for the model
new_logical_replica_count, (
) = rebalance_experts( new_physical_to_logical_map,
global_expert_load_window,
num_replicas,
num_groups,
num_nodes,
num_gpus,
)
# Update expert weights
rearrange_expert_weights_inplace(
self.physical_to_logical_map,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
is_profile,
rank_mapping,
)
if not is_profile:
if (
self.physical_to_logical_map.shape[1]
!= new_physical_to_logical_map.shape[1]
):
self.physical_to_logical_map = new_physical_to_logical_map.to(
self.physical_to_logical_map.device
)
else:
self.physical_to_logical_map.copy_(new_physical_to_logical_map)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map, new_logical_to_physical_map,
(0, self.logical_to_physical_map.shape[-1] - max_physical_slots), new_logical_replica_count,
value=-1, ) = rebalance_experts(
global_expert_load_window,
num_replicas,
num_groups,
num_nodes,
num_gpus,
) )
self.logical_to_physical_map.copy_(new_logical_to_physical_map)
self.logical_replica_count.copy_(new_logical_replica_count) # Update expert weights
rearrange_expert_weights_inplace(
eplb_model_state.physical_to_logical_map,
new_physical_to_logical_map,
eplb_model_state.model.expert_weights,
ep_group,
is_profile,
rank_mapping,
)
if not is_profile:
if (
eplb_model_state.physical_to_logical_map.shape[1]
!= new_physical_to_logical_map.shape[1]
):
eplb_model_state.physical_to_logical_map = (
new_physical_to_logical_map.to(
eplb_model_state.physical_to_logical_map.device
)
)
else:
eplb_model_state.physical_to_logical_map.copy_(
new_physical_to_logical_map
)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert (
max_physical_slots
<= eplb_model_state.logical_to_physical_map.shape[-1]
)
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(
0,
eplb_model_state.logical_to_physical_map.shape[-1]
- max_physical_slots,
),
value=-1,
)
eplb_model_state.logical_to_physical_map.copy_(
new_logical_to_physical_map
)
eplb_model_state.logical_replica_count.copy_(new_logical_replica_count)
if is_main_rank: if is_main_rank:
assert time_start is not None assert time_start is not None
@ -581,32 +686,118 @@ class EplbState:
return None return None
@staticmethod @staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]: def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
""" """
Receive the expert load and old placement from the master rank. Receive the expert load and old placement from the master rank.
""" """
ep_group = get_ep_group() ep_group = get_ep_group()
metadata = torch.empty(3, dtype=torch.int32, device="cpu") num_models = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0) torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = ( num_models = num_models.item()
metadata.tolist() global_expert_loads = []
) old_global_expert_indices_per_model = []
global_expert_load = torch.zeros( for _ in range(num_models):
(num_moe_layers, num_logical_experts), metadata = torch.empty(3, dtype=torch.int32, device="cpu")
dtype=torch.int64, torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
device=ep_group.device, num_moe_layers, num_logical_experts, num_old_physical_experts = (
) metadata.tolist()
all_reduce(global_expert_load, group=ep_group.device_group) )
old_global_expert_indices = torch.empty( global_expert_load = torch.zeros(
(num_moe_layers, num_old_physical_experts), (num_moe_layers, num_logical_experts),
dtype=torch.int64, dtype=torch.int64,
device=ep_group.device, device=ep_group.device,
) )
all_reduce(global_expert_load, group=ep_group.device_group)
old_global_expert_indices = torch.empty(
(num_moe_layers, num_old_physical_experts),
dtype=torch.int64,
device=ep_group.device,
)
torch.distributed.broadcast(
old_global_expert_indices,
group=ep_group.device_group,
group_src=0,
)
global_expert_loads.append(global_expert_load)
old_global_expert_indices_per_model.append(old_global_expert_indices)
return global_expert_loads, old_global_expert_indices_per_model
@classmethod
def get_eep_state(
cls, parallel_config: ParallelConfig
) -> tuple[
list[torch.Tensor] | None,
list[torch.Tensor] | None,
dict[int, int] | None,
]:
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast( torch.distributed.broadcast(
old_global_expert_indices, group=ep_group.device_group, group_src=0 num_local_physical_experts,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts.item())
new_ep_size = get_ep_group().world_size
global_expert_loads, old_global_expert_indices_per_model = (
EplbState.recv_state()
) )
return global_expert_load, old_global_expert_indices # EP configuration for all models has to be the same so as eplb config
num_logical_experts = global_expert_loads[0].shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts
)
assert (
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
== 0
)
old_ep_size = (
old_global_expert_indices_per_model[0].shape[1]
// num_local_physical_experts
)
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
return (
global_expert_loads,
old_global_expert_indices_per_model,
rank_mapping,
)
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
    """
    All-reduce every tensor of `tensor_list` over the EP device group.

    With exactly one tensor, that tensor is all-reduced directly and the
    input list is handed back. With several tensors, they are concatenated
    along dim 0 so a single collective call covers all of them, and the
    reduced buffer is split back into one chunk per input, in input order.
    """
    if len(tensor_list) == 1:
        (only_tensor,) = tensor_list
        all_reduce(only_tensor, group=get_ep_group().device_group)
        return tensor_list

    assert all(t.dim() == 2 for t in tensor_list), "All tensors must be 2D."
    assert all(t.shape[1] == tensor_list[0].shape[1] for t in tensor_list), (
        "All tensors must have the same shape[1]."
    )

    # Every tensor is 2D with a common shape[1] (the expert dimension), so
    # only the row counts are needed to undo the concatenation afterwards.
    row_counts = [t.shape[0] for t in tensor_list]
    merged = torch.cat(tensor_list, dim=0)

    device_group = get_ep_group().device_group
    all_reduce(merged, group=device_group)

    # Splitting by the recorded row counts yields the same row ranges as
    # walking a running offset over the concatenated buffer.
    return list(merged.split(row_counts, dim=0))
def _sync_load_pass(self) -> list[torch.Tensor]:
    """
    All-reduce a copy of each model's `expert_load_pass` for stats logging.

    Clones are reduced, so the `expert_load_pass` tensors stored in
    `self.model_states` are not modified. The returned list follows the
    iteration order of `self.model_states`.
    """
    load_pass_copies = [
        model_state.expert_load_pass.clone()
        for model_state in self.model_states.values()
    ]
    return self._allreduce_list(load_pass_copies)
def _node_count_with_rank_mapping( def _node_count_with_rank_mapping(

View File

@ -226,7 +226,7 @@ class ToolParserManager:
if isinstance(name, str): if isinstance(name, str):
names = [name] names = [name]
elif is_list_of(name, str): elif name is not None and is_list_of(name, str):
names = name names = name
else: else:
names = [class_name] names = [class_name]

View File

@ -24,9 +24,12 @@ from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV3ForCausalLM, DeepseekV3ForCausalLM,
) )
from vllm.utils import init_logger
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
logger = init_logger(__name__)
@support_torch_compile @support_torch_compile
class DeepseekV2Model(nn.Module): class DeepseekV2Model(nn.Module):
@ -215,6 +218,10 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
self.config.vocab_size, scale=logit_scale self.config.vocab_size, scale=logit_scale
) )
# Set MoE hyperparameters
self.num_moe_layers = self.config.num_hidden_layers
self.set_moe_parameters()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)

View File

@ -8,6 +8,7 @@ from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -25,11 +26,15 @@ from vllm.sequence import IntermediateTensors
from .deepseek_v2 import ( from .deepseek_v2 import (
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV2MixtureOfExperts,
DeepseekV2MoE,
get_spec_layer_idx_from_weight_name, get_spec_layer_idx_from_weight_name,
) )
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
logger = init_logger(__name__)
class SharedHead(nn.Module): class SharedHead(nn.Module):
def __init__( def __init__(
@ -119,6 +124,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
self.mtp_start_layer_idx = config.num_hidden_layers self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights # to map the exact layer index from weights
self.layers = torch.nn.ModuleDict( self.layers = torch.nn.ModuleDict(
{ {
str(idx): DeepSeekMultiTokenPredictorLayer( str(idx): DeepSeekMultiTokenPredictorLayer(
@ -172,13 +178,33 @@ class DeepSeekMultiTokenPredictor(nn.Module):
@support_torch_compile @support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP): class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
self.model = DeepSeekMultiTokenPredictor( self.model = DeepSeekMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
# Set MoE hyperparameters
self.set_moe_parameters()
def set_moe_parameters(self):
    """Gather the MoE blocks of the MTP layers and publish MoE metadata.

    Resets ``expert_weights``, ``moe_layers`` and ``moe_mlp_layers``, seeds
    ``num_moe_layers`` and ``num_expert_groups`` from ``self.config``, and
    finally passes the last MoE MLP encountered (``None`` when there is
    none) to ``extract_moe_parameters``.
    """
    self.expert_weights = []
    self.num_moe_layers = self.config.num_nextn_predict_layers
    self.num_expert_groups = self.config.n_group

    self.moe_layers = []
    self.moe_mlp_layers = []

    last_moe = None
    for predictor_layer in self.model.layers.values():
        assert isinstance(predictor_layer, DeepSeekMultiTokenPredictorLayer)
        decoder_block = predictor_layer.mtp_block
        assert isinstance(decoder_block, DeepseekV2DecoderLayer)
        mlp = decoder_block.mlp
        if not isinstance(mlp, DeepseekV2MoE):
            # Dense MLP: nothing to register for this layer.
            continue
        last_moe = mlp
        self.moe_mlp_layers.append(mlp)
        self.moe_layers.append(mlp.experts)

    self.extract_moe_parameters(last_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)

View File

@ -166,7 +166,7 @@ class DeepseekV2MoE(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts self.n_shared_experts: int = config.n_shared_experts
@ -1122,7 +1122,6 @@ class DeepseekV2Model(nn.Module):
) )
else: else:
self.embed_tokens = PPMissingLayer() self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer( lambda prefix: DeepseekV2DecoderLayer(
@ -1172,7 +1171,50 @@ class DeepseekV2Model(nn.Module):
return hidden_states return hidden_states
class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): class DeepseekV2MixtureOfExperts(MixtureOfExperts):
moe_mlp_layers: list[DeepseekV2MoE]
"""
List of MoE MLP layers in the model.
"""
def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
    """Copy expert counts from a representative MoE block onto the model.

    Args:
        example_moe: An MoE MLP whose ``n_*`` counters are mirrored into the
            model's ``num_*`` attributes, or ``None`` when the model has no
            MoE layer. In that case every counter (including
            ``num_moe_layers`` and ``num_expert_groups``) is set to 0 and a
            warning is logged.
    """
    if example_moe is not None:
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts
        return

    # No MoE layer at all: report an empty expert configuration.
    self.num_moe_layers = 0
    self.num_expert_groups = 0
    self.num_logical_experts = 0
    self.num_physical_experts = 0
    self.num_local_physical_experts = 0
    self.num_routed_experts = 0
    self.num_shared_experts = 0
    self.num_redundant_experts = 0
    logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")
def update_physical_experts_metadata(
    self,
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None:
    """Propagate a new physical-expert count to the model and its MoE blocks.

    Args:
        num_physical_experts: New total number of physical experts.
        num_local_physical_experts: Number of physical experts on this rank;
            asserted to equal the currently stored value.
    """
    assert self.num_local_physical_experts == num_local_physical_experts

    redundant = num_physical_experts - self.num_logical_experts
    self.num_physical_experts = num_physical_experts
    self.num_local_physical_experts = num_local_physical_experts
    self.num_redundant_experts = redundant

    for moe_block in self.moe_mlp_layers:
        moe_block.n_local_physical_experts = num_local_physical_experts
        moe_block.n_physical_experts = num_physical_experts
        moe_block.n_redundant_experts = redundant
        # Let the fused experts layer recompute its expert map.
        moe_block.experts.update_expert_map()
class DeepseekV2ForCausalLM(
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
):
packed_modules_mapping = { packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
} }
@ -1213,13 +1255,19 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
# Set MoE hyperparameters
self.num_moe_layers = (
self.config.num_hidden_layers - self.config.first_k_dense_replace
)
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = [] self.expert_weights = []
# Set MoE hyperparameters self.num_expert_groups = self.config.n_group
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
@ -1229,50 +1277,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoR
if isinstance(layer.mlp, DeepseekV2MoE): if isinstance(layer.mlp, DeepseekV2MoE):
# Pick last one layer since the first ones may be dense layers. # Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts) self.moe_layers.append(layer.mlp.experts)
if example_moe is None: self.extract_moe_parameters(example_moe)
raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    """Hand the EPLB tensors to every MoE layer and record its weights.

    For each entry of ``self.moe_layers`` (in order) the layer's expert
    weights are appended to ``self.expert_weights`` and the layer receives
    the three tensors together with its position as ``moe_layer_idx``.
    """
    eplb_tensors = dict(
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )
    for idx, moe_layer in enumerate(self.moe_layers):
        # Weights are registered before the layer is given its EPLB state.
        self.expert_weights.append(moe_layer.get_expert_weights())
        moe_layer.set_eplb_state(moe_layer_idx=idx, **eplb_tensors)
def update_physical_experts_metadata(
    self,
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None:
    """Update physical-expert counts on the model and on each MoE MLP.

    Walks ``self.model.layers`` and refreshes every layer whose ``mlp`` is a
    ``DeepseekV2MoE``; other layers are left untouched.

    Args:
        num_physical_experts: New total number of physical experts.
        num_local_physical_experts: Number of physical experts on this rank;
            asserted to equal the currently stored value.
    """
    assert self.num_local_physical_experts == num_local_physical_experts

    redundant = num_physical_experts - self.num_logical_experts
    self.num_physical_experts = num_physical_experts
    self.num_local_physical_experts = num_local_physical_experts
    self.num_redundant_experts = redundant

    for decoder_layer in self.model.layers:
        moe_block = decoder_layer.mlp
        if not isinstance(moe_block, DeepseekV2MoE):
            continue
        moe_block.n_local_physical_experts = num_local_physical_experts
        moe_block.n_physical_experts = num_physical_experts
        moe_block.n_redundant_experts = redundant
        moe_block.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)

View File

@ -133,7 +133,7 @@ class Ernie4_5_MoeMoE(nn.Module):
self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None) self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None)
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.moe_num_experts self.n_routed_experts: int = config.moe_num_experts
self.n_shared_experts: int = self.moe_num_shared_experts self.n_shared_experts: int = self.moe_num_shared_experts
@ -709,22 +709,6 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
self.num_shared_experts = example_moe.n_shared_experts self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    """Register expert weights and forward EPLB tensors to each MoE layer.

    Iterates ``self.moe_layers`` in order; each layer's expert weights are
    appended to ``self.expert_weights`` and the layer is given the three
    tensors plus its index as ``moe_layer_idx``.
    """
    eplb_tensors = dict(
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )
    for idx, moe_layer in enumerate(self.moe_layers):
        # Record the weights first, then configure the layer.
        self.expert_weights.append(moe_layer.get_expert_weights())
        moe_layer.set_eplb_state(moe_layer_idx=idx, **eplb_tensors)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,

View File

@ -62,7 +62,7 @@ from vllm.model_executor.model_loader.weight_utils import (
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
PPMissingLayer, PPMissingLayer,
@ -127,7 +127,7 @@ class Glm4MoE(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts: int = config.n_routed_experts self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts self.n_shared_experts: int = config.n_shared_experts
@ -616,7 +616,35 @@ class Glm4MoeModel(nn.Module):
return loaded_params return loaded_params
class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): class Glm4MixtureOfExperts(MixtureOfExperts):
def extract_moe_parameters(self, example_moe: Glm4MoE | None) -> None:
    """Mirror the expert counts of a representative MoE block onto the model.

    Args:
        example_moe: MoE block whose ``n_*`` counters are copied into the
            model's ``num_*`` attributes.

    Raises:
        RuntimeError: If ``example_moe`` is ``None`` (no MoE layer found).
    """
    if example_moe is None:
        raise RuntimeError("No Glm4MoE layer found in model.layers.")

    self.num_logical_experts = example_moe.n_logical_experts
    self.num_physical_experts = example_moe.n_physical_experts
    self.num_local_physical_experts = example_moe.n_local_physical_experts
    self.num_routed_experts = example_moe.n_routed_experts
    self.num_shared_experts = example_moe.n_shared_experts
    self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
    self,
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None:
    """Push a new physical-expert count to the model and its MoE blocks.

    Args:
        num_physical_experts: New total number of physical experts.
        num_local_physical_experts: Number of physical experts on this rank;
            asserted to equal the currently stored value.
    """
    assert self.num_local_physical_experts == num_local_physical_experts

    redundant = num_physical_experts - self.num_logical_experts
    self.num_physical_experts = num_physical_experts
    self.num_local_physical_experts = num_local_physical_experts
    self.num_redundant_experts = redundant

    for moe_block in self.moe_mlp_layers:
        moe_block.n_local_physical_experts = num_local_physical_experts
        moe_block.n_physical_experts = num_physical_experts
        moe_block.n_redundant_experts = redundant
        # The experts layer rebuilds its expert map from the new counts.
        moe_block.experts.update_expert_map()
class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExperts):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@ -659,7 +687,9 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = config.n_group self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
self.moe_mlp_layers: list[Glm4MoE] = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
@ -669,33 +699,10 @@ class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
if isinstance(layer.mlp, Glm4MoE): if isinstance(layer.mlp, Glm4MoE):
# Pick last one layer since the first ones may be dense layers. # Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts) self.moe_layers.append(layer.mlp.experts)
if example_moe is None: self.extract_moe_parameters(example_moe)
raise RuntimeError("No Glm4MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    """Distribute the EPLB tensors to all MoE layers of the model.

    Each layer in ``self.moe_layers`` contributes its expert weights to
    ``self.expert_weights`` and is then configured with the three tensors
    and its own index (``moe_layer_idx``).
    """
    eplb_tensors = dict(
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )
    for idx, moe_layer in enumerate(self.moe_layers):
        # Register the layer's expert weights before wiring up its state.
        self.expert_weights.append(moe_layer.get_expert_weights())
        moe_layer.set_eplb_state(moe_layer_idx=idx, **eplb_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)

View File

@ -29,7 +29,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -41,7 +41,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name from .glm4_moe import (
Glm4MixtureOfExperts,
Glm4MoE,
Glm4MoeDecoderLayer,
get_spec_layer_idx_from_weight_name,
)
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
@ -73,6 +78,7 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
prefix: str, prefix: str,
cache_config: CacheConfig | None = None, cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -81,11 +87,13 @@ class Glm4MoeMultiTokenPredictorLayer(nn.Module):
self.shared_head = SharedHead( self.shared_head = SharedHead(
config=config, prefix=prefix, quant_config=quant_config config=config, prefix=prefix, quant_config=quant_config
) )
self.enable_eplb = parallel_config.enable_eplb
self.mtp_block = Glm4MoeDecoderLayer( self.mtp_block = Glm4MoeDecoderLayer(
config=config, config=config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
enable_eplb=self.enable_eplb,
) )
def forward( def forward(
@ -127,6 +135,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
f"{prefix}.layers.{idx}", f"{prefix}.layers.{idx}",
cache_config=vllm_config.cache_config, cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config, quant_config=vllm_config.quant_config,
parallel_config=vllm_config.parallel_config,
) )
for idx in range( for idx in range(
self.mtp_start_layer_idx, self.mtp_start_layer_idx,
@ -175,7 +184,7 @@ class Glm4MoeMultiTokenPredictor(nn.Module):
return logits return logits
class Glm4MoeMTP(nn.Module, SupportsPP): class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
@ -183,6 +192,25 @@ class Glm4MoeMTP(nn.Module, SupportsPP):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
self.expert_weights = []
# Set MoE hyperparameters
self.num_moe_layers = self.config.num_nextn_predict_layers
self.num_expert_groups = self.config.n_group
self.moe_layers: list[FusedMoE] = []
self.moe_mlp_layers: list[Glm4MoE] = []
example_moe = None
for layer in self.model.layers.values():
assert isinstance(layer, Glm4MoeMultiTokenPredictorLayer)
layer = layer.mtp_block
assert isinstance(layer, Glm4MoeDecoderLayer)
if isinstance(layer.mlp, Glm4MoE):
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)

View File

@ -374,7 +374,7 @@ class HunYuanSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
@ -1007,7 +1007,7 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.num_expert_groups = 1 self.num_expert_groups = 1
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
@ -1028,22 +1028,6 @@ class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
self.num_routed_experts = example_layer.n_routed_experts self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    """Collect expert weights and pass the EPLB tensors to each MoE layer.

    Layers are visited in ``self.moe_layers`` order; the position in that
    list is forwarded to the layer as ``moe_layer_idx``.
    """
    eplb_tensors = dict(
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )
    for idx, moe_layer in enumerate(self.moe_layers):
        # Expert weights are registered ahead of the per-layer EPLB setup.
        self.expert_weights.append(moe_layer.get_expert_weights())
        moe_layer.set_eplb_state(moe_layer_idx=idx, **eplb_tensors)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,

View File

@ -14,6 +14,7 @@ from typing import (
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from torch import Tensor from torch import Tensor
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.models.whisper.tokenization_whisper import LANGUAGES from transformers.models.whisper.tokenization_whisper import LANGUAGES
@ -641,6 +642,9 @@ class MixtureOfExperts(Protocol):
num_redundant_experts: int num_redundant_experts: int
"""Number of redundant experts in this model.""" """Number of redundant experts in this model."""
moe_layers: Iterable[nn.Module]
"""List of MoE layers in this model."""
def set_eplb_state( def set_eplb_state(
self, self,
expert_load_view: Tensor, expert_load_view: Tensor,
@ -663,7 +667,15 @@ class MixtureOfExperts(Protocol):
logical_to_physical_map: Mapping from logical to physical experts. logical_to_physical_map: Mapping from logical to physical experts.
logical_replica_count: Count of replicas for each logical expert. logical_replica_count: Count of replicas for each logical expert.
""" """
... for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,

View File

@ -105,7 +105,7 @@ class Lfm2MoeSparseMoeBlock(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
@ -707,7 +707,7 @@ class Lfm2MoeForCausalLM(
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.moe_layers: list[FusedMoE] = [] self.moe_layers = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
@ -737,22 +737,6 @@ class Lfm2MoeForCausalLM(
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
) -> None:
    """Configure every MoE layer for EPLB and remember its expert weights.

    For each layer of ``self.moe_layers`` the expert weights are appended to
    ``self.expert_weights``; the layer then receives the three tensors along
    with its list index as ``moe_layer_idx``.
    """
    eplb_tensors = dict(
        expert_load_view=expert_load_view,
        logical_to_physical_map=logical_to_physical_map,
        logical_replica_count=logical_replica_count,
    )
    for idx, moe_layer in enumerate(self.moe_layers):
        # Keep a handle on the weights, then hand the layer its EPLB state.
        self.expert_weights.append(moe_layer.get_expert_weights())
        moe_layer.set_eplb_state(moe_layer_idx=idx, **eplb_tensors)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,

View File

@ -30,9 +30,11 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
get_ep_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
@ -46,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
@ -56,6 +59,8 @@ from .utils import (
is_pp_missing_parameter, is_pp_missing_parameter,
) )
logger = init_logger(__name__)
class Llama4MoE(nn.Module): class Llama4MoE(nn.Module):
@staticmethod @staticmethod
@ -80,6 +85,9 @@ class Llama4MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok self.top_k = config.num_experts_per_tok
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.ep_group = get_ep_group().device_group
self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size()
intermediate_size_moe = config.intermediate_size intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear( self.router = ReplicatedLinear(
@ -101,6 +109,20 @@ class Llama4MoE(nn.Module):
disable_tp=self.is_sequence_parallel, disable_tp=self.is_sequence_parallel,
) )
# Load balancing settings.
eplb_config = parallel_config.eplb_config if parallel_config else None
self.enable_eplb = parallel_config.enable_eplb if parallel_config else False
self.n_redundant_experts = (
eplb_config.num_redundant_experts if eplb_config else 0
)
self.n_routed_experts: int = config.num_local_experts
self.n_logical_experts = self.n_routed_experts
self.n_shared_experts: int = 1
self.n_local_experts: int = config.num_local_experts
self.n_physical_experts = self.n_local_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.experts = SharedFusedMoE( self.experts = SharedFusedMoE(
shared_experts=self.shared_expert, shared_experts=self.shared_expert,
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
@ -114,6 +136,8 @@ class Llama4MoE(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
is_sequence_parallel=self.is_sequence_parallel, is_sequence_parallel=self.is_sequence_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
) )
def forward(self, hidden_states): def forward(self, hidden_states):
@ -378,6 +402,9 @@ class Llama4Model(LlamaModel):
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer, layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
): ):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts self.num_experts = vllm_config.model_config.hf_config.num_local_experts
self.n_redundant_experts = (
vllm_config.parallel_config.eplb_config.num_redundant_experts
)
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
def load_moe_expert_weights( def load_moe_expert_weights(
@ -499,7 +526,6 @@ class Llama4Model(LlamaModel):
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id, expert_id=expert_id,
) )
loaded_params.add(full_param_name) loaded_params.add(full_param_name)
expert_param_loaded = True expert_param_loaded = True
@ -526,6 +552,7 @@ class Llama4Model(LlamaModel):
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.num_experts, num_experts=self.num_experts,
num_redundant_experts=self.n_redundant_experts,
) )
# Expert parameter mapping for the case where the expert weights are # Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor. # fused into a single weight tensor.
@ -683,7 +710,7 @@ class Llama4Model(LlamaModel):
return loaded_params return loaded_params
class Llama4ForCausalLM(LlamaForCausalLM): class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
@ -702,6 +729,57 @@ class Llama4ForCausalLM(LlamaForCausalLM):
super().__init__( super().__init__(
vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer
) )
# Set MoE hyperparameters
self.set_moe_parameters()
def set_moe_parameters(self):
    """Discover the model's MoE layers and derive the MoE metadata from them.

    Resets ``expert_weights`` and ``moe_layers``, collects the ``experts``
    module of every decoder layer whose ``feed_forward`` is a ``Llama4MoE``,
    and copies the expert counters of the last such block onto the model.
    When no MoE block exists, all counters are zeroed and a warning is
    logged.
    """
    self.expert_weights = []
    self.moe_layers = []

    last_moe = None
    for decoder_layer in self.model.layers:
        assert isinstance(decoder_layer, Llama4DecoderLayer)
        feed_forward = decoder_layer.feed_forward
        if not isinstance(feed_forward, Llama4MoE):
            continue
        # Keep the most recent MoE block as the representative one, since
        # earlier layers may be dense.
        last_moe = feed_forward
        self.moe_layers.append(feed_forward.experts)

    if last_moe is None:
        self.num_moe_layers = 0
        self.num_expert_groups = 0
        self.num_logical_experts = 0
        self.num_physical_experts = 0
        self.num_local_physical_experts = 0
        self.num_routed_experts = 0
        self.num_shared_experts = 0
        self.num_redundant_experts = 0
        logger.warning("No Llama4MoE layer found in model.layers.")
        return

    self.num_moe_layers = len(self.moe_layers)
    self.num_expert_groups = 1
    self.num_logical_experts = last_moe.n_logical_experts
    self.num_physical_experts = last_moe.n_physical_experts
    self.num_local_physical_experts = last_moe.n_local_physical_experts
    self.num_routed_experts = last_moe.n_routed_experts
    self.num_shared_experts = last_moe.n_shared_experts
    self.num_redundant_experts = last_moe.n_redundant_experts
def update_physical_experts_metadata(
    self,
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None:
    """Update physical-expert counts on the model and on each Llama4 MoE block.

    Walks ``self.model.layers`` and refreshes every layer whose
    ``feed_forward`` is a ``Llama4MoE``; other layers are left untouched.

    Args:
        num_physical_experts: New total number of physical experts.
        num_local_physical_experts: Number of physical experts on this rank;
            asserted to equal the currently stored value.
    """
    assert self.num_local_physical_experts == num_local_physical_experts

    redundant = num_physical_experts - self.num_logical_experts
    self.num_physical_experts = num_physical_experts
    self.num_local_physical_experts = num_local_physical_experts
    self.num_redundant_experts = redundant

    for decoder_layer in self.model.layers:
        moe_block = decoder_layer.feed_forward
        if not isinstance(moe_block, Llama4MoE):
            continue
        moe_block.n_local_physical_experts = num_local_physical_experts
        moe_block.n_physical_experts = num_physical_experts
        moe_block.n_redundant_experts = redundant
        moe_block.experts.update_expert_map()
def _init_model( def _init_model(
self, self,

View File

@ -189,6 +189,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
self.config.vocab_size, scale=logit_scale self.config.vocab_size, scale=logit_scale
) )
# Set MoE hyperparameters
self.set_moe_parameters()
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model

View File

@ -578,6 +578,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
parallel_config = vllm_config.parallel_config
self.prefix = prefix self.prefix = prefix
self.vllm_config = vllm_config self.vllm_config = vllm_config
@ -613,6 +614,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
if parallel_config.enable_eplb and getattr(config, "num_experts", 0) > 0:
raise NotImplementedError("EPLB is not supported for MiniCPM yet.")
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) return MiniCPMModel(vllm_config=vllm_config, prefix=prefix)

View File

@ -98,7 +98,7 @@ class MixtralMoE(nn.Module):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
# Expert Parallelism Load balancing settings. # Expert Parallelism Load balancing settings.
@ -546,7 +546,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
) )
self.expert_weights = [] self.expert_weights = []
self.moe_layers: list[FusedMoE] = [] self.moe_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
@ -572,22 +572,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
self.num_expert_groups = 1 self.num_expert_groups = 1
self.num_shared_experts = 0 self.num_shared_experts = 0
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,

View File

@ -65,6 +65,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
MixtureOfExperts,
MultiModalEmbeddings, MultiModalEmbeddings,
SupportsEagle3, SupportsEagle3,
SupportsMultiModal, SupportsMultiModal,
@ -723,7 +724,7 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
dummy_inputs=Mllama4DummyInputsBuilder, dummy_inputs=Mllama4DummyInputsBuilder,
) )
class Llama4ForConditionalGeneration( class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3
): ):
merge_by_field_config = True merge_by_field_config = True
@ -776,6 +777,17 @@ class Llama4ForConditionalGeneration(
self.language_model.make_empty_intermediate_tensors self.language_model.make_empty_intermediate_tensors
) )
# Set MoE hyperparameters
self.num_expert_groups = 1
self.num_logical_experts = self.language_model.num_logical_experts
self.num_physical_experts = self.language_model.num_physical_experts
self.num_local_physical_experts = self.language_model.num_local_physical_experts
self.num_routed_experts = self.language_model.num_routed_experts
self.num_shared_experts = self.language_model.num_shared_experts
self.num_redundant_experts = self.language_model.num_redundant_experts
self.moe_layers = self.language_model.moe_layers
self.num_moe_layers = len(self.moe_layers)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
"""Set which layers should output auxiliary hidden states for EAGLE3.""" """Set which layers should output auxiliary hidden states for EAGLE3."""
# Delegate to underlying language model (Llama4ForCausalLM) # Delegate to underlying language model (Llama4ForCausalLM)
@ -792,6 +804,24 @@ class Llama4ForConditionalGeneration(
assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers") assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
return self.language_model.get_eagle3_aux_hidden_state_layers() return self.language_model.get_eagle3_aux_hidden_state_layers()
def set_eplb_state(
    self,
    expert_load_view: torch.Tensor,
    logical_to_physical_map: torch.Tensor,
    logical_replica_count: torch.Tensor,
):
    """Forward the EPLB state tensors to the wrapped language model.

    After delegating, the wrapper's ``expert_weights`` is pointed at the
    language model's ``expert_weights`` so both refer to the same object.
    """
    inner = self.language_model
    eplb_args = (expert_load_view, logical_to_physical_map, logical_replica_count)
    inner.set_eplb_state(*eplb_args)
    self.expert_weights = inner.expert_weights
def update_physical_experts_metadata(
    self, num_physical_experts: int, num_local_physical_experts: int
) -> None:
    """Update physical-expert counts on the language model and this wrapper.

    The work is delegated to ``self.language_model``. The wrapper keeps its
    own copies of the expert counts, so they are re-read from the language
    model afterwards; otherwise they would keep their pre-update values.

    Args:
        num_physical_experts: New total number of physical experts.
        num_local_physical_experts: Number of physical experts held locally.
    """
    inner = self.language_model
    inner.update_physical_experts_metadata(
        num_physical_experts, num_local_physical_experts
    )
    # Mirror the refreshed counts so callers reading them from the wrapper
    # see the same values as the wrapped language model.
    self.num_physical_experts = inner.num_physical_experts
    self.num_local_physical_experts = inner.num_local_physical_experts
    self.num_redundant_experts = inner.num_redundant_experts
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
) -> Llama4ImagePatchInputs | None: ) -> Llama4ImagePatchInputs | None:

View File

@ -807,7 +807,7 @@ class NemotronHForCausalLM(
self.expert_weights = [] self.expert_weights = []
self.num_expert_groups = config.n_group self.num_expert_groups = config.n_group
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, NemotronHMoEDecoderLayer): if isinstance(layer, NemotronHMoEDecoderLayer):
@ -824,22 +824,6 @@ class NemotronHForCausalLM(
self.num_shared_experts = example_moe.n_shared_experts self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,

View File

@ -1009,7 +1009,7 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts):
self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
self.num_expert_groups = 1 self.num_expert_groups = 1
self.moe_layers: list[SharedFusedMoE] = [] self.moe_layers = []
example_moe = None example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
@ -1031,22 +1031,6 @@ class OpenPanguMoEModel(OpenPanguModelBase, MixtureOfExperts):
self.n_shared_experts = example_moe.n_shared_experts self.n_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts self.num_redundant_experts = example_moe.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,

View File

@ -132,7 +132,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
@ -665,7 +665,7 @@ class Qwen3MoeForCausalLM(
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.expert_weights = []
self.moe_layers: list[FusedMoE] = [] self.moe_layers = []
example_layer = None example_layer = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer): if isinstance(layer, PPMissingLayer):
@ -688,22 +688,6 @@ class Qwen3MoeForCausalLM(
self.num_routed_experts = example_layer.n_routed_experts self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata( def update_physical_experts_metadata(
self, self,
num_physical_experts: int, num_physical_experts: int,

View File

@ -107,7 +107,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
self.ep_rank = self.ep_group.rank() self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
@ -1095,8 +1095,57 @@ class Qwen3NextModel(nn.Module):
return loaded_params return loaded_params
class QwenNextMixtureOfExperts(MixtureOfExperts):
    """MoE bookkeeping shared by Qwen3-Next model classes.

    Both methods expect the host class to provide ``self.model.layers``.
    """

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Record new physical-expert counts on the model and its MoE blocks.

        Args:
            num_physical_experts: New total number of physical experts.
            num_local_physical_experts: Number of physical experts held
                locally; must equal the value already stored on ``self``.

        Raises:
            AssertionError: If ``num_local_physical_experts`` differs from
                ``self.num_local_physical_experts``.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        # Redundant experts are whatever exceeds the logical expert count.
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # `set_moe_parameters` type-checks the layer before reading
            # `.mlp`; do the equivalent here so a layer without an `mlp`
            # attribute is skipped rather than raising AttributeError.
            moe = getattr(layer, "mlp", None)
            if not isinstance(moe, Qwen3NextSparseMoeBlock):
                continue
            moe.n_local_physical_experts = num_local_physical_experts
            moe.n_physical_experts = num_physical_experts
            moe.n_redundant_experts = self.num_redundant_experts
            moe.experts.update_expert_map()

    def set_moe_parameters(self) -> None:
        """Collect the model's MoE layers and derive the expert counts.

        Resets ``self.expert_weights`` and ``self.moe_layers``, gathers the
        ``experts`` module of every sparse-MoE decoder layer, and copies the
        expert counts from the last such block found.

        Raises:
            RuntimeError: If no sparse-MoE decoder layer is present.
        """
        self.expert_weights = []
        self.moe_layers = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
                layer.mlp, Qwen3NextSparseMoeBlock
            ):
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")

        # Set MoE hyperparameters
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts
class Qwen3NextForCausalLM( class Qwen3NextForCausalLM(
nn.Module, HasInnerState, SupportsLoRA, SupportsPP, MixtureOfExperts, IsHybrid nn.Module,
HasInnerState,
SupportsLoRA,
SupportsPP,
QwenNextMixtureOfExperts,
IsHybrid,
): ):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
@ -1147,63 +1196,7 @@ class Qwen3NextForCausalLM(
) )
# Set MoE hyperparameters # Set MoE hyperparameters
self.expert_weights = [] self.set_moe_parameters()
self.moe_layers: list[SharedFusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Qwen3NextDecoderLayer)
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No Qwen3Next layer found in the model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts
def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
# Register the expert weights.
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)

View File

@ -23,6 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_next import ( from vllm.model_executor.models.qwen3_next import (
Qwen3NextDecoderLayer, Qwen3NextDecoderLayer,
Qwen3NextRMSNorm, Qwen3NextRMSNorm,
QwenNextMixtureOfExperts,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.transformers_utils.configs import Qwen3NextConfig
@ -226,7 +227,7 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
@support_torch_compile @support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP): class Qwen3NextMTP(nn.Module, SupportsPP, QwenNextMixtureOfExperts):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@ -265,6 +266,7 @@ class Qwen3NextMTP(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors self.model.make_empty_intermediate_tensors
) )
self.set_moe_parameters()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)

View File

@ -125,7 +125,7 @@ class MoEMixin(MixtureOfExperts):
logical_to_physical_map: torch.Tensor, logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor, logical_replica_count: torch.Tensor,
): ):
for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers): for moe_layer_idx, mlp_layer in enumerate(self.mlp_moe_layers):
mlp_layer.experts.set_eplb_state( mlp_layer.experts.set_eplb_state(
moe_layer_idx=moe_layer_idx, moe_layer_idx=moe_layer_idx,
expert_load_view=expert_load_view, expert_load_view=expert_load_view,
@ -142,7 +142,7 @@ class MoEMixin(MixtureOfExperts):
self.num_physical_experts = num_physical_experts self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for mlp in self.mlp_layers: for mlp in self.mlp_moe_layers:
mlp.n_local_physical_experts = num_local_physical_experts mlp.n_local_physical_experts = num_local_physical_experts
mlp.n_physical_experts = num_physical_experts mlp.n_physical_experts = num_physical_experts
mlp.n_redundant_experts = self.num_redundant_experts mlp.n_redundant_experts = self.num_redundant_experts
@ -240,7 +240,8 @@ class MoEMixin(MixtureOfExperts):
# MixtureOfExperts mixin settings # MixtureOfExperts mixin settings
ep_size = get_ep_group().world_size ep_size = get_ep_group().world_size
self.mlp_layers = [] # Used for MixtureOfExperts methods self.mlp_moe_layers = [] # Used for MixtureOfExperts methods
self.moe_layers = []
self.expert_weights = [] self.expert_weights = []
self.num_moe_layers = 0 self.num_moe_layers = 0
self.num_expert_groups = 1 if num_expert_group is None else num_expert_group self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
@ -298,7 +299,8 @@ class MoEMixin(MixtureOfExperts):
mlp.experts = fused_experts mlp.experts = fused_experts
log_replacement(qual_name, experts, fused_experts) log_replacement(qual_name, experts, fused_experts)
# Update MixtureOfExperts mixin state # Update MixtureOfExperts mixin state
self.mlp_layers.append(mlp) self.mlp_moe_layers.append(mlp)
self.moe_layers.append(fused_experts)
self.expert_weights.append(fused_experts.get_expert_weights()) self.expert_weights.append(fused_experts.get_expert_weights())
self.num_moe_layers += 1 self.num_moe_layers += 1
# If results are not all-reduced in FusedMoE, ensure they # If results are not all-reduced in FusedMoE, ensure they

View File

@ -8,6 +8,7 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
# Initialize logger # Initialize logger
@ -56,6 +57,10 @@ class MedusaProposer:
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
model_config=self.vllm_config.speculative_config.draft_model_config, model_config=self.vllm_config.speculative_config.draft_model_config,
) )
assert not (
is_mixture_of_experts(self.model)
and self.vllm_config.parallel_config.enable_eplb
), "EPLB for Medusa is not supported"
@torch.inference_mode() @torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None: def dummy_run(self, num_tokens: int) -> None:

View File

@ -2046,7 +2046,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model = self.get_model() model = self.get_model()
assert is_mixture_of_experts(model) assert is_mixture_of_experts(model)
self.eplb_state.step( self.eplb_state.step(
model,
is_dummy, is_dummy,
is_profile, is_profile,
log_stats=self.parallel_config.eplb_config.log_balancedness, log_stats=self.parallel_config.eplb_config.log_balancedness,
@ -2803,7 +2802,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else: else:
indices = [] indices = []
offset = 0 offset = 0
assert spec_decode_metadata is not None assert spec_decode_metadata is not None, (
"No spec decode metadata for medusa"
)
for num_draft, tokens in zip( for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens, sampled_token_ids spec_decode_metadata.num_draft_tokens, sampled_token_ids
): ):
@ -2934,32 +2935,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model_config.model, self.model_config.model,
scope="global", scope="global",
) )
if eep_scale_up: global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
from vllm.distributed.parallel_state import get_ep_group EplbState.get_eep_state(self.parallel_config)
if eep_scale_up
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") else (None, None, None)
torch.distributed.broadcast( )
num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0
)
num_local_physical_experts = int(num_local_physical_experts.item())
new_ep_size = get_ep_group().world_size
global_expert_load, old_global_expert_indices = EplbState.recv_state()
num_logical_experts = global_expert_load.shape[1]
self.parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts
)
assert old_global_expert_indices.shape[1] % num_local_physical_experts == 0
old_ep_size = (
old_global_expert_indices.shape[1] // num_local_physical_experts
)
rank_mapping = {
old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)
}
else:
global_expert_load = None
old_global_expert_indices = None
rank_mapping = None
if self.parallel_config.enable_eplb:
self.eplb_state = EplbState(self.parallel_config, self.device)
eplb_models = 0
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
time_before_load = time.perf_counter() time_before_load = time.perf_counter()
model_loader = get_model_loader(self.load_config) model_loader = get_model_loader(self.load_config)
@ -2971,8 +2955,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.model, self.vllm_config, self.device self.model, self.vllm_config, self.device
) )
if hasattr(self, "drafter"): if hasattr(self, "drafter"):
logger.info("Loading drafter model...") logger.info_once("Loading drafter model...")
self.drafter.load_model(self.model) self.drafter.load_model(self.model)
if (
hasattr(self.drafter, "model")
and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb
):
logger.info_once(
"EPLB is enabled for drafter model %s.",
self.vllm_config.speculative_config.draft_model_config.model,
)
global_expert_load = (
global_expert_loads[eplb_models]
if global_expert_loads
else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
if self.eplb_state is None:
self.eplb_state = EplbState(self.parallel_config, self.device)
self.eplb_state.add_model(
self.drafter.model,
self.vllm_config.speculative_config.draft_model_config,
global_expert_load,
old_global_expert_indices,
rank_mapping,
)
eplb_models += 1
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
if not supports_eagle3(self.get_model()): if not supports_eagle3(self.get_model()):
raise RuntimeError( raise RuntimeError(
@ -3001,18 +3016,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scope="local", scope="local",
) )
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
self.is_multimodal_pruning_enabled = ( self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model()) supports_multimodal_pruning(self.get_model())
and self.model_config.multimodal_config.is_multimodal_pruning_enabled() and self.model_config.multimodal_config.is_multimodal_pruning_enabled()
) )
if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
logger.info("EPLB is enabled for model %s.", self.model_config.model) logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
self.eplb_state = EplbState.build( global_expert_load = (
global_expert_loads[eplb_models] if global_expert_loads else None
)
old_global_expert_indices = (
old_global_expert_indices_per_model[eplb_models]
if old_global_expert_indices_per_model
else None
)
assert self.eplb_state is not None
self.eplb_state.add_model(
self.model, self.model,
self.device, self.model_config,
self.parallel_config,
global_expert_load, global_expert_load,
old_global_expert_indices, old_global_expert_indices,
rank_mapping, rank_mapping,

View File

@ -32,6 +32,7 @@ from vllm.distributed.parallel_state import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
@ -613,7 +614,6 @@ class Worker(WorkerBase):
} }
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange( self.model_runner.eplb_state.rearrange(
self.model_runner.model,
execute_shuffle=True, execute_shuffle=True,
global_expert_load=None, global_expert_load=None,
rank_mapping=rank_mapping, rank_mapping=rank_mapping,
@ -626,7 +626,7 @@ class Worker(WorkerBase):
self, self,
old_ep_size: int, old_ep_size: int,
new_ep_size: int, new_ep_size: int,
global_expert_load: torch.Tensor | None, global_expert_loads: list[torch.Tensor] | None,
) -> None: ) -> None:
from vllm.distributed.parallel_state import get_ep_group from vllm.distributed.parallel_state import get_ep_group
@ -635,9 +635,8 @@ class Worker(WorkerBase):
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange( self.model_runner.eplb_state.rearrange(
self.model_runner.model,
execute_shuffle=True, execute_shuffle=True,
global_expert_load=global_expert_load, global_expert_loads=global_expert_loads,
rank_mapping=rank_mapping, rank_mapping=rank_mapping,
) )
if get_ep_group().rank == 0: if get_ep_group().rank == 0:
@ -684,31 +683,56 @@ class Worker(WorkerBase):
get_ep_group, get_ep_group,
prepare_communication_buffer_for_model, prepare_communication_buffer_for_model,
) )
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEParallelConfig,
)
parallel_config = self.vllm_config.parallel_config parallel_config = self.vllm_config.parallel_config
moe_modules = [
module def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
for module in self.model_runner.model.modules() return [
if ( module
module.__class__.__name__ == "FusedMoE" for module in model.modules()
or module.__class__.__name__ == "SharedFusedMoE" if (
) module.__class__.__name__ == "FusedMoE"
] or module.__class__.__name__ == "SharedFusedMoE"
num_local_experts = moe_modules[0].moe_config.num_local_experts )
assert all( ]
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
), "All MoE modules must have the same number of experts" assert all(
for module in moe_modules: module.moe_config.num_local_experts == num_local_experts
module.moe_config.num_experts = num_local_experts * new_ep_size for module in moe_modules
module.global_num_experts = module.moe_config.num_experts ), "All MoE modules must have the same number of experts"
module.moe_parallel_config = FusedMoEParallelConfig.make( for module in moe_modules:
tp_size_=get_tp_group().world_size, module.moe_config.num_experts = num_local_experts * new_ep_size
dp_size_=get_dp_group().world_size, module.global_num_experts = module.moe_config.num_experts
vllm_parallel_config=parallel_config, module.moe_parallel_config = FusedMoEParallelConfig.make(
) tp_size_=get_tp_group().world_size,
module.moe_config.moe_parallel_config = module.moe_parallel_config dp_size_=get_dp_group().world_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
return moe_modules
model_moe_modules = get_moe_modules(self.model_runner.model)
num_local_experts = model_moe_modules[0].moe_config.num_local_experts
update_moe_modules(model_moe_modules, num_local_experts)
drafter_model = None
if hasattr(self.model_runner, "drafter") and hasattr(
self.model_runner.drafter, "model"
):
drafter_model = self.model_runner.drafter.model
if drafter_model is not None and is_mixture_of_experts(drafter_model):
drafter_moe_modules = get_moe_modules(drafter_model)
# Check if drafter and model have matching configs
assert (
drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
), "Drafter and model configs should be the same"
update_moe_modules(drafter_moe_modules, num_local_experts)
if new_ep_size < old_ep_size: if new_ep_size < old_ep_size:
num_local_physical_experts = num_local_experts num_local_physical_experts = num_local_experts
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
@ -719,7 +743,7 @@ class Worker(WorkerBase):
new_physical_experts new_physical_experts
- self.model_runner.eplb_state.logical_replica_count.shape[1] - self.model_runner.eplb_state.logical_replica_count.shape[1]
) )
global_expert_load = None global_expert_loads = None
else: else:
num_local_physical_experts = torch.tensor( num_local_physical_experts = torch.tensor(
[num_local_experts], dtype=torch.int32, device="cpu" [num_local_experts], dtype=torch.int32, device="cpu"
@ -730,18 +754,20 @@ class Worker(WorkerBase):
num_local_physical_experts = num_local_physical_experts.item() num_local_physical_experts = num_local_physical_experts.item()
new_physical_experts = num_local_physical_experts * new_ep_size new_physical_experts = num_local_physical_experts * new_ep_size
assert self.model_runner.eplb_state is not None assert self.model_runner.eplb_state is not None
global_expert_load = self.model_runner.eplb_state.rearrange( global_expert_loads = self.model_runner.eplb_state.rearrange(
self.model_runner.model, execute_shuffle=False execute_shuffle=False
) )
parallel_config.eplb_config.num_redundant_experts = ( parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_load.shape[1] new_physical_experts - global_expert_loads[0].shape[1]
) )
prepare_communication_buffer_for_model(self.model_runner.model) prepare_communication_buffer_for_model(self.model_runner.model)
if drafter_model is not None:
prepare_communication_buffer_for_model(drafter_model)
self.model_runner.model.update_physical_experts_metadata( self.model_runner.model.update_physical_experts_metadata(
num_physical_experts=new_physical_experts, num_physical_experts=new_physical_experts,
num_local_physical_experts=num_local_physical_experts, num_local_physical_experts=num_local_physical_experts,
) )
return global_expert_load return global_expert_loads
def reinitialize_distributed( def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest self, reconfig_request: ReconfigureDistributedRequest
@ -782,11 +808,11 @@ class Worker(WorkerBase):
self.local_rank, self.local_rank,
) )
global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size) global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
if new_ep_size > old_ep_size: if new_ep_size > old_ep_size:
assert global_expert_load is not None assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load) self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
def save_sharded_state( def save_sharded_state(
self, self,