diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 360fd79b55aad..d798311188c1e 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -28,8 +28,10 @@ Learn more about Ray placement groups: https://docs.ray.io/en/latest/placement-groups.html """ +import argparse import gc import os +import time import ray import torch @@ -156,96 +158,237 @@ class RayTrainingActor: gc.collect() torch.cuda.empty_cache() + def update_weights_async(self, num_handles: int): + align_size = 256 -# Ray manages four GPUs. + def get_size(p: torch.Tensor) -> int: + return (p.nbytes + align_size - 1) // align_size * align_size -os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" -ray.init() + named_parameters: dict[str, torch.nn.Parameter] = dict( + self.model.named_parameters() + ) + max_tensor_size = max(get_size(p) for p in named_parameters.values()) + # use max_tensor_size * 2 as buffer size + buffer_capacity = max_tensor_size * 2 + buffers = [ + torch.empty(buffer_capacity, dtype=torch.uint8, device="cuda:0") + for _ in range(num_handles) + ] + handles = [reduce_tensor(b) for b in buffers] -# Co-locate vLLM instances and training actors on the same set of GPUs: -# * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0 -# (tensor parallelism = 2). -# * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1 -# (tensor parallelism = 2). 
+ # Establish ZMQ connection + s = self.zmq_context.socket(zmq.DEALER) + s.connect(self.zmq_handle) + s.send_pyobj(handles) + s.recv_multipart() # ACK for handles -pg = placement_group([{"GPU": 1, "CPU": 0}] * 4) -ray.get(pg.ready()) -print(f"placement group has bundles {pg.bundle_specs=}") + # === Partition tensors into buffers === + offset = 0 + buckets: list[tuple[list[dict], list[torch.Tensor]]] = [] + named_tensors: list[dict] = [] + real_tensors: list[torch.Tensor] = [] -training_actors = [] -training_actor_device_ids = [] -inference_engines = [] -inference_engine_device_ids = [] + for name, p in named_parameters.items(): + size = get_size(p) + if offset + size > buffer_capacity: + buckets.append((named_tensors, real_tensors)) + named_tensors, real_tensors = [], [] + offset = 0 + # assume tensors are contiguous + named_tensors.append( + {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset} + ) + real_tensors.append(p) + offset += size + if named_tensors: + buckets.append((named_tensors, real_tensors)) -for bundle_index in [0, 1, 2, 3]: - training_actor = ray.remote( - num_cpus=0, - num_gpus=0.4, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=bundle_index, - ), - )(RayTrainingActor).remote() - training_actors.append(training_actor) + poller = zmq.Poller() + poller.register(s, zmq.POLLOUT) + free_buffers = list(range(num_handles)) + inflight = 0 + idx = 0 + total = len(buckets) + print(f"[Training] Total {total} buckets to send.") -for bundle_index, training_actor in enumerate(training_actors): - device_id = ray.get(training_actor.report_device_id.remote()) - print(f"training actor {bundle_index} is on {device_id}") - training_actor_device_ids.append(device_id) + # === Send loop === + while idx < total or inflight > 0: + events = dict(poller.poll(timeout=50)) + if ( + s in events + and (events[s] & zmq.POLLOUT) + and idx < total + and 
free_buffers + ): + buf_id = free_buffers.pop(0) + meta, tensors = buckets[idx] + buffer = buffers[buf_id] + offset = 0 + for item, t in zip(meta, tensors): + size = get_size(t) + buffer[offset : offset + t.nbytes].copy_( + t.contiguous().view(torch.uint8).flatten() + ) + offset += size + torch.cuda.synchronize() + s.send_pyobj((buf_id, meta)) + inflight += 1 + print(f"[Training] Sent bucket {idx} using buffer {buf_id}.") + idx += 1 -for i, bundle_indices in enumerate([[0, 1], [2, 3]]): - # Use the following syntax instead of the @ray.remote decorator so that - # the placement group is customized for each bundle. - llm = ray.remote( - num_cpus=0, - num_gpus=0, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_capture_child_tasks=True, - ), - )(MyLLM).remote( - model="facebook/opt-125m", - enforce_eager=True, - worker_extension_cls="rlhf_utils.ColocateWorkerExtension", - tensor_parallel_size=2, - distributed_executor_backend="ray", - gpu_memory_utilization=0.4, - bundle_indices=bundle_indices, + # Receive buffer-free ACKs + try: + while s.getsockopt(zmq.EVENTS) & zmq.POLLIN: + ack = int(s.recv_multipart(flags=zmq.NOBLOCK)[0].decode()) + free_buffers.append(ack) + inflight -= 1 + print(f"[Training] Ack received: buffer {ack} now free.") + except zmq.Again: + pass + + # Signal done + s.send_pyobj((None, None)) + s.recv_multipart() # DONE ack + s.close() + del buffers + gc.collect() + torch.cuda.empty_cache() + + +def setup_train_cluster(): + # Ray manages four GPUs. + + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + ray.init() + + # Co-locate vLLM instances and training actors on the same set of GPUs: + # * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0 + # (tensor parallelism = 2). + # * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1 + # (tensor parallelism = 2). 
+ + pg = placement_group([{"GPU": 1, "CPU": 0}] * 4) + ray.get(pg.ready()) + print(f"placement group has bundles {pg.bundle_specs=}") + + training_actors = [] + training_actor_device_ids = [] + inference_engines = [] + inference_engine_device_ids = [] + + for bundle_index in [0, 1, 2, 3]: + training_actor = ray.remote( + num_cpus=0, + num_gpus=0.4, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_index, + ), + )(RayTrainingActor).remote() + training_actors.append(training_actor) + + for bundle_index, training_actor in enumerate(training_actors): + device_id = ray.get(training_actor.report_device_id.remote()) + print(f"training actor {bundle_index} is on {device_id}") + training_actor_device_ids.append(device_id) + + for i, bundle_indices in enumerate([[0, 1], [2, 3]]): + # Use the following syntax instead of the @ray.remote decorator so that + # the placement group is customized for each bundle. + llm = ray.remote( + num_cpus=0, + num_gpus=0, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_capture_child_tasks=True, + ), + )(MyLLM).remote( + model="facebook/opt-125m", + enforce_eager=True, + worker_extension_cls="rlhf_utils.ColocateWorkerExtension", + tensor_parallel_size=2, + distributed_executor_backend="ray", + gpu_memory_utilization=0.3, + bundle_indices=bundle_indices, + ) + inference_engines.append(llm) + # Do not call any method on the inference engine at this point; the call + # blocks until the vLLM instance finishes initialization. + + for i, llm in enumerate(inference_engines): + inference_engine_device_ids.append( + ray.get(llm.collective_rpc.remote("report_device_id", args=tuple())) + ) + print(f"inference engine {i} is on {inference_engine_device_ids[-1]}") + + # Verify placement: the first two training actors share the same GPUs as + # the first inference engine. 
+ assert training_actor_device_ids[:2] == inference_engine_device_ids[0] + # Verify placement: the last two training actors share the same GPUs as + # the second inference engine. + assert training_actor_device_ids[2:] == inference_engine_device_ids[1] + + print("Gather all the ZMQ handles from the training actors.") + zmq_handles = {} + for actor in training_actors: + zmq_handles.update(ray.get(actor.get_zmq_handles.remote())) + + print(f"ZMQ handles: {zmq_handles}") + return training_actors, inference_engines, zmq_handles + + +def main(): + parser = argparse.ArgumentParser( + description="Update model weights across training and inference actors." ) - inference_engines.append(llm) - # Do not call any method on the inference engine at this point; the call - # blocks until the vLLM instance finishes initialization. - -for i, llm in enumerate(inference_engines): - inference_engine_device_ids.append( - ray.get(llm.collective_rpc.remote("report_device_id", args=tuple())) + parser.add_argument( + "--num-ipc-handles", + type=int, + default=1, + help="Number of IPC handles. If 1, use synchronous update; \ + if >1, use asynchronous update.", ) - print(f"inference engine {i} is on {inference_engine_device_ids[-1]}") + args = parser.parse_args() -# Verify placement: the first two training actors share the same GPUs as -# the first inference engine. -assert training_actor_device_ids[:2] == inference_engine_device_ids[0] -# Verify placement: the last two training actors share the same GPUs as -# the second inference engine. 
-assert training_actor_device_ids[2:] == inference_engine_device_ids[1] + num_handles = args.num_ipc_handles + training_actors, inference_engines, zmq_handles = setup_train_cluster() + print("Update the weights of the inference engines.") + start_time = time.time() + if num_handles == 1: + # Synchronous update + ray.get( + [actor.update_weights.remote() for actor in training_actors] + + [ + llm.collective_rpc.remote( + "update_weights_from_ipc", args=(zmq_handles,) + ) + for llm in inference_engines + ] + ) + else: + # Asynchronous update + ray.get( + [ + actor.update_weights_async.remote(num_handles=num_handles) + for actor in training_actors + ] + + [ + llm.collective_rpc.remote( + "update_weights_from_ipc_async", args=(zmq_handles,) + ) + for llm in inference_engines + ] + ) + end_time = time.time() + elapsed = end_time - start_time + print(f"Weight update completed in {elapsed:.2f} seconds.") + print("Check if the weights are updated.") + for llm in inference_engines: + assert ray.get( + llm.collective_rpc.remote("check_weights_changed", args=tuple()) + ), "Weights were not updated properly!" 
-print("Gather all the ZMQ handles from the training actors.")
-zmq_handles = {}
-for actor in training_actors:
-    zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
-print(f"ZMQ handles: {zmq_handles}")
-
-print("Update the weights of the inference engines.")
-ray.get(
-    [actor.update_weights.remote() for actor in training_actors]
-    + [
-        llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
-        for llm in inference_engines
-    ]
-)
-
-print("Check if the weights are updated.")
-for llm in inference_engines:
-    assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
+if __name__ == "__main__":
+    main()
diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py
index 5c0787b8778d6..c1847a2c92977 100644
--- a/examples/offline_inference/rlhf_utils.py
+++ b/examples/offline_inference/rlhf_utils.py
@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import gc
+from typing import Dict, Tuple
 from collections.abc import Callable
+from enum import Enum
 from typing import TypedDict
 
 import torch
@@ -92,6 +94,15 @@ class FlattenedTensorMetadata(TypedDict):
     offset: int
 
 
+class PayloadType(Enum):
+    """Enumerates possible payload types in IPC protocol."""
+
+    HANDLES = "handles"
+    BUFFER_UPDATE = "buffer_update"
+    DONE = "done"
+    UNKNOWN = "unknown"
+
+
 class ColocateWorkerExtension:
     """
     The class for vLLM's worker to inherit from, in the colocate setting.
@@ -152,12 +163,87 @@ class ColocateWorkerExtension: gc.collect() torch.cuda.empty_cache() + def update_weights_from_ipc_async(self, zmq_handles: dict[str, str]): + assert self.device is not None + if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: + self._zmq_ctx = zmq.Context() + socket = self._zmq_ctx.socket(zmq.ROUTER) + socket.bind(zmq_handles[self.report_device_id()]) + poller = zmq.Poller() + poller.register(socket, zmq.POLLIN) + + buffers: Dict[int, torch.Tensor] = {} + while True: + events = dict(poller.poll(timeout=100)) + if socket in events and (events[socket] & zmq.POLLIN): + # Router identity + identity = socket.recv() + + payload: ( + list[tuple[Callable, tuple]] + | tuple[int, list[FlattenedTensorMetadata]] + | None + ) = socket.recv_pyobj() + + payload_type = self._identify_payload_type(payload) + + # === HANDLE LIST OF SHARED MEMORY HANDLES === + if payload_type == PayloadType.HANDLES: + handles: list[tuple[Callable, tuple]] = payload + for i, h in enumerate(handles): + buffers[i] = rebuild_ipc(h, self.device.index) + socket.send_multipart([identity, b"ACK_HANDLES"]) + continue + + # === HANDLE BUFFERED MODEL UPDATES === + if payload_type == PayloadType.BUFFER_UPDATE: + buf_id, items = payload + buffer = buffers.get(buf_id) + if buffer is None: + continue + + weights: list[Tuple[str, torch.Tensor]] = [] + for item in items: + assert isinstance(item, dict) + shape = torch.Size(item["shape"]) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = ( + buffer[offset : offset + size].view(dtype=dtype).view(shape) + ) + weights.append((item["name"], tensor)) + + self.model_runner.model.load_weights(weights=weights) + torch.cuda.synchronize() + socket.send_multipart([identity, str(buf_id).encode()]) + + # === DONE SIGNAL === + elif payload_type == PayloadType.DONE: + socket.send_multipart([identity, b"DONE"]) + break + else: + continue + + socket.close() + gc.collect() + torch.cuda.empty_cache() + def 
report_device_id(self) -> str: from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(self.device.index) return self.device_uuid + def _identify_payload_type(self, payload) -> PayloadType: + if isinstance(payload, list): + return PayloadType.HANDLES + elif isinstance(payload, tuple): + buf_id, _ = payload + if buf_id is None: + return PayloadType.DONE + return PayloadType.BUFFER_UPDATE + return PayloadType.UNKNOWN + def check_weights_changed(self): """ Check if the weights are updated to 0.