Introduce RayPPCommunicator for ray-based PP (#21660)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
parent 9659bc7f27
commit d331759488
vllm/distributed/device_communicators/ray_communicator.py (new file, 257 additions)
@@ -0,0 +1,257 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from typing import Any, Optional

import ray
import torch
from ray.exceptions import RayChannelError
from ray.experimental.channel.communicator import (Communicator,
                                                    TorchTensorAllocator)
from torch.distributed import ReduceOp

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.utils import current_stream

logger = init_logger(__name__)


class RayPPCommunicator(Communicator):
    """
    Communicator to be used for pipeline parallelism in Ray Compiled Graph.
    This wraps around the vLLM _PP GroupCoordinator.

    This class is not thread-safe.
    """

    _comm: Optional[DeviceCommunicatorBase]

    def __init__(
        self,
        world_size: int,
        comm_id: Any,
        rank: Optional[int],
        actor_handles: list["ray.actor.ActorHandle"],
        cuda_stream: Optional[torch.cuda.Stream],
        use_communication_streams: bool = False,
    ):
        """
        Initialize a RayPPCommunicator that can be used to communicate with
        other Ray Compiled Graph actors for pipeline parallelism.

        Args:
            world_size: The number of participating actors.
            comm_id: A unique communicator ID. This is just to conform with
                the Ray Communicator API and is not used.
            rank: The rank of this actor. If None, then the caller is not a
                participant of the RayPPCommunicator group (e.g., the Ray
                driver).
            actor_handles: A list of actor handles.
            cuda_stream: A CUDA stream to dispatch communication ops to. This
                is not supported.
            use_communication_streams: Whether to use communication streams.
                This is not supported.
        """
        self._world_size = world_size
        self._rank: Optional[int] = None
        self._actor_handles = actor_handles
        if use_communication_streams:
            raise NotImplementedError(
                "use_communication_streams is not supported")
        if cuda_stream is not None and cuda_stream != current_stream():
            raise ValueError(
                "cuda_stream other than the current stream is not supported")

        if rank is not None:
            # Rank is not None, so this is a Ray worker.
            assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"

            self._comm = get_pp_group().device_communicator

            # Since we wrap around the vLLM _PP communicator, we use
            # the rank from the vLLM communicator, and ignore the rank
            # passed in from Ray.
            # TODO(rui): refactor the Ray Communicator API so that
            # it also supports no rank passed in.
            self._rank = self._comm.rank_in_group

            self._build_actor_rank_mapping()
        else:
            # Rank is None, so this is the Ray driver.
            self._comm = None

        self._closed = False

    def _build_actor_rank_mapping(self):
        """
        Use collective communication to build a mapping from actor IDs
        to ranks. This should be called once during initialization.
        """
        if self._comm is None:
            return {}

        current_actor = ray.get_runtime_context().current_actor
        actor_id_str = current_actor._actor_id.hex()

        # Ray actor IDs are 32-character hex strings (128 bits)
        ACTOR_ID_LEN = 32
        actor_id_bytes = actor_id_str.encode('utf-8')
        assert len(
            actor_id_bytes
        ) == ACTOR_ID_LEN, f"Unexpected actor ID length: {len(actor_id_bytes)}"

        actor_id_tensor = torch.frombuffer(
            actor_id_bytes, dtype=torch.uint8).to(self._comm.device)

        # All-gather full actor IDs from all actors
        gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0)

        # Build mapping: actor_id -> device_comm_rank
        self._actor_id_to_rank = {}
        for rank in range(self._world_size):
            start_idx = rank * ACTOR_ID_LEN
            end_idx = (rank + 1) * ACTOR_ID_LEN
            actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy(
            ).tobytes()
            actor_id = actor_bytes.decode('utf-8')
            self._actor_id_to_rank[actor_id] = rank

    def initialize(self, rank: int) -> None:
        # No additional initialization is needed.
        pass

    def get_actor_handles(self) -> list["ray.actor.ActorHandle"]:
        return self._actor_handles

    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        """
        Return the given actor's rank using device communicator collective ops.
        """
        assert hasattr(self, '_actor_id_to_rank'), (
            "Actor rank mapping not built. "
            "This should have been done during initialization.")

        actor_id_str = actor._actor_id.hex()

        if actor_id_str in self._actor_id_to_rank:
            return self._actor_id_to_rank[actor_id_str]  # type: ignore
        else:
            raise ValueError(f"Actor {actor} not found in communicator group")

    def get_self_rank(self) -> Optional[int]:
        """
        Return this actor's rank.
        """
        return self._rank

    def get_world_size(self) -> int:
        """
        Return the number of ranks in the RayPPCommunicator group.
        """
        return self._world_size

    def send(self, buf: "torch.Tensor", peer_rank: int) -> None:
        """
        Send a torch.Tensor to a peer.

        This returns when the send kernel has been queued, but the kernel may
        not have completed. Therefore, the caller should ensure that there are
        no concurrent writes to the sent `buf` until the send has finished.
        That is, either all writes should be submitted on the current stream
        (self._cuda_stream) or, if on a different stream, that stream should
        synchronize with the current stream.

        Args:
            buf: The torch.Tensor to send. It should already be on this
                actor's default device.
            peer_rank: The rank of the actor to send to.
        """
        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")

        assert self._comm is not None
        self._comm.send(buf, peer_rank)

    def recv(
        self,
        shape: tuple[int],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: TorchTensorAllocator,
    ) -> "torch.Tensor":
        """
        Receive a torch.Tensor from a peer and synchronize the current stream.

        After this call returns, the receive buffer is safe to read from any
        stream. A RayChannelError will be raised if an error occurred (e.g.,
        the remote actor died), in which case the buffer is not safe to read.

        Args:
            shape: The shape of the tensor to receive.
            dtype: The dtype of the tensor to receive.
            peer_rank: The rank of the actor to receive from.
            allocator: The allocator to use to create the received tensor.
                This is ignored for this implementation.
        """
        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")

        assert self._comm is not None
        size = torch.Size(shape)
        buf = self._comm.recv(size, dtype, src=peer_rank)

        # Buffer values are undefined if NCCL ops are aborted. Therefore, we
        # need to synchronize here and check that the channel is still
        # open to ensure that the receive buffer is valid.
        # TODO(swang): Avoid CUDA synchronization.
        current_stream().synchronize()

        if self._closed:
            raise RayChannelError("RayPPCommunicator has been destroyed.")
        return buf

    def allgather(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
    ):
        raise NotImplementedError("allgather is not supported")

    def allreduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        raise NotImplementedError("allreduce is not supported")

    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        raise NotImplementedError("reducescatter is not supported")

    @property
    def recv_stream(self):
        return torch.cuda.StreamContext(current_stream())

    @property
    def send_stream(self):
        return torch.cuda.StreamContext(current_stream())

    def destroy(self) -> None:
        # Just sets a flag; vLLM manages the lifecycle of the underlying
        # _PP GroupCoordinator.
        self._closed = True

    def get_transport_name(self) -> str:
        return "nccl"

    @classmethod
    def generate_communicator_id(cls) -> Any:
        return uuid.uuid4()
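A note on the actor-to-rank mapping built above: each worker encodes its 32-character hex actor ID into a uint8 tensor, all-gathers those tensors through the wrapped _PP device communicator, and decodes each rank's slice back into a string. Below is a minimal, CPU-only sketch of that encode/decode round trip. It is not part of this commit: it stands in uuid.uuid4().hex values for real Ray actor IDs and fakes the all-gather with torch.cat so it runs in a single process.

# Sketch only: single-process emulation of the ID round trip used by
# _build_actor_rank_mapping (real code all-gathers over the _PP communicator).
import uuid

import torch

ACTOR_ID_LEN = 32  # Ray actor IDs are 32-character hex strings (128 bits)

# Stand-ins for real Ray actor IDs; uuid4().hex is also 32 hex characters.
actor_ids = [uuid.uuid4().hex for _ in range(4)]

# Encode: what each rank would contribute to the all-gather.
per_rank_tensors = [
    torch.frombuffer(aid.encode("utf-8"), dtype=torch.uint8)
    for aid in actor_ids
]

# Emulate all_gather(dim=0): concatenate everyone's bytes into one flat tensor.
gathered = torch.cat(per_rank_tensors, dim=0)

# Decode: slice per rank and rebuild the actor_id -> rank mapping.
actor_id_to_rank = {}
for rank in range(len(actor_ids)):
    chunk = gathered[rank * ACTOR_ID_LEN:(rank + 1) * ACTOR_ID_LEN]
    actor_id_to_rank[chunk.numpy().tobytes().decode("utf-8")] = rank

assert all(actor_id_to_rank[aid] == rank for rank, aid in enumerate(actor_ids))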
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
    VLLM_USE_RAY_COMPILED_DAG: bool = False
    VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "auto"
    VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
    VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
    VLLM_XLA_USE_SPMD: bool = False
    VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
    VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
@@ -498,6 +499,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
    lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0"))
                 ),

    # If the env var is set, it uses a Ray Communicator wrapping
    # vLLM's pipeline parallelism communicator to interact with Ray's
    # Compiled Graph. Otherwise, it uses Ray's NCCL communicator.
    # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set.
    "VLLM_USE_RAY_WRAPPED_PP_COMM":
    lambda: bool(int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1"))),

    # Use dedicated multiprocess context for workers.
    # Both spawn and fork work
    "VLLM_WORKER_MULTIPROC_METHOD":
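For reference, the new entry above parses the variable as an integer string with a default of "1", so the wrapped PP communicator is enabled unless the variable is explicitly set to "0" (and, per the comment, the flag only matters when VLLM_USE_RAY_COMPILED_DAG is in effect). A small sketch of that parsing behavior follows; it is not part of this commit, and the helper name use_ray_wrapped_pp_comm is made up for illustration.

# Sketch only: mirrors the lambda registered for "VLLM_USE_RAY_WRAPPED_PP_COMM".
import os


def use_ray_wrapped_pp_comm() -> bool:
    # Any non-zero integer string enables the wrapper; the default is "1".
    return bool(int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1")))


os.environ["VLLM_USE_RAY_WRAPPED_PP_COMM"] = "0"  # opt out of the wrapper
assert use_ray_wrapped_pp_comm() is False

os.environ.pop("VLLM_USE_RAY_WRAPPED_PP_COMM")    # unset -> default applies
assert use_ray_wrapped_pp_comm() is True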
@@ -608,6 +608,21 @@ class RayDistributedExecutor(DistributedExecutorBase):

        forward_dag = MultiOutputNode(outputs)

        if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
            from ray.experimental.channel.accelerator_context import (
                register_accelerator_context)

            from vllm.distributed.device_communicators.ray_communicator import (
                RayPPCommunicator)
            register_accelerator_context(torch_module_name="cuda",
                                         communicator_cls=RayPPCommunicator)
            logger.info("Using RayPPCommunicator "
                        "(which wraps vLLM _PP GroupCoordinator) "
                        "for Ray Compiled Graph communication.")
        else:
            logger.info("Using Ray's NCCL communicator for "
                        "Ray Compiled Graph communication.")

        return forward_dag.experimental_compile(
            enable_asyncio=enable_asyncio,
            _overlap_gpu_communication=envs.
            VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
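The following is a standalone sketch, not part of this commit, of the registration step the RayDistributedExecutor performs when VLLM_USE_RAY_WRAPPED_PP_COMM is enabled. It assumes it runs inside a vLLM Ray worker/driver setup where the _PP group has been initialized, since RayPPCommunicator wraps that group; with the flag disabled, the call is skipped and Ray Compiled Graph falls back to its own NCCL communicator.

# Sketch only: reproduces the executor's registration of the wrapped
# communicator so Ray Compiled Graph routes CUDA tensor transport for
# pipeline parallelism through vLLM's _PP GroupCoordinator instead of
# creating its own NCCL group.
from ray.experimental.channel.accelerator_context import (
    register_accelerator_context)

from vllm.distributed.device_communicators.ray_communicator import (
    RayPPCommunicator)

register_accelerator_context(torch_module_name="cuda",
                             communicator_cls=RayPPCommunicator)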