[core][rlhf] add colocate example for RLHF (#12984)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao <youkaichao@gmail.com>, committed by GitHub
Date: 2025-02-10 10:28:59 +08:00
Parent: 59fff4a01a
Commit: aa0ca5ebb7
2 changed files with 78 additions and 10 deletions

.buildkite/test-pipeline.yaml

@@ -128,7 +128,7 @@ steps:
   - tests/spec_decode/e2e/test_integration_dist_tp4
   - tests/compile
   - examples/offline_inference/rlhf.py
-  - examples/offline_inference/ray_placement.py
+  - examples/offline_inference/rlhf_colocate.py
   commands:
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
@@ -137,7 +137,7 @@ steps:
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
   - python3 ../examples/offline_inference/rlhf.py
-  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py
+  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
 - label: Metrics, Tracing Test # 10min
   num_gpus: 2

examples/offline_inference/rlhf_colocate.py (renamed from examples/offline_inference/ray_placement.py)

@@ -1,13 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 """
-a simple demonstration to show how to control
-the placement of the vLLM workers with Ray.
-The key is to set VLLM_RAY_PER_WORKER_GPUS and
-VLLM_RAY_BUNDLE_INDICES properly.
+a simple demonstration of how to co-locate
+vLLM workers with training actors on the same GPUs,
+for RLHF-like applications.
+The key points:
+- Control the placement of the vLLM workers with Ray, by setting
+  VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly.
+- Use cuda-ipc to pass tensors, since NCCL does not work when we have
+  multiple processes on the same GPU.
 """
 import os
 
 import ray
+import torch
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
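
Note for reviewers: the two environment variables named in the docstring are what make the colocation work. Each vLLM worker claims only a fraction of a GPU (VLLM_RAY_PER_WORKER_GPUS) and is pinned to specific placement-group bundles (VLLM_RAY_BUNDLE_INDICES), so Ray can pack an engine onto GPUs that training actors already occupy. A minimal sketch of the intended setup, assuming a 4-GPU node; the 0.4 fraction and the bundle split are illustrative values, not prescribed by this PR:

```python
import os

import ray
from ray.util.placement_group import placement_group

ray.init()

# one bundle per physical GPU; training actors and vLLM workers are both
# scheduled into these bundles, which is what makes them share devices
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
ray.get(pg.ready())

# conceptually, what a colocated engine sets before vLLM spawns its workers
# (the MyLLM subclass below does this in its __init__):
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"  # < 0.5, so a worker and a
                                                # training actor fit on one GPU
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"   # pin this engine's workers
                                                # to bundles 0 and 1 of pg
```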
@@ -19,7 +24,33 @@ class MyWorker(Worker):
 
     def report_device_id(self) -> str:
         from vllm.platforms import current_platform
-        return current_platform.get_device_uuid(self.device.index)
+        self.device_uuid = current_platform.get_device_uuid(self.device.index)
+        return self.device_uuid
+
+    def update_weights_from_ipc_handles(self, ipc_handles):
+        handles = ipc_handles[self.device_uuid]
+        device_id = self.device.index
+        weights = []
+        for name, handle in handles.items():
+            func, args = handle
+            list_args = list(args)
+            # the key is to change device id to the current device id
+            # in case two processes have different CUDA_VISIBLE_DEVICES
+            list_args[6] = device_id
+            tensor = func(*list_args)
+            weights.append((name, tensor))
+        self.model_runner.model.load_weights(weights=weights)
+        torch.cuda.synchronize()
+
+    def check_weights_changed(self):
+        """
+        Check if the weights are updated to 0.
+        """
+        weights_updated = True
+        for name, p in self.model_runner.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(
+                p, torch.zeros_like(p))
+        return weights_updated
 
 
 class MyLLM(LLM):
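
The `(func, args)` pair unpacked in `update_weights_from_ipc_handles` is what `torch.multiprocessing.reductions.reduce_tensor` produces: a picklable rebuild function plus its arguments, where one slot (index 6 for CUDA tensors, which is what the code above patches) holds the device index. A minimal sketch of the producer/consumer split, assuming both processes see the same physical GPU; note that a cuda-ipc handle can only be opened in a process other than the one that exported it:

```python
import torch
from torch.multiprocessing.reductions import reduce_tensor

# --- producer (training actor) ---
weight = torch.randn(1024, device="cuda:0")
handle = reduce_tensor(weight)   # picklable (rebuild_fn, args), no data copy

# --- consumer (vLLM worker in another process, same physical GPU) ---
func, args = handle              # received e.g. through Ray
list_args = list(args)
list_args[6] = 0                 # patch to this process's device index, since
                                 # its CUDA_VISIBLE_DEVICES may differ
shared = func(*list_args)        # maps the producer's memory via cuda-ipc;
                                 # no bytes are copied between processes
```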
@@ -40,12 +71,32 @@ class MyLLM(LLM):
 class RayTrainingActor:
 
-    def report_device_id(self) -> str:
+    def __init__(self):
+        # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
+        from transformers import AutoModelForCausalLM
+        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+        self.model.to("cuda:0")
+        for name, p in self.model.named_parameters():
+            p.data.zero_()
+        torch.cuda.synchronize()
         # the argument for get_device_uuid is the index
         # of the GPU in the visible devices.
-        # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
         from vllm.platforms import current_platform
-        return current_platform.get_device_uuid(0)
+        self.device_uuid = current_platform.get_device_uuid(0)
+
+    def report_device_id(self) -> str:
+        return self.device_uuid
+
+    def get_weight_ipc_handles(self):
+        from torch.multiprocessing.reductions import reduce_tensor
+        data = {}
+        for name, p in self.model.named_parameters():
+            # the training actor might only have a subset of the weights
+            # and need to all-gather the weights from all the actors.
+            # for demonstration, here we assume all training actors have
+            # the full weights.
+            data[name] = reduce_tensor(p.detach())
+        return {self.device_uuid: data}
 
 
 # ray manages 4 GPUs
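
Both actor types report a device UUID rather than an index because indices are relative to each process's CUDA_VISIBLE_DEVICES: with colocation, "cuda:0" in a training actor and "cuda:0" in a vLLM worker may or may not be the same physical GPU. The UUID is stable across processes, which is why the IPC-handle dict is keyed by it. A sketch of the underlying query through NVML; `current_platform.get_device_uuid` exposes the same information, and using pynvml directly here is an assumption for illustration:

```python
import pynvml

pynvml.nvmlInit()
# NVML indexes physical GPUs, independent of CUDA_VISIBLE_DEVICES
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
print(pynvml.nvmlDeviceGetUUID(handle))  # e.g. "GPU-8a3c..."; stable per device
pynvml.nvmlShutdown()
```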
@@ -78,6 +129,8 @@ for bundle_index in [0, 1, 2, 3]:
         ),
     )(RayTrainingActor).remote()
     training_actors.append(training_actor)
+
+for bundle_index, training_actor in enumerate(training_actors):
     device_id = ray.get(training_actor.report_device_id.remote())
     print(f"training actor {bundle_index} is on {device_id}")
     training_actor_device_ids.append(device_id)
@@ -119,3 +172,18 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
 # the last two training actors should be
 # on the same GPUs as the second inference engine
 assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
+
+print("gather all the IPC handles from the training actors")
+ipc_handles = {}
+for actor in training_actors:
+    ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
+
+print("update the weights of the inference engines")
+for llm in inference_engines:
+    ray.get(
+        llm.collective_rpc.remote("update_weights_from_ipc_handles",
+                                  args=(ipc_handles, )))
+
+print("check if the weights are updated")
+for llm in inference_engines:
+    assert ray.get(
+        llm.collective_rpc.remote("check_weights_changed", args=tuple()))