[V1][PP] Support PP for MultiprocExecutor (#14219)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>
Author: Li, Jiang
Date: 2025-05-06 22:58:05 +08:00 (committed by GitHub)
Parent: d419aa5dc4
Commit: a6fed02068
5 changed files with 98 additions and 28 deletions

tests/distributed/test_pipeline_parallel.py

@@ -100,9 +100,8 @@ class PPTestSettings:
                               eager_mode=True,
                               chunked_prefill=False),
             ],
-            # only ray is supported for V1
-            distributed_backends=["mp", "ray", "ray"],
-            vllm_major_versions=["0", "0", "1"],
+            distributed_backends=["mp", "mp", "ray", "ray"],
+            vllm_major_versions=["0", "1", "0", "1"],
             task=task,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        load_format=load_format),
@@ -350,6 +349,11 @@ def _compare_tp(
         # Temporary. Currently when zeromq + SPMD is used, it does not properly
         # terminate because of a Ray Compiled Graph issue.
         common_args.append("--disable-frontend-multiprocessing")
+    elif distributed_backend == "mp":
+        # Both V0/V1 of multiprocessing executor support PP
+        pp_env = {
+            "VLLM_USE_V1": vllm_major_version,
+        }
     else:
         pp_env = None
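For clarity on what the test-matrix change above exercises: each entry in distributed_backends is paired positionally with vllm_major_versions, so "mp" now runs under both V0 and V1 while Ray keeps its existing coverage. A minimal standalone sketch of that pairing (illustrative only, not code from this commit):

# Illustrative pairing of the updated test matrix; the real test plumbs the
# version through the VLLM_USE_V1 environment variable for the "mp" backend.
distributed_backends = ["mp", "mp", "ray", "ray"]
vllm_major_versions = ["0", "1", "0", "1"]

for backend, major in zip(distributed_backends, vllm_major_versions):
    pp_env = {"VLLM_USE_V1": major} if backend == "mp" else None
    print(f"{backend}: env={pp_env}")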

vllm/engine/arg_utils.py

@@ -1338,11 +1338,10 @@ class EngineArgs:
                 and _warn_or_fallback("Engine in background thread")):
             return False

-        # PP is supported on V1 with Ray distributed executor,
-        # but off for MP distributed executor for now.
         if (self.pipeline_parallel_size > 1
-                and self.distributed_executor_backend != "ray"):
-            name = "Pipeline Parallelism without Ray distributed executor"
+                and self.distributed_executor_backend not in ["ray", "mp"]):
+            name = "Pipeline Parallelism without Ray distributed executor " \
+                "or multiprocessing executor"
             _raise_or_fallback(feature_name=name, recommend_to_remove=False)
             return False
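As a usage note (a hedged sketch, not part of the diff): with this relaxation of the V1 fallback check, pipeline parallelism can be requested with the multiprocessing executor roughly as follows, assuming the standard LLM entry point and the EngineArgs field names shown above:

import os

from vllm import LLM

# Sketch only: exact defaults may differ by version.
os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine explicitly

llm = LLM(
    model="facebook/opt-125m",          # placeholder model
    tensor_parallel_size=2,
    pipeline_parallel_size=2,           # world_size = TP x PP = 4 workers
    distributed_executor_backend="mp",  # no longer forces a fallback to V0
)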

vllm/v1/executor/multiproc_executor.py

@@ -8,7 +8,7 @@ import threading
 import time
 import traceback
 import weakref
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -53,10 +53,11 @@ class MultiprocExecutor(Executor):
         self.world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
-        assert self.world_size == tensor_parallel_size, (
+        pp_parallel_size = self.parallel_config.pipeline_parallel_size
+        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
             f"world_size ({self.world_size}) must be equal to the "
-            f"tensor_parallel_size ({tensor_parallel_size}). "
-            f"Pipeline parallelism is not yet implemented in v1")
+            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
+            f"_parallel_size ({pp_parallel_size}). ")

         # Set multiprocessing envs that are common to V0 and V1
         set_multiprocessing_worker_envs(self.parallel_config)
@@ -104,6 +105,17 @@ class MultiprocExecutor(Executor):
                 self._ensure_worker_termination(
                     [w.proc for w in unready_workers])

+        # For pipeline parallel, we use a thread pool for asynchronous
+        # execute_model.
+        self.io_thread_pool: Optional[ThreadPoolExecutor] = None
+        if self.max_concurrent_batches > 1:
+            # Note: must use only 1 IO thread to keep dequeue sequence
+            # from the response queue
+            self.io_thread_pool = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="mp_exec_io")
+
+        self.output_rank = self._get_output_rank()
+
     def start_worker_monitor(self):
         workers = self.workers
         self_ref = weakref.ref(self)
@@ -145,7 +157,9 @@ class MultiprocExecutor(Executor):
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         (output, ) = self.collective_rpc("execute_model",
                                          args=(scheduler_output, ),
-                                         rank0_reply_only=True,
+                                         unique_reply_rank=self.output_rank,
+                                         non_block=self.max_concurrent_batches
+                                         > 1,
                                          timeout=EXECUTE_MODEL_TIMEOUT_S)
         return output
@@ -154,7 +168,8 @@ class MultiprocExecutor(Executor):
                        timeout: Optional[float] = None,
                        args: tuple = (),
                        kwargs: Optional[dict] = None,
-                       rank0_reply_only: bool = False) -> list[Any]:
+                       non_block: bool = False,
+                       unique_reply_rank: Optional[int] = None) -> list[Any]:
         if self.is_failed:
             raise RuntimeError("Executor failed.")
@@ -171,22 +186,35 @@ class MultiprocExecutor(Executor):
                 send_method = cloudpickle.dumps(
                     method, protocol=pickle.HIGHEST_PROTOCOL)
             self.rpc_broadcast_mq.enqueue(
-                (send_method, args, kwargs, rank0_reply_only))
+                (send_method, args, kwargs, unique_reply_rank))

-            workers = (self.workers[0], ) if rank0_reply_only else self.workers
-            responses = [None] * len(workers)
-            for w in workers:
-                dequeue_timeout = None if deadline is None else (
-                    deadline - time.monotonic())
+            workers = (self.workers[unique_reply_rank],
+                       ) if unique_reply_rank is not None else self.workers
+            responses = []
+
+            def get_response(w: WorkerProcHandle,
+                             dequeue_timeout: Optional[float] = None,
+                             cancel_event: Optional[threading.Event] = None):
                 status, result = w.worker_response_mq.dequeue(
-                    timeout=dequeue_timeout, cancel=self.shutdown_event)
+                    timeout=dequeue_timeout, cancel=cancel_event)

                 if status != WorkerProc.ResponseStatus.SUCCESS:
                     raise RuntimeError(
                         f"Worker failed with error '{result}', please check the"
                         " stack trace above for the root cause")
+                return result

-                responses[w.rank] = result
+            for w in workers:
+                dequeue_timeout = None if deadline is None else (
+                    deadline - time.monotonic())
+
+                if non_block:
+                    result = self.io_thread_pool.submit(  # type: ignore
+                        get_response, w, dequeue_timeout, self.shutdown_event)
+                else:
+                    result = get_response(w, dequeue_timeout)
+
+                responses.append(result)

             return responses
         except TimeoutError as e:
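The non_block path above hands each dequeue to the io_thread_pool and returns concurrent.futures.Future objects instead of results. A small self-contained sketch (with a hypothetical stand-in for the response queue) of why a single IO thread preserves the dequeue order:

from concurrent.futures import Future, ThreadPoolExecutor

# With max_workers=1, submitted calls execute strictly in submission order,
# mirroring the "only 1 IO thread to keep dequeue sequence" note above.
pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mp_exec_io")

def fake_dequeue(i: int) -> int:
    return i  # stand-in for w.worker_response_mq.dequeue(...)

futures: list[Future[int]] = [pool.submit(fake_dequeue, i) for i in range(4)]
assert [f.result() for f in futures] == [0, 1, 2, 3]
pool.shutdown()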
@@ -225,6 +253,11 @@ class MultiprocExecutor(Executor):
         if not getattr(self, 'shutting_down', False):
             self.shutting_down = True
             self.shutdown_event.set()
+
+            if self.io_thread_pool is not None:
+                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
+                self.io_thread_pool = None
+
             for w in self.workers:
                 w.worker_response_mq = None
             self._ensure_worker_termination([w.proc for w in self.workers])
@@ -235,6 +268,22 @@ class MultiprocExecutor(Executor):
             self.collective_rpc("check_health", timeout=10)
             return

+    @property
+    def max_concurrent_batches(self) -> int:
+        return self.parallel_config.pipeline_parallel_size
+
+    def _get_output_rank(self) -> int:
+        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
+        # (the first TP worker of the last PP stage).
+        # Example:
+        # Assuming TP=8, PP=4, then the world_size=32
+        # 0-7, PP rank 0
+        # 8-15, PP rank 1
+        # 16-23, PP rank 2
+        # 24-31, PP rank 3
+        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
+        return self.world_size - self.parallel_config.tensor_parallel_size
+

 @dataclass
 class UnreadyWorkerProcHandle:
@@ -280,12 +329,14 @@ class WorkerProc:
         all_kwargs: list[dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
         ]
+        is_driver_worker = (
+            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
         all_kwargs[rank] = {
             "vllm_config": vllm_config,
             "local_rank": local_rank,
             "rank": rank,
             "distributed_init_method": distributed_init_method,
-            "is_driver_worker": rank == 0,
+            "is_driver_worker": is_driver_worker,
         }
         wrapper.init_worker(all_kwargs)
         self.worker = wrapper
@@ -455,7 +506,7 @@ class WorkerProc:
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
         while True:
-            method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
+            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

             try:
                 if isinstance(method, str):
@@ -470,11 +521,11 @@ class WorkerProc:
                 logger.exception("WorkerProc hit an exception.")
                 # exception might not be serializable, so we convert it to
                 # string, only for logging purpose.
-                if not rank0_only or self.rank == 0:
+                if output_rank is None or self.rank == output_rank:
                     self.worker_response_mq.enqueue(
                         (WorkerProc.ResponseStatus.FAILURE, str(e)))
                 continue

-            if not rank0_only or self.rank == 0:
+            if output_rank is None or self.rank == output_rank:
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.SUCCESS, output))
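To make the new rank bookkeeping concrete, a standalone sketch (assumed TP=8, PP=4, matching the comment in _get_output_rank) of how the driver-worker flag and the single output rank fall out of the TP x PP layout:

# Assumed sizes for illustration: TP=8, PP=4, so world_size = 32.
tp_size, pp_size = 8, 4
world_size = tp_size * pp_size

# Ranks are laid out PP-stage-major: 0-7 are stage 0, 8-15 stage 1, and so on.
# A worker is a driver if it is the first TP rank of its PP stage.
driver_ranks = [r for r in range(world_size) if r % tp_size == 0]
assert driver_ranks == [0, 8, 16, 24]

# The executor collects ModelRunnerOutput only from the first TP rank of the
# last PP stage, i.e. world_size - tp_size.
output_rank = world_size - tp_size
assert output_rank == 24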

vllm/v1/worker/gpu_model_runner.py

@@ -1016,7 +1016,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, torch.Tensor]:
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         # Update KVConnector with the KVConnector metadata forward().
         if has_kv_transfer_group():
             get_kv_transfer_group().bind_connector_metadata(

vllm/v1/worker/gpu_worker.py

@@ -15,11 +15,12 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
-from vllm.distributed.parallel_state import get_pp_group
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.sequence import IntermediateTensors
 from vllm.utils import GiB_bytes
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
@@ -266,7 +267,22 @@ class Worker(WorkerBase):
         self,
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
-        output = self.model_runner.execute_model(scheduler_output)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if not get_pp_group().is_last_rank:
+            assert isinstance(output, IntermediateTensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
+            return None
+
+        assert isinstance(output, ModelRunnerOutput)
         return output if self.is_driver_worker else None

     def profile(self, is_start: bool = True):
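Finally, a simplified, framework-free sketch (hypothetical helpers, not vLLM APIs) of the per-stage decision that Worker.execute_model now makes: non-first stages receive intermediate tensors, non-last stages forward them to the next stage and return None to the executor, and only the last stage produces the final output:

from typing import Optional

def stage_step(pp_rank: int, pp_size: int,
               received: Optional[dict]) -> Optional[dict]:
    # Non-first stages consume what the previous stage sent (recv_tensor_dict
    # stand-in); the first stage starts from the scheduled tokens.
    inputs = received if pp_rank > 0 else {"tokens": [1, 2, 3]}
    activations = {"hidden_states": ["..."], "carried": inputs}  # forward pass stand-in
    if pp_rank < pp_size - 1:
        return activations  # send_tensor_dict stand-in; the worker returns None
    return {"sampled_token_ids": [42]}  # ModelRunnerOutput stand-in

message = None
for pp_rank in range(4):
    message = stage_step(pp_rank, 4, message)
print(message)  # only the last stage yields the final output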