mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 02:44:26 +08:00
[Feature] use --eplb_config to set eplb param (#20562)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: rongfu.leng <lenronfu@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
4e51fa8cba
commit
4fbda0b20c
@ -33,7 +33,8 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
|||||||
PrefixCachingHashAlgo)
|
PrefixCachingHashAlgo)
|
||||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||||
CUDAGraphMode, PassConfig)
|
CUDAGraphMode, PassConfig)
|
||||||
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
|
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||||
|
ParallelConfig)
|
||||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||||
from vllm.config.utils import ConfigType, config
|
from vllm.config.utils import ConfigType, config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from dataclasses import field
|
|||||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import model_validator
|
from pydantic import TypeAdapter, model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
from torch.distributed import ProcessGroup, ReduceOp
|
from torch.distributed import ProcessGroup, ReduceOp
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
@ -32,6 +32,38 @@ logger = init_logger(__name__)
|
|||||||
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
|
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class EPLBConfig:
|
||||||
|
"""Configuration for Expert Parallel Load Balancing (EP)."""
|
||||||
|
|
||||||
|
window_size: int = 1000
|
||||||
|
"""Window size for expert load recording."""
|
||||||
|
step_interval: int = 3000
|
||||||
|
"""
|
||||||
|
Interval for rearranging experts in expert parallelism.
|
||||||
|
|
||||||
|
Note that if this is greater than the EPLB window size, only the metrics
|
||||||
|
of the last `lb_window_size` steps will be used for rearranging experts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_redundant_experts: int = 0
|
||||||
|
"""Number of redundant experts to use for expert parallelism."""
|
||||||
|
|
||||||
|
log_balancedness: bool = False
|
||||||
|
"""
|
||||||
|
Log the balancedness each step of expert parallelism.
|
||||||
|
This is turned off by default since it will cause communication overhead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cli(cls, cli_value: str) -> "EPLBConfig":
|
||||||
|
"""Parse the CLI value for the compilation config.
|
||||||
|
-O1, -O2, -O3, etc. is handled in FlexibleArgumentParser.
|
||||||
|
"""
|
||||||
|
return TypeAdapter(EPLBConfig).validate_json(cli_value)
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParallelConfig:
|
class ParallelConfig:
|
||||||
@ -75,22 +107,24 @@ class ParallelConfig:
|
|||||||
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
||||||
enable_eplb: bool = False
|
enable_eplb: bool = False
|
||||||
"""Enable expert parallelism load balancing for MoE layers."""
|
"""Enable expert parallelism load balancing for MoE layers."""
|
||||||
num_redundant_experts: int = 0
|
eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
|
||||||
"""Number of redundant experts to use for expert parallelism."""
|
"""Expert parallelism configuration."""
|
||||||
eplb_window_size: int = 1000
|
num_redundant_experts: Optional[int] = None
|
||||||
"""Window size for expert load recording."""
|
"""`num_redundant_experts` is deprecated and has been replaced with
|
||||||
eplb_step_interval: int = 3000
|
`eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
|
||||||
"""
|
Please use `eplb_config.num_redundant_experts` instead."""
|
||||||
Interval for rearranging experts in expert parallelism.
|
eplb_window_size: Optional[int] = None
|
||||||
|
"""`eplb_window_size` is deprecated and has been replaced with
|
||||||
Note that if this is greater than the EPLB window size, only the metrics
|
`eplb_config.window_size`. This will be removed in v0.12.0.
|
||||||
of the last `eplb_window_size` steps will be used for rearranging experts.
|
Please use `eplb_config.window_size` instead."""
|
||||||
"""
|
eplb_step_interval: Optional[int] = None
|
||||||
eplb_log_balancedness: bool = False
|
"""`eplb_step_interval` is deprecated and has been replaced with
|
||||||
"""
|
`eplb_config.step_interval`. This will be removed in v0.12.0.
|
||||||
Log the balancedness each step of expert parallelism.
|
Please use `eplb_config.step_interval` instead."""
|
||||||
This is turned off by default since it will cause communication overhead.
|
eplb_log_balancedness: Optional[bool] = None
|
||||||
"""
|
"""`eplb_log_balancedness` is deprecated and has been replaced with
|
||||||
|
`eplb_config.log_balancedness`. This will be removed in v0.12.0.
|
||||||
|
Please use `eplb_config.log_balancedness` instead."""
|
||||||
|
|
||||||
max_parallel_loading_workers: Optional[int] = None
|
max_parallel_loading_workers: Optional[int] = None
|
||||||
"""Maximum number of parallel loading workers when loading model
|
"""Maximum number of parallel loading workers when loading model
|
||||||
@ -237,6 +271,38 @@ class ParallelConfig:
|
|||||||
return hashlib.sha256(str(factors).encode()).hexdigest()
|
return hashlib.sha256(str(factors).encode()).hexdigest()
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
# Forward deprecated fields to their new location
|
||||||
|
if self.num_redundant_experts is not None:
|
||||||
|
self.eplb_config.num_redundant_experts = (
|
||||||
|
self.num_redundant_experts)
|
||||||
|
logger.warning_once(
|
||||||
|
"num_redundant_experts is deprecated and has been replaced "
|
||||||
|
"with eplb_config.num_redundant_experts. This will be removed "
|
||||||
|
"in v0.12.0. Changing this field after initialization will "
|
||||||
|
"have no effect.")
|
||||||
|
if self.eplb_window_size is not None:
|
||||||
|
self.eplb_config.window_size = self.eplb_window_size
|
||||||
|
logger.warning_once(
|
||||||
|
"eplb_window_size is deprecated and has been replaced "
|
||||||
|
"with eplb_config.window_size. This will be removed "
|
||||||
|
"in v0.12.0. Changing this field after initialization will "
|
||||||
|
"have no effect.")
|
||||||
|
if self.eplb_step_interval is not None:
|
||||||
|
self.eplb_config.step_interval = self.eplb_step_interval
|
||||||
|
logger.warning_once(
|
||||||
|
"eplb_step_interval is deprecated and has been replaced "
|
||||||
|
"with eplb_config.step_interval. This will be removed "
|
||||||
|
"in v0.12.0. Changing this field after initialization will "
|
||||||
|
"have no effect.")
|
||||||
|
if self.eplb_log_balancedness is not None:
|
||||||
|
self.eplb_config.log_balancedness = self.eplb_log_balancedness
|
||||||
|
logger.warning_once(
|
||||||
|
"eplb_log_balancedness is deprecated and has been replaced "
|
||||||
|
"with eplb_config.log_balancedness. This will be removed "
|
||||||
|
"in v0.12.0. Changing this field after initialization will "
|
||||||
|
"have no effect.")
|
||||||
|
|
||||||
|
# Continue with the rest of the initialization
|
||||||
self.world_size = self.pipeline_parallel_size * \
|
self.world_size = self.pipeline_parallel_size * \
|
||||||
self.tensor_parallel_size
|
self.tensor_parallel_size
|
||||||
|
|
||||||
@ -275,10 +341,10 @@ class ParallelConfig:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expert parallelism load balancing is only supported on "
|
"Expert parallelism load balancing is only supported on "
|
||||||
"CUDA devices now.")
|
"CUDA devices now.")
|
||||||
if self.num_redundant_experts < 0:
|
if self.eplb_config.num_redundant_experts < 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"num_redundant_experts must be non-negative, but got "
|
"num_redundant_experts must be non-negative, but got "
|
||||||
f"{self.num_redundant_experts}.")
|
f"{self.eplb_config.num_redundant_experts}.")
|
||||||
if not self.enable_expert_parallel:
|
if not self.enable_expert_parallel:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"enable_expert_parallel must be True to use EPLB.")
|
"enable_expert_parallel must be True to use EPLB.")
|
||||||
@ -289,10 +355,10 @@ class ParallelConfig:
|
|||||||
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
|
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.num_redundant_experts != 0:
|
if self.eplb_config.num_redundant_experts != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"num_redundant_experts should be used with EPLB."
|
"num_redundant_experts should be used with EPLB."
|
||||||
f"{self.num_redundant_experts}.")
|
f"{self.eplb_config.num_redundant_experts}.")
|
||||||
if self.distributed_executor_backend is None and self.world_size > 1:
|
if self.distributed_executor_backend is None and self.world_size > 1:
|
||||||
# We use multiprocessing by default if world_size fits on the
|
# We use multiprocessing by default if world_size fits on the
|
||||||
# current node and we aren't in a ray placement group.
|
# current node and we aren't in a ray placement group.
|
||||||
|
|||||||
@ -244,7 +244,7 @@ class EplbState:
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
expert_load_window_size = parallel_config.eplb_window_size
|
expert_load_window_size = parallel_config.eplb_config.window_size
|
||||||
expert_load_window = torch.zeros(
|
expert_load_window = torch.zeros(
|
||||||
(expert_load_window_size, model.num_moe_layers,
|
(expert_load_window_size, model.num_moe_layers,
|
||||||
model.num_physical_experts),
|
model.num_physical_experts),
|
||||||
@ -253,7 +253,7 @@ class EplbState:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Set the initial progress of rearrangement to 3/4
|
# Set the initial progress of rearrangement to 3/4
|
||||||
eplb_step_interval = parallel_config.eplb_step_interval
|
eplb_step_interval = parallel_config.eplb_config.step_interval
|
||||||
expert_rearrangement_step = max(
|
expert_rearrangement_step = max(
|
||||||
0, eplb_step_interval - eplb_step_interval // 4)
|
0, eplb_step_interval - eplb_step_interval // 4)
|
||||||
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ import vllm.envs as envs
|
|||||||
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||||
ConfigFormat, ConfigType, ConvertOption,
|
ConfigFormat, ConfigType, ConvertOption,
|
||||||
DecodingConfig, DetailedTraceModules, Device,
|
DecodingConfig, DetailedTraceModules, Device,
|
||||||
DeviceConfig, DistributedExecutorBackend,
|
DeviceConfig, DistributedExecutorBackend, EPLBConfig,
|
||||||
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
||||||
KVTransferConfig, LoadConfig, LogprobsMode,
|
KVTransferConfig, LoadConfig, LogprobsMode,
|
||||||
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
|
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
|
||||||
@ -305,11 +305,12 @@ class EngineArgs:
|
|||||||
data_parallel_hybrid_lb: bool = False
|
data_parallel_hybrid_lb: bool = False
|
||||||
data_parallel_backend: str = ParallelConfig.data_parallel_backend
|
data_parallel_backend: str = ParallelConfig.data_parallel_backend
|
||||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||||
|
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
|
||||||
enable_eplb: bool = ParallelConfig.enable_eplb
|
enable_eplb: bool = ParallelConfig.enable_eplb
|
||||||
num_redundant_experts: int = ParallelConfig.num_redundant_experts
|
num_redundant_experts: int = EPLBConfig.num_redundant_experts
|
||||||
eplb_window_size: int = ParallelConfig.eplb_window_size
|
eplb_window_size: int = EPLBConfig.window_size
|
||||||
eplb_step_interval: int = ParallelConfig.eplb_step_interval
|
eplb_step_interval: int = EPLBConfig.step_interval
|
||||||
eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness
|
eplb_log_balancedness: bool = EPLBConfig.log_balancedness
|
||||||
max_parallel_loading_workers: Optional[
|
max_parallel_loading_workers: Optional[
|
||||||
int] = ParallelConfig.max_parallel_loading_workers
|
int] = ParallelConfig.max_parallel_loading_workers
|
||||||
block_size: Optional[BlockSize] = CacheConfig.block_size
|
block_size: Optional[BlockSize] = CacheConfig.block_size
|
||||||
@ -454,6 +455,9 @@ class EngineArgs:
|
|||||||
if isinstance(self.compilation_config, dict):
|
if isinstance(self.compilation_config, dict):
|
||||||
self.compilation_config = CompilationConfig(
|
self.compilation_config = CompilationConfig(
|
||||||
**self.compilation_config)
|
**self.compilation_config)
|
||||||
|
if isinstance(self.eplb_config, dict):
|
||||||
|
self.eplb_config = EPLBConfig.from_cli(json.dumps(
|
||||||
|
self.eplb_config))
|
||||||
# Setup plugins
|
# Setup plugins
|
||||||
from vllm.plugins import load_general_plugins
|
from vllm.plugins import load_general_plugins
|
||||||
load_general_plugins()
|
load_general_plugins()
|
||||||
@ -661,14 +665,32 @@ class EngineArgs:
|
|||||||
**parallel_kwargs["enable_expert_parallel"])
|
**parallel_kwargs["enable_expert_parallel"])
|
||||||
parallel_group.add_argument("--enable-eplb",
|
parallel_group.add_argument("--enable-eplb",
|
||||||
**parallel_kwargs["enable_eplb"])
|
**parallel_kwargs["enable_eplb"])
|
||||||
parallel_group.add_argument("--num-redundant-experts",
|
parallel_group.add_argument("--eplb-config",
|
||||||
**parallel_kwargs["num_redundant_experts"])
|
**parallel_kwargs["eplb_config"])
|
||||||
parallel_group.add_argument("--eplb-window-size",
|
parallel_group.add_argument(
|
||||||
**parallel_kwargs["eplb_window_size"])
|
"--num-redundant-experts",
|
||||||
parallel_group.add_argument("--eplb-step-interval",
|
type=int,
|
||||||
**parallel_kwargs["eplb_step_interval"])
|
help=
|
||||||
parallel_group.add_argument("--eplb-log-balancedness",
|
"[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
|
||||||
**parallel_kwargs["eplb_log_balancedness"])
|
deprecated=True)
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--eplb-window-size",
|
||||||
|
type=int,
|
||||||
|
help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
|
||||||
|
deprecated=True)
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--eplb-step-interval",
|
||||||
|
type=int,
|
||||||
|
help=
|
||||||
|
"[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
|
||||||
|
deprecated=True)
|
||||||
|
parallel_group.add_argument(
|
||||||
|
"--eplb-log-balancedness",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
help=
|
||||||
|
"[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
|
||||||
|
deprecated=True)
|
||||||
|
|
||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--max-parallel-loading-workers",
|
"--max-parallel-loading-workers",
|
||||||
**parallel_kwargs["max_parallel_loading_workers"])
|
**parallel_kwargs["max_parallel_loading_workers"])
|
||||||
@ -1244,6 +1266,16 @@ class EngineArgs:
|
|||||||
"Currently, speculative decoding is not supported with "
|
"Currently, speculative decoding is not supported with "
|
||||||
"async scheduling.")
|
"async scheduling.")
|
||||||
|
|
||||||
|
# Forward the deprecated CLI args to the EPLB config.
|
||||||
|
if self.num_redundant_experts is not None:
|
||||||
|
self.eplb_config.num_redundant_experts = self.num_redundant_experts
|
||||||
|
if self.eplb_window_size is not None:
|
||||||
|
self.eplb_config.window_size = self.eplb_window_size
|
||||||
|
if self.eplb_step_interval is not None:
|
||||||
|
self.eplb_config.step_interval = self.eplb_step_interval
|
||||||
|
if self.eplb_log_balancedness is not None:
|
||||||
|
self.eplb_config.log_balancedness = self.eplb_log_balancedness
|
||||||
|
|
||||||
parallel_config = ParallelConfig(
|
parallel_config = ParallelConfig(
|
||||||
pipeline_parallel_size=self.pipeline_parallel_size,
|
pipeline_parallel_size=self.pipeline_parallel_size,
|
||||||
tensor_parallel_size=self.tensor_parallel_size,
|
tensor_parallel_size=self.tensor_parallel_size,
|
||||||
@ -1257,10 +1289,7 @@ class EngineArgs:
|
|||||||
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
|
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
|
||||||
enable_expert_parallel=self.enable_expert_parallel,
|
enable_expert_parallel=self.enable_expert_parallel,
|
||||||
enable_eplb=self.enable_eplb,
|
enable_eplb=self.enable_eplb,
|
||||||
num_redundant_experts=self.num_redundant_experts,
|
eplb_config=self.eplb_config,
|
||||||
eplb_window_size=self.eplb_window_size,
|
|
||||||
eplb_step_interval=self.eplb_step_interval,
|
|
||||||
eplb_log_balancedness=self.eplb_log_balancedness,
|
|
||||||
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
||||||
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
||||||
ray_workers_use_nsight=self.ray_workers_use_nsight,
|
ray_workers_use_nsight=self.ray_workers_use_nsight,
|
||||||
|
|||||||
@ -132,10 +132,10 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
# Load balancing settings.
|
# Load balancing settings.
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
parallel_config = vllm_config.parallel_config
|
eplb_config = vllm_config.parallel_config.eplb_config
|
||||||
self.enable_eplb = enable_eplb
|
self.enable_eplb = enable_eplb
|
||||||
|
|
||||||
self.n_redundant_experts = parallel_config.num_redundant_experts
|
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||||
self.n_logical_experts = self.n_routed_experts
|
self.n_logical_experts = self.n_routed_experts
|
||||||
self.n_physical_experts = (self.n_logical_experts +
|
self.n_physical_experts = (self.n_logical_experts +
|
||||||
self.n_redundant_experts)
|
self.n_redundant_experts)
|
||||||
|
|||||||
@ -131,10 +131,10 @@ class Glm4MoE(nn.Module):
|
|||||||
|
|
||||||
# Load balancing settings.
|
# Load balancing settings.
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
parallel_config = vllm_config.parallel_config
|
eplb_config = vllm_config.parallel_config.eplb_config
|
||||||
self.enable_eplb = enable_eplb
|
self.enable_eplb = enable_eplb
|
||||||
|
|
||||||
self.n_redundant_experts = parallel_config.num_redundant_experts
|
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||||
self.n_logical_experts = self.n_routed_experts
|
self.n_logical_experts = self.n_routed_experts
|
||||||
self.n_physical_experts = (self.n_logical_experts +
|
self.n_physical_experts = (self.n_logical_experts +
|
||||||
self.n_redundant_experts)
|
self.n_redundant_experts)
|
||||||
|
|||||||
@ -121,11 +121,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
|
|
||||||
# Load balancing settings.
|
# Load balancing settings.
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
parallel_config = vllm_config.parallel_config
|
eplb_config = vllm_config.parallel_config.eplb_config
|
||||||
self.enable_eplb = enable_eplb
|
self.enable_eplb = enable_eplb
|
||||||
|
|
||||||
self.n_logical_experts = self.n_routed_experts
|
self.n_logical_experts = self.n_routed_experts
|
||||||
self.n_redundant_experts = parallel_config.num_redundant_experts
|
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||||
self.n_physical_experts = (self.n_logical_experts +
|
self.n_physical_experts = (self.n_logical_experts +
|
||||||
self.n_redundant_experts)
|
self.n_redundant_experts)
|
||||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||||
@ -363,7 +363,8 @@ class Qwen3MoeModel(nn.Module):
|
|||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
enable_eplb = parallel_config.enable_eplb
|
enable_eplb = parallel_config.enable_eplb
|
||||||
self.num_redundant_experts = parallel_config.num_redundant_experts
|
eplb_config = parallel_config.eplb_config
|
||||||
|
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||||
|
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
|||||||
@ -1435,7 +1435,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
model,
|
model,
|
||||||
is_dummy,
|
is_dummy,
|
||||||
is_profile,
|
is_profile,
|
||||||
log_stats=self.parallel_config.eplb_log_balancedness,
|
log_stats=self.parallel_config.eplb_config.log_balancedness,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_dp_padding(self,
|
def get_dp_padding(self,
|
||||||
@ -1977,7 +1977,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
global_expert_load, old_global_expert_indices = (
|
global_expert_load, old_global_expert_indices = (
|
||||||
EplbState.recv_state())
|
EplbState.recv_state())
|
||||||
num_logical_experts = global_expert_load.shape[1]
|
num_logical_experts = global_expert_load.shape[1]
|
||||||
self.parallel_config.num_redundant_experts = (
|
self.parallel_config.eplb_config.num_redundant_experts = (
|
||||||
num_local_physical_experts * new_ep_size - num_logical_experts)
|
num_local_physical_experts * new_ep_size - num_logical_experts)
|
||||||
assert old_global_expert_indices.shape[
|
assert old_global_expert_indices.shape[
|
||||||
1] % num_local_physical_experts == 0
|
1] % num_local_physical_experts == 0
|
||||||
|
|||||||
@ -515,7 +515,7 @@ class Worker(WorkerBase):
|
|||||||
assert self.model_runner.eplb_state is not None
|
assert self.model_runner.eplb_state is not None
|
||||||
new_physical_experts = \
|
new_physical_experts = \
|
||||||
self.model_runner.eplb_state.physical_to_logical_map.shape[1]
|
self.model_runner.eplb_state.physical_to_logical_map.shape[1]
|
||||||
parallel_config.num_redundant_experts = (
|
parallel_config.eplb_config.num_redundant_experts = (
|
||||||
new_physical_experts -
|
new_physical_experts -
|
||||||
self.model_runner.eplb_state.logical_replica_count.shape[1])
|
self.model_runner.eplb_state.logical_replica_count.shape[1])
|
||||||
global_expert_load = None
|
global_expert_load = None
|
||||||
@ -531,7 +531,7 @@ class Worker(WorkerBase):
|
|||||||
assert self.model_runner.eplb_state is not None
|
assert self.model_runner.eplb_state is not None
|
||||||
global_expert_load = self.model_runner.eplb_state.rearrange(
|
global_expert_load = self.model_runner.eplb_state.rearrange(
|
||||||
self.model_runner.model, execute_shuffle=False)
|
self.model_runner.model, execute_shuffle=False)
|
||||||
parallel_config.num_redundant_experts = (
|
parallel_config.eplb_config.num_redundant_experts = (
|
||||||
new_physical_experts - global_expert_load.shape[1])
|
new_physical_experts - global_expert_load.shape[1])
|
||||||
prepare_communication_buffer_for_model(self.model_runner.model)
|
prepare_communication_buffer_for_model(self.model_runner.model)
|
||||||
self.model_runner.model.update_physical_experts_metadata(
|
self.model_runner.model.update_physical_experts_metadata(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user