[RL] fast weight update with zmq + ipc handles (#24295)

Signed-off-by: huangweixiao <huangweixiao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Weixiao Huang 2025-09-09 16:57:46 +08:00 committed by GitHub
parent 1116590b16
commit 3d2a2de8f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 152 additions and 33 deletions

View File

@ -28,12 +28,15 @@ Learn more about Ray placement groups:
https://docs.ray.io/en/latest/placement-groups.html
"""
import gc
import os
import ray
import torch
import zmq
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.multiprocessing.reductions import reduce_tensor
from vllm import LLM
@ -86,20 +89,72 @@ class RayTrainingActor:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(0)
self.zmq_context = zmq.Context()
self.zmq_address_counter = 0
self.zmq_handle = None
def report_device_id(self) -> str:
return self.device_uuid
def get_weight_ipc_handles(self):
from torch.multiprocessing.reductions import reduce_tensor
def get_zmq_handles(self) -> dict[str, str]:
suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
self.zmq_address_counter += 1
return {self.device_uuid: self.zmq_handle}
data = {}
for name, p in self.model.named_parameters():
# A training actor might hold only a subset of the weights and may
# need to gather weights from other actors. For demonstration
# purposes, each training actor owns the full weight set.
data[name] = reduce_tensor(p.detach())
return {self.device_uuid: data}
def update_weights(self):
# align size to avoid misaligned address
align_size = 256
def get_size(p: torch.Tensor) -> int:
return (p.nbytes + align_size - 1) // align_size * align_size
named_parameters: dict[str, torch.nn.Parameter] = dict(
self.model.named_parameters()
)
max_tensor_size = max(get_size(p) for p in named_parameters.values())
# use max_tensor_size * 2 as buffer size
buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
s = self.zmq_context.socket(zmq.REQ)
s.bind(self.zmq_handle)
handle = reduce_tensor(buffer)
offset = 0
buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
named_tensors: list[dict] = []
real_tensors: list[torch.Tensor] = []
for name, p in named_parameters.items():
size = get_size(p)
if offset + size > buffer.numel():
buckets.append((named_tensors, real_tensors))
named_tensors, real_tensors = [], []
offset = 0
# assume tensors are contiguous
named_tensors.append(
{"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
)
real_tensors.append(p)
offset += size
if named_tensors:
buckets.append((named_tensors, real_tensors))
s.send_pyobj(handle)
s.recv()
for named_tensors, real_tensors in buckets:
offset = 0
for p in real_tensors:
buffer[offset : offset + p.nbytes].data.copy_(
p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
)
offset += get_size(p)
torch.cuda.synchronize()
s.send_pyobj(named_tensors)
s.recv()
s.send_pyobj(None)
s.recv()
s.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
# Ray manages four GPUs.
@ -175,18 +230,22 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
# the second inference engine.
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
print("Gather all the IPC handles from the training actors.")
ipc_handles = {}
print("Gather all the ZMQ handles from the training actors.")
zmq_handles = {}
for actor in training_actors:
ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
print(f"ZMQ handles: {zmq_handles}")
print("Update the weights of the inference engines.")
for llm in inference_engines:
ray.get(
llm.collective_rpc.remote(
"update_weights_from_ipc_handles", args=(ipc_handles,)
)
)
ray.get(
[actor.update_weights.remote() for actor in training_actors]
+ [
llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
for llm in inference_engines
]
)
print("Check if the weights are updated.")
for llm in inference_engines:
assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))

View File

@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from typing import Callable, Optional, TypedDict
import torch
import zmq
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
@ -66,6 +70,27 @@ class WorkerExtension:
return weights_updated
def rebuild_ipc(
handle: tuple[Callable, tuple], device_id: Optional[int] = None
) -> torch.Tensor:
func, args = handle
list_args = list(args)
if device_id is not None:
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
buffer = func(*list_args)
return buffer
class FlattenedTensorMetadata(TypedDict):
name: str
shape: torch.Size
dtype: torch.dtype
# specify the start offset of this tensor in shared ipc_buffer tensor
offset: int
class ColocateWorkerExtension:
"""
The class for vLLM's worker to inherit from, in the colocate setting.
@ -76,27 +101,62 @@ class ColocateWorkerExtension:
should pass the full qualified name as `worker_extension_cls` argument.
"""
def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
from vllm.model_executor.model_loader.utils import process_weights_after_loading
assert self.device is not None
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
self._zmq_ctx = zmq.Context()
socket = self._zmq_ctx.socket(zmq.REP)
socket.connect(zmq_handles[self.report_device_id()])
buffer: Optional[torch.Tensor] = None
while True:
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
socket.recv_pyobj()
)
if payload is None:
# means the update is done
process_weights_after_loading(
self.model_runner.model, self.model_config, self.device
)
torch.cuda.synchronize()
socket.send(b"")
break
if isinstance(payload, tuple):
# an ipc handle that vLLM can use `func, args = handle`
# and `func(*args)` to rebuild GPU tensor.
buffer = rebuild_ipc(payload, self.device.index)
assert buffer.dtype == torch.uint8
socket.send(b"")
continue
assert isinstance(payload, list)
assert buffer is not None
weights = []
for item in payload:
shape = item["shape"]
if isinstance(shape, (list, tuple)):
shape = torch.Size(shape)
assert isinstance(shape, torch.Size)
dtype, offset = item["dtype"], item["offset"]
size = dtype.itemsize * shape.numel()
tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
weights.append((item["name"], tensor))
self.model_runner.model.load_weights(weights=weights)
del weights
torch.cuda.synchronize()
socket.send(b"")
socket.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
def report_device_id(self) -> str:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid
def update_weights_from_ipc_handles(self, ipc_handles):
handles = ipc_handles[self.device_uuid]
device_id = self.device.index
weights = []
for name, handle in handles.items():
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()
def check_weights_changed(self):
"""
Check if the weights are updated to 0.