[RL] Multi-IPC-handles example for colocated RLHF

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
parent: 1761dea1a8
commit: bdf34c1265
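Summary (editor's note, inferred from the diff below): this commit extends the colocated RLHF weight-sync example (likely examples/offline_inference/rlhf_colocate.py together with rlhf_utils.py, per the worker_extension_cls="rlhf_utils.ColocateWorkerExtension" reference) with an asynchronous, multi-buffer update path. The training actor allocates --num-ipc-handles device byte buffers, shares them with each vLLM worker once via CUDA IPC handles over ZMQ, then streams weight "buckets" through whichever buffers are free while the workers load each bucket and ACK the buffer id back. With the default of one handle, the original synchronous update_weights_from_ipc path is used.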
@@ -28,6 +28,7 @@ Learn more about Ray placement groups:
 https://docs.ray.io/en/latest/placement-groups.html
 """
 
+import argparse
 import gc
 import os
 
@@ -78,7 +79,9 @@ class RayTrainingActor:
         # Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor.
         from transformers import AutoModelForCausalLM
 
-        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+        self.model = AutoModelForCausalLM.from_pretrained(
+            "/mnt/nvme3n1/models/qwen2.5_7B"
+        )
         self.model.to("cuda:0")
         # Zero out all the parameters.
         for name, p in self.model.named_parameters():
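Editor's note: this hunk pins the example to a local checkpoint (/mnt/nvme3n1/models/qwen2.5_7B), which only resolves on the author's machine. A hypothetical, more portable variant would fall back to the Hub model the example previously used (MODEL_PATH is an assumed environment variable, not part of the commit):

import os

from transformers import AutoModelForCausalLM

# MODEL_PATH is a hypothetical override; defaults to the Hub model
# the example used before this commit.
model_name = os.environ.get("MODEL_PATH", "facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained(model_name)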
@@ -156,96 +159,235 @@ class RayTrainingActor:
         gc.collect()
         torch.cuda.empty_cache()
 
-# Ray manages four GPUs.
-os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
-ray.init()
-
-# Co-locate vLLM instances and training actors on the same set of GPUs:
-# * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
-#   (tensor parallelism = 2).
-# * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
-#   (tensor parallelism = 2).
-pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
-ray.get(pg.ready())
-print(f"placement group has bundles {pg.bundle_specs=}")
-
-training_actors = []
-training_actor_device_ids = []
-inference_engines = []
-inference_engine_device_ids = []
-
-for bundle_index in [0, 1, 2, 3]:
-    training_actor = ray.remote(
-        num_cpus=0,
-        num_gpus=0.4,
-        scheduling_strategy=PlacementGroupSchedulingStrategy(
-            placement_group=pg,
-            placement_group_capture_child_tasks=True,
-            placement_group_bundle_index=bundle_index,
-        ),
-    )(RayTrainingActor).remote()
-    training_actors.append(training_actor)
-
-for bundle_index, training_actor in enumerate(training_actors):
-    device_id = ray.get(training_actor.report_device_id.remote())
-    print(f"training actor {bundle_index} is on {device_id}")
-    training_actor_device_ids.append(device_id)
-
-for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
-    # Use the following syntax instead of the @ray.remote decorator so that
-    # the placement group is customized for each bundle.
-    llm = ray.remote(
-        num_cpus=0,
-        num_gpus=0,
-        scheduling_strategy=PlacementGroupSchedulingStrategy(
-            placement_group=pg,
-            placement_group_capture_child_tasks=True,
-        ),
-    )(MyLLM).remote(
-        model="facebook/opt-125m",
-        enforce_eager=True,
-        worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
-        tensor_parallel_size=2,
-        distributed_executor_backend="ray",
-        gpu_memory_utilization=0.4,
-        bundle_indices=bundle_indices,
-    )
-    inference_engines.append(llm)
-    # Do not call any method on the inference engine at this point; the call
-    # blocks until the vLLM instance finishes initialization.
-
-for i, llm in enumerate(inference_engines):
-    inference_engine_device_ids.append(
-        ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
-    )
-    print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
-
-# Verify placement: the first two training actors share the same GPUs as
-# the first inference engine.
-assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
-# Verify placement: the last two training actors share the same GPUs as
-# the second inference engine.
-assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
-
-print("Gather all the ZMQ handles from the training actors.")
-zmq_handles = {}
-for actor in training_actors:
-    zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
-
-print(f"ZMQ handles: {zmq_handles}")
-
-print("Update the weights of the inference engines.")
-ray.get(
-    [actor.update_weights.remote() for actor in training_actors]
-    + [
-        llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
-        for llm in inference_engines
-    ]
-)
-
-print("Check if the weights are updated.")
-for llm in inference_engines:
-    assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
+    def update_weights_async(self, num_handles: int):
+        align_size = 256
+
+        def get_size(p: torch.Tensor) -> int:
+            return (p.nbytes + align_size - 1) // align_size * align_size
+
+        named_parameters: dict[str, torch.nn.Parameter] = dict(
+            self.model.named_parameters()
+        )
+        max_tensor_size = max(get_size(p) for p in named_parameters.values())
+        # use max_tensor_size * 2 as buffer size
+        buffer_capacity = max_tensor_size * 2
+        buffers = [
+            torch.empty(buffer_capacity, dtype=torch.uint8, device="cuda:0")
+            for _ in range(num_handles)
+        ]
+        handles = [reduce_tensor(b) for b in buffers]
+
+        # Establish ZMQ connection
+        s = self.zmq_context.socket(zmq.DEALER)
+        s.connect(self.zmq_handle)
+        s.send_pyobj(handles)
+        _ = s.recv_pyobj()
+
+        # === Partition tensors into buffers ===
+        offset = 0
+        buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
+        named_tensors: list[dict] = []
+        real_tensors: list[torch.Tensor] = []
+
+        for name, p in named_parameters.items():
+            size = get_size(p)
+            if offset + size > buffer_capacity:
+                buckets.append((named_tensors, real_tensors))
+                named_tensors, real_tensors = [], []
+                offset = 0
+            # assume tensors are contiguous
+            named_tensors.append(
+                {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
+            )
+            real_tensors.append(p)
+            offset += size
+        if named_tensors:
+            buckets.append((named_tensors, real_tensors))
+
+        poller = zmq.Poller()
+        poller.register(s, zmq.POLLOUT)
+        free_buffers = list(range(num_handles))
+        inflight = 0
+        idx = 0
+        total = len(buckets)
+        print(f"[Training] Total {total} buckets to send.")
+
+        # === Send loop ===
+        while idx < total or inflight > 0:
+            events = dict(poller.poll(timeout=50))
+            if (
+                s in events
+                and (events[s] & zmq.POLLOUT)
+                and idx < total
+                and free_buffers
+            ):
+                buf_id = free_buffers.pop(0)
+                meta, tensors = buckets[idx]
+                buffer = buffers[buf_id]
+                offset = 0
+                for item, t in zip(meta, tensors):
+                    size = get_size(t)
+                    buffer[offset : offset + t.nbytes].copy_(
+                        t.contiguous().view(torch.uint8).flatten()
+                    )
+                    offset += size
+                torch.cuda.synchronize()
+                s.send_pyobj((buf_id, meta))
+                inflight += 1
+                print(f"[Training] Sent bucket {idx} using buffer {buf_id}.")
+                idx += 1
+
+            # Receive buffer-free ACKs
+            try:
+                while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
+                    ack = s.recv_pyobj(flags=zmq.NOBLOCK)
+                    if isinstance(ack, int):
+                        free_buffers.append(ack)
+                        inflight -= 1
+                        print(f"[Training] Ack received: buffer {ack} now free.")
+            except zmq.Again:
+                pass
+
+        # Signal done
+        s.send_pyobj((None, None))
+        _ = s.recv_pyobj()
+        s.close()
+        del buffers
+        gc.collect()
+        torch.cuda.empty_cache()
+
+
+def setup_train_cluster():
+    # Ray manages four GPUs.
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
+    ray.init()
+
+    # Co-locate vLLM instances and training actors on the same set of GPUs:
+    # * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
+    #   (tensor parallelism = 2).
+    # * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
+    #   (tensor parallelism = 2).
+    pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
+    ray.get(pg.ready())
+    print(f"placement group has bundles {pg.bundle_specs=}")
+
+    training_actors = []
+    training_actor_device_ids = []
+    inference_engines = []
+    inference_engine_device_ids = []
+
+    for bundle_index in [0, 1, 2, 3]:
+        training_actor = ray.remote(
+            num_cpus=0,
+            num_gpus=0.4,
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=pg,
+                placement_group_capture_child_tasks=True,
+                placement_group_bundle_index=bundle_index,
+            ),
+        )(RayTrainingActor).remote()
+        training_actors.append(training_actor)
+
+    for bundle_index, training_actor in enumerate(training_actors):
+        device_id = ray.get(training_actor.report_device_id.remote())
+        print(f"training actor {bundle_index} is on {device_id}")
+        training_actor_device_ids.append(device_id)
+
+    for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
+        # Use the following syntax instead of the @ray.remote decorator so that
+        # the placement group is customized for each bundle.
+        llm = ray.remote(
+            num_cpus=0,
+            num_gpus=0,
+            scheduling_strategy=PlacementGroupSchedulingStrategy(
+                placement_group=pg,
+                placement_group_capture_child_tasks=True,
+            ),
+        )(MyLLM).remote(
+            model="/mnt/nvme3n1/models/qwen2.5_7B",
+            enforce_eager=True,
+            worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
+            tensor_parallel_size=2,
+            distributed_executor_backend="ray",
+            gpu_memory_utilization=0.3,
+            bundle_indices=bundle_indices,
+        )
+        inference_engines.append(llm)
+        # Do not call any method on the inference engine at this point; the call
+        # blocks until the vLLM instance finishes initialization.
+
+    for i, llm in enumerate(inference_engines):
+        inference_engine_device_ids.append(
+            ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
+        )
+        print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
+
+    # Verify placement: the first two training actors share the same GPUs as
+    # the first inference engine.
+    assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
+    # Verify placement: the last two training actors share the same GPUs as
+    # the second inference engine.
+    assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
+
+    print("Gather all the ZMQ handles from the training actors.")
+    zmq_handles = {}
+    for actor in training_actors:
+        zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
+
+    print(f"ZMQ handles: {zmq_handles}")
+    return training_actors, inference_engines, zmq_handles
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Update model weights across training and inference actors."
+    )
+    parser.add_argument(
+        "--num-ipc-handles",
+        type=int,
+        default=1,
+        help="Number of IPC handles. If 1, use synchronous update; \
+            if >1, use asynchronous update.",
+    )
+    args = parser.parse_args()
+
+    num_handles = args.num_ipc_handles
+    training_actors, inference_engines, zmq_handles = setup_train_cluster()
+    print("Update the weights of the inference engines.")
+    if num_handles == 1:
+        # Synchronous update
+        ray.get(
+            [actor.update_weights.remote() for actor in training_actors]
+            + [
+                llm.collective_rpc.remote(
+                    "update_weights_from_ipc", args=(zmq_handles,)
+                )
+                for llm in inference_engines
+            ]
+        )
+    else:
+        # Asynchronous update
+        ray.get(
+            [
+                actor.update_weights_async.remote(num_handles=num_handles)
+                for actor in training_actors
+            ]
+            + [
+                llm.collective_rpc.remote(
+                    "update_weights_from_ipc_async", args=(zmq_handles,)
+                )
+                for llm in inference_engines
+            ]
+        )
+
+    print("Check if the weights are updated.")
+    for llm in inference_engines:
+        assert ray.get(
+            llm.collective_rpc.remote("check_weights_changed", args=tuple())
+        ), "Weights were not updated properly!"
+
+
+if __name__ == "__main__":
+    main()
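Editor's note: the sender above (DEALER) and the receiver added to rlhf_utils below (ROUTER) form a credit-based pipeline: the worker ACKs each bucket with its buffer id, returning that buffer to the sender's free list, so at most num_handles buckets are ever in flight. A minimal CPU-only sketch of the handshake, assuming pyzmq; the endpoint, thread layout, and stand-in payloads are illustrative, not part of the commit:

import pickle
import threading

import zmq

ENDPOINT = "tcp://127.0.0.1:5599"  # arbitrary illustrative endpoint
ctx = zmq.Context.instance()


def worker() -> None:
    # ROUTER side (stands in for update_weights_from_ipc_async): ACK the
    # handle list, ACK each (buf_id, meta) bucket, stop on (None, None).
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(ENDPOINT)
    while True:
        identity = sock.recv()       # DEALER identity frame
        payload = sock.recv_pyobj()  # pickled payload frame
        if isinstance(payload, list):
            sock.send_multipart([identity, pickle.dumps("ACK_HANDLES")])
        elif payload == (None, None):
            sock.send_multipart([identity, pickle.dumps("DONE")])
            break
        else:
            buf_id, _meta = payload
            # Returning buf_id marks the buffer free again.
            sock.send_multipart([identity, pickle.dumps(buf_id)])
    sock.close()


t = threading.Thread(target=worker)
t.start()

# DEALER side (stands in for update_weights_async), minus all CUDA work.
s = ctx.socket(zmq.DEALER)
s.connect(ENDPOINT)
s.send_pyobj(["handle-0", "handle-1"])  # stand-ins for reduce_tensor() handles
assert s.recv_pyobj() == "ACK_HANDLES"

free_buffers = [0, 1]
for idx in range(4):  # four fake weight buckets
    if not free_buffers:
        free_buffers.append(s.recv_pyobj())  # block until a buffer is ACKed back
    buf_id = free_buffers.pop(0)
    s.send_pyobj((buf_id, [{"name": f"w{idx}", "offset": 0}]))

while len(free_buffers) < 2:  # drain outstanding ACKs
    free_buffers.append(s.recv_pyobj())
s.send_pyobj((None, None))  # done marker, as in the commit
assert s.recv_pyobj() == "DONE"
t.join()
print("protocol round-trip OK")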
The remaining hunks touch the worker-side utilities (rlhf_utils.py, per the worker_extension_cls string). Note: the commit as posted added "from ast import Dict, Tuple", but ast.Dict and ast.Tuple are AST node classes, not type annotations; the built-in dict/tuple generics are used instead here and in the annotations below.

@@ -1,7 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import gc
+import pickle
 from collections.abc import Callable
+from enum import Enum
 from typing import TypedDict
 
 import torch
@@ -92,6 +94,15 @@ class FlattenedTensorMetadata(TypedDict):
     offset: int
 
 
+class PayloadType(Enum):
+    """Enumerates possible payload types in IPC protocol."""
+
+    HANDLES = "handles"
+    BUFFER_UPDATE = "buffer_update"
+    DONE = "done"
+    UNKNOWN = "unknown"
+
+
 class ColocateWorkerExtension:
     """
     The class for vLLM's worker to inherit from, in the colocate setting.
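The wire protocol is discriminated structurally rather than by explicit message tags: a list is the one-time handle announcement, a tuple is either a bucket or the (None, None) done marker. A self-contained sketch of that dispatch rule (the standalone identify() mirrors the commit's _identify_payload_type and is illustrative only):

from enum import Enum


class PayloadType(Enum):
    HANDLES = "handles"
    BUFFER_UPDATE = "buffer_update"
    DONE = "done"
    UNKNOWN = "unknown"


def identify(payload) -> PayloadType:
    # Mirrors ColocateWorkerExtension._identify_payload_type.
    if isinstance(payload, list):
        return PayloadType.HANDLES
    if isinstance(payload, tuple):
        buf_id, _ = payload
        return PayloadType.DONE if buf_id is None else PayloadType.BUFFER_UPDATE
    return PayloadType.UNKNOWN


assert identify([("rebuild_fn", ())]) is PayloadType.HANDLES
assert identify((0, [{"name": "w"}])) is PayloadType.BUFFER_UPDATE
assert identify((None, None)) is PayloadType.DONE
assert identify("junk") is PayloadType.UNKNOWN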
@@ -152,12 +163,90 @@ class ColocateWorkerExtension:
         gc.collect()
         torch.cuda.empty_cache()
 
+    def update_weights_from_ipc_async(self, zmq_handles: dict[str, str]):
+        assert self.device is not None
+        if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
+            self._zmq_ctx = zmq.Context()
+        socket = self._zmq_ctx.socket(zmq.ROUTER)
+        socket.bind(zmq_handles[self.report_device_id()])
+        poller = zmq.Poller()
+        poller.register(socket, zmq.POLLIN)
+
+        buffers: dict[int, torch.Tensor] = {}
+        while True:
+            events = dict(poller.poll(timeout=100))
+            if socket in events and (events[socket] & zmq.POLLIN):
+                # Router identity
+                identity = socket.recv()
+
+                payload: (
+                    list[tuple[Callable, tuple]]
+                    | tuple[int, list[FlattenedTensorMetadata]]
+                    | None
+                ) = socket.recv_pyobj()
+
+                payload_type = self._identify_payload_type(payload)
+
+                # === HANDLE LIST OF SHARED MEMORY HANDLES ===
+                if payload_type == PayloadType.HANDLES:
+                    handles: list[tuple[Callable, tuple]] = payload
+                    for i, h in enumerate(handles):
+                        buffers[i] = rebuild_ipc(h, self.device.index)
+                    socket.send_multipart([identity, pickle.dumps("ACK_HANDLES")])
+                    continue
+
+                # === HANDLE BUFFERED MODEL UPDATES ===
+                if payload_type == PayloadType.BUFFER_UPDATE:
+                    buf_id, items = payload
+                    buffer = buffers.get(buf_id)
+                    if buffer is None:
+                        continue
+
+                    weights: list[tuple[str, torch.Tensor]] = []
+                    for item in items:
+                        assert isinstance(item, dict)
+                        shape = torch.Size(item["shape"])
+                        dtype, offset = item["dtype"], item["offset"]
+                        size = dtype.itemsize * shape.numel()
+                        tensor = (
+                            buffer[offset : offset + size].view(dtype=dtype).view(shape)
+                        )
+                        weights.append((item["name"], tensor))
+
+                    self.model_runner.model.load_weights(weights=weights)
+                    torch.cuda.synchronize()
+                    # --- Optional verify step (disabled) ---
+                    # if self.verify_weights_enabled:
+                    #     self.verify_weights(weights)
+                    socket.send_multipart([identity, pickle.dumps(buf_id)])
+
+                # === DONE SIGNAL ===
+                elif payload_type == PayloadType.DONE:
+                    socket.send_multipart([identity, pickle.dumps("DONE")])
+                    break
+                else:
+                    continue
+
+        socket.close()
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def report_device_id(self) -> str:
         from vllm.platforms import current_platform
 
         self.device_uuid = current_platform.get_device_uuid(self.device.index)
         return self.device_uuid
 
+    def _identify_payload_type(self, payload) -> PayloadType:
+        if isinstance(payload, list):
+            return PayloadType.HANDLES
+        elif isinstance(payload, tuple):
+            buf_id, _ = payload
+            if buf_id is None:
+                return PayloadType.DONE
+            return PayloadType.BUFFER_UPDATE
+        return PayloadType.UNKNOWN
+
     def check_weights_changed(self):
         """
         Check if the weights are updated to 0.
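On the receiving side, each FlattenedTensorMetadata entry is rehydrated from the shared byte buffer with a dtype view; the sender's align_size = 256 guarantees every offset satisfies the alignment that Tensor.view(dtype) requires. A torch-only round-trip of that step (shape, dtype, and offset here are arbitrary stand-ins):

import torch

# Flatten a bf16 weight into a byte buffer, then recover it, exactly as
# the receiver does per metadata entry.
src = torch.randn(4, 3, dtype=torch.bfloat16)
buffer = torch.empty(1024, dtype=torch.uint8)
offset = 256  # aligned offset, as produced by the sender's get_size()
buffer[offset : offset + src.nbytes].copy_(
    src.contiguous().view(torch.uint8).flatten()
)
size = src.dtype.itemsize * src.shape.numel()
out = buffer[offset : offset + size].view(dtype=torch.bfloat16).view(src.shape)
assert torch.equal(out, src)  # exact byte-for-byte round trip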