From 3d2a2de8f75b973927d63b4cab63d9abb1a24722 Mon Sep 17 00:00:00 2001
From: Weixiao Huang
Date: Tue, 9 Sep 2025 16:57:46 +0800
Subject: [PATCH] [RL] fast weight update with zmq + ipc handles (#24295)

Signed-off-by: huangweixiao
Signed-off-by: youkaichao
Co-authored-by: youkaichao
---
 examples/offline_inference/rlhf_colocate.py | 95 +++++++++++++++++----
 examples/offline_inference/rlhf_utils.py    | 90 +++++++++++++++----
 2 files changed, 152 insertions(+), 33 deletions(-)

diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py
index 65621023ab6c..360fd79b55aa 100644
--- a/examples/offline_inference/rlhf_colocate.py
+++ b/examples/offline_inference/rlhf_colocate.py
@@ -28,12 +28,15 @@ Learn more about Ray placement groups:
 https://docs.ray.io/en/latest/placement-groups.html
 """
 
+import gc
 import os
 
 import ray
 import torch
+import zmq
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from torch.multiprocessing.reductions import reduce_tensor
 
 from vllm import LLM
 
@@ -86,20 +89,72 @@ class RayTrainingActor:
         from vllm.platforms import current_platform
 
         self.device_uuid = current_platform.get_device_uuid(0)
+        self.zmq_context = zmq.Context()
+        self.zmq_address_counter = 0
+        self.zmq_handle = None
 
     def report_device_id(self) -> str:
         return self.device_uuid
 
-    def get_weight_ipc_handles(self):
-        from torch.multiprocessing.reductions import reduce_tensor
+    def get_zmq_handles(self) -> dict[str, str]:
+        suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
+        self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
+        self.zmq_address_counter += 1
+        return {self.device_uuid: self.zmq_handle}
 
-        data = {}
-        for name, p in self.model.named_parameters():
-            # A training actor might hold only a subset of the weights and may
-            # need to gather weights from other actors. For demonstration
-            # purposes, each training actor owns the full weight set.
-            data[name] = reduce_tensor(p.detach())
-        return {self.device_uuid: data}
+    def update_weights(self):
+        # align size to avoid misaligned address
+        align_size = 256
+
+        def get_size(p: torch.Tensor) -> int:
+            return (p.nbytes + align_size - 1) // align_size * align_size
+
+        named_parameters: dict[str, torch.nn.Parameter] = dict(
+            self.model.named_parameters()
+        )
+        max_tensor_size = max(get_size(p) for p in named_parameters.values())
+        # use max_tensor_size * 2 as buffer size
+        buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
+        s = self.zmq_context.socket(zmq.REQ)
+        s.bind(self.zmq_handle)
+        handle = reduce_tensor(buffer)
+
+        offset = 0
+        buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
+        named_tensors: list[dict] = []
+        real_tensors: list[torch.Tensor] = []
+        for name, p in named_parameters.items():
+            size = get_size(p)
+            if offset + size > buffer.numel():
+                buckets.append((named_tensors, real_tensors))
+                named_tensors, real_tensors = [], []
+                offset = 0
+            # assume tensors are contiguous
+            named_tensors.append(
+                {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
+            )
+            real_tensors.append(p)
+            offset += size
+        if named_tensors:
+            buckets.append((named_tensors, real_tensors))
+        s.send_pyobj(handle)
+        s.recv()
+        for named_tensors, real_tensors in buckets:
+            offset = 0
+            for p in real_tensors:
+                buffer[offset : offset + p.nbytes].data.copy_(
+                    p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
+                )
+                offset += get_size(p)
+            torch.cuda.synchronize()
+            s.send_pyobj(named_tensors)
+            s.recv()
+        s.send_pyobj(None)
+        s.recv()
+        s.close()
+        del buffer
+        gc.collect()
+        torch.cuda.empty_cache()
 
 
 # Ray manages four GPUs.
@@ -175,18 +230,22 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
 # the second inference engine.
 assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
 
-print("Gather all the IPC handles from the training actors.")
-ipc_handles = {}
+print("Gather all the ZMQ handles from the training actors.")
+zmq_handles = {}
 for actor in training_actors:
-    ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
+    zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
+
+print(f"ZMQ handles: {zmq_handles}")
 
 print("Update the weights of the inference engines.")
-for llm in inference_engines:
-    ray.get(
-        llm.collective_rpc.remote(
-            "update_weights_from_ipc_handles", args=(ipc_handles,)
-        )
-    )
+ray.get(
+    [actor.update_weights.remote() for actor in training_actors]
+    + [
+        llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
+        for llm in inference_engines
+    ]
+)
+
 print("Check if the weights are updated.")
 for llm in inference_engines:
     assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py
index d2a8419ffabc..c0e60b979340 100644
--- a/examples/offline_inference/rlhf_utils.py
+++ b/examples/offline_inference/rlhf_utils.py
@@ -1,6 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
+from typing import Callable, Optional, TypedDict
+
 import torch
+import zmq
 
 
 def stateless_init_process_group(master_address, master_port, rank, world_size, device):
@@ -66,6 +70,27 @@ class WorkerExtension:
         return weights_updated
 
 
+def rebuild_ipc(
+    handle: tuple[Callable, tuple], device_id: Optional[int] = None
+) -> torch.Tensor:
+    func, args = handle
+    list_args = list(args)
+    if device_id is not None:
+        # the key is to change device id to the current device id
+        # in case two processes have different CUDA_VISIBLE_DEVICES
+        list_args[6] = device_id
+    buffer = func(*list_args)
+    return buffer
+
+
+class FlattenedTensorMetadata(TypedDict):
+    name: str
+    shape: torch.Size
+    dtype: torch.dtype
+    # specify the start offset of this tensor in shared ipc_buffer tensor
+    offset: int
+
+
 class ColocateWorkerExtension:
     """
     The class for vLLM's worker to inherit from, in the colocate setting.
@@ -76,27 +101,62 @@ class ColocateWorkerExtension:
     should pass the full qualified name as `worker_extension_cls` argument.
     """
 
+    def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
+        from vllm.model_executor.model_loader.utils import process_weights_after_loading
+
+        assert self.device is not None
+        if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
+            self._zmq_ctx = zmq.Context()
+        socket = self._zmq_ctx.socket(zmq.REP)
+        socket.connect(zmq_handles[self.report_device_id()])
+        buffer: Optional[torch.Tensor] = None
+        while True:
+            payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
+                socket.recv_pyobj()
+            )
+            if payload is None:
+                # a None payload means the weight update is done
+                process_weights_after_loading(
+                    self.model_runner.model, self.model_config, self.device
+                )
+                torch.cuda.synchronize()
+                socket.send(b"")
+                break
+            if isinstance(payload, tuple):
+                # an IPC handle: unpack it with `func, args = handle` and
+                # call `func(*args)` to rebuild the shared GPU tensor.
+                buffer = rebuild_ipc(payload, self.device.index)
+                assert buffer.dtype == torch.uint8
+                socket.send(b"")
+                continue
+            assert isinstance(payload, list)
+            assert buffer is not None
+            weights = []
+            for item in payload:
+                shape = item["shape"]
+                if isinstance(shape, (list, tuple)):
+                    shape = torch.Size(shape)
+                assert isinstance(shape, torch.Size)
+                dtype, offset = item["dtype"], item["offset"]
+                size = dtype.itemsize * shape.numel()
+                tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
+                weights.append((item["name"], tensor))
+            self.model_runner.model.load_weights(weights=weights)
+            del weights
+            torch.cuda.synchronize()
+            socket.send(b"")
+
+        socket.close()
+        del buffer
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def report_device_id(self) -> str:
         from vllm.platforms import current_platform
 
         self.device_uuid = current_platform.get_device_uuid(self.device.index)
         return self.device_uuid
 
-    def update_weights_from_ipc_handles(self, ipc_handles):
-        handles = ipc_handles[self.device_uuid]
-        device_id = self.device.index
-        weights = []
-        for name, handle in handles.items():
-            func, args = handle
-            list_args = list(args)
-            # the key is to change device id to the current device id
-            # in case two processes have different CUDA_VISIBLE_DEVICES
-            list_args[6] = device_id
-            tensor = func(*list_args)
-            weights.append((name, tensor))
-        self.model_runner.model.load_weights(weights=weights)
-        torch.cuda.synchronize()
-
     def check_weights_changed(self):
         """
         Check if the weights are updated to 0.
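Reviewer note (not part of the patch): the core idea above is that the trainer packs
parameters into one 256-byte-aligned uint8 buffer, shares that buffer across processes
once via a CUDA IPC handle, and then only ships FlattenedTensorMetadata
(name, dtype, shape, offset) over ZMQ for each bucket; the vLLM worker rebuilds
zero-copy views from that metadata. Below is a minimal, single-process CPU sketch of
just the pack/unpack scheme, with the ZMQ handshake and CUDA IPC sharing omitted.
The helper names pack_tensors, unpack_tensors, ALIGN, and aligned_size are
illustrative only and do not appear in the patch.

    import torch

    ALIGN = 256  # same 256-byte alignment the patch uses to keep views aligned

    def aligned_size(t: torch.Tensor) -> int:
        # round each tensor's byte size up to the alignment boundary
        return (t.nbytes + ALIGN - 1) // ALIGN * ALIGN

    def pack_tensors(named: dict[str, torch.Tensor]) -> tuple[torch.Tensor, list[dict]]:
        # copy every tensor into one flat uint8 buffer and record
        # (name, dtype, shape, offset) metadata, mirroring FlattenedTensorMetadata
        total = sum(aligned_size(t) for t in named.values())
        buffer = torch.empty(total, dtype=torch.uint8)
        metadata, offset = [], 0
        for name, t in named.items():
            flat = t.detach().contiguous().view(-1).view(torch.uint8)
            buffer[offset : offset + t.nbytes].copy_(flat)
            metadata.append(
                {"name": name, "dtype": t.dtype, "shape": t.shape, "offset": offset}
            )
            offset += aligned_size(t)
        return buffer, metadata

    def unpack_tensors(buffer: torch.Tensor, metadata: list[dict]) -> dict[str, torch.Tensor]:
        # rebuild tensors as zero-copy views into the shared buffer,
        # mirroring how update_weights_from_ipc reads the IPC buffer in this patch
        out = {}
        for m in metadata:
            size = m["dtype"].itemsize * torch.Size(m["shape"]).numel()
            chunk = buffer[m["offset"] : m["offset"] + size]
            out[m["name"]] = chunk.view(m["dtype"]).view(m["shape"])
        return out

    if __name__ == "__main__":
        params = {"w": torch.randn(4, 8), "b": torch.arange(16, dtype=torch.int64)}
        buf, meta = pack_tensors(params)
        rebuilt = unpack_tensors(buf, meta)
        assert all(torch.equal(params[k], rebuilt[k]) for k in params)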