mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 03:55:41 +08:00
[DP] support torchrun external launcher with Data Parallelism (#24899)
Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
239ef0c1ac
commit
922979bfcc
@ -165,10 +165,18 @@ steps:
|
|||||||
- tests/v1/test_hybrid_lb_dp.py
|
- tests/v1/test_hybrid_lb_dp.py
|
||||||
- tests/v1/engine/test_engine_core_client.py
|
- tests/v1/engine/test_engine_core_client.py
|
||||||
commands:
|
commands:
|
||||||
# test with tp=2 and external_dp=2
|
# test with torchrun tp=2 and external_dp=2
|
||||||
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||||
# test with tp=2 and pp=2
|
# test with torchrun tp=2 and pp=2
|
||||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||||
|
# test with torchrun tp=4 and dp=1
|
||||||
|
- TP_SIZE=4 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||||
|
# test with torchrun tp=2, pp=2 and dp=1
|
||||||
|
- PP_SIZE=2 TP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||||
|
# test with torchrun tp=1 and dp=4 with ep
|
||||||
|
- DP_SIZE=4 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||||
|
# test with torchrun tp=2 and dp=2 with ep
|
||||||
|
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||||
# test with internal dp
|
# test with internal dp
|
||||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||||
|
|||||||
81
examples/offline_inference/torchrun_dp_example.py
Normal file
81
examples/offline_inference/torchrun_dp_example.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
experimental support for data-parallel inference with torchrun
|
||||||
|
Note the data load balancing and distribution is done out of the vllm engine,
|
||||||
|
no internal lb supported in external_launcher mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
# Create prompts, the same across all ranks
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
] * 50
|
||||||
|
|
||||||
|
# Create sampling parameters, the same across all ranks
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
|
# Use `distributed_executor_backend="external_launcher"` so that
|
||||||
|
# this llm engine/instance only creates one worker.
|
||||||
|
# it is important to set an explicit seed to make sure that
|
||||||
|
# all ranks have the same random seed, so that sampling can be
|
||||||
|
# deterministic across ranks.
|
||||||
|
llm = LLM(
|
||||||
|
model="microsoft/Phi-mini-MoE-instruct",
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
data_parallel_size=2,
|
||||||
|
pipeline_parallel_size=1,
|
||||||
|
enable_expert_parallel=False,
|
||||||
|
distributed_executor_backend="external_launcher",
|
||||||
|
max_model_len=4096,
|
||||||
|
gpu_memory_utilization=0.6,
|
||||||
|
seed=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
|
||||||
|
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
|
||||||
|
]
|
||||||
|
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
# all ranks will have the same outputs
|
||||||
|
print("-" * 50)
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
|
||||||
|
print("-" * 50)
|
||||||
|
"""
|
||||||
|
Further tips:
|
||||||
|
|
||||||
|
1. to communicate control messages across all ranks, use the cpu group,
|
||||||
|
a PyTorch ProcessGroup with GLOO backend.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from vllm.distributed.parallel_state import get_world_group
|
||||||
|
cpu_group = get_world_group().cpu_group
|
||||||
|
torch_rank = dist.get_rank(group=cpu_group)
|
||||||
|
if torch_rank == 0:
|
||||||
|
# do something for rank 0, e.g. saving the results to disk.
|
||||||
|
```
|
||||||
|
|
||||||
|
2. to communicate data across all ranks, use the model's device group,
|
||||||
|
a PyTorch ProcessGroup with NCCL backend.
|
||||||
|
```python
|
||||||
|
from vllm.distributed.parallel_state import get_world_group
|
||||||
|
device_group = get_world_group().device_group
|
||||||
|
```
|
||||||
|
|
||||||
|
3. to access the model directly in every rank, use the following code:
|
||||||
|
```python
|
||||||
|
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||||
|
```
|
||||||
|
"""
|
||||||
81
tests/distributed/test_torchrun_example_moe.py
Normal file
81
tests/distributed/test_torchrun_example_moe.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# unit test for `examples/offline_inference/torchrun_example.py`
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.distributed.parallel_state import get_tp_group, get_world_group
|
||||||
|
|
||||||
|
dist.init_process_group(backend="gloo")
|
||||||
|
|
||||||
|
# Create prompts
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
] * 10
|
||||||
|
dp_size = int(os.getenv("DP_SIZE", "1"))
|
||||||
|
dp_rank = int(os.getenv("DP_RANK", "0"))
|
||||||
|
|
||||||
|
if dp_size > 1:
|
||||||
|
# distribute the prompts across the data parallel ranks
|
||||||
|
prompts = [
|
||||||
|
prompt for idx, prompt in enumerate(prompts)
|
||||||
|
if idx % dp_size == dp_rank
|
||||||
|
]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
|
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
|
||||||
|
# to test if all ranks agree on the same kv cache configuration.
|
||||||
|
llm = LLM(model="microsoft/Phi-mini-MoE-instruct",
|
||||||
|
tensor_parallel_size=int(os.getenv("TP_SIZE", "1")),
|
||||||
|
pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")),
|
||||||
|
enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1,
|
||||||
|
distributed_executor_backend="external_launcher",
|
||||||
|
gpu_memory_utilization=random.uniform(0.7, 0.9),
|
||||||
|
swap_space=random.randint(1, 4),
|
||||||
|
seed=0)
|
||||||
|
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
group = get_world_group() if dp_size == 1 else get_tp_group()
|
||||||
|
cpu_group = group.cpu_group
|
||||||
|
group_rank = dist.get_rank(group=cpu_group)
|
||||||
|
|
||||||
|
|
||||||
|
def test_consistent_across_ranks(obj):
|
||||||
|
if group_rank == 0:
|
||||||
|
dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group)
|
||||||
|
else:
|
||||||
|
container = [None]
|
||||||
|
dist.broadcast_object_list(container,
|
||||||
|
src=group.ranks[0],
|
||||||
|
group=cpu_group)
|
||||||
|
assert container[0] == obj
|
||||||
|
|
||||||
|
|
||||||
|
test_consistent_across_ranks(
|
||||||
|
llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
|
||||||
|
test_consistent_across_ranks(
|
||||||
|
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
|
||||||
|
|
||||||
|
# make sure we can access the model parameters from the calling process
|
||||||
|
# of the `LLM` instance.
|
||||||
|
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
|
||||||
|
model.parameters())
|
||||||
|
test_consistent_across_ranks(len(params))
|
||||||
|
|
||||||
|
# all ranks should have the same outputs
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
test_consistent_across_ranks(prompt)
|
||||||
|
test_consistent_across_ranks(generated_text)
|
||||||
|
print(f"Rank {group_rank}, Prompt: {prompt!r}, "
|
||||||
|
f"Generated text: {generated_text!r}")
|
||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import os
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||||
|
|
||||||
@ -351,6 +352,10 @@ class ParallelConfig:
|
|||||||
self.world_size = self.pipeline_parallel_size * \
|
self.world_size = self.pipeline_parallel_size * \
|
||||||
self.tensor_parallel_size
|
self.tensor_parallel_size
|
||||||
|
|
||||||
|
if self.distributed_executor_backend == "external_launcher":
|
||||||
|
logger.info("Using external launcher for distributed inference.")
|
||||||
|
self.world_size *= self.data_parallel_size
|
||||||
|
|
||||||
if self.data_parallel_size_local > self.data_parallel_size:
|
if self.data_parallel_size_local > self.data_parallel_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"data_parallel_size_local ({self.data_parallel_size_local}) "
|
f"data_parallel_size_local ({self.data_parallel_size_local}) "
|
||||||
@ -358,6 +363,13 @@ class ParallelConfig:
|
|||||||
|
|
||||||
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
|
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
|
||||||
# Data parallel was specified in the engine args.
|
# Data parallel was specified in the engine args.
|
||||||
|
if self.distributed_executor_backend == "external_launcher":
|
||||||
|
# For external launcher,
|
||||||
|
# we need to set the data parallel rank automatically
|
||||||
|
self.data_parallel_rank = int(os.environ["RANK"]) \
|
||||||
|
// (self.world_size // self.data_parallel_size)
|
||||||
|
logger.info("Set data_parallel_rank to %d automatically.",
|
||||||
|
self.data_parallel_rank)
|
||||||
if not self._data_parallel_master_port_list:
|
if not self._data_parallel_master_port_list:
|
||||||
self._data_parallel_master_port_list = get_open_ports_list(5)
|
self._data_parallel_master_port_list = get_open_ports_list(5)
|
||||||
self.data_parallel_master_port = \
|
self.data_parallel_master_port = \
|
||||||
@ -380,7 +392,6 @@ class ParallelConfig:
|
|||||||
"be set when data_parallel_size > 1")
|
"be set when data_parallel_size > 1")
|
||||||
|
|
||||||
if self.distributed_executor_backend == "external_launcher":
|
if self.distributed_executor_backend == "external_launcher":
|
||||||
import os
|
|
||||||
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
||||||
logger.info("Disabling V1 multiprocessing for external launcher.")
|
logger.info("Disabling V1 multiprocessing for external launcher.")
|
||||||
|
|
||||||
|
|||||||
@ -1032,7 +1032,9 @@ def init_distributed_environment(world_size: int = -1,
|
|||||||
distributed_init_method, backend)
|
distributed_init_method, backend)
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
config = get_current_vllm_config()
|
config = get_current_vllm_config()
|
||||||
if config is not None and config.parallel_config.data_parallel_size > 1:
|
if config is not None and config.parallel_config.data_parallel_size > 1 \
|
||||||
|
and config.parallel_config.distributed_executor_backend \
|
||||||
|
!= "external_launcher":
|
||||||
parallel_config = config.parallel_config
|
parallel_config = config.parallel_config
|
||||||
# adjust to take into account data parallelism
|
# adjust to take into account data parallelism
|
||||||
# offset the rank by the data parallel rank
|
# offset the rank by the data parallel rank
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from typing_extensions import TypeVar
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import ParallelConfig, VllmConfig
|
from vllm.config import ParallelConfig, VllmConfig
|
||||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||||
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptType
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -77,10 +78,15 @@ class LLMEngine:
|
|||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
self.stat_logger = PrometheusStatLogger(vllm_config)
|
self.stat_logger = PrometheusStatLogger(vllm_config)
|
||||||
|
|
||||||
|
executor_backend = (
|
||||||
|
self.vllm_config.parallel_config.distributed_executor_backend)
|
||||||
|
parallel_config = vllm_config.parallel_config
|
||||||
|
self.external_launcher_dp = (parallel_config.data_parallel_size > 1 and
|
||||||
|
executor_backend == "external_launcher")
|
||||||
# important: init dp group before init the engine_core
|
# important: init dp group before init the engine_core
|
||||||
# In the decoupled engine case this is handled in EngineCoreProc.
|
# In the decoupled engine case this is handled in EngineCoreProc.
|
||||||
parallel_config = vllm_config.parallel_config
|
if not multiprocess_mode and parallel_config.data_parallel_size > 1 \
|
||||||
if not multiprocess_mode and parallel_config.data_parallel_size > 1:
|
and not self.external_launcher_dp:
|
||||||
self.dp_group = parallel_config.stateless_init_dp_group()
|
self.dp_group = parallel_config.stateless_init_dp_group()
|
||||||
else:
|
else:
|
||||||
self.dp_group = None
|
self.dp_group = None
|
||||||
@ -120,6 +126,11 @@ class LLMEngine:
|
|||||||
# for v0 compatibility
|
# for v0 compatibility
|
||||||
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
|
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
|
||||||
|
|
||||||
|
if self.external_launcher_dp:
|
||||||
|
# If we use DP in external launcher mode, we reuse the
|
||||||
|
# existing DP group used for data communication.
|
||||||
|
self.dp_group = get_dp_group().cpu_group
|
||||||
|
|
||||||
# Don't keep the dummy data in memory
|
# Don't keep the dummy data in memory
|
||||||
self.reset_mm_cache()
|
self.reset_mm_cache()
|
||||||
|
|
||||||
@ -331,5 +342,6 @@ class LLMEngine:
|
|||||||
return self.collective_rpc("apply_model", args=(func, ))
|
return self.collective_rpc("apply_model", args=(func, ))
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if dp_group := getattr(self, "dp_group", None):
|
if dp_group := getattr(self, "dp_group",
|
||||||
|
None) and not self.external_launcher_dp:
|
||||||
stateless_destroy_torch_distributed_process_group(dp_group)
|
stateless_destroy_torch_distributed_process_group(dp_group)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user