[RL] Multi-IPC-handle example for colocated RLHF weight updates

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
knlnguyen1802 2025-11-13 10:00:30 +08:00
parent 1761dea1a8
commit bdf34c1265
2 changed files with 312 additions and 80 deletions

View File

@ -28,6 +28,7 @@ Learn more about Ray placement groups:
https://docs.ray.io/en/latest/placement-groups.html
"""
import argparse
import gc
import os
@ -78,7 +79,9 @@ class RayTrainingActor:
# Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor.
from transformers import AutoModelForCausalLM
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
self.model = AutoModelForCausalLM.from_pretrained(
"/mnt/nvme3n1/models/qwen2.5_7B"
)
self.model.to("cuda:0")
# Zero out all the parameters.
for name, p in self.model.named_parameters():
@ -156,96 +159,235 @@ class RayTrainingActor:
gc.collect()
torch.cuda.empty_cache()
def update_weights_async(self, num_handles: int):
    """Stream this actor's model weights to the colocated vLLM workers through
    ``num_handles`` shared CUDA IPC staging buffers.

    The parameters are packed into fixed-capacity "buckets"; each bucket is
    copied into a free buffer, its metadata is sent over ZMQ, and the buffer is
    reused once the consumer acknowledges it. Using more than one handle lets
    buffer fills overlap with consumption on the inference side.

    Args:
        num_handles: Number of staging buffers (and IPC handles) to allocate.
    """
    align_size = 256

    def get_size(p: torch.Tensor) -> int:
        # Round the tensor's byte size up to the 256-byte alignment boundary.
        return (p.nbytes + align_size - 1) // align_size * align_size

    named_parameters: dict[str, torch.nn.Parameter] = dict(
        self.model.named_parameters()
    )
    max_tensor_size = max(get_size(p) for p in named_parameters.values())
    # use max_tensor_size * 2 as buffer size so every parameter fits and most
    # buckets hold several tensors
    buffer_capacity = max_tensor_size * 2
    buffers = [
        torch.empty(buffer_capacity, dtype=torch.uint8, device="cuda:0")
        for _ in range(num_handles)
    ]
    handles = [reduce_tensor(b) for b in buffers]
    # Establish ZMQ connection and hand the IPC handles to the consumer.
    s = self.zmq_context.socket(zmq.DEALER)
    s.connect(self.zmq_handle)
    s.send_pyobj(handles)
    _ = s.recv_pyobj()
    # === Partition tensors into buckets that each fit in one buffer ===
    offset = 0
    buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
    named_tensors: list[dict] = []
    real_tensors: list[torch.Tensor] = []
    for name, p in named_parameters.items():
        size = get_size(p)
        if offset + size > buffer_capacity:
            # Current bucket is full; start a new one.
            buckets.append((named_tensors, real_tensors))
            named_tensors, real_tensors = [], []
            offset = 0
        # assume tensors are contiguous
        named_tensors.append(
            {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
        )
        real_tensors.append(p)
        offset += size
    if named_tensors:
        buckets.append((named_tensors, real_tensors))
    poller = zmq.Poller()
    poller.register(s, zmq.POLLOUT)
    free_buffers = list(range(num_handles))
    inflight = 0
    idx = 0
    total = len(buckets)
    print(f"[Training] Total {total} buckets to send.")
    # === Send loop: fill a free buffer and announce it, then drain ACKs ===
    while idx < total or inflight > 0:
        events = dict(poller.poll(timeout=50))
        if (
            s in events
            and (events[s] & zmq.POLLOUT)
            and idx < total
            and free_buffers
        ):
            buf_id = free_buffers.pop(0)
            meta, tensors = buckets[idx]
            buffer = buffers[buf_id]
            offset = 0
            for item, t in zip(meta, tensors):
                size = get_size(t)
                # Copy raw bytes; the consumer reinterprets them via metadata.
                buffer[offset : offset + t.nbytes].copy_(
                    t.contiguous().view(torch.uint8).flatten()
                )
                offset += size
            # Make sure the copies are visible before the consumer reads.
            torch.cuda.synchronize()
            s.send_pyobj((buf_id, meta))
            inflight += 1
            print(f"[Training] Sent bucket {idx} using buffer {buf_id}.")
            idx += 1
        # Receive buffer-free ACKs without blocking the send path.
        try:
            while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
                ack = s.recv_pyobj(flags=zmq.NOBLOCK)
                if isinstance(ack, int):
                    free_buffers.append(ack)
                    inflight -= 1
                    print(f"[Training] Ack received: buffer {ack} now free.")
        except zmq.Again:
            pass
    # Signal done
    s.send_pyobj((None, None))
    _ = s.recv_pyobj()
    s.close()
    del buffers
    gc.collect()
    torch.cuda.empty_cache()
def setup_train_cluster(model: str = "/mnt/nvme3n1/models/qwen2.5_7B"):
    """Create the colocated Ray placement group, training actors, and vLLM engines.

    Args:
        model: Model path/name served by the vLLM engines. Defaults to the
            path used by the training actors so both sides share weights.

    Returns:
        A ``(training_actors, inference_engines, zmq_handles)`` tuple, where
        ``zmq_handles`` maps device UUIDs to ZMQ endpoints gathered from the
        training actors.
    """
    # Ray manages four GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
    ray.init()
    # Co-locate vLLM instances and training actors on the same set of GPUs:
    # * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
    #   (tensor parallelism = 2).
    # * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
    #   (tensor parallelism = 2).
    pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
    ray.get(pg.ready())
    print(f"placement group has bundles {pg.bundle_specs=}")
    training_actors = []
    training_actor_device_ids = []
    inference_engines = []
    inference_engine_device_ids = []
    for bundle_index in [0, 1, 2, 3]:
        training_actor = ray.remote(
            num_cpus=0,
            num_gpus=0.4,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_index,
            ),
        )(RayTrainingActor).remote()
        training_actors.append(training_actor)
    for bundle_index, training_actor in enumerate(training_actors):
        device_id = ray.get(training_actor.report_device_id.remote())
        print(f"training actor {bundle_index} is on {device_id}")
        training_actor_device_ids.append(device_id)
    for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
        # Use the following syntax instead of the @ray.remote decorator so that
        # the placement group is customized for each bundle.
        llm = ray.remote(
            num_cpus=0,
            num_gpus=0,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
            ),
        )(MyLLM).remote(
            model=model,
            enforce_eager=True,
            worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
            tensor_parallel_size=2,
            distributed_executor_backend="ray",
            gpu_memory_utilization=0.3,
            bundle_indices=bundle_indices,
        )
        inference_engines.append(llm)
        # Do not call any method on the inference engine at this point; the call
        # blocks until the vLLM instance finishes initialization.
    for i, llm in enumerate(inference_engines):
        inference_engine_device_ids.append(
            ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
        )
        print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
    # Verify placement: the first two training actors share the same GPUs as
    # the first inference engine.
    assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
    # Verify placement: the last two training actors share the same GPUs as
    # the second inference engine.
    assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
    print("Gather all the ZMQ handles from the training actors.")
    zmq_handles = {}
    for actor in training_actors:
        zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
    print(f"ZMQ handles: {zmq_handles}")
    return training_actors, inference_engines, zmq_handles
def main():
    """Drive a full weight update: sync when one IPC handle, async when more."""
    parser = argparse.ArgumentParser(
        description="Update model weights across training and inference actors."
    )
    parser.add_argument(
        "--num-ipc-handles",
        type=int,
        default=1,
        help="Number of IPC handles. If 1, use synchronous update; "
        "if >1, use asynchronous update.",
    )
    args = parser.parse_args()
    # Guard against nonsensical values before any cluster setup happens.
    if args.num_ipc_handles < 1:
        parser.error("--num-ipc-handles must be >= 1")
    num_handles = args.num_ipc_handles
    training_actors, inference_engines, zmq_handles = setup_train_cluster()
    print("Update the weights of the inference engines.")
    if num_handles == 1:
        # Synchronous update: one buffer, sender and receiver in lockstep.
        ray.get(
            [actor.update_weights.remote() for actor in training_actors]
            + [
                llm.collective_rpc.remote(
                    "update_weights_from_ipc", args=(zmq_handles,)
                )
                for llm in inference_engines
            ]
        )
    else:
        # Asynchronous update: multiple buffers pipelined via ACKs.
        ray.get(
            [
                actor.update_weights_async.remote(num_handles=num_handles)
                for actor in training_actors
            ]
            + [
                llm.collective_rpc.remote(
                    "update_weights_from_ipc_async", args=(zmq_handles,)
                )
                for llm in inference_engines
            ]
        )
    print("Check if the weights are updated.")
    for llm in inference_engines:
        assert ray.get(
            llm.collective_rpc.remote("check_weights_changed", args=tuple())
        ), "Weights were not updated properly!"


if __name__ == "__main__":
    main()

View File

@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import pickle
from ast import Dict, Tuple
from collections.abc import Callable
from enum import Enum
from typing import TypedDict
import torch
@ -92,6 +95,15 @@ class FlattenedTensorMetadata(TypedDict):
offset: int
class PayloadType(Enum):
    """Enumerates possible payload types in IPC protocol."""

    # A list of CUDA IPC handles for the shared staging buffers.
    HANDLES = "handles"
    # A (buffer_id, tensor-metadata-list) tuple describing staged weights.
    BUFFER_UPDATE = "buffer_update"
    # A (None, None) sentinel: the sender has finished streaming weights.
    DONE = "done"
    # Any payload that matches none of the shapes above.
    UNKNOWN = "unknown"
class ColocateWorkerExtension:
"""
The class for vLLM's worker to inherit from, in the colocate setting.
@ -152,12 +164,90 @@ class ColocateWorkerExtension:
gc.collect()
torch.cuda.empty_cache()
def update_weights_from_ipc_async(self, zmq_handles: dict[str, str]):
    """Receive streamed weight updates over ZMQ via shared CUDA IPC buffers.

    Protocol (every payload is a pickled Python object):
      * a list of IPC handles          -> rebuild the buffers, reply "ACK_HANDLES"
      * ``(buf_id, metadata_items)``   -> load weights from that buffer, reply ``buf_id``
      * ``(None, None)``               -> reply "DONE" and exit

    Args:
        zmq_handles: Maps device UUIDs to the ZMQ endpoint this worker binds.
    """
    assert self.device is not None
    # Lazily create one ZMQ context per worker process.
    if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
        self._zmq_ctx = zmq.Context()
    socket = self._zmq_ctx.socket(zmq.ROUTER)
    socket.bind(zmq_handles[self.report_device_id()])
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)
    # NOTE: annotated with the builtin generic; the previous `Dict`/`Tuple`
    # names were wrongly imported from `ast` (AST node classes, not typing).
    buffers: dict[int, torch.Tensor] = {}
    while True:
        events = dict(poller.poll(timeout=100))
        if socket in events and (events[socket] & zmq.POLLIN):
            # ROUTER sockets prefix each message with the sender identity.
            identity = socket.recv()
            payload: (
                list[tuple[Callable, tuple]]
                | tuple[int, list[FlattenedTensorMetadata]]
                | None
            ) = socket.recv_pyobj()
            payload_type = self._identify_payload_type(payload)
            # === HANDLE LIST OF SHARED MEMORY HANDLES ===
            if payload_type == PayloadType.HANDLES:
                handles: list[tuple[Callable, tuple]] = payload
                for i, h in enumerate(handles):
                    buffers[i] = rebuild_ipc(h, self.device.index)
                socket.send_multipart([identity, pickle.dumps("ACK_HANDLES")])
                continue
            # === HANDLE BUFFERED MODEL UPDATES ===
            if payload_type == PayloadType.BUFFER_UPDATE:
                buf_id, items = payload
                buffer = buffers.get(buf_id)
                if buffer is None:
                    # Fail loudly: silently dropping the update would leave the
                    # sender waiting forever for an ACK it never receives.
                    raise RuntimeError(
                        f"Received update for unknown buffer {buf_id}; "
                        "IPC handles were not registered first."
                    )
                weights: list[tuple[str, torch.Tensor]] = []
                for item in items:
                    assert isinstance(item, dict)
                    shape = torch.Size(item["shape"])
                    dtype, offset = item["dtype"], item["offset"]
                    size = dtype.itemsize * shape.numel()
                    # Reinterpret the raw byte slice as the original tensor.
                    tensor = (
                        buffer[offset : offset + size].view(dtype=dtype).view(shape)
                    )
                    weights.append((item["name"], tensor))
                self.model_runner.model.load_weights(weights=weights)
                # Ensure the load finished before releasing the buffer.
                torch.cuda.synchronize()
                socket.send_multipart([identity, pickle.dumps(buf_id)])
            # === DONE SIGNAL ===
            elif payload_type == PayloadType.DONE:
                socket.send_multipart([identity, pickle.dumps("DONE")])
                break
            else:
                # Unknown payload: ignore and keep polling.
                continue
    socket.close()
    gc.collect()
    torch.cuda.empty_cache()
def report_device_id(self) -> str:
    """Return the platform UUID of this worker's device, caching it on self."""
    from vllm.platforms import current_platform

    uuid = current_platform.get_device_uuid(self.device.index)
    self.device_uuid = uuid
    return uuid
def _identify_payload_type(self, payload) -> PayloadType:
    """Classify an incoming IPC payload by its Python shape.

    Lists carry IPC handles; 2-tuples are either a done sentinel (first
    element ``None``) or a buffer update; anything else is unknown.
    """
    if isinstance(payload, list):
        return PayloadType.HANDLES
    if not isinstance(payload, tuple):
        return PayloadType.UNKNOWN
    first, _ = payload
    return PayloadType.DONE if first is None else PayloadType.BUFFER_UPDATE
def check_weights_changed(self):
"""
Check if the weights are updated to 0.