Merge 5cb240825e4f81fc69a34a1ca96424018691c3c5 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
knlnguyen1802 2025-12-25 08:48:49 +08:00 committed by GitHub
commit 7895845974
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 309 additions and 80 deletions

View File

@ -28,8 +28,10 @@ Learn more about Ray placement groups:
https://docs.ray.io/en/latest/placement-groups.html
"""
import argparse
import gc
import os
import time
import ray
import torch
@ -156,96 +158,237 @@ class RayTrainingActor:
gc.collect()
torch.cuda.empty_cache()
def update_weights_async(self, num_handles: int):
    """Stream model weights to the colocated vLLM worker over ZMQ + CUDA IPC.

    Allocates ``num_handles`` reusable GPU staging buffers, shares them with
    the inference side via CUDA IPC handles, then packs parameters into
    buffer-sized "buckets" and sends them asynchronously: a bucket is sent
    whenever a free buffer exists, and buffers are recycled as the receiver
    acknowledges them. A ``(None, None)`` payload signals completion.

    Args:
        num_handles: Number of shared staging buffers (pipeline depth).
    """
    align_size = 256

    def get_size(p: torch.Tensor) -> int:
        # Round each tensor's byte size up to the alignment boundary so
        # consecutive tensors start at aligned offsets inside a buffer.
        return (p.nbytes + align_size - 1) // align_size * align_size

    named_parameters: dict[str, torch.nn.Parameter] = dict(
        self.model.named_parameters()
    )
    max_tensor_size = max(get_size(p) for p in named_parameters.values())
    # use max_tensor_size * 2 as buffer size
    buffer_capacity = max_tensor_size * 2
    buffers = [
        torch.empty(buffer_capacity, dtype=torch.uint8, device="cuda:0")
        for _ in range(num_handles)
    ]
    # CUDA IPC handles that let the receiving process map these buffers.
    handles = [reduce_tensor(b) for b in buffers]

    # Establish ZMQ connection
    s = self.zmq_context.socket(zmq.DEALER)
    s.connect(self.zmq_handle)
    s.send_pyobj(handles)
    s.recv_multipart()  # ACK for handles

    # === Partition tensors into buffers ===
    # Greedily pack parameters (in iteration order) into buckets no larger
    # than one staging buffer; each bucket carries per-tensor metadata
    # (name/dtype/shape/offset) so the receiver can reconstruct views.
    offset = 0
    buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
    named_tensors: list[dict] = []
    real_tensors: list[torch.Tensor] = []
    for name, p in named_parameters.items():
        size = get_size(p)
        if offset + size > buffer_capacity:
            buckets.append((named_tensors, real_tensors))
            named_tensors, real_tensors = [], []
            offset = 0
        # assume tensors are contiguous
        named_tensors.append(
            {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
        )
        real_tensors.append(p)
        offset += size
    if named_tensors:
        # Flush the final partially-filled bucket.
        buckets.append((named_tensors, real_tensors))

    poller = zmq.Poller()
    poller.register(s, zmq.POLLOUT)
    free_buffers = list(range(num_handles))
    inflight = 0
    idx = 0
    total = len(buckets)
    print(f"[Training] Total {total} buckets to send.")

    # === Send loop ===
    # Overlap copies/sends with receiver-side consumption: send while any
    # staging buffer is free, and reclaim buffers as ACKs arrive.
    while idx < total or inflight > 0:
        events = dict(poller.poll(timeout=50))
        if (
            s in events
            and (events[s] & zmq.POLLOUT)
            and idx < total
            and free_buffers
        ):
            buf_id = free_buffers.pop(0)
            meta, tensors = buckets[idx]
            buffer = buffers[buf_id]
            offset = 0
            for item, t in zip(meta, tensors):
                size = get_size(t)
                # Copy the raw bytes; the aligned `size` (not t.nbytes)
                # advances the offset so the layout matches the metadata.
                buffer[offset : offset + t.nbytes].copy_(
                    t.contiguous().view(torch.uint8).flatten()
                )
                offset += size
            # Make sure the device copies finished before the receiver
            # is told the buffer is ready.
            torch.cuda.synchronize()
            s.send_pyobj((buf_id, meta))
            inflight += 1
            print(f"[Training] Sent bucket {idx} using buffer {buf_id}.")
            idx += 1
        # Receive buffer-free ACKs
        try:
            while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
                ack = int(s.recv_multipart(flags=zmq.NOBLOCK)[0].decode())
                free_buffers.append(ack)
                inflight -= 1
                print(f"[Training] Ack received: buffer {ack} now free.")
        except zmq.Again:
            pass

    # Signal done
    s.send_pyobj((None, None))
    s.recv_multipart()  # DONE ack
    s.close()
    del buffers
    gc.collect()
    torch.cuda.empty_cache()
def setup_train_cluster():
    """Create the colocated Ray training/inference cluster on four GPUs.

    Returns:
        A tuple ``(training_actors, inference_engines, zmq_handles)`` where
        ``zmq_handles`` maps device UUIDs to ZMQ endpoint addresses gathered
        from the training actors.
    """
    # Ray manages four GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
    ray.init()

    # Co-locate vLLM instances and training actors on the same set of GPUs:
    # * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
    #   (tensor parallelism = 2).
    # * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
    #   (tensor parallelism = 2).
    pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
    ray.get(pg.ready())
    print(f"placement group has bundles {pg.bundle_specs=}")

    training_actors = []
    training_actor_device_ids = []
    inference_engines = []
    inference_engine_device_ids = []

    for bundle_index in [0, 1, 2, 3]:
        # 0.4 GPU per training actor leaves room on each device for the
        # colocated vLLM worker (gpu_memory_utilization=0.3 below).
        training_actor = ray.remote(
            num_cpus=0,
            num_gpus=0.4,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_index,
            ),
        )(RayTrainingActor).remote()
        training_actors.append(training_actor)

    for bundle_index, training_actor in enumerate(training_actors):
        device_id = ray.get(training_actor.report_device_id.remote())
        print(f"training actor {bundle_index} is on {device_id}")
        training_actor_device_ids.append(device_id)

    for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
        # Use the following syntax instead of the @ray.remote decorator so that
        # the placement group is customized for each bundle.
        llm = ray.remote(
            num_cpus=0,
            num_gpus=0,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
            ),
        )(MyLLM).remote(
            model="facebook/opt-125m",
            enforce_eager=True,
            worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
            tensor_parallel_size=2,
            distributed_executor_backend="ray",
            gpu_memory_utilization=0.3,
            bundle_indices=bundle_indices,
        )
        inference_engines.append(llm)
        # Do not call any method on the inference engine at this point; the call
        # blocks until the vLLM instance finishes initialization.

    for i, llm in enumerate(inference_engines):
        inference_engine_device_ids.append(
            ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
        )
        print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")

    # Verify placement: the first two training actors share the same GPUs as
    # the first inference engine.
    assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
    # Verify placement: the last two training actors share the same GPUs as
    # the second inference engine.
    assert training_actor_device_ids[2:] == inference_engine_device_ids[1]

    print("Gather all the ZMQ handles from the training actors.")
    zmq_handles = {}
    for actor in training_actors:
        zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
    print(f"ZMQ handles: {zmq_handles}")
    return training_actors, inference_engines, zmq_handles
def main():
    """Entry point: set up the cluster, push weights, and verify the update.

    With ``--num-ipc-handles 1`` the synchronous weight-update path is used;
    with a larger value the asynchronous double-buffered path is used.
    """
    parser = argparse.ArgumentParser(
        description="Update model weights across training and inference actors."
    )
    parser.add_argument(
        "--num-ipc-handles",
        type=int,
        default=1,
        help="Number of IPC handles. If 1, use synchronous update; \
if >1, use asynchronous update.",
    )
    args = parser.parse_args()
    num_handles = args.num_ipc_handles

    training_actors, inference_engines, zmq_handles = setup_train_cluster()

    print("Update the weights of the inference engines.")
    start_time = time.time()
    if num_handles == 1:
        # Synchronous update: senders and receivers must run concurrently,
        # so launch both sides before waiting on any of them.
        ray.get(
            [actor.update_weights.remote() for actor in training_actors]
            + [
                llm.collective_rpc.remote(
                    "update_weights_from_ipc", args=(zmq_handles,)
                )
                for llm in inference_engines
            ]
        )
    else:
        # Asynchronous update
        ray.get(
            [
                actor.update_weights_async.remote(num_handles=num_handles)
                for actor in training_actors
            ]
            + [
                llm.collective_rpc.remote(
                    "update_weights_from_ipc_async", args=(zmq_handles,)
                )
                for llm in inference_engines
            ]
        )
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"Weight update completed in {elapsed:.2f} seconds.")

    print("Check if the weights are updated.")
    for llm in inference_engines:
        assert ray.get(
            llm.collective_rpc.remote("check_weights_changed", args=tuple())
        ), "Weights were not updated properly!"


if __name__ == "__main__":
    main()

View File

@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from ast import Dict, Tuple
from collections.abc import Callable
from enum import Enum
from typing import TypedDict
import torch
@ -92,6 +94,15 @@ class FlattenedTensorMetadata(TypedDict):
offset: int
class PayloadType(Enum):
    """Enumerates possible payload types in IPC protocol."""

    # Initial message: a list of CUDA IPC handles for the shared buffers.
    HANDLES = "handles"
    # A (buffer_id, metadata) tuple describing tensors staged in a buffer.
    BUFFER_UPDATE = "buffer_update"
    # A (None, None) tuple: sender has finished streaming weights.
    DONE = "done"
    # Any payload that matches none of the shapes above.
    UNKNOWN = "unknown"
class ColocateWorkerExtension:
"""
The class for vLLM's worker to inherit from, in the colocate setting.
@ -152,12 +163,87 @@ class ColocateWorkerExtension:
gc.collect()
torch.cuda.empty_cache()
def update_weights_from_ipc_async(self, zmq_handles: dict[str, str]):
    """Receive streamed weight updates from the training actor over ZMQ.

    Binds a ROUTER socket at this worker's endpoint, maps the shared CUDA
    IPC buffers announced by the sender, then repeatedly loads the tensors
    described by each ``(buffer_id, metadata)`` message into the model and
    ACKs the buffer so the sender can reuse it. Exits on the DONE signal.

    Args:
        zmq_handles: Mapping from device UUID to ZMQ endpoint address.
    """
    assert self.device is not None
    if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
        self._zmq_ctx = zmq.Context()
    socket = self._zmq_ctx.socket(zmq.ROUTER)
    socket.bind(zmq_handles[self.report_device_id()])
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)
    # buffer_id -> GPU tensor mapped from the sender's CUDA IPC handle.
    # (Was annotated with ast.Dict/ast.Tuple upstream — those are AST node
    # classes, not typing generics; builtin generics are correct here.)
    buffers: dict[int, torch.Tensor] = {}
    while True:
        events = dict(poller.poll(timeout=100))
        if socket in events and (events[socket] & zmq.POLLIN):
            # Router identity
            identity = socket.recv()
            payload: (
                list[tuple[Callable, tuple]]
                | tuple[int, list[FlattenedTensorMetadata]]
                | None
            ) = socket.recv_pyobj()
            payload_type = self._identify_payload_type(payload)

            # === HANDLE LIST OF SHARED MEMORY HANDLES ===
            if payload_type == PayloadType.HANDLES:
                handles: list[tuple[Callable, tuple]] = payload
                for i, h in enumerate(handles):
                    buffers[i] = rebuild_ipc(h, self.device.index)
                socket.send_multipart([identity, b"ACK_HANDLES"])
                continue

            # === HANDLE BUFFERED MODEL UPDATES ===
            if payload_type == PayloadType.BUFFER_UPDATE:
                buf_id, items = payload
                buffer = buffers.get(buf_id)
                if buffer is None:
                    # NOTE(review): skipping without an ACK leaves the
                    # sender's inflight count stuck for this buffer; this
                    # can only happen if buf_id was never announced.
                    continue
                weights: list[tuple[str, torch.Tensor]] = []
                for item in items:
                    assert isinstance(item, dict)
                    shape = torch.Size(item["shape"])
                    dtype, offset = item["dtype"], item["offset"]
                    # Exact byte size of the tensor (offsets in the
                    # metadata already account for alignment padding).
                    size = dtype.itemsize * shape.numel()
                    tensor = (
                        buffer[offset : offset + size].view(dtype=dtype).view(shape)
                    )
                    weights.append((item["name"], tensor))
                self.model_runner.model.load_weights(weights=weights)
                # Finish consuming the buffer before ACKing so the sender
                # can safely overwrite it.
                torch.cuda.synchronize()
                socket.send_multipart([identity, str(buf_id).encode()])

            # === DONE SIGNAL ===
            elif payload_type == PayloadType.DONE:
                socket.send_multipart([identity, b"DONE"])
                break
            else:
                continue
    socket.close()
    gc.collect()
    torch.cuda.empty_cache()
def report_device_id(self) -> str:
    """Return the platform UUID of this worker's device, caching it on self."""
    from vllm.platforms import current_platform

    uuid = current_platform.get_device_uuid(self.device.index)
    self.device_uuid = uuid
    return uuid
def _identify_payload_type(self, payload) -> PayloadType:
    """Classify an incoming IPC payload by its Python structure.

    A list is the initial batch of buffer handles; a tuple is either the
    DONE sentinel (first element None) or a buffer update; anything else
    is UNKNOWN.
    """
    if isinstance(payload, list):
        return PayloadType.HANDLES
    if not isinstance(payload, tuple):
        return PayloadType.UNKNOWN
    first, _ = payload
    return PayloadType.DONE if first is None else PayloadType.BUFFER_UPDATE
def check_weights_changed(self):
"""
Check if the weights are updated to 0.