mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-09 21:52:38 +08:00
[Docs] Improve documentation for RLHF example (#20598)
Signed-off-by: Ricardo Decal <rdecal@anyscale.com>
This commit is contained in:
parent
68d28e37b0
commit
235bfd5dfe
@ -1,17 +1,31 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""
|
"""
|
||||||
a simple demonstration of RLHF with vLLM, inspired by
|
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
|
||||||
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
|
|
||||||
It follows the design that, training processes and inference processes
|
The script separates training and inference workloads onto distinct GPUs
|
||||||
are different, and they live on different GPUs.
|
so that Ray can manage process placement and inter-process communication.
|
||||||
Training processes send prompts to inference processes to generate data,
|
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
|
||||||
and also synchronize the weights of the model by broadcasting the weights
|
tensor-parallel vLLM inference engine occupies GPUs 1–2.
|
||||||
from the training process to the inference process.
|
|
||||||
Note that this is a simple demonstration of one training instance and one
|
The example performs the following steps:
|
||||||
inference instance. In practice, there could be multiple training instances
|
|
||||||
and multiple inference instances. For the full implementation, please refer
|
* Load the training model on GPU 0.
|
||||||
to the OpenRLHF framework.
|
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
|
||||||
|
and Ray placement groups.
|
||||||
|
* Generate text from a list of prompts using the inference engine.
|
||||||
|
* Update the weights of the training model and broadcast the updated weights
|
||||||
|
to the inference engine over a dedicated process group, coordinated through `collective_rpc` calls. Note that
|
||||||
|
for demonstration purposes we simply zero out the weights.
|
||||||
|
|
||||||
|
For a production-ready implementation that supports multiple training and
|
||||||
|
inference replicas, see the OpenRLHF framework:
|
||||||
|
https://github.com/OpenRLHF/OpenRLHF
|
||||||
|
|
||||||
|
This example assumes a single-node cluster with three GPUs, but Ray
|
||||||
|
supports multi-node clusters. vLLM expects that the GPUs are used only for vLLM
|
||||||
|
workloads. Residual GPU activity interferes with vLLM memory profiling and
|
||||||
|
causes unexpected behavior.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -28,29 +42,27 @@ from vllm.utils import get_ip, get_open_port
|
|||||||
|
|
||||||
|
|
||||||
class MyLLM(LLM):
|
class MyLLM(LLM):
|
||||||
|
"""Configure the vLLM worker for Ray placement group execution."""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
# a hack to make the script work.
|
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
|
||||||
# stop ray from manipulating CUDA_VISIBLE_DEVICES
|
# so that vLLM can manage its own device placement within the worker.
|
||||||
# at the top-level
|
|
||||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
"""
|
# Load the OPT-125M model onto GPU 0 for the training workload.
|
||||||
Start the training process, here we use huggingface transformers
|
|
||||||
as an example to hold a model on GPU 0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||||
train_model.to("cuda:0")
|
train_model.to("cuda:0")
|
||||||
"""
|
|
||||||
Start the inference process, here we use vLLM to hold a model on GPU 1 and
|
# Set the visible devices and initialize Ray. The vLLM engine will
|
||||||
GPU 2. For the details on how to use ray, please refer to the ray
|
# be placed on GPUs 1 and 2.
|
||||||
documentation https://docs.ray.io/en/latest/ .
|
|
||||||
"""
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
|
||||||
ray.init()
|
ray.init()
|
||||||
|
|
||||||
|
# Create a placement group that reserves GPUs 1–2 for the vLLM inference engine.
|
||||||
|
# Learn more about Ray placement groups:
|
||||||
|
# https://docs.ray.io/en/latest/placement-groups.html
|
||||||
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
||||||
ray.get(pg_inference.ready())
|
ray.get(pg_inference.ready())
|
||||||
scheduling_inference = PlacementGroupSchedulingStrategy(
|
scheduling_inference = PlacementGroupSchedulingStrategy(
|
||||||
@ -58,10 +70,9 @@ scheduling_inference = PlacementGroupSchedulingStrategy(
|
|||||||
placement_group_capture_child_tasks=True,
|
placement_group_capture_child_tasks=True,
|
||||||
placement_group_bundle_index=0,
|
placement_group_bundle_index=0,
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
launch the vLLM inference engine.
|
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
|
||||||
here we use `enforce_eager` to reduce the start time.
|
# start-up latency.
|
||||||
"""
|
|
||||||
llm = ray.remote(
|
llm = ray.remote(
|
||||||
num_cpus=0,
|
num_cpus=0,
|
||||||
num_gpus=0,
|
num_gpus=0,
|
||||||
@ -74,7 +85,7 @@ llm = ray.remote(
|
|||||||
distributed_executor_backend="ray",
|
distributed_executor_backend="ray",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate texts from the prompts.
|
# Generate text from the prompts.
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
"The president of the United States is",
|
"The president of the United States is",
|
||||||
@ -93,8 +104,8 @@ for output in outputs:
|
|||||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
# set up the communication between the training process
|
# Set up the communication channel between the training process and the
|
||||||
# and the inference engine.
|
# inference engine.
|
||||||
master_address = get_ip()
|
master_address = get_ip()
|
||||||
master_port = get_open_port()
|
master_port = get_open_port()
|
||||||
|
|
||||||
@ -107,21 +118,23 @@ model_update_group = stateless_init_process_group(
|
|||||||
)
|
)
|
||||||
ray.get(handle)
|
ray.get(handle)
|
||||||
|
|
||||||
# simulate training, modify the weights of the model.
|
# Simulate a training step by zeroing out all model weights.
|
||||||
|
# In a real RLHF training loop the weights would be updated using the gradient
|
||||||
|
# from an RL objective such as PPO, with rewards supplied by a reward model.
|
||||||
for name, p in train_model.named_parameters():
|
for name, p in train_model.named_parameters():
|
||||||
p.data.zero_()
|
p.data.zero_()
|
||||||
|
|
||||||
# sync weight from the training process to the inference engine.
|
# Synchronize the updated weights to the inference engine.
|
||||||
for name, p in train_model.named_parameters():
|
for name, p in train_model.named_parameters():
|
||||||
handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
|
handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
|
||||||
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
|
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
|
||||||
ray.get(handle)
|
ray.get(handle)
|
||||||
|
|
||||||
# check if the weights are updated.
|
# Verify that the inference weights have been updated.
|
||||||
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
|
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
|
||||||
|
|
||||||
# use the updated model to generate texts, they will be nonsense
|
# Generate text with the updated model. The output is expected to be nonsense
|
||||||
# because the weights are all zeros.
|
# because the weights are zero.
|
||||||
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
for output in outputs_updated:
|
for output in outputs_updated:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user