diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index fa7e57fca78a3..dbd9aff404f6e 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -34,6 +34,7 @@ import os import ray import torch +import time import zmq from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -353,6 +354,7 @@ def main(): num_handles = args.num_ipc_handles training_actors, inference_engines, zmq_handles = setup_train_cluster() print("Update the weights of the inference engines.") + start_time = time.time() if num_handles == 1: # Synchronous update ray.get( @@ -378,7 +380,9 @@ def main(): for llm in inference_engines ] ) - + end_time = time.time() + elapsed = end_time - start_time + print(f"Weight update completed in {elapsed:.2f} seconds.") print("Check if the weights are updated.") for llm in inference_engines: assert ray.get(