diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py
index c6e63531a99d1..752117a4e3623 100644
--- a/examples/offline_inference/rlhf.py
+++ b/examples/offline_inference/rlhf.py
@@ -1,17 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
-a simple demonstration of RLHF with vLLM, inspired by
-the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
-It follows the design that, training processes and inference processes
-are different, and they live on different GPUs.
-Training processes send prompts to inference processes to generate data,
-and also synchronize the weights of the model by broadcasting the weights
-from the training process to the inference process.
-Note that this is a simple demonstration of one training instance and one
-inference instance. In practice, there could be multiple training instances
-and multiple inference instances. For the full implementation, please refer
-to the OpenRLHF framework.
+Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
+
+The script separates training and inference workloads onto distinct GPUs
+so that Ray can manage process placement and inter-process communication.
+A Hugging Face Transformers model occupies GPU 0 for training, whereas a
+tensor-parallel vLLM inference engine occupies GPUs 1–2.
+
+The example performs the following steps:
+
+* Load the training model on GPU 0.
+* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
+  and Ray placement groups.
+* Generate text from a list of prompts using the inference engine.
+* Update the weights of the training model and broadcast the updated weights
+  to the inference engine using a Ray collective RPC group. Note that for
+  demonstration purposes we simply zero out the weights.
+
+For a production-ready implementation that supports multiple training and
+inference replicas, see the OpenRLHF framework:
+https://github.com/OpenRLHF/OpenRLHF
+
+This example assumes a single-node cluster with three GPUs, but Ray
+supports multi-node clusters. vLLM expects that the GPUs are used only for
+vLLM workloads; residual GPU activity interferes with vLLM memory profiling
+and causes unexpected behavior.
 """
 
 import os
@@ -28,29 +42,27 @@ from vllm.utils import get_ip, get_open_port
 
 
 class MyLLM(LLM):
+    """Configure the vLLM worker for Ray placement group execution."""
+
     def __init__(self, *args, **kwargs):
-        # a hack to make the script work.
-        # stop ray from manipulating CUDA_VISIBLE_DEVICES
-        # at the top-level
+        # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
+        # so that vLLM can manage its own device placement within the worker.
         os.environ.pop("CUDA_VISIBLE_DEVICES", None)
         super().__init__(*args, **kwargs)
 
 
-"""
-Start the training process, here we use huggingface transformers
-as an example to hold a model on GPU 0.
-"""
-
+# Load the OPT-125M model onto GPU 0 for the training workload.
 train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
 train_model.to("cuda:0")
-"""
-Start the inference process, here we use vLLM to hold a model on GPU 1 and
-GPU 2. For the details on how to use ray, please refer to the ray
-documentation https://docs.ray.io/en/latest/ .
-"""
+
+# Initialize Ray and set the visible devices. The vLLM engine will
+# be placed on GPUs 1 and 2.
 os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
 ray.init()
 
+# Create a placement group that reserves GPUs 1–2 for the vLLM inference engine.
+# Learn more about Ray placement groups:
+# https://docs.ray.io/en/latest/placement-groups.html
 pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
 ray.get(pg_inference.ready())
 scheduling_inference = PlacementGroupSchedulingStrategy(
@@ -58,10 +70,9 @@ scheduling_inference = PlacementGroupSchedulingStrategy(
     placement_group_capture_child_tasks=True,
     placement_group_bundle_index=0,
 )
-"""
-launch the vLLM inference engine.
-here we use `enforce_eager` to reduce the start time.
-"""
+
+# Launch the vLLM inference engine. The `enforce_eager` flag reduces
+# start-up latency.
 llm = ray.remote(
     num_cpus=0,
     num_gpus=0,
@@ -74,7 +85,7 @@ llm = ray.remote(
     distributed_executor_backend="ray",
 )
 
-# Generate texts from the prompts.
+# Generate text from the prompts.
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -93,8 +104,8 @@ for output in outputs:
     print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
     print("-" * 50)
 
-# set up the communication between the training process
-# and the inference engine.
+# Set up the communication channel between the training process and the
+# inference engine.
 master_address = get_ip()
 master_port = get_open_port()
 
@@ -107,21 +118,23 @@ model_update_group = stateless_init_process_group(
 )
 ray.get(handle)
 
-# simulate training, modify the weights of the model.
+# Simulate a training step by zeroing out all model weights.
+# In a real RLHF training loop, the weights would be updated by optimizing
+# an RL objective such as PPO against a reward model.
 for name, p in train_model.named_parameters():
     p.data.zero_()
 
-# sync weight from the training process to the inference engine.
+# Synchronize the updated weights to the inference engine.
 for name, p in train_model.named_parameters():
     handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
     model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
     ray.get(handle)
 
-# check if the weights are updated.
+# Verify that the inference weights have been updated.
 assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
 
-# use the updated model to generate texts, they will be nonsense
-# because the weights are all zeros.
+# Generate text with the updated model. The output is expected to be nonsense
+# because the weights are zero.
 outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
 print("-" * 50)
 for output in outputs_updated:
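Two helpers referenced by the patch are defined outside this file and are not shown in any hunk: `stateless_init_process_group` (visible in the hunk header for the communication setup) and the `collective_rpc` targets `"update_weight"` / `"check_weights_changed"`. For readers reviewing the diff in isolation, here is a minimal sketch of what the process-group helper can look like, assuming vLLM's stateless process-group utilities; the import paths and signature below are assumptions that may vary across vLLM versions:

```python
# Sketch only: the real helper ships with the example, outside this patch.
# Assumes vLLM's StatelessProcessGroup and PyNcclCommunicator utilities;
# import paths may differ in other vLLM versions.
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup


def stateless_init_process_group(master_address, master_port, rank, world_size, device):
    """Build a NCCL group linking the trainer (rank 0) and the vLLM workers
    without touching torch.distributed's global state."""
    pg = StatelessProcessGroup.create(
        host=master_address, port=master_port, rank=rank, world_size=world_size
    )
    return PyNcclCommunicator(pg, device=device)
```

Keeping this group separate from `torch.distributed`'s default group is what lets the trainer and the Ray-managed vLLM workers communicate without disturbing each framework's own process-group state.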
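And a hedged sketch of a worker-side extension that could serve those `collective_rpc` calls. The init method name `init_weight_update_group` and attributes such as `self.model_update_group`, `self.model_runner`, and `self.device` are assumptions about vLLM worker internals, not confirmed by this diff:

```python
# Hypothetical worker extension serving the collective_rpc targets used in the
# script. Attribute names (self.model_update_group, self.model_runner,
# self.device) and the init RPC name are assumptions for illustration only.
import torch


class WorkerExtension:
    def init_weight_update_group(self, master_address, master_port, rank_offset, world_size):
        from vllm.distributed.parallel_state import get_world_group

        # Each worker joins the NCCL group at its global rank shifted past
        # the trainer, which holds rank 0. Reuses the sketch above.
        rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address, master_port, rank, world_size, self.device
        )

    def update_weight(self, name, dtype, shape):
        # Receive one parameter broadcast by the trainer and load it into
        # the live inference model.
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(
            weight, src=0, stream=torch.cuda.current_stream()
        )
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight

    def check_weights_changed(self):
        # The demo zeroes every training parameter, so after synchronization
        # all inference-side parameters should be zero as well.
        return all(
            torch.allclose(p, torch.zeros_like(p))
            for _, p in self.model_runner.model.named_parameters()
        )
```

Note how `update_weight` mirrors the trainer-side `model_update_group.broadcast(p, src=0, ...)` call in the diff: both sides must issue the broadcast on the same group, for every parameter, in the same order, or the NCCL collectives will deadlock.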