[RL] fast weight update with zmq + ipc handles (#24295)

Signed-off-by: huangweixiao <huangweixiao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>

parent 1116590b16
commit 3d2a2de8f7
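In short, as read from the diff below: instead of shipping one CUDA IPC handle per parameter, each training actor now binds a ZMQ REQ socket on a per-device `ipc://` endpoint, shares a single reusable uint8 staging buffer through one IPC handle, and streams flattened weights through that buffer in buckets. The colocated vLLM worker rebuilds tensor views from `FlattenedTensorMetadata` and calls `load_weights` per bucket, with a `None` sentinel ending the update.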
@@ -28,12 +28,15 @@ Learn more about Ray placement groups:
 https://docs.ray.io/en/latest/placement-groups.html
 """
 
+import gc
 import os
 
 import ray
 import torch
+import zmq
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from torch.multiprocessing.reductions import reduce_tensor
 
 from vllm import LLM
 
@@ -86,20 +89,72 @@ class RayTrainingActor:
         from vllm.platforms import current_platform
 
         self.device_uuid = current_platform.get_device_uuid(0)
+        self.zmq_context = zmq.Context()
+        self.zmq_address_counter = 0
+        self.zmq_handle = None
 
     def report_device_id(self) -> str:
         return self.device_uuid
 
-    def get_weight_ipc_handles(self):
-        from torch.multiprocessing.reductions import reduce_tensor
-
-        data = {}
-        for name, p in self.model.named_parameters():
-            # A training actor might hold only a subset of the weights and may
-            # need to gather weights from other actors. For demonstration
-            # purposes, each training actor owns the full weight set.
-            data[name] = reduce_tensor(p.detach())
-        return {self.device_uuid: data}
+    def get_zmq_handles(self) -> dict[str, str]:
+        suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
+        self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
+        self.zmq_address_counter += 1
+        return {self.device_uuid: self.zmq_handle}
+
+    def update_weights(self):
+        # align size to avoid misaligned address
+        align_size = 256
+
+        def get_size(p: torch.Tensor) -> int:
+            return (p.nbytes + align_size - 1) // align_size * align_size
+
+        named_parameters: dict[str, torch.nn.Parameter] = dict(
+            self.model.named_parameters()
+        )
+        max_tensor_size = max(get_size(p) for p in named_parameters.values())
+        # use max_tensor_size * 2 as buffer size
+        buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
+        s = self.zmq_context.socket(zmq.REQ)
+        s.bind(self.zmq_handle)
+        handle = reduce_tensor(buffer)
+
+        offset = 0
+        buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
+        named_tensors: list[dict] = []
+        real_tensors: list[torch.Tensor] = []
+        for name, p in named_parameters.items():
+            size = get_size(p)
+            if offset + size > buffer.numel():
+                buckets.append((named_tensors, real_tensors))
+                named_tensors, real_tensors = [], []
+                offset = 0
+            # assume tensors are contiguous
+            named_tensors.append(
+                {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
+            )
+            real_tensors.append(p)
+            offset += size
+        if named_tensors:
+            buckets.append((named_tensors, real_tensors))
+        s.send_pyobj(handle)
+        s.recv()
+        for named_tensors, real_tensors in buckets:
+            offset = 0
+            for p in real_tensors:
+                buffer[offset : offset + p.nbytes].data.copy_(
+                    p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
+                )
+                offset += get_size(p)
+            torch.cuda.synchronize()
+            s.send_pyobj(named_tensors)
+            s.recv()
+        s.send_pyobj(None)
+        s.recv()
+        s.close()
+        del buffer
+        gc.collect()
+        torch.cuda.empty_cache()
 
 
 # Ray manages four GPUs.
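The `update_weights` method above rounds each tensor's byte size up to a 256-byte boundary and greedily packs tensors into buckets that fit the staging buffer, starting a new bucket whenever the next tensor would overflow it. A minimal CPU-only sketch of that round-up-and-bucket arithmetic (the tensors and sizes in `fake_params` are made up for illustration; no ZMQ or CUDA involved):

```python
import torch

align_size = 256  # align offsets to 256 bytes, as in the diff


def get_size(p: torch.Tensor) -> int:
    # round nbytes up to the next multiple of align_size
    return (p.nbytes + align_size - 1) // align_size * align_size


# toy stand-ins for model parameters (illustration only)
fake_params = {
    "a": torch.zeros(100, dtype=torch.float32),  # 400 B  -> padded to 512
    "b": torch.zeros(300, dtype=torch.float32),  # 1200 B -> padded to 1280
    "c": torch.zeros(50, dtype=torch.float16),   # 100 B  -> padded to 256
}

# buffer is twice the largest padded tensor, as in the diff: 2 * 1280 = 2560
buffer_size = max(get_size(p) for p in fake_params.values()) * 2

offset = 0
buckets: list[list[dict]] = []
current: list[dict] = []
for name, p in fake_params.items():
    size = get_size(p)
    if offset + size > buffer_size:  # next tensor would overflow: new bucket
        buckets.append(current)
        current, offset = [], 0
    current.append({"name": name, "offset": offset, "nbytes": p.nbytes})
    offset += size
if current:
    buckets.append(current)

# 512 + 1280 + 256 = 2048 <= 2560, so all three land in a single bucket
print(buckets)
```

Rounding every offset up to 256 bytes keeps each tensor's start aligned for any dtype, which is what lets the consumer later reinterpret slices of the uint8 buffer with `view(dtype)`.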
@@ -175,18 +230,22 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
 # the second inference engine.
 assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
 
-print("Gather all the IPC handles from the training actors.")
-ipc_handles = {}
+print("Gather all the ZMQ handles from the training actors.")
+zmq_handles = {}
 for actor in training_actors:
-    ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
+    zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
+
+print(f"ZMQ handles: {zmq_handles}")
 
 print("Update the weights of the inference engines.")
-for llm in inference_engines:
-    ray.get(
-        llm.collective_rpc.remote(
-            "update_weights_from_ipc_handles", args=(ipc_handles,)
-        )
-    )
+ray.get(
+    [actor.update_weights.remote() for actor in training_actors]
+    + [
+        llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
+        for llm in inference_engines
+    ]
+)
 
 print("Check if the weights are updated.")
 for llm in inference_engines:
     assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
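Note that the driver now launches producers and consumers in one `ray.get`: each training actor's `update_weights` blocks on its REQ socket until the matching worker connects, so both sides must be in flight together. For orientation, `collective_rpc` dispatches `update_weights_from_ipc` by method name on every worker, which is why the extension class is injected via `worker_extension_cls`. A hedged sketch of that wiring (the model name, GPU settings, and the `rlhf_utils` module path are illustrative, not taken from this diff; the real example also configures placement-group scheduling and per-worker GPU fractions):

```python
import ray
from vllm import LLM

# illustrative wiring only: construct an inference engine as a Ray actor
# whose workers inherit from the colocate extension class
llm = ray.remote(num_gpus=0)(LLM).remote(
    model="facebook/opt-125m",  # placeholder model
    worker_extension_cls="rlhf_utils.ColocateWorkerExtension",  # full qualified name
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)

# collective_rpc invokes the named extension method on every worker
device_ids = ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
```

The remaining hunks below modify the companion worker-extension module (the file name is not shown in this view).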
@@ -1,6 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
+from typing import Callable, Optional, TypedDict
+
 import torch
+import zmq
 
 
 def stateless_init_process_group(master_address, master_port, rank, world_size, device):
@@ -66,6 +70,27 @@ class WorkerExtension:
         return weights_updated
 
 
+def rebuild_ipc(
+    handle: tuple[Callable, tuple], device_id: Optional[int] = None
+) -> torch.Tensor:
+    func, args = handle
+    list_args = list(args)
+    if device_id is not None:
+        # the key is to change device id to the current device id
+        # in case two processes have different CUDA_VISIBLE_DEVICES
+        list_args[6] = device_id
+    buffer = func(*list_args)
+    return buffer
+
+
+class FlattenedTensorMetadata(TypedDict):
+    name: str
+    shape: torch.Size
+    dtype: torch.dtype
+    # specify the start offset of this tensor in shared ipc_buffer tensor
+    offset: int
+
+
 class ColocateWorkerExtension:
     """
     The class for vLLM's worker to inherit from, in the colocate setting.
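`rebuild_ipc` and `FlattenedTensorMetadata` together let the consumer turn one shared uint8 buffer into per-parameter tensors: the IPC handle rebuilds the buffer once, and the metadata records where each flattened tensor lives inside it. A CPU-only sketch of the pack-then-view round trip the metadata enables (local buffer, no IPC; the names `t1`/`t2` and sizes are invented for illustration):

```python
import torch

# pack two tensors into one flat uint8 buffer at known offsets
t1 = torch.randn(4, 4)                    # float32, 64 bytes
t2 = torch.arange(8, dtype=torch.int64)   # int64, 64 bytes
buffer = torch.empty(256, dtype=torch.uint8)

meta, offset = [], 0
for name, t in [("t1", t1), ("t2", t2)]:
    buffer[offset : offset + t.nbytes] = t.contiguous().view(-1).view(torch.uint8)
    meta.append({"name": name, "shape": t.shape, "dtype": t.dtype, "offset": offset})
    offset += t.nbytes  # the real code pads this to a 256-byte boundary

# reconstruct views exactly as the worker does
originals = {"t1": t1, "t2": t2}
for item in meta:
    size = item["dtype"].itemsize * item["shape"].numel()
    tensor = (
        buffer[item["offset"] : item["offset"] + size]
        .view(item["dtype"])
        .view(item["shape"])
    )
    assert torch.equal(tensor, originals[item["name"]])
```

`Tensor.view(dtype)` requires the slice's storage offset to be aligned to the target dtype's element size, which is exactly why the producer pads offsets to 256 bytes in the real code.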
@@ -76,27 +101,62 @@ class ColocateWorkerExtension:
     should pass the full qualified name as `worker_extension_cls` argument.
     """
 
+    def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
+        from vllm.model_executor.model_loader.utils import process_weights_after_loading
+
+        assert self.device is not None
+        if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
+            self._zmq_ctx = zmq.Context()
+        socket = self._zmq_ctx.socket(zmq.REP)
+        socket.connect(zmq_handles[self.report_device_id()])
+        buffer: Optional[torch.Tensor] = None
+        while True:
+            payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
+                socket.recv_pyobj()
+            )
+            if payload is None:
+                # means the update is done
+                process_weights_after_loading(
+                    self.model_runner.model, self.model_config, self.device
+                )
+                torch.cuda.synchronize()
+                socket.send(b"")
+                break
+            if isinstance(payload, tuple):
+                # an ipc handle that vLLM can use `func, args = handle`
+                # and `func(*args)` to rebuild the GPU tensor
+                buffer = rebuild_ipc(payload, self.device.index)
+                assert buffer.dtype == torch.uint8
+                socket.send(b"")
+                continue
+            assert isinstance(payload, list)
+            assert buffer is not None
+            weights = []
+            for item in payload:
+                shape = item["shape"]
+                if isinstance(shape, (list, tuple)):
+                    shape = torch.Size(shape)
+                assert isinstance(shape, torch.Size)
+                dtype, offset = item["dtype"], item["offset"]
+                size = dtype.itemsize * shape.numel()
+                tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
+                weights.append((item["name"], tensor))
+            self.model_runner.model.load_weights(weights=weights)
+            del weights
+            torch.cuda.synchronize()
+            socket.send(b"")
+
+        socket.close()
+        del buffer
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def report_device_id(self) -> str:
         from vllm.platforms import current_platform
 
         self.device_uuid = current_platform.get_device_uuid(self.device.index)
         return self.device_uuid
 
-    def update_weights_from_ipc_handles(self, ipc_handles):
-        handles = ipc_handles[self.device_uuid]
-        device_id = self.device.index
-        weights = []
-        for name, handle in handles.items():
-            func, args = handle
-            list_args = list(args)
-            # the key is to change device id to the current device id
-            # in case two processes have different CUDA_VISIBLE_DEVICES
-            list_args[6] = device_id
-            tensor = func(*list_args)
-            weights.append((name, tensor))
-        self.model_runner.model.load_weights(weights=weights)
-        torch.cuda.synchronize()
-
     def check_weights_changed(self):
         """
         Check if the weights are updated to 0.
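Putting both sides together, the producer drives a three-phase REQ/REP conversation: first the staging buffer's IPC handle, then one metadata list per bucket, then a `None` sentinel. A toy sketch of just that handshake, using ZMQ's in-process transport and plain strings in place of the real handle and metadata (illustration only, not the vLLM code):

```python
import threading
import zmq

ctx = zmq.Context()
addr = "inproc://toy-weight-update"  # stands in for ipc:///tmp/rl-colocate-zmq-*.sock

producer = ctx.socket(zmq.REQ)
producer.bind(addr)  # the training actor binds, as in the diff


def consumer() -> None:
    # the vLLM worker connects with a REP socket, as in the diff
    sock = ctx.socket(zmq.REP)
    sock.connect(addr)
    print("handle:", sock.recv_pyobj())  # phase 1: shared-buffer IPC handle
    sock.send(b"")
    while True:
        payload = sock.recv_pyobj()
        if payload is None:  # phase 3: sentinel ends the update
            sock.send(b"")
            break
        print("bucket:", payload)  # phase 2: one metadata list per bucket
        sock.send(b"")
    sock.close()


t = threading.Thread(target=consumer)
t.start()

producer.send_pyobj("fake-ipc-handle")  # phase 1
producer.recv()
for bucket in (["w1", "w2"], ["w3"]):  # phase 2: two toy buckets
    producer.send_pyobj(bucket)
    producer.recv()
producer.send_pyobj(None)  # phase 3
producer.recv()
producer.close()
t.join()
ctx.term()
```

The strict send/recv alternation of REQ/REP doubles as flow control: the producer only overwrites the shared buffer with the next bucket after the worker has acknowledged loading the previous one.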