diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index 752117a4e3623..ed974b90b57ee 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -126,7 +126,10 @@ for name, p in train_model.named_parameters(): # Synchronize the updated weights to the inference engine. for name, p in train_model.named_parameters(): - handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape)) + dtype_name = str(p.dtype).split(".")[-1] + handle = llm.collective_rpc.remote( + "update_weight", args=(name, dtype_name, p.shape) + ) model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) ray.get(handle) diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index c445224d75686..d2a8419ffabcd 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -45,7 +45,8 @@ class WorkerExtension: self.device, ) - def update_weight(self, name, dtype, shape): + def update_weight(self, name, dtype_name, shape): + dtype = getattr(torch, dtype_name) weight = torch.empty(shape, dtype=dtype, device="cuda") self.model_update_group.broadcast( weight, src=0, stream=torch.cuda.current_stream()