[V1][PP] Support PP for MultiprocExecutor (#14219)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>
commit a6fed02068
parent d419aa5dc4
@@ -100,9 +100,8 @@ class PPTestSettings:
                               eager_mode=True,
                               chunked_prefill=False),
             ],
-            # only ray is supported for V1
-            distributed_backends=["mp", "ray", "ray"],
-            vllm_major_versions=["0", "0", "1"],
+            distributed_backends=["mp", "mp", "ray", "ray"],
+            vllm_major_versions=["0", "1", "0", "1"],
             task=task,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        load_format=load_format),
@@ -350,6 +349,11 @@ def _compare_tp(
         # Temporary. Currently when zeromq + SPMD is used, it does not properly
         # terminate because of a Ray Compiled Graph issue.
         common_args.append("--disable-frontend-multiprocessing")
+    elif distributed_backend == "mp":
+        # Both V0/V1 of multiprocessing executor support PP
+        pp_env = {
+            "VLLM_USE_V1": vllm_major_version,
+        }
     else:
         pp_env = None

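The widened test matrix above pairs each distributed backend with the engine major version it should run under. A minimal sketch of what that pairing encodes (the VLLM_USE_V1 variable name is taken from the diff itself):

# Minimal sketch: backends and versions are matched index-by-index, so the
# "mp" backend is now exercised under both V0 and V1.
backends = ["mp", "mp", "ray", "ray"]
versions = ["0", "1", "0", "1"]
for backend, use_v1 in zip(backends, versions):
    print(f"distributed_backend={backend}, VLLM_USE_V1={use_v1}")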
@@ -1338,11 +1338,10 @@ class EngineArgs:
                 and _warn_or_fallback("Engine in background thread")):
             return False

-        # PP is supported on V1 with Ray distributed executor,
-        # but off for MP distributed executor for now.
         if (self.pipeline_parallel_size > 1
-                and self.distributed_executor_backend != "ray"):
-            name = "Pipeline Parallelism without Ray distributed executor"
+                and self.distributed_executor_backend not in ["ray", "mp"]):
+            name = "Pipeline Parallelism without Ray distributed executor " \
+                "or multiprocessing executor"
             _raise_or_fallback(feature_name=name, recommend_to_remove=False)
             return False

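With the fallback check relaxed, pipeline parallelism no longer forces the Ray backend on V1. A hedged usage sketch (model name and parallel sizes are illustrative, not from the commit):

from vllm import LLM

# Assumed example arguments; any model and TP/PP split that fits the hardware
# would do. distributed_executor_backend="mp" now passes the V1 oracle check
# together with pipeline_parallel_size > 1.
llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=1,
          pipeline_parallel_size=2,
          distributed_executor_backend="mp")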
@@ -8,7 +8,7 @@ import threading
 import time
 import traceback
 import weakref
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -53,10 +53,11 @@ class MultiprocExecutor(Executor):

         self.world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
-        assert self.world_size == tensor_parallel_size, (
+        pp_parallel_size = self.parallel_config.pipeline_parallel_size
+        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
             f"world_size ({self.world_size}) must be equal to the "
-            f"tensor_parallel_size ({tensor_parallel_size}). "
-            f"Pipeline parallelism is not yet implemented in v1")
+            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
+            f"_parallel_size ({pp_parallel_size}). ")

         # Set multiprocessing envs that are common to V0 and V1
         set_multiprocessing_worker_envs(self.parallel_config)
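A worked instance of the relaxed invariant, assuming a 2x2 split:

# world_size must now equal tensor_parallel_size * pipeline_parallel_size,
# e.g. TP=2 and PP=2 launches four worker processes in total.
tensor_parallel_size = 2
pipeline_parallel_size = 2
world_size = tensor_parallel_size * pipeline_parallel_size
assert world_size == 4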
@@ -104,6 +105,17 @@ class MultiprocExecutor(Executor):
             self._ensure_worker_termination(
                 [w.proc for w in unready_workers])

+        # For pipeline parallel, we use a thread pool for asynchronous
+        # execute_model.
+        self.io_thread_pool: Optional[ThreadPoolExecutor] = None
+        if self.max_concurrent_batches > 1:
+            # Note: must use only 1 IO thread to keep dequeue sequence
+            # from the response queue
+            self.io_thread_pool = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="mp_exec_io")
+
+        self.output_rank = self._get_output_rank()
+
     def start_worker_monitor(self):
         workers = self.workers
         self_ref = weakref.ref(self)
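The single-worker pool is what keeps responses in dequeue order; a self-contained sketch of that property:

# With max_workers=1, futures complete strictly in submission order, so
# results dequeued through the pool cannot be interleaved across batches.
import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="demo_io")

def fake_dequeue(i: int) -> int:
    time.sleep(0.01)  # stand-in for worker_response_mq.dequeue()
    return i

futures = [pool.submit(fake_dequeue, i) for i in range(4)]
print([f.result() for f in futures])  # [0, 1, 2, 3], in submission order
pool.shutdown()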
@@ -145,7 +157,9 @@ class MultiprocExecutor(Executor):
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         (output, ) = self.collective_rpc("execute_model",
                                          args=(scheduler_output, ),
-                                         rank0_reply_only=True,
+                                         unique_reply_rank=self.output_rank,
+                                         non_block=self.max_concurrent_batches
+                                         > 1,
                                          timeout=EXECUTE_MODEL_TIMEOUT_S)
         return output

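Because execute_model can now return either a ModelRunnerOutput or a Future of one, callers have to unwrap accordingly; a hedged helper sketch (resolve is illustrative, not a vLLM API):

from concurrent.futures import Future

def resolve(output):
    # Illustrative only: block on the Future when PP made the call async.
    return output.result() if isinstance(output, Future) else output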
@@ -154,7 +168,8 @@ class MultiprocExecutor(Executor):
                        timeout: Optional[float] = None,
                        args: tuple = (),
                        kwargs: Optional[dict] = None,
-                       rank0_reply_only: bool = False) -> list[Any]:
+                       non_block: bool = False,
+                       unique_reply_rank: Optional[int] = None) -> list[Any]:
         if self.is_failed:
             raise RuntimeError("Executor failed.")

@@ -171,22 +186,35 @@ class MultiprocExecutor(Executor):
                 send_method = cloudpickle.dumps(
                     method, protocol=pickle.HIGHEST_PROTOCOL)
             self.rpc_broadcast_mq.enqueue(
-                (send_method, args, kwargs, rank0_reply_only))
+                (send_method, args, kwargs, unique_reply_rank))

-            workers = (self.workers[0], ) if rank0_reply_only else self.workers
-            responses = [None] * len(workers)
-            for w in workers:
-                dequeue_timeout = None if deadline is None else (
-                    deadline - time.monotonic())
+            workers = (self.workers[unique_reply_rank],
+                       ) if unique_reply_rank is not None else self.workers
+            responses = []
+
+            def get_response(w: WorkerProcHandle,
+                             dequeue_timeout: Optional[float] = None,
+                             cancel_event: Optional[threading.Event] = None):
                 status, result = w.worker_response_mq.dequeue(
-                    timeout=dequeue_timeout, cancel=self.shutdown_event)
+                    timeout=dequeue_timeout, cancel=cancel_event)

                 if status != WorkerProc.ResponseStatus.SUCCESS:
                     raise RuntimeError(
                         f"Worker failed with error '{result}', please check the"
                         " stack trace above for the root cause")
+                return result

-                responses[w.rank] = result
+            for w in workers:
+                dequeue_timeout = None if deadline is None else (
+                    deadline - time.monotonic())
+
+                if non_block:
+                    result = self.io_thread_pool.submit( # type: ignore
+                        get_response, w, dequeue_timeout, self.shutdown_event)
+                else:
+                    result = get_response(w, dequeue_timeout)
+
+                responses.append(result)

             return responses
         except TimeoutError as e:
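A self-contained sketch of the response-collection pattern introduced above: each per-worker dequeue is either run inline or handed to the single IO thread, and the returned list holds either values or Futures consistently (the names below are placeholders, not the executor's real message queues):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional, Union

io_pool = ThreadPoolExecutor(max_workers=1)

def get_response(worker_id: int, timeout: Optional[float] = None) -> str:
    # Stand-in for worker_response_mq.dequeue() in the real executor.
    return f"response-from-worker-{worker_id}"

def collect(workers: list[int], non_block: bool) -> list[Union[str, Future]]:
    responses: list[Union[str, Future]] = []
    for w in workers:
        result = io_pool.submit(get_response, w) if non_block \
            else get_response(w)
        responses.append(result)
    return responses

print(collect([0, 1], non_block=False))
print([f.result() for f in collect([0, 1], non_block=True)])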
@@ -225,6 +253,11 @@ class MultiprocExecutor(Executor):
         if not getattr(self, 'shutting_down', False):
             self.shutting_down = True
             self.shutdown_event.set()
+
+            if self.io_thread_pool is not None:
+                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
+                self.io_thread_pool = None
+
             for w in self.workers:
                 w.worker_response_mq = None
             self._ensure_worker_termination([w.proc for w in self.workers])
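The shutdown path relies on cancel_futures so a pending dequeue cannot block teardown; a small sketch of that stdlib behaviour (Python 3.9+):

import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)
running = pool.submit(time.sleep, 0.1)  # may already be running
queued = pool.submit(time.sleep, 10)    # still waiting in the queue
pool.shutdown(wait=False, cancel_futures=True)
print(queued.cancelled())  # True: the queued task never runs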
@@ -235,6 +268,22 @@ class MultiprocExecutor(Executor):
         self.collective_rpc("check_health", timeout=10)
         return

+    @property
+    def max_concurrent_batches(self) -> int:
+        return self.parallel_config.pipeline_parallel_size
+
+    def _get_output_rank(self) -> int:
+        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
+        # (the first TP worker of the last PP stage).
+        # Example:
+        # Assuming TP=8, PP=4, then the world_size=32
+        # 0-7, PP rank 0
+        # 8-15, PP rank 1
+        # 16-23, PP rank 2
+        # 24-31, PP rank 3
+        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
+        return self.world_size - self.parallel_config.tensor_parallel_size
+

 @dataclass
 class UnreadyWorkerProcHandle:
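The worked example in the comment can be checked directly:

# TP=8, PP=4 (values from the comment above): world_size = 32 and the reply
# rank is the first TP worker of the last PP stage.
tp_size, pp_size = 8, 4
world_size = tp_size * pp_size
output_rank = world_size - tp_size
assert output_rank == 24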
@@ -280,12 +329,14 @@ class WorkerProc:
         all_kwargs: list[dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
         ]
+        is_driver_worker = (
+            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
         all_kwargs[rank] = {
             "vllm_config": vllm_config,
             "local_rank": local_rank,
             "rank": rank,
             "distributed_init_method": distributed_init_method,
-            "is_driver_worker": rank == 0,
+            "is_driver_worker": is_driver_worker,
         }
         wrapper.init_worker(all_kwargs)
         self.worker = wrapper
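The new driver-worker rule makes the first TP rank of every PP stage a driver instead of only global rank 0; a worked example with TP=2, PP=2:

tp_size, pp_size = 2, 2
drivers = [rank for rank in range(tp_size * pp_size) if rank % tp_size == 0]
print(drivers)  # [0, 2] -- one driver per pipeline stage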
@@ -455,7 +506,7 @@ class WorkerProc:
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
         while True:
-            method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
+            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

             try:
                 if isinstance(method, str):
@@ -470,11 +521,11 @@ class WorkerProc:
                 logger.exception("WorkerProc hit an exception.")
                 # exception might not be serializable, so we convert it to
                 # string, only for logging purpose.
-                if not rank0_only or self.rank == 0:
+                if output_rank is None or self.rank == output_rank:
                     self.worker_response_mq.enqueue(
                         (WorkerProc.ResponseStatus.FAILURE, str(e)))
                 continue

-            if not rank0_only or self.rank == 0:
+            if output_rank is None or self.rank == output_rank:
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.SUCCESS, output))
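The reply gate used twice above can be summarised as a single predicate (sketch, not vLLM code):

def should_reply(my_rank: int, output_rank) -> bool:
    # Broadcast RPCs carry output_rank=None, so every worker replies;
    # execute_model names a single rank, so only that worker replies.
    return output_rank is None or my_rank == output_rank

assert should_reply(3, None)
assert should_reply(24, 24) and not should_reply(0, 24)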
@@ -1016,7 +1016,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, torch.Tensor]:
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         # Update KVConnector with the KVConnector metadata forward().
         if has_kv_transfer_group():
             get_kv_transfer_group().bind_connector_metadata(
@@ -15,11 +15,12 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
-from vllm.distributed.parallel_state import get_pp_group
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.sequence import IntermediateTensors
 from vllm.utils import GiB_bytes
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
@@ -266,7 +267,22 @@ class Worker(WorkerBase):
         self,
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
-        output = self.model_runner.execute_model(scheduler_output)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if not get_pp_group().is_last_rank:
+            assert isinstance(output, IntermediateTensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
+            return None
+
+        assert isinstance(output, ModelRunnerOutput)
         return output if self.is_driver_worker else None

     def profile(self, is_start: bool = True):
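The worker's execute_model now implements a simple pipeline hand-off: non-first stages receive IntermediateTensors, non-last stages forward their output and return None to the executor. A hedged, framework-free sketch of that control flow (the recv/send callables are placeholders, not vLLM APIs):

def pp_step(stage: int, num_stages: int, run_model, recv_prev, send_next):
    intermediate = None if stage == 0 else recv_prev()
    output = run_model(intermediate)
    if stage < num_stages - 1:
        send_next(output)  # hand intermediate tensors to the next stage
        return None        # only the last stage returns model output
    return output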