Add estimate time for weight update

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
knlnguyen1802 2025-11-19 15:17:12 +08:00
parent c64ec86c33
commit 18b68fad65

View File

@ -34,6 +34,7 @@ import os
import ray
import torch
import time
import zmq
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@ -353,6 +354,7 @@ def main():
num_handles = args.num_ipc_handles
training_actors, inference_engines, zmq_handles = setup_train_cluster()
print("Update the weights of the inference engines.")
start_time = time.time()
if num_handles == 1:
# Synchronous update
ray.get(
@ -378,7 +380,9 @@ def main():
for llm in inference_engines
]
)
end_time = time.time()
elapsed = end_time - start_time
print(f"Weight update completed in {elapsed:.2f} seconds.")
print("Check if the weights are updated.")
for llm in inference_engines:
assert ray.get(