[RL] Multi IPC handles example for rlhf colocated

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
knlnguyen1802 2025-11-13 10:00:30 +08:00
parent 1761dea1a8
commit bdf34c1265
2 changed files with 312 additions and 80 deletions

View File

@@ -28,6 +28,7 @@ Learn more about Ray placement groups:
https://docs.ray.io/en/latest/placement-groups.html
"""
import argparse
import gc
import os
@@ -78,7 +79,9 @@ class RayTrainingActor:
        # Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor.
        from transformers import AutoModelForCausalLM

        self.model = AutoModelForCausalLM.from_pretrained(
            "/mnt/nvme3n1/models/qwen2.5_7B"
        )
        self.model.to("cuda:0")
        # Zero out all the parameters.
        for name, p in self.model.named_parameters():
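Note: the checkpoint path above is a machine-local directory (by its name, a Qwen2.5-7B checkpoint). Any model id or local path accepted by AutoModelForCausalLM.from_pretrained can be substituted, as long as the vLLM engines created further below are pointed at the same model.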
@@ -156,96 +159,235 @@ class RayTrainingActor:
        gc.collect()
        torch.cuda.empty_cache()

    def update_weights_async(self, num_handles: int):
        align_size = 256

        def get_size(p: torch.Tensor) -> int:
            return (p.nbytes + align_size - 1) // align_size * align_size

        named_parameters: dict[str, torch.nn.Parameter] = dict(
            self.model.named_parameters()
        )
        max_tensor_size = max(get_size(p) for p in named_parameters.values())
        # Use max_tensor_size * 2 as the buffer size.
        buffer_capacity = max_tensor_size * 2
        buffers = [
            torch.empty(buffer_capacity, dtype=torch.uint8, device="cuda:0")
            for _ in range(num_handles)
        ]
        handles = [reduce_tensor(b) for b in buffers]

        # Establish the ZMQ connection and share the IPC handles.
        s = self.zmq_context.socket(zmq.DEALER)
        s.connect(self.zmq_handle)
        s.send_pyobj(handles)
        _ = s.recv_pyobj()

        # === Partition tensors into buffers ===
        offset = 0
        buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
        named_tensors: list[dict] = []
        real_tensors: list[torch.Tensor] = []
        for name, p in named_parameters.items():
            size = get_size(p)
            if offset + size > buffer_capacity:
                buckets.append((named_tensors, real_tensors))
                named_tensors, real_tensors = [], []
                offset = 0
            # Assume tensors are contiguous.
            named_tensors.append(
                {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
            )
            real_tensors.append(p)
            offset += size
        if named_tensors:
            buckets.append((named_tensors, real_tensors))

        poller = zmq.Poller()
        poller.register(s, zmq.POLLOUT)
        free_buffers = list(range(num_handles))
        inflight = 0
        idx = 0
        total = len(buckets)
        print(f"[Training] Total {total} buckets to send.")

        # === Send loop ===
        while idx < total or inflight > 0:
            events = dict(poller.poll(timeout=50))
            if (
                s in events
                and (events[s] & zmq.POLLOUT)
                and idx < total
                and free_buffers
            ):
                buf_id = free_buffers.pop(0)
                meta, tensors = buckets[idx]
                buffer = buffers[buf_id]
                offset = 0
                for item, t in zip(meta, tensors):
                    size = get_size(t)
                    buffer[offset : offset + t.nbytes].copy_(
                        t.contiguous().view(torch.uint8).flatten()
                    )
                    offset += size
                torch.cuda.synchronize()
                s.send_pyobj((buf_id, meta))
                inflight += 1
                print(f"[Training] Sent bucket {idx} using buffer {buf_id}.")
                idx += 1

            # Receive buffer-free ACKs.
            try:
                while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
                    ack = s.recv_pyobj(flags=zmq.NOBLOCK)
                    if isinstance(ack, int):
                        free_buffers.append(ack)
                        inflight -= 1
                        print(f"[Training] Ack received: buffer {ack} now free.")
            except zmq.Again:
                pass

        # Signal done.
        s.send_pyobj((None, None))
        _ = s.recv_pyobj()
        s.close()

        del buffers
        gc.collect()
        torch.cuda.empty_cache()

def setup_train_cluster():
    # Ray manages four GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
    ray.init()

    # Co-locate vLLM instances and training actors on the same set of GPUs:
    # * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
    #   (tensor parallelism = 2).
    # * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
    #   (tensor parallelism = 2).
    pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
    ray.get(pg.ready())
    print(f"placement group has bundles {pg.bundle_specs=}")

    training_actors = []
    training_actor_device_ids = []
    inference_engines = []
    inference_engine_device_ids = []

    for bundle_index in [0, 1, 2, 3]:
        training_actor = ray.remote(
            num_cpus=0,
            num_gpus=0.4,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_index,
            ),
        )(RayTrainingActor).remote()
        training_actors.append(training_actor)

    for bundle_index, training_actor in enumerate(training_actors):
        device_id = ray.get(training_actor.report_device_id.remote())
        print(f"training actor {bundle_index} is on {device_id}")
        training_actor_device_ids.append(device_id)

    for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
        # Use the following syntax instead of the @ray.remote decorator so that
        # the placement group is customized for each bundle.
        llm = ray.remote(
            num_cpus=0,
            num_gpus=0,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
            ),
        )(MyLLM).remote(
            model="/mnt/nvme3n1/models/qwen2.5_7B",
            enforce_eager=True,
            worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
            tensor_parallel_size=2,
            distributed_executor_backend="ray",
            gpu_memory_utilization=0.3,
            bundle_indices=bundle_indices,
        )
        inference_engines.append(llm)
        # Do not call any method on the inference engine at this point; the call
        # blocks until the vLLM instance finishes initialization.

    for i, llm in enumerate(inference_engines):
        inference_engine_device_ids.append(
            ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
        )
        print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")

    # Verify placement: the first two training actors share the same GPUs as
    # the first inference engine.
    assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
    # Verify placement: the last two training actors share the same GPUs as
    # the second inference engine.
    assert training_actor_device_ids[2:] == inference_engine_device_ids[1]

    print("Gather all the ZMQ handles from the training actors.")
    zmq_handles = {}
    for actor in training_actors:
        zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
    print(f"ZMQ handles: {zmq_handles}")

    return training_actors, inference_engines, zmq_handles

def main():
    parser = argparse.ArgumentParser(
        description="Update model weights across training and inference actors."
    )
    parser.add_argument(
        "--num-ipc-handles",
        type=int,
        default=1,
        help="Number of IPC handles. If 1, use synchronous update; "
        "if >1, use asynchronous update.",
    )
    args = parser.parse_args()
    num_handles = args.num_ipc_handles

    training_actors, inference_engines, zmq_handles = setup_train_cluster()

    print("Update the weights of the inference engines.")
    if num_handles == 1:
        # Synchronous update
        ray.get(
            [actor.update_weights.remote() for actor in training_actors]
            + [
                llm.collective_rpc.remote(
                    "update_weights_from_ipc", args=(zmq_handles,)
                )
                for llm in inference_engines
            ]
        )
    else:
        # Asynchronous update
        ray.get(
            [
                actor.update_weights_async.remote(num_handles=num_handles)
                for actor in training_actors
            ]
            + [
                llm.collective_rpc.remote(
                    "update_weights_from_ipc_async", args=(zmq_handles,)
                )
                for llm in inference_engines
            ]
        )

    print("Check if the weights are updated.")
    for llm in inference_engines:
        assert ray.get(
            llm.collective_rpc.remote("check_weights_changed", args=tuple())
        ), "Weights were not updated properly!"


if __name__ == "__main__":
    main()
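Run the example with --num-ipc-handles 1 (the default) for the original synchronous update, or a larger value for the double-buffered asynchronous path; the script filename depends on where this example is saved. As a side note, the bucketing done by update_weights_async can be illustrated in isolation. The sketch below is CPU-only, the parameter names and sizes are invented, and it only mirrors the round-up-to-256-bytes and greedy-packing logic:

import torch

ALIGN = 256


def aligned_size(t: torch.Tensor) -> int:
    # Round the tensor's byte size up to the next 256-byte boundary.
    return (t.nbytes + ALIGN - 1) // ALIGN * ALIGN


def pack_into_buckets(named_params, capacity):
    # Greedily pack (name, tensor) pairs into buckets of at most `capacity` bytes.
    buckets, current, offset = [], [], 0
    for name, t in named_params:
        size = aligned_size(t)
        if offset + size > capacity:
            buckets.append(current)
            current, offset = [], 0
        current.append({"name": name, "offset": offset, "nbytes": t.nbytes})
        offset += size
    if current:
        buckets.append(current)
    return buckets


# Invented parameters, CPU-only, just to show the packing behaviour.
params = [
    ("embed", torch.zeros(1000, dtype=torch.float16)),  # 2000 B -> 2048 B aligned
    ("w1", torch.zeros(300, dtype=torch.float32)),      # 1200 B -> 1280 B aligned
    ("w2", torch.zeros(50, dtype=torch.float32)),       #  200 B ->  256 B aligned
]
capacity = 2 * max(aligned_size(t) for _, t in params)  # 4096 B, as in the example
for i, bucket in enumerate(pack_into_buckets(params, capacity)):
    print(i, [(m["name"], m["offset"]) for m in bucket])
# -> 0 [('embed', 0), ('w1', 2048), ('w2', 3328)]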

View File

@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import pickle
from collections.abc import Callable
from enum import Enum
from typing import TypedDict

import torch
@@ -92,6 +95,15 @@ class FlattenedTensorMetadata(TypedDict):
    offset: int


class PayloadType(Enum):
    """Enumerates possible payload types in the IPC protocol."""

    HANDLES = "handles"
    BUFFER_UPDATE = "buffer_update"
    DONE = "done"
    UNKNOWN = "unknown"


class ColocateWorkerExtension:
    """
    The class for vLLM's worker to inherit from, in the colocate setting.
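For intuition on how the consumer turns a region of the shared byte buffer back into a typed tensor, here is a small CPU-only sketch that mirrors the view logic in update_weights_from_ipc_async below; the metadata values and the parameter name are invented:

import torch

# A fake flat byte buffer standing in for the shared CUDA buffer.
buffer = torch.zeros(1024, dtype=torch.uint8)

# One metadata entry, as produced by the training side (values are illustrative).
item = {
    "name": "fc.weight",
    "dtype": torch.float16,
    "shape": torch.Size([4, 8]),
    "offset": 256,
}

size = item["dtype"].itemsize * item["shape"].numel()  # 2 * 32 = 64 bytes
start = item["offset"]
flat = buffer[start : start + size]                    # 64 uint8 elements
tensor = flat.view(dtype=item["dtype"]).view(item["shape"])
print(tensor.shape, tensor.dtype)                      # torch.Size([4, 8]) torch.float16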
@@ -152,12 +164,90 @@ class ColocateWorkerExtension:
        gc.collect()
        torch.cuda.empty_cache()

    def update_weights_from_ipc_async(self, zmq_handles: dict[str, str]):
        assert self.device is not None
        if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
            self._zmq_ctx = zmq.Context()
        socket = self._zmq_ctx.socket(zmq.ROUTER)
        socket.bind(zmq_handles[self.report_device_id()])

        poller = zmq.Poller()
        poller.register(socket, zmq.POLLIN)

        buffers: dict[int, torch.Tensor] = {}
        while True:
            events = dict(poller.poll(timeout=100))
            if socket in events and (events[socket] & zmq.POLLIN):
                # Router identity
                identity = socket.recv()
                payload: (
                    list[tuple[Callable, tuple]]
                    | tuple[int, list[FlattenedTensorMetadata]]
                    | None
                ) = socket.recv_pyobj()
                payload_type = self._identify_payload_type(payload)

                # === HANDLE LIST OF SHARED MEMORY HANDLES ===
                if payload_type == PayloadType.HANDLES:
                    handles: list[tuple[Callable, tuple]] = payload
                    for i, h in enumerate(handles):
                        buffers[i] = rebuild_ipc(h, self.device.index)
                    socket.send_multipart([identity, pickle.dumps("ACK_HANDLES")])
                    continue

                # === HANDLE BUFFERED MODEL UPDATES ===
                if payload_type == PayloadType.BUFFER_UPDATE:
                    buf_id, items = payload
                    buffer = buffers.get(buf_id)
                    if buffer is None:
                        continue
                    weights: list[tuple[str, torch.Tensor]] = []
                    for item in items:
                        assert isinstance(item, dict)
                        shape = torch.Size(item["shape"])
                        dtype, offset = item["dtype"], item["offset"]
                        size = dtype.itemsize * shape.numel()
                        tensor = (
                            buffer[offset : offset + size].view(dtype=dtype).view(shape)
                        )
                        weights.append((item["name"], tensor))
                    self.model_runner.model.load_weights(weights=weights)
                    torch.cuda.synchronize()
                    socket.send_multipart([identity, pickle.dumps(buf_id)])

                # === DONE SIGNAL ===
                elif payload_type == PayloadType.DONE:
                    socket.send_multipart([identity, pickle.dumps("DONE")])
                    break
                else:
                    continue

        socket.close()
        gc.collect()
        torch.cuda.empty_cache()
    def report_device_id(self) -> str:
        from vllm.platforms import current_platform

        self.device_uuid = current_platform.get_device_uuid(self.device.index)
        return self.device_uuid

    def _identify_payload_type(self, payload) -> PayloadType:
        if isinstance(payload, list):
            return PayloadType.HANDLES
        elif isinstance(payload, tuple):
            buf_id, _ = payload
            if buf_id is None:
                return PayloadType.DONE
            return PayloadType.BUFFER_UPDATE
        return PayloadType.UNKNOWN

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.
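Finally, a minimal CPU-only sketch of the DEALER/ROUTER handshake that the two sides above implement. It replaces CUDA IPC handles with plain placeholder values and model loading with a print, but keeps the same message flow: the producer first sends the list of buffer handles and waits for an ACK, then streams (buffer_id, metadata) messages, reusing a buffer only after its id is ACKed back, and finally sends (None, None) to terminate. The endpoint name and payloads are invented for this sketch, and the producer waits for each ACK synchronously instead of using the non-blocking poller shown above:

import pickle
import threading

import zmq

ENDPOINT = "inproc://weights-demo"  # hypothetical endpoint, only for this sketch
ctx = zmq.Context()
ready = threading.Event()


def consumer():
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(ENDPOINT)
    ready.set()
    while True:
        identity = sock.recv()          # ROUTER prepends the sender identity frame
        payload = sock.recv_pyobj()
        if isinstance(payload, list):   # list of "handles"
            print(f"[consumer] received {len(payload)} buffer handles")
            sock.send_multipart([identity, pickle.dumps("ACK_HANDLES")])
        elif payload[0] is None:        # (None, None) terminates the exchange
            sock.send_multipart([identity, pickle.dumps("DONE")])
            break
        else:                           # (buffer_id, metadata) update
            buf_id, meta = payload
            print(f"[consumer] got bucket {meta} in buffer {buf_id}")
            sock.send_multipart([identity, pickle.dumps(buf_id)])  # free the buffer
    sock.close()


def producer():
    ready.wait()
    sock = ctx.socket(zmq.DEALER)
    sock.connect(ENDPOINT)
    sock.send_pyobj([4096, 4096])       # two fake "buffer handles"
    assert sock.recv_pyobj() == "ACK_HANDLES"
    for bucket in (["w1", "w2"], ["w3"]):
        sock.send_pyobj((0, bucket))    # always reuse buffer 0 in this toy version
        assert sock.recv_pyobj() == 0   # wait for the ACK before reusing the buffer
    sock.send_pyobj((None, None))
    assert sock.recv_pyobj() == "DONE"
    sock.close()


t = threading.Thread(target=consumer)
t.start()
producer()
t.join()
ctx.term()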