[Core] Multiprocessing Pipeline Parallel support (#6130)

Co-authored-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
Nick Hill 2024-07-18 19:15:52 -07:00 committed by GitHub
parent c5df56f88b
commit b5672a112c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 152 additions and 99 deletions

View File

@ -54,7 +54,7 @@ steps:
- label: Core Test
mirror_hardwares: [amd]
fast_check: true
commands:
commands:
- pytest -v -s core
- pytest -v -s distributed/test_parallel_state.py
@ -73,7 +73,7 @@ steps:
commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
- pytest -v -s distributed/test_pipeline_parallel.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
@ -123,7 +123,7 @@ steps:
- label: Engine Test
mirror_hardwares: [amd]
commands:
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py
# OOM in the CI unless we run this separately
- pytest -v -s tokenization

View File

@ -1,28 +1,42 @@
import os
import pytest
from ..utils import compare_two_settings
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
@pytest.mark.parametrize(
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND",
[
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
])
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
DIST_BACKEND):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
pp_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"float16",
"--pipeline-parallel-size",
str(PP_SIZE),
"--tensor-parallel-size",
str(TP_SIZE),
"--distributed-executor-backend",
"ray",
DIST_BACKEND,
]
# compare without pipeline parallelism

View File

@ -712,10 +712,6 @@ class ParallelConfig:
self.rank = 0
def _verify_args(self) -> None:
if (self.pipeline_parallel_size > 1
and self.distributed_executor_backend == "mp"):
raise NotImplementedError("Pipeline parallelism is not supported "
"yet with multiprocessing.")
if self.distributed_executor_backend not in ("ray", "mp", None):
raise ValueError(
"Unrecognized distributed executor backend. Supported values "

View File

@ -1,4 +1,3 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
@ -132,26 +131,6 @@ class ExecutorBase(ABC):
class ExecutorAsyncBase(ExecutorBase):
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
) -> None:
self.pp_locks: Optional[List[asyncio.Lock]] = None
super().__init__(model_config, cache_config, parallel_config,
scheduler_config, device_config, load_config,
lora_config, multimodal_config, speculative_config,
prompt_adapter_config)
@abstractmethod
async def execute_model_async(
self,

View File

@ -12,6 +12,15 @@ from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
def create_worker(worker_module_name, worker_class_name, **kwargs):
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**kwargs)
return wrapper.worker
class GPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
@ -51,25 +60,30 @@ class GPUExecutor(ExecutorBase):
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
def _get_create_worker_kwargs(
self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None) -> Dict:
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method)
if self.speculative_config is None:
worker_kwargs.update(worker_module_name="vllm.worker.worker",
worker_class_name="Worker")
else:
worker_kwargs.update(
worker_module_name="vllm.spec_decode.spec_decode_worker",
worker_class_name="create_spec_worker")
return worker_kwargs
def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
if self.speculative_config is None:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
else:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
return create_worker(**self._get_create_worker_kwargs(
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method))
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the

View File

@ -7,12 +7,13 @@ from typing import Any, List, Optional
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (cuda_device_count_stateless,
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
error_on_invalid_device_count_status,
get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async,
@ -26,7 +27,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None:
# Create the parallel GPU workers.
world_size = self.parallel_config.tensor_parallel_size
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if "CUDA_VISIBLE_DEVICES" not in os.environ:
@ -49,8 +51,15 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if world_size > 1:
maybe_set_triton_cache_manager()
assert world_size <= cuda_device_count_stateless(), (
"please set tensor_parallel_size to less than max local gpu count")
cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (
f"please set tensor_parallel_size ({tensor_parallel_size}) "
f"to less than max local gpu count ({cuda_device_count})")
assert world_size <= cuda_device_count, (
f"please ensure that world_size ({world_size}) "
f"is less than than max local gpu count ({cuda_device_count})")
error_on_invalid_device_count_status()
@ -60,21 +69,35 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())
self.workers: List[ProcessWorkerWrapper] = []
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[ProcessWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[ProcessWorkerWrapper] = []
if world_size == 1:
self.workers = []
self.worker_monitor = None
else:
result_handler = ResultHandler()
self.workers = [
ProcessWorkerWrapper(
for rank in range(1, world_size):
worker = ProcessWorkerWrapper(
result_handler,
partial(
self._create_worker,
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)) for rank in range(1, world_size)
]
create_worker,
**self._get_create_worker_kwargs(
rank=rank,
local_rank=rank,
distributed_init_method=distributed_init_method,
)))
self.workers.append(worker)
if rank % tensor_parallel_size == 0:
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(worker)
self.worker_monitor = WorkerMonitor(self.workers, result_handler)
result_handler.start()
@ -136,16 +159,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
# Start the workers first.
if async_run_tensor_parallel_workers_only:
# Run only non-driver workers and just return futures.
return [
worker.execute_method(method, *args, **kwargs)
for worker in self.non_driver_workers
]
# Start all remote workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
for worker in self.workers
]
if async_run_tensor_parallel_workers_only:
# Just return futures
return worker_outputs
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*args, **kwargs)
@ -172,16 +198,45 @@ class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_exec_model = make_async(self.driver_worker.execute_model)
self.pp_locks: Optional[List[asyncio.Lock]] = None
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return await self.driver_exec_model(execute_model_req)
if not self.tp_driver_workers:
return await self.driver_exec_model(execute_model_req)
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# We create the locks here to avoid creating them in the constructor
# which uses a different asyncio loop.
self.pp_locks = [
asyncio.Lock()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
tasks = [
asyncio.create_task(
_run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
execute_model_req))
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(
asyncio.create_task(
_run_task_with_lock(driver_worker.execute_method_async,
self.pp_locks[pp_rank],
"execute_model", execute_model_req)))
results = await asyncio.gather(*tasks)
# Only the last PP stage has the final results.
return results[-1]
async def _start_worker_execution_loop(self):
coros = [
worker.execute_method_async("start_worker_execution_loop")
for worker in self.workers
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)

View File

@ -10,7 +10,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (error_on_invalid_device_count_status,
from vllm.utils import (_run_task_with_lock,
error_on_invalid_device_count_status,
get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
@ -240,27 +241,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []
tp_driver_worker_ranks = []
non_driver_worker_ranks = []
for idx, rank in enumerate(worker_ranks[1:]):
# Enforce rank order for correct rank to return final output.
for rank, worker in sorted(zip(worker_ranks[1:], self.workers)):
# We need to skip the driver worker, which we
# do by skipping worker_ranks[0] which is always 0.
if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(self.workers[idx])
tp_driver_worker_ranks.append(rank)
self.tp_driver_workers.append(worker)
else:
self.non_driver_workers.append(self.workers[idx])
non_driver_worker_ranks.append(rank)
# Enforce rank order for correct rank to return final output.
self.tp_driver_workers = [
worker for _, worker in sorted(
zip(tp_driver_worker_ranks, self.tp_driver_workers))
]
self.non_driver_workers = [
worker for _, worker in sorted(
zip(non_driver_worker_ranks, self.non_driver_workers))
]
self.non_driver_workers.append(worker)
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
@ -413,6 +401,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pp_locks: Optional[List[asyncio.Lock]] = None
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
@ -437,6 +426,9 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
) -> List[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
if not self.tp_driver_workers:
return await self.driver_exec_method("execute_model",
execute_model_req)
if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
@ -447,15 +439,11 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
for _ in range(self.parallel_config.pipeline_parallel_size)
]
async def _run_task_with_lock(task, lock, *args, **kwargs):
async with lock:
return await task(*args, **kwargs)
tasks = []
tasks.append(
tasks = [
asyncio.create_task(
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
"execute_model", execute_model_req)))
"execute_model", execute_model_req))
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1):
tasks.append(

View File

@ -939,3 +939,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
processed_args.append(arg)
return super().parse_args(processed_args, namespace)
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
**kwargs):
"""Utility function to run async task in a lock"""
async with lock:
return await task(*args, **kwargs)

View File

@ -274,11 +274,11 @@ class LocalOrDistributedWorkerBase(WorkerBase):
num_steps)
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
get_pp_group().send_tensor_dict(output.tensors)
return [None]
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
# output is List[SamplerOutput]
return output
def _execute_model_spmd(