From 845420ac2c2bc27ae0f96c25430b4f1cd20063cc Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Sun, 3 Aug 2025 19:43:33 -0700 Subject: [PATCH] [RLHF] Fix torch.dtype not serializable in example (#22158) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> --- examples/offline_inference/rlhf.py | 5 ++++- examples/offline_inference/rlhf_utils.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index 752117a4e3623..ed974b90b57ee 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -126,7 +126,10 @@ for name, p in train_model.named_parameters(): # Synchronize the updated weights to the inference engine. for name, p in train_model.named_parameters(): - handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape)) + dtype_name = str(p.dtype).split(".")[-1] + handle = llm.collective_rpc.remote( + "update_weight", args=(name, dtype_name, p.shape) + ) model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) ray.get(handle) diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index c445224d75686..d2a8419ffabcd 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -45,7 +45,8 @@ class WorkerExtension: self.device, ) - def update_weight(self, name, dtype, shape): + def update_weight(self, name, dtype_name, shape): + dtype = getattr(torch, dtype_name) weight = torch.empty(shape, dtype=dtype, device="cuda") self.model_update_group.broadcast( weight, src=0, stream=torch.cuda.current_stream()