[Docs] Improve documentation for RLHF example (#20598)

Signed-off-by: Ricardo Decal <rdecal@anyscale.com>
Ricardo Decal 2025-07-15 04:54:10 -04:00 committed by GitHub
parent 68d28e37b0
commit 235bfd5dfe


@@ -1,17 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
-a simple demonstration of RLHF with vLLM, inspired by
-the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
-It follows the design that, training processes and inference processes
-are different, and they live on different GPUs.
-Training processes send prompts to inference processes to generate data,
-and also synchronize the weights of the model by broadcasting the weights
-from the training process to the inference process.
-Note that this is a simple demonstration of one training instance and one
-inference instance. In practice, there could be multiple training instances
-and multiple inference instances. For the full implementation, please refer
-to the OpenRLHF framework.
+Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
+
+The script separates training and inference workloads onto distinct GPUs
+so that Ray can manage process placement and inter-process communication.
+A Hugging Face Transformer model occupies GPU 0 for training, whereas a
+tensor-parallel vLLM inference engine occupies GPUs 1-2.
+
+The example performs the following steps:
+
+* Load the training model on GPU 0.
+* Split the inference model across GPUs 1-2 using vLLM's tensor parallelism
+  and Ray placement groups.
+* Generate text from a list of prompts using the inference engine.
+* Update the weights of the training model and broadcast the updated weights
+  to the inference engine by using a Ray collective RPC group. Note that
+  for demonstration purposes we simply zero out the weights.
+
+For a production-ready implementation that supports multiple training and
+inference replicas, see the OpenRLHF framework:
+https://github.com/OpenRLHF/OpenRLHF
+
+This example assumes a single-node cluster with three GPUs, but Ray
+supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
+workloads. Residual GPU activity interferes with vLLM memory profiling and
+causes unexpected behavior.
 """
 
 import os
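
Note: the diff elides the rest of the script's import block. Judging from the names used below, it presumably looks something like this sketch; the `rlhf_utils` module name for the `stateless_init_process_group` helper is an assumption, not shown in this diff.

    import ray
    import torch
    from ray.util.placement_group import placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
    # Assumed companion module providing the process-group helper used below.
    from rlhf_utils import stateless_init_process_group
    from transformers import AutoModelForCausalLM

    from vllm import LLM, SamplingParams
    from vllm.utils import get_ip, get_open_port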
@@ -28,29 +42,27 @@ from vllm.utils import get_ip, get_open_port
 
 
 class MyLLM(LLM):
+    """Configure the vLLM worker for Ray placement group execution."""
+
     def __init__(self, *args, **kwargs):
-        # a hack to make the script work.
-        # stop ray from manipulating CUDA_VISIBLE_DEVICES
-        # at the top-level
+        # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
+        # so that vLLM can manage its own device placement within the worker.
         os.environ.pop("CUDA_VISIBLE_DEVICES", None)
         super().__init__(*args, **kwargs)
 
 
-"""
-Start the training process, here we use huggingface transformers
-as an example to hold a model on GPU 0.
-"""
+# Load the OPT-125M model onto GPU 0 for the training workload.
 train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
 train_model.to("cuda:0")
 
-"""
-Start the inference process, here we use vLLM to hold a model on GPU 1 and
-GPU 2. For the details on how to use ray, please refer to the ray
-documentation https://docs.ray.io/en/latest/ .
-"""
+# Initialize Ray and set the visible devices. The vLLM engine will
+# be placed on GPUs 1 and 2.
 os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
 ray.init()
 
+# Create a placement group that reserves GPUs 1-2 for the vLLM inference engine.
+# Learn more about Ray placement groups:
+# https://docs.ray.io/en/latest/placement-groups.html
 pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
 ray.get(pg_inference.ready())
 scheduling_inference = PlacementGroupSchedulingStrategy(
@@ -58,10 +70,9 @@ scheduling_inference = PlacementGroupSchedulingStrategy(
     placement_group_capture_child_tasks=True,
     placement_group_bundle_index=0,
 )
-"""
-launch the vLLM inference engine.
-here we use `enforce_eager` to reduce the start time.
-"""
+
+# Launch the vLLM inference engine. The `enforce_eager` flag reduces
+# start-up latency.
 llm = ray.remote(
     num_cpus=0,
     num_gpus=0,
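
Note: the remainder of this `ray.remote(...)` call is elided by the diff. Based on the surrounding context lines, the actor construction presumably continues roughly as sketched below; the `worker_extension_cls` value is an assumption (the class supplying the `update_weight` and `check_weights_changed` RPC targets lives outside this file).

    llm = ray.remote(
        num_cpus=0,
        num_gpus=0,
        scheduling_strategy=scheduling_inference,
    )(MyLLM).remote(
        model="facebook/opt-125m",
        enforce_eager=True,
        worker_extension_cls="rlhf_utils.WorkerExtension",  # assumed module path
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
    )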
@@ -74,7 +85,7 @@ llm = ray.remote(
     distributed_executor_backend="ray",
 )
 
-# Generate texts from the prompts.
+# Generate text from the prompts.
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -93,8 +104,8 @@ for output in outputs:
     print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
     print("-" * 50)
 
-# set up the communication between the training process
-# and the inference engine.
+# Set up the communication channel between the training process and the
+# inference engine.
 master_address = get_ip()
 master_port = get_open_port()
 
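Note: `stateless_init_process_group` is defined outside this file, so its body does not appear in the diff. A minimal sketch, assuming vLLM's `StatelessProcessGroup` and `PyNcclCommunicator` utilities, might look like:

    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup


    def stateless_init_process_group(master_address, master_port, rank, world_size, device):
        # Build a NCCL communicator between the training process (rank 0)
        # and the vLLM workers without touching torch.distributed's
        # global process group.
        pg = StatelessProcessGroup.create(
            host=master_address, port=master_port, rank=rank, world_size=world_size
        )
        return PyNcclCommunicator(pg, device=device)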
@@ -107,21 +118,23 @@ model_update_group = stateless_init_process_group(
 )
 ray.get(handle)
 
-# simulate training, modify the weights of the model.
+# Simulate a training step by zeroing out all model weights.
+# In a real RLHF training loop the weights would be updated using the gradient
+# from an RL objective such as PPO on a reward model.
 for name, p in train_model.named_parameters():
     p.data.zero_()
 
-# sync weight from the training process to the inference engine.
+# Synchronize the updated weights to the inference engine.
 for name, p in train_model.named_parameters():
     handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
     model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
     ray.get(handle)
 
-# check if the weights are updated.
+# Verify that the inference weights have been updated.
 assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
 
-# use the updated model to generate texts, they will be nonsense
-# because the weights are all zeros.
+# Generate text with the updated model. The output is expected to be nonsense
+# because the weights are zero.
 outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
 print("-" * 50)
 for output in outputs_updated:
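
Note: the `update_weight` and `check_weights_changed` RPC targets run inside the vLLM workers and are likewise defined outside this file. A rough sketch of what such a worker extension might implement, assuming access to the worker's `model_runner` and a `model_update_group` communicator created by an earlier init RPC:

    import torch


    class WorkerExtension:
        # Hypothetical worker-side methods reached via llm.collective_rpc.

        def update_weight(self, name, dtype, shape):
            # Receive one tensor broadcast by the training process (rank 0)
            # and load it into the inference model.
            weight = torch.empty(shape, dtype=dtype, device="cuda")
            self.model_update_group.broadcast(
                weight, src=0, stream=torch.cuda.current_stream()
            )
            self.model_runner.model.load_weights(weights=[(name, weight)])
            del weight

        def check_weights_changed(self):
            # The demo zeroes the training weights, so every parameter
            # received by the worker should now be all zeros.
            return all(
                torch.allclose(p, torch.zeros_like(p))
                for _, p in self.model_runner.model.named_parameters()
            )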