[Feature] Use pydantic validation in parallel.py config (#26417)

Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Simon Danielsson 2025-10-09 14:41:31 +02:00 committed by GitHub
parent d1ddf340c8
commit 92be3f3517

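Note: the diff below replaces plain dataclass field declarations with pydantic `Field` declarations, so per-field range checks (`ge=0` on `num_redundant_experts`, `gt=0` on `_api_process_count`, `ge=-1` on `_api_process_rank`) and the `Literal["ray", "mp"]` type on `data_parallel_backend` are enforced by pydantic at construction time. A minimal sketch of that pattern, assuming pydantic v2; the sketch class and its fields are illustrative only and not part of this commit:

from typing import Literal

from pydantic import Field, ValidationError
from pydantic.dataclasses import dataclass

DataParallelBackend = Literal["ray", "mp"]


@dataclass
class ParallelConfigSketch:
    # Constrained field: negative values are rejected when the config is built.
    num_redundant_experts: int = Field(default=0, ge=0)
    # Literal-typed field: only "ray" or "mp" pass validation.
    data_parallel_backend: DataParallelBackend = "mp"


try:
    ParallelConfigSketch(num_redundant_experts=-1)
except ValidationError as exc:
    print(exc)  # reports that -1 is not greater than or equal to 0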

@@ -3,11 +3,10 @@
 import hashlib
 import os
-from dataclasses import field
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 import torch
-from pydantic import model_validator
+from pydantic import Field, model_validator
 from pydantic.dataclasses import dataclass
 from torch.distributed import ProcessGroup, ReduceOp
 from typing_extensions import Self
@@ -32,6 +31,7 @@ logger = init_logger(__name__)
 ExpertPlacementStrategy = Literal["linear", "round_robin"]
 DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
+DataParallelBackend = Literal["ray", "mp"]
 @config
@@ -49,7 +49,7 @@ class EPLBConfig:
     of the last `lb_window_size` steps will be used for rearranging experts.
     """
-    num_redundant_experts: int = 0
+    num_redundant_experts: int = Field(default=0, ge=0)
     """Number of redundant experts to use for expert parallelism."""
     log_balancedness: bool = False
@@ -84,7 +84,7 @@ class ParallelConfig:
     """Port for data parallel messaging."""
     data_parallel_master_port: int = 29500
     """Port of the data parallel master."""
-    data_parallel_backend: str = "mp"
+    data_parallel_backend: DataParallelBackend = "mp"
     """Backend to use for data parallel, either "mp" or "ray"."""
     data_parallel_external_lb: bool = False
     """Whether to use "external" DP LB mode. Applies only to online serving
@@ -102,7 +102,7 @@ class ParallelConfig:
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     enable_eplb: bool = False
     """Enable expert parallelism load balancing for MoE layers."""
-    eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
+    eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)
     """Expert parallelism configuration."""
     expert_placement_strategy: ExpertPlacementStrategy = "linear"
     """The expert placement strategy for MoE layers:\n
@@ -188,13 +188,13 @@ class ParallelConfig:
     new attributes and methods to the worker class for use in collective_rpc
     calls."""
-    world_size: int = field(init=False)
+    world_size: int = Field(init=False)
     """world_size is TPxPP, it affects the number of workers we create."""
     rank: int = 0
     """Global rank in distributed setup."""
-    _data_parallel_master_port_list: list[int] = field(default_factory=list)
+    _data_parallel_master_port_list: list[int] = Field(default_factory=list)
     """List of open port auto-queried for data parallel messaging.
     Set to be private as it's not intended to be configured by users.
     """
@@ -204,7 +204,7 @@ class ParallelConfig:
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""
-    _api_process_count: int = 1
+    _api_process_count: int = Field(default=1, gt=0)
     """
     The number of API processes initialized.
@@ -213,7 +213,7 @@ class ParallelConfig:
     should only be set by API server scale-out.
     """
-    _api_process_rank: int = 0
+    _api_process_rank: int = Field(default=0, ge=-1)
     """
     The rank of this API process, or `-1` for engine core processes
     under API server scale-out.
@@ -223,6 +223,51 @@
     should only be set by API server scale-out.
     """
+    @model_validator(mode="after")
+    def _validate_parallel_config(self) -> Self:
+        if self._api_process_rank >= self._api_process_count:
+            raise ValueError(
+                "Invalid value of `_api_process_rank`. "
+                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
+                f"but found: {self._api_process_rank}"
+            )
+        if self.data_parallel_size_local > self.data_parallel_size:
+            raise ValueError(
+                f"data_parallel_size_local ({self.data_parallel_size_local}) "
+                f"must be <= data_parallel_size ({self.data_parallel_size})"
+            )
+        if self.data_parallel_size <= 1 and self.data_parallel_external_lb:
+            raise ValueError(
+                "data_parallel_external_lb can only be set when data_parallel_size > 1"
+            )
+        if self.enable_eplb:
+            if not current_platform.is_cuda():
+                raise ValueError(
+                    "Expert parallelism load balancing is only supported on "
+                    "CUDA devices now."
+                )
+            if not self.enable_expert_parallel:
+                raise ValueError("enable_expert_parallel must be True to use EPLB.")
+            if self.tensor_parallel_size * self.data_parallel_size <= 1:
+                raise ValueError(
+                    "EPLB requires tensor_parallel_size or data_parallel_size "
+                    f"to be greater than 1, but got "
+                    f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
+                )
+        else:
+            if self.eplb_config.num_redundant_experts != 0:
+                raise ValueError(
+                    "num_redundant_experts is set to "
+                    f"{self.eplb_config.num_redundant_experts} but EPLB is not "
+                    "enabled. Either enable EPLB or unset "
+                    "num_redundant_experts."
+                )
+        return self
     @property
     def world_size_across_dp(self) -> int:
         """world_size_across_dp is TPxPPxDP, it is the size of the world
@@ -396,12 +441,6 @@ class ParallelConfig:
             logger.info("Using external launcher for distributed inference.")
             self.world_size *= self.data_parallel_size
-        if self.data_parallel_size_local > self.data_parallel_size:
-            raise ValueError(
-                f"data_parallel_size_local ({self.data_parallel_size_local}) "
-                f"must be <= data_parallel_size ({self.data_parallel_size})"
-            )
         if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
             # Data parallel was specified in the engine args.
             if self.distributed_executor_backend == "external_launcher":
@@ -431,43 +470,10 @@
             self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
             self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
-            if self.data_parallel_external_lb:
-                raise ValueError(
-                    "data_parallel_external_lb can only "
-                    "be set when data_parallel_size > 1"
-                )
         if self.distributed_executor_backend == "external_launcher":
             os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
             logger.info("Disabling V1 multiprocessing for external launcher.")
-        if self.enable_eplb:
-            if not current_platform.is_cuda():
-                raise ValueError(
-                    "Expert parallelism load balancing is only supported on "
-                    "CUDA devices now."
-                )
-            if self.eplb_config.num_redundant_experts < 0:
-                raise ValueError(
-                    "num_redundant_experts must be non-negative, but got "
-                    f"{self.eplb_config.num_redundant_experts}."
-                )
-            if not self.enable_expert_parallel:
-                raise ValueError("enable_expert_parallel must be True to use EPLB.")
-            if self.tensor_parallel_size * self.data_parallel_size <= 1:
-                raise ValueError(
-                    "EPLB requires tensor_parallel_size or data_parallel_size "
-                    f"to be greater than 1, but got "
-                    f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
-                )
-        else:
-            if self.eplb_config.num_redundant_experts != 0:
-                raise ValueError(
-                    "num_redundant_experts is set to "
-                    f"{self.eplb_config.num_redundant_experts} but EPLB is not "
-                    "enabled. Either enable EPLB or unset "
-                    "num_redundant_experts."
-                )
         if self.distributed_executor_backend is None and self.world_size > 1:
             # We use multiprocessing by default if world_size fits on the
             # current node and we aren't in a ray placement group.
@@ -514,13 +520,6 @@ class ParallelConfig:
         if self.distributed_executor_backend is None and self.world_size == 1:
            self.distributed_executor_backend = "uni"
-        if not -1 <= self._api_process_rank < self._api_process_count:
-            raise ValueError(
-                "Invalid value of `_api_process_rank`. "
-                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
-                f"but found: {self._api_process_rank}"
-            )
     @property
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
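Note: the cross-field checks removed from `__post_init__` above now run once in the `@model_validator(mode="after")` hook added to `ParallelConfig`. A hedged sketch of that validator pattern on a standalone pydantic dataclass (illustrative names, not vLLM's actual class):

from pydantic import Field, ValidationError, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self


@dataclass
class DPConfigSketch:
    data_parallel_size: int = Field(default=1, ge=1)
    data_parallel_size_local: int = Field(default=1, ge=0)

    @model_validator(mode="after")
    def _check_sizes(self) -> Self:
        # Cross-field invariant, mirroring the check the commit moves out of __post_init__.
        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                f"data_parallel_size_local ({self.data_parallel_size_local}) "
                f"must be <= data_parallel_size ({self.data_parallel_size})"
            )
        return self


try:
    DPConfigSketch(data_parallel_size=2, data_parallel_size_local=4)
except ValidationError as exc:
    # The ValueError raised inside the validator surfaces as a pydantic ValidationError.
    print(exc)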