diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index ab6a576b22b8..948eab97ffae 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -128,7 +128,7 @@ steps:
   - tests/spec_decode/e2e/test_integration_dist_tp4
   - tests/compile
   - examples/offline_inference/rlhf.py
-  - examples/offline_inference/ray_placement.py
+  - examples/offline_inference/rlhf_colocate.py
   commands:
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
@@ -137,7 +137,7 @@ steps:
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
   - python3 ../examples/offline_inference/rlhf.py
-  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py
+  - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
 
 - label: Metrics, Tracing Test # 10min
   num_gpus: 2
diff --git a/examples/offline_inference/ray_placement.py b/examples/offline_inference/rlhf_colocate.py
similarity index 56%
rename from examples/offline_inference/ray_placement.py
rename to examples/offline_inference/rlhf_colocate.py
index cd801a3c0c85..b921bc71feb9 100644
--- a/examples/offline_inference/ray_placement.py
+++ b/examples/offline_inference/rlhf_colocate.py
@@ -1,13 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 """
-a simple demonstration to show how to control
-the placement of the vLLM workers with Ray.
-The key is to set VLLM_RAY_PER_WORKER_GPUS and
-VLLM_RAY_BUNDLE_INDICES properly.
+a simple demonstration to show how to co-locate
+vLLM workers with training actors on the same GPUs,
+for RLHF-like applications.
+The key points:
+- Control the placement of the vLLM workers with Ray, by setting
+  VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly.
+- Use cuda-ipc to pass tensors, since NCCL does not work when we have
+  multiple processes on the same GPU.
 """
 import os
 
 import ray
+import torch
 from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
@@ -19,7 +24,33 @@ class MyWorker(Worker):
 
     def report_device_id(self) -> str:
         from vllm.platforms import current_platform
-        return current_platform.get_device_uuid(self.device.index)
+        self.device_uuid = current_platform.get_device_uuid(self.device.index)
+        return self.device_uuid
+
+    def update_weights_from_ipc_handles(self, ipc_handles):
+        handles = ipc_handles[self.device_uuid]
+        device_id = self.device.index
+        weights = []
+        for name, handle in handles.items():
+            func, args = handle
+            list_args = list(args)
+            # the key is to change device id to the current device id
+            # in case two processes have different CUDA_VISIBLE_DEVICES
+            list_args[6] = device_id
+            tensor = func(*list_args)
+            weights.append((name, tensor))
+        self.model_runner.model.load_weights(weights=weights)
+        torch.cuda.synchronize()
+
+    def check_weights_changed(self):
+        """
+        Check if the weights are updated to 0.
+        """
+        weights_updated = True
+        for name, p in self.model_runner.model.named_parameters():
+            weights_updated = weights_updated and torch.allclose(
+                p, torch.zeros_like(p))
+        return weights_updated
 
 
 class MyLLM(LLM):
@@ -40,12 +71,32 @@ class MyLLM(LLM):
 
 class RayTrainingActor:
 
-    def report_device_id(self) -> str:
+    def __init__(self):
+        # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
+        from transformers import AutoModelForCausalLM
+        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+        self.model.to("cuda:0")
+        for name, p in self.model.named_parameters():
+            p.data.zero_()
+        torch.cuda.synchronize()
         # the argument for get_device_uuid is the index
         # of the GPU in the visible devices.
-        # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
         from vllm.platforms import current_platform
-        return current_platform.get_device_uuid(0)
+        self.device_uuid = current_platform.get_device_uuid(0)
+
+    def report_device_id(self) -> str:
+        return self.device_uuid
+
+    def get_weight_ipc_handles(self):
+        from torch.multiprocessing.reductions import reduce_tensor
+        data = {}
+        for name, p in self.model.named_parameters():
+            # the training actor might only have a subset of the weights
+            # and need to all-gather the weights from all the actors.
+            # for demonstration, here we assume all training actors have
+            # the full weights.
+            data[name] = reduce_tensor(p.detach())
+        return {self.device_uuid: data}
 
 
 # ray manages 4 GPUs
@@ -78,6 +129,8 @@ for bundle_index in [0, 1, 2, 3]:
         ),
     )(RayTrainingActor).remote()
     training_actors.append(training_actor)
+
+for bundle_index, training_actor in enumerate(training_actors):
     device_id = ray.get(training_actor.report_device_id.remote())
     print(f"training actor {bundle_index} is on {device_id}")
     training_actor_device_ids.append(device_id)
@@ -119,3 +172,18 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
 # the last two training actors should be
 # on the same GPUs as the second inference engine
 assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
+
+print("gather all the IPC handles from the training actors")
+ipc_handles = {}
+for actor in training_actors:
+    ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))
+
+print("update the weights of the inference engines")
+for llm in inference_engines:
+    ray.get(
+        llm.collective_rpc.remote("update_weights_from_ipc_handles",
+                                  args=(ipc_handles, )))
+print("check if the weights are updated")
+for llm in inference_engines:
+    assert ray.get(
+        llm.collective_rpc.remote("check_weights_changed", args=tuple()))