[Feature] Use --eplb-config to set EPLB parameters (#20562)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: rongfu.leng <lenronfu@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
rongfu.leng 2025-08-21 05:07:28 +08:00 committed by GitHub
parent 4e51fa8cba
commit 4fbda0b20c
9 changed files with 149 additions and 52 deletions


@ -33,7 +33,8 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
CUDAGraphMode, PassConfig)
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
ParallelConfig)
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger


@ -6,7 +6,7 @@ from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch
from pydantic import model_validator
from pydantic import TypeAdapter, model_validator
from pydantic.dataclasses import dataclass
from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self
@ -32,6 +32,38 @@ logger = init_logger(__name__)
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
@config
@dataclass
class EPLBConfig:
"""Configuration for Expert Parallel Load Balancing (EP)."""
window_size: int = 1000
"""Window size for expert load recording."""
step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `window_size` steps will be used for rearranging experts.
"""
num_redundant_experts: int = 0
"""Number of redundant experts to use for expert parallelism."""
log_balancedness: bool = False
"""
Log the balancedness of expert parallelism at each step.
This is turned off by default since it incurs communication overhead.
"""
@classmethod
def from_cli(cls, cli_value: str) -> "EPLBConfig":
"""Parse the CLI value for the compilation config.
-O1, -O2, -O3, etc. is handled in FlexibleArgumentParser.
"""
return TypeAdapter(EPLBConfig).validate_json(cli_value)
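
For illustration, a minimal sketch (not part of the diff) of how from_cli consumes the JSON string passed on the command line; the field values below are made up:

# Hedged example: parse the same JSON a user would pass via --eplb-config.
cfg = EPLBConfig.from_cli(
    '{"window_size": 1000, "step_interval": 3000,'
    ' "num_redundant_experts": 2, "log_balancedness": true}')
assert cfg.num_redundant_experts == 2  # pydantic validates the JSON into the dataclass
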
@config
@dataclass
class ParallelConfig:
@ -75,22 +107,24 @@ class ParallelConfig:
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
"""Enable expert parallelism load balancing for MoE layers."""
num_redundant_experts: int = 0
"""Number of redundant experts to use for expert parallelism."""
eplb_window_size: int = 1000
"""Window size for expert load recording."""
eplb_step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `eplb_window_size` steps will be used for rearranging experts.
"""
eplb_log_balancedness: bool = False
"""
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
"""Expert parallelism configuration."""
num_redundant_experts: Optional[int] = None
"""`num_redundant_experts` is deprecated and has been replaced with
`eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
Please use `eplb_config.num_redundant_experts` instead."""
eplb_window_size: Optional[int] = None
"""`eplb_window_size` is deprecated and has been replaced with
`eplb_config.window_size`. This will be removed in v0.12.0.
Please use `eplb_config.window_size` instead."""
eplb_step_interval: Optional[int] = None
"""`eplb_step_interval` is deprecated and has been replaced with
`eplb_config.step_interval`. This will be removed in v0.12.0.
Please use `eplb_config.step_interval` instead."""
eplb_log_balancedness: Optional[bool] = None
"""`eplb_log_balancedness` is deprecated and has been replaced with
`eplb_config.log_balancedness`. This will be removed in v0.12.0.
Please use `eplb_config.log_balancedness` instead."""
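
With the flat fields reduced to deprecation shims, the nested config is the single source of truth. A minimal sketch of the new-style construction (illustrative values, not taken from the diff):

eplb = EPLBConfig(window_size=1000, step_interval=3000,
                  num_redundant_experts=0, log_balancedness=False)
parallel = ParallelConfig(eplb_config=eplb)
# enable_expert_parallel/enable_eplb (and a CUDA device) are still required for
# the balancer to actually run; with the defaults this only stores the config.
assert parallel.eplb_config.step_interval == 3000
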
max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallel loading workers when loading model
@ -237,6 +271,38 @@ class ParallelConfig:
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None:
# Forward deprecated fields to their new location
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = (
self.num_redundant_experts)
logger.warning_once(
"num_redundant_experts is deprecated and has been replaced "
"with eplb_config.num_redundant_experts. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_window_size is not None:
self.eplb_config.window_size = self.eplb_window_size
logger.warning_once(
"eplb_window_size is deprecated and has been replaced "
"with eplb_config.window_size. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_step_interval is not None:
self.eplb_config.step_interval = self.eplb_step_interval
logger.warning_once(
"eplb_step_interval is deprecated and has been replaced "
"with eplb_config.step_interval. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_log_balancedness is not None:
self.eplb_config.log_balancedness = self.eplb_log_balancedness
logger.warning_once(
"eplb_log_balancedness is deprecated and has been replaced "
"with eplb_config.log_balancedness. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
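
A sketch of what the forwarding above buys existing callers, assuming ParallelConfig can be constructed standalone; the deprecated keyword is still accepted, but only at initialization time:

p = ParallelConfig(eplb_window_size=2000)   # deprecated spelling, still works
assert p.eplb_config.window_size == 2000    # forwarded into the nested config
p.eplb_window_size = 5000                   # changing it afterwards has no effect,
assert p.eplb_config.window_size == 2000    # exactly as the warning text states
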
@ -275,10 +341,10 @@ class ParallelConfig:
raise ValueError(
"Expert parallelism load balancing is only supported on "
"CUDA devices now.")
if self.num_redundant_experts < 0:
if self.eplb_config.num_redundant_experts < 0:
raise ValueError(
"num_redundant_experts must be non-negative, but got "
f"{self.num_redundant_experts}.")
f"{self.eplb_config.num_redundant_experts}.")
if not self.enable_expert_parallel:
raise ValueError(
"enable_expert_parallel must be True to use EPLB.")
@ -289,10 +355,10 @@ class ParallelConfig:
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
)
else:
if self.num_redundant_experts != 0:
if self.eplb_config.num_redundant_experts != 0:
raise ValueError(
"num_redundant_experts should be used with EPLB."
f"{self.num_redundant_experts}.")
f"{self.eplb_config.num_redundant_experts}.")
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
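
To make the tightened validation concrete, a hedged sketch (not from the diff) of the else branch above:

try:
    # enable_eplb defaults to False, so redundant experts are rejected
    ParallelConfig(eplb_config=EPLBConfig(num_redundant_experts=2))
except ValueError as err:
    print(err)  # redundant experts are only meaningful when EPLB is enabled
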


@ -244,7 +244,7 @@ class EplbState:
dtype=torch.int32,
device=device,
)
expert_load_window_size = parallel_config.eplb_window_size
expert_load_window_size = parallel_config.eplb_config.window_size
expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers,
model.num_physical_experts),
@ -253,7 +253,7 @@ class EplbState:
)
# Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_step_interval
eplb_step_interval = parallel_config.eplb_config.step_interval
expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4)
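
The arithmetic of those two lines, spelled out with the default step_interval from EPLBConfig (nothing model-specific is assumed):

step_interval = 3000                                            # eplb_config.step_interval
expert_rearrangement_step = max(0, step_interval - step_interval // 4)
# 3000 - 750 = 2250: the counter starts at 3/4 progress, so the first
# rearrangement happens after another 750 engine steps.
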


@ -25,7 +25,7 @@ import vllm.envs as envs
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigFormat, ConfigType, ConvertOption,
DecodingConfig, DetailedTraceModules, Device,
DeviceConfig, DistributedExecutorBackend,
DeviceConfig, DistributedExecutorBackend, EPLBConfig,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
@ -305,11 +305,12 @@ class EngineArgs:
data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb
num_redundant_experts: int = ParallelConfig.num_redundant_experts
eplb_window_size: int = ParallelConfig.eplb_window_size
eplb_step_interval: int = ParallelConfig.eplb_step_interval
eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness
num_redundant_experts: int = EPLBConfig.num_redundant_experts
eplb_window_size: int = EPLBConfig.window_size
eplb_step_interval: int = EPLBConfig.step_interval
eplb_log_balancedness: bool = EPLBConfig.log_balancedness
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
block_size: Optional[BlockSize] = CacheConfig.block_size
@ -454,6 +455,9 @@ class EngineArgs:
if isinstance(self.compilation_config, dict):
self.compilation_config = CompilationConfig(
**self.compilation_config)
if isinstance(self.eplb_config, dict):
self.eplb_config = EPLBConfig.from_cli(json.dumps(
self.eplb_config))
# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()
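
Since --eplb-config arrives as a JSON string (or as a dict when EngineArgs is built programmatically), the snippet above normalizes dicts through EPLBConfig.from_cli. A hedged usage sketch; the model name is a placeholder and the values are illustrative:

# Assumed CLI invocation:
#   vllm serve <model> --enable-expert-parallel --enable-eplb \
#       --eplb-config '{"window_size": 1000, "step_interval": 3000, "num_redundant_experts": 2}'
# Programmatic equivalent:
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="<model>",             # placeholder
    enable_expert_parallel=True,
    enable_eplb=True,
    eplb_config={"window_size": 1000, "step_interval": 3000,
                 "num_redundant_experts": 2},  # dict is converted via EPLBConfig.from_cli
)
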
@ -661,14 +665,32 @@ class EngineArgs:
**parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument("--enable-eplb",
**parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--num-redundant-experts",
**parallel_kwargs["num_redundant_experts"])
parallel_group.add_argument("--eplb-window-size",
**parallel_kwargs["eplb_window_size"])
parallel_group.add_argument("--eplb-step-interval",
**parallel_kwargs["eplb_step_interval"])
parallel_group.add_argument("--eplb-log-balancedness",
**parallel_kwargs["eplb_log_balancedness"])
parallel_group.add_argument("--eplb-config",
**parallel_kwargs["eplb_config"])
parallel_group.add_argument(
"--num-redundant-experts",
type=int,
help=
"[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
deprecated=True)
parallel_group.add_argument(
"--eplb-window-size",
type=int,
help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
deprecated=True)
parallel_group.add_argument(
"--eplb-step-interval",
type=int,
help=
"[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
deprecated=True)
parallel_group.add_argument(
"--eplb-log-balancedness",
action=argparse.BooleanOptionalAction,
help=
"[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
deprecated=True)
parallel_group.add_argument(
"--max-parallel-loading-workers",
**parallel_kwargs["max_parallel_loading_workers"])
@ -1244,6 +1266,16 @@ class EngineArgs:
"Currently, speculative decoding is not supported with "
"async scheduling.")
# Forward the deprecated CLI args to the EPLB config.
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts
if self.eplb_window_size is not None:
self.eplb_config.window_size = self.eplb_window_size
if self.eplb_step_interval is not None:
self.eplb_config.step_interval = self.eplb_step_interval
if self.eplb_log_balancedness is not None:
self.eplb_config.log_balancedness = self.eplb_log_balancedness
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
@ -1257,10 +1289,7 @@ class EngineArgs:
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_expert_parallel=self.enable_expert_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.num_redundant_experts,
eplb_window_size=self.eplb_window_size,
eplb_step_interval=self.eplb_step_interval,
eplb_log_balancedness=self.eplb_log_balancedness,
eplb_config=self.eplb_config,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
ray_workers_use_nsight=self.ray_workers_use_nsight,


@ -132,10 +132,10 @@ class DeepseekV2MoE(nn.Module):
# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
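
The bookkeeping above with made-up sizes (64 routed experts, 8 redundant, expert-parallel size 8); the same pattern is repeated in the GLM-4-MoE and Qwen3-MoE changes below:

n_logical_experts = 64                # the model's n_routed_experts
n_redundant_experts = 8               # eplb_config.num_redundant_experts
ep_size = 8
n_physical_experts = n_logical_experts + n_redundant_experts   # 72
n_local_physical_experts = n_physical_experts // ep_size       # 9 experts per EP rank
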


@ -131,10 +131,10 @@ class Glm4MoE(nn.Module):
# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)


@ -121,11 +121,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
@ -363,7 +363,8 @@ class Qwen3MoeModel(nn.Module):
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
enable_eplb = parallel_config.enable_eplb
self.num_redundant_experts = parallel_config.num_redundant_experts
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size


@ -1435,7 +1435,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model,
is_dummy,
is_profile,
log_stats=self.parallel_config.eplb_log_balancedness,
log_stats=self.parallel_config.eplb_config.log_balancedness,
)
def get_dp_padding(self,
@ -1977,7 +1977,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
global_expert_load, old_global_expert_indices = (
EplbState.recv_state())
num_logical_experts = global_expert_load.shape[1]
self.parallel_config.num_redundant_experts = (
self.parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts)
assert old_global_expert_indices.shape[
1] % num_local_physical_experts == 0
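
Inverting that relationship when the EP world size changes, using the same made-up numbers as before (nothing here comes from the diff):

num_local_physical_experts = 9
new_ep_size = 8
num_logical_experts = 64
num_redundant_experts = (num_local_physical_experts * new_ep_size
                         - num_logical_experts)                # 72 - 64 = 8
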


@ -515,7 +515,7 @@ class Worker(WorkerBase):
assert self.model_runner.eplb_state is not None
new_physical_experts = \
self.model_runner.eplb_state.physical_to_logical_map.shape[1]
parallel_config.num_redundant_experts = (
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts -
self.model_runner.eplb_state.logical_replica_count.shape[1])
global_expert_load = None
@ -531,7 +531,7 @@ class Worker(WorkerBase):
assert self.model_runner.eplb_state is not None
global_expert_load = self.model_runner.eplb_state.rearrange(
self.model_runner.model, execute_shuffle=False)
parallel_config.num_redundant_experts = (
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_load.shape[1])
prepare_communication_buffer_for_model(self.model_runner.model)
self.model_runner.model.update_physical_experts_metadata(