[Docs] Improve documentation for RLHF example (#20598)

Signed-off-by: Ricardo Decal <rdecal@anyscale.com>
This commit is contained in:
Ricardo Decal 2025-07-15 04:54:10 -04:00 committed by GitHub
parent 68d28e37b0
commit 235bfd5dfe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,17 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
a simple demonstration of RLHF with vLLM, inspired by
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
It follows the design that, training processes and inference processes
are different, and they live on different GPUs.
Training processes send prompts to inference processes to generate data,
and also synchronize the weights of the model by broadcasting the weights
from the training process to the inference process.
Note that this is a simple demonstration of one training instance and one
inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPU 12.
The example performs the following steps:
* Load the training model on GPU 0.
* Split the inference model across GPUs 12 using vLLM's tensor parallelism
and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using a Ray collective RPC group. Note that
for demonstration purposes we simply zero out the weights.
For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import os
@ -28,29 +42,27 @@ from vllm.utils import get_ip, get_open_port
class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, *args, **kwargs):
# a hack to make the script work.
# stop ray from manipulating CUDA_VISIBLE_DEVICES
# at the top-level
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
# so that vLLM can manage its own device placement within the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
super().__init__(*args, **kwargs)
"""
Start the training process, here we use huggingface transformers
as an example to hold a model on GPU 0.
"""
# Load the OPT-125M model onto GPU 0 for the training workload.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
"""
Start the inference process, here we use vLLM to hold a model on GPU 1 and
GPU 2. For the details on how to use ray, please refer to the ray
documentation https://docs.ray.io/en/latest/ .
"""
# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()
# Create a placement group that reserves GPU 12 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
@ -58,10 +70,9 @@ scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
"""
launch the vLLM inference engine.
here we use `enforce_eager` to reduce the start time.
"""
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
llm = ray.remote(
num_cpus=0,
num_gpus=0,
@ -74,7 +85,7 @@ llm = ray.remote(
distributed_executor_backend="ray",
)
# Generate texts from the prompts.
# Generate text from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
@ -93,8 +104,8 @@ for output in outputs:
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# set up the communication between the training process
# and the inference engine.
# Set up the communication channel between the training process and the
# inference engine.
master_address = get_ip()
master_port = get_open_port()
@ -107,21 +118,23 @@ model_update_group = stateless_init_process_group(
)
ray.get(handle)
# simulate training, modify the weights of the model.
# Simulate a training step by zeroing out all model weights.
# In a real RLHF training loop the weights would be updated using the gradient
# from an RL objective such as PPO on a reward model.
for name, p in train_model.named_parameters():
p.data.zero_()
# sync weight from the training process to the inference engine.
# Synchronize the updated weights to the inference engine.
for name, p in train_model.named_parameters():
handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle)
# check if the weights are updated.
# Verify that the inference weights have been updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
# use the updated model to generate texts, they will be nonsense
# because the weights are all zeros.
# Generate text with the updated model. The output is expected to be nonsense
# because the weights are zero.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated: