mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-27 05:01:23 +08:00
Merge 5cb240825e4f81fc69a34a1ca96424018691c3c5 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
7895845974
@ -28,8 +28,10 @@ Learn more about Ray placement groups:
|
||||
https://docs.ray.io/en/latest/placement-groups.html
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
|
||||
import ray
|
||||
import torch
|
||||
@ -156,96 +158,237 @@ class RayTrainingActor:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def update_weights_async(self, num_handles: int):
    """Stream updated model weights to the inference engines over ZMQ IPC.

    `num_handles` CUDA staging buffers are allocated and their IPC handles
    are sent to the receiver first.  Parameters are then partitioned into
    buckets that each fit one buffer, and the buckets are pipelined through
    whichever buffers are currently free, overlapping copies with the
    receiver's consumption.  A ``(None, None)`` payload signals completion.

    Args:
        num_handles: Number of shared staging buffers (and IPC handles).
    """
    align_size = 256

    def get_size(p: torch.Tensor) -> int:
        # Round each tensor's byte size up to the alignment boundary so
        # consecutive tensors start at aligned offsets inside a buffer.
        return (p.nbytes + align_size - 1) // align_size * align_size

    named_parameters: dict[str, torch.nn.Parameter] = dict(
        self.model.named_parameters()
    )
    max_tensor_size = max(get_size(p) for p in named_parameters.values())
    # use max_tensor_size * 2 as buffer size so any single tensor fits
    buffer_capacity = max_tensor_size * 2
    buffers = [
        torch.empty(buffer_capacity, dtype=torch.uint8, device="cuda:0")
        for _ in range(num_handles)
    ]
    handles = [reduce_tensor(b) for b in buffers]

    # Establish ZMQ connection and share the IPC handles with the receiver.
    s = self.zmq_context.socket(zmq.DEALER)
    s.connect(self.zmq_handle)
    s.send_pyobj(handles)
    s.recv_multipart()  # ACK for handles

    # === Partition tensors into buckets, each fitting one buffer ===
    offset = 0
    buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
    named_tensors: list[dict] = []
    real_tensors: list[torch.Tensor] = []
    for name, p in named_parameters.items():
        size = get_size(p)
        if offset + size > buffer_capacity:
            buckets.append((named_tensors, real_tensors))
            named_tensors, real_tensors = [], []
            offset = 0
        # assume tensors are contiguous
        named_tensors.append(
            {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
        )
        real_tensors.append(p)
        offset += size
    if named_tensors:
        buckets.append((named_tensors, real_tensors))

    poller = zmq.Poller()
    poller.register(s, zmq.POLLOUT)
    free_buffers = list(range(num_handles))
    inflight = 0
    idx = 0
    total = len(buckets)
    print(f"[Training] Total {total} buckets to send.")

    # === Send loop: pipeline buckets through the free buffers ===
    while idx < total or inflight > 0:
        events = dict(poller.poll(timeout=50))
        if (
            s in events
            and (events[s] & zmq.POLLOUT)
            and idx < total
            and free_buffers
        ):
            buf_id = free_buffers.pop(0)
            meta, tensors = buckets[idx]
            buffer = buffers[buf_id]
            offset = 0
            for item, t in zip(meta, tensors):
                size = get_size(t)
                buffer[offset : offset + t.nbytes].copy_(
                    t.contiguous().view(torch.uint8).flatten()
                )
                offset += size
            # Make sure the staging copies are visible before the receiver
            # is told the buffer is ready.
            torch.cuda.synchronize()
            s.send_pyobj((buf_id, meta))
            inflight += 1
            print(f"[Training] Sent bucket {idx} using buffer {buf_id}.")
            idx += 1

        # Receive buffer-free ACKs without blocking the send loop.
        try:
            while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
                ack = int(s.recv_multipart(flags=zmq.NOBLOCK)[0].decode())
                free_buffers.append(ack)
                inflight -= 1
                print(f"[Training] Ack received: buffer {ack} now free.")
        except zmq.Again:
            pass

    # Signal done and release the staging buffers.
    s.send_pyobj((None, None))
    s.recv_multipart()  # DONE ack
    s.close()
    del buffers
    gc.collect()
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def setup_train_cluster():
    """Bring up the co-located Ray training/inference cluster.

    Creates a 4-bundle placement group, schedules four training actors and
    two tensor-parallel vLLM engines onto it, verifies GPU co-location, and
    gathers the actors' ZMQ handles.

    Returns:
        A ``(training_actors, inference_engines, zmq_handles)`` tuple.
    """
    # Ray manages four GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
    ray.init()

    # Co-locate vLLM instances and training actors on the same set of GPUs:
    # * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
    #   (tensor parallelism = 2).
    # * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
    #   (tensor parallelism = 2).
    pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
    ray.get(pg.ready())
    print(f"placement group has bundles {pg.bundle_specs=}")

    training_actors = []
    training_actor_device_ids = []
    inference_engines = []
    inference_engine_device_ids = []

    # One training actor per bundle, each claiming a 0.4 GPU share.
    for bundle_index in [0, 1, 2, 3]:
        actor_cls = ray.remote(
            num_cpus=0,
            num_gpus=0.4,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_index,
            ),
        )(RayTrainingActor)
        training_actors.append(actor_cls.remote())

    for bundle_index, training_actor in enumerate(training_actors):
        device_id = ray.get(training_actor.report_device_id.remote())
        print(f"training actor {bundle_index} is on {device_id}")
        training_actor_device_ids.append(device_id)

    for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
        # Use the following syntax instead of the @ray.remote decorator so that
        # the placement group is customized for each bundle.
        engine_cls = ray.remote(
            num_cpus=0,
            num_gpus=0,
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
            ),
        )(MyLLM)
        llm = engine_cls.remote(
            model="facebook/opt-125m",
            enforce_eager=True,
            worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
            tensor_parallel_size=2,
            distributed_executor_backend="ray",
            gpu_memory_utilization=0.3,
            bundle_indices=bundle_indices,
        )
        inference_engines.append(llm)
        # Do not call any method on the inference engine at this point; the call
        # blocks until the vLLM instance finishes initialization.

    for i, llm in enumerate(inference_engines):
        device_ids = ray.get(
            llm.collective_rpc.remote("report_device_id", args=tuple())
        )
        inference_engine_device_ids.append(device_ids)
        print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")

    # Verify placement: the first two training actors share the same GPUs as
    # the first inference engine.
    assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
    # Verify placement: the last two training actors share the same GPUs as
    # the second inference engine.
    assert training_actor_device_ids[2:] == inference_engine_device_ids[1]

    print("Gather all the ZMQ handles from the training actors.")
    zmq_handles = {}
    for actor in training_actors:
        zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))

    print(f"ZMQ handles: {zmq_handles}")
    return training_actors, inference_engines, zmq_handles
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, set up the cluster, and run a weight update.

    With ``--num-ipc-handles 1`` the synchronous update path is used;
    with a value greater than 1 the asynchronous, pipelined path is used.
    """
    parser = argparse.ArgumentParser(
        description="Update model weights across training and inference actors."
    )
    parser.add_argument(
        "--num-ipc-handles",
        type=int,
        default=1,
        help="Number of IPC handles. If 1, use synchronous update; \
            if >1, use asynchronous update.",
    )
    args = parser.parse_args()
    num_handles = args.num_ipc_handles
    # Guard against nonsensical values before any cluster work is done.
    if num_handles < 1:
        parser.error("--num-ipc-handles must be >= 1")

    training_actors, inference_engines, zmq_handles = setup_train_cluster()

    print("Update the weights of the inference engines.")
    start_time = time.time()
    if num_handles == 1:
        # Synchronous update
        ray.get(
            [actor.update_weights.remote() for actor in training_actors]
            + [
                llm.collective_rpc.remote(
                    "update_weights_from_ipc", args=(zmq_handles,)
                )
                for llm in inference_engines
            ]
        )
    else:
        # Asynchronous update
        ray.get(
            [
                actor.update_weights_async.remote(num_handles=num_handles)
                for actor in training_actors
            ]
            + [
                llm.collective_rpc.remote(
                    "update_weights_from_ipc_async", args=(zmq_handles,)
                )
                for llm in inference_engines
            ]
        )
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"Weight update completed in {elapsed:.2f} seconds.")

    print("Check if the weights are updated.")
    for llm in inference_engines:
        assert ray.get(
            llm.collective_rpc.remote("check_weights_changed", args=tuple())
        ), "Weights were not updated properly!"


if __name__ == "__main__":
    main()
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
from ast import Dict, Tuple
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
@ -92,6 +94,15 @@ class FlattenedTensorMetadata(TypedDict):
|
||||
offset: int
|
||||
|
||||
|
||||
class PayloadType(Enum):
    """Enumerates possible payload types in IPC protocol."""

    # A list of CUDA IPC handles for the shared staging buffers.
    HANDLES = "handles"
    # A (buffer_id, metadata) tuple describing tensors staged in a buffer.
    BUFFER_UPDATE = "buffer_update"
    # A (None, None) sentinel marking the end of the update stream.
    DONE = "done"
    # Any payload that matches none of the above shapes.
    UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class ColocateWorkerExtension:
|
||||
"""
|
||||
The class for vLLM's worker to inherit from, in the colocate setting.
|
||||
@ -152,12 +163,87 @@ class ColocateWorkerExtension:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def update_weights_from_ipc_async(self, zmq_handles: dict[str, str]):
    """Receive weight updates over ZMQ via shared CUDA IPC staging buffers.

    Protocol (see PayloadType): a list of IPC handles arrives first and is
    rebuilt into local buffer views; then (buffer_id, metadata) messages
    describe tensors staged in a buffer, which are loaded into the model
    and ACKed by echoing the buffer id back; a (None, None) payload ends
    the loop.

    Args:
        zmq_handles: Mapping from device UUID to the ZMQ endpoint this
            worker should bind.
    """
    assert self.device is not None
    # Lazily create a per-worker ZMQ context, reused across calls.
    if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
        self._zmq_ctx = zmq.Context()
    socket = self._zmq_ctx.socket(zmq.ROUTER)
    socket.bind(zmq_handles[self.report_device_id()])
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)

    # Use builtin generics here: the module-level `from ast import Dict,
    # Tuple` names are AST node classes, not typing constructs.
    buffers: dict[int, torch.Tensor] = {}
    while True:
        events = dict(poller.poll(timeout=100))
        if socket in events and (events[socket] & zmq.POLLIN):
            # Router identity
            identity = socket.recv()

            payload: (
                list[tuple[Callable, tuple]]
                | tuple[int, list[FlattenedTensorMetadata]]
                | None
            ) = socket.recv_pyobj()

            payload_type = self._identify_payload_type(payload)

            # === HANDLE LIST OF SHARED MEMORY HANDLES ===
            if payload_type == PayloadType.HANDLES:
                handles: list[tuple[Callable, tuple]] = payload
                for i, h in enumerate(handles):
                    buffers[i] = rebuild_ipc(h, self.device.index)
                socket.send_multipart([identity, b"ACK_HANDLES"])
                continue

            # === HANDLE BUFFERED MODEL UPDATES ===
            if payload_type == PayloadType.BUFFER_UPDATE:
                buf_id, items = payload
                buffer = buffers.get(buf_id)
                if buffer is None:
                    # Unknown buffer id; skip (no ACK is sent in this case).
                    continue

                weights: list[tuple[str, torch.Tensor]] = []
                for item in items:
                    assert isinstance(item, dict)
                    shape = torch.Size(item["shape"])
                    dtype, offset = item["dtype"], item["offset"]
                    size = dtype.itemsize * shape.numel()
                    tensor = (
                        buffer[offset : offset + size].view(dtype=dtype).view(shape)
                    )
                    weights.append((item["name"], tensor))

                self.model_runner.model.load_weights(weights=weights)
                torch.cuda.synchronize()
                # ACK by echoing the now-free buffer id back to the sender.
                socket.send_multipart([identity, str(buf_id).encode()])

            # === DONE SIGNAL ===
            elif payload_type == PayloadType.DONE:
                socket.send_multipart([identity, b"DONE"])
                break
            else:
                continue

    socket.close()
    gc.collect()
    torch.cuda.empty_cache()
|
||||
|
||||
def report_device_id(self) -> str:
    """Look up this worker's GPU UUID, cache it on self, and return it."""
    from vllm.platforms import current_platform

    uuid = current_platform.get_device_uuid(self.device.index)
    self.device_uuid = uuid
    return uuid
|
||||
|
||||
def _identify_payload_type(self, payload) -> PayloadType:
    """Classify an incoming IPC payload by its Python structure.

    A list is a batch of IPC handles; a tuple is either the (None, None)
    done sentinel or a (buffer_id, metadata) update; anything else is
    unknown.
    """
    if isinstance(payload, list):
        return PayloadType.HANDLES
    if isinstance(payload, tuple):
        first, _ = payload
        return PayloadType.DONE if first is None else PayloadType.BUFFER_UPDATE
    return PayloadType.UNKNOWN
|
||||
|
||||
def check_weights_changed(self):
|
||||
"""
|
||||
Check if the weights are updated to 0.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user