enable multi-node in external launcher mode (#29833)

commit ad32e3e19c
parent 1109f98288
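For context: external_launcher is the executor backend used when an outside process launcher (typically torchrun) starts one worker process per GPU itself, and this commit lets that mode span more than one node. Below is a minimal launch sketch under that assumption; the script name, model placeholder, torchrun flag values and the choice of tensor_parallel_size are illustrative and not taken from this commit.

# Hypothetical two-node launch, assuming standard torchrun flags
# (run once per node, with --node-rank 0 and 1 respectively):
#
#   torchrun --nnodes=2 --nproc-per-node=8 --node-rank=<0|1> \
#       --master-addr=<head-node-ip> --master-port=29500 run_inference.py
#
# run_inference.py -- every launched rank executes the same script;
# the model and sampling settings below are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="<your-model>",
    tensor_parallel_size=16,  # assumed to match the 2 x 8 launched ranks
    distributed_executor_backend="external_launcher",
)
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
for output in outputs:
    print(output.outputs[0].text)

Since every rank runs the script, output handling is usually restricted to a single rank in practice.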
@@ -593,10 +593,14 @@ class ParallelConfig:
                 "max_parallel_loading_workers is currently "
                 "not supported and will be ignored."
             )
-        if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
+        allowed_backends = ("mp", "uni", "external_launcher")
+        if (
+            self.distributed_executor_backend not in allowed_backends
+            and self.nnodes > 1
+        ):
             raise ValueError(
                 "nnodes > 1 can only be set when distributed executor "
-                "backend is mp or uni."
+                "backend is mp, uni or external_launcher."
             )
 
     @property
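The widened check above is the whole behavioural change in this hunk: external_launcher joins mp and uni as a backend allowed to set nnodes > 1. As a standalone illustration, the same validation with a stub class standing in for the real ParallelConfig; everything except the backend names and the error message is illustrative.

# Minimal sketch of the widened check; StubParallelConfig is a stand-in,
# not the real vllm.config.ParallelConfig.
from dataclasses import dataclass


@dataclass
class StubParallelConfig:
    distributed_executor_backend: str = "mp"
    nnodes: int = 1

    def verify(self) -> None:
        allowed_backends = ("mp", "uni", "external_launcher")
        if (
            self.distributed_executor_backend not in allowed_backends
            and self.nnodes > 1
        ):
            raise ValueError(
                "nnodes > 1 can only be set when distributed executor "
                "backend is mp, uni or external_launcher."
            )


# Previously rejected, now accepted:
StubParallelConfig("external_launcher", nnodes=2).verify()
# Still rejected (e.g. the Ray backend):
try:
    StubParallelConfig("ray", nnodes=2).verify()
except ValueError as err:
    print(err)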
@@ -1169,17 +1169,13 @@ def init_distributed_environment(
     from vllm.config import get_current_vllm_config
 
     config = get_current_vllm_config()
-    if config is not None and config.parallel_config.nnodes > 1:
-        parallel_config = config.parallel_config
-        ip = parallel_config.master_addr
-        rank = parallel_config.data_parallel_rank * world_size + rank
-        world_size = parallel_config.world_size_across_dp
-        port = parallel_config.master_port
-        distributed_init_method = get_distributed_init_method(ip, port)
-    elif (
+    if (
         config is not None
-        and config.parallel_config.data_parallel_size > 1
         and config.parallel_config.distributed_executor_backend != "external_launcher"
+        and (
+            config.parallel_config.nnodes > 1
+            or config.parallel_config.data_parallel_size > 1
+        )
     ):
         parallel_config = config.parallel_config
         # adjust to take into account data parallelism
@@ -1187,15 +1183,22 @@ def init_distributed_environment(
         rank = parallel_config.data_parallel_rank * world_size + rank
         # adjust the world size to take into account data parallelism
         world_size = parallel_config.world_size_across_dp
-        ip = parallel_config.data_parallel_master_ip
-        port = parallel_config.get_next_dp_init_port()
-        distributed_init_method = get_distributed_init_method(ip, port)
-        logger.debug(
-            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
-            world_size,
-            rank,
-            distributed_init_method,
-        )
+
+        # Use appropriate IP and port based on configuration
+        if parallel_config.nnodes > 1:
+            ip = parallel_config.master_addr
+            port = parallel_config.master_port
+            distributed_init_method = get_distributed_init_method(ip, port)
+        else:
+            ip = parallel_config.data_parallel_master_ip
+            port = parallel_config.get_next_dp_init_port()
+            distributed_init_method = get_distributed_init_method(ip, port)
+        logger.debug(
+            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
+            world_size,
+            rank,
+            distributed_init_method,
+        )
     if not torch.distributed.is_initialized():
         logger.info(
             "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",