From c64ec86c337f628df452bf8d807d9318f6c7a7ed Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Thu, 13 Nov 2025 11:14:54 +0800 Subject: [PATCH] [Fix] remove pickle and clean code Signed-off-by: knlnguyen1802 --- examples/offline_inference/rlhf_colocate.py | 13 ++++++------- examples/offline_inference/rlhf_utils.py | 10 +++------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index a7fb1f7a66f67..fa7e57fca78a3 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -179,7 +179,7 @@ class RayTrainingActor: s = self.zmq_context.socket(zmq.DEALER) s.connect(self.zmq_handle) s.send_pyobj(handles) - _ = s.recv_pyobj() + s.recv_multipart() # ACK for handles # === Partition tensors into buffers === offset = 0 @@ -238,17 +238,16 @@ class RayTrainingActor: # Receive buffer-free ACKs try: while s.getsockopt(zmq.EVENTS) & zmq.POLLIN: - ack = s.recv_pyobj(flags=zmq.NOBLOCK) - if isinstance(ack, int): - free_buffers.append(ack) - inflight -= 1 - print(f"[Training] Ack received: buffer {ack} now free.") + ack = int(s.recv_multipart(flags=zmq.NOBLOCK)[0].decode()) + free_buffers.append(ack) + inflight -= 1 + print(f"[Training] Ack received: buffer {ack} now free.") except zmq.Again: pass # Signal done s.send_pyobj((None, None)) - _ = s.recv_pyobj() + s.recv_multipart() # DONE ack s.close() del buffers gc.collect() diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index e25f36ad48afb..044d6e0c6aab5 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc -import pickle from ast import Dict, Tuple from collections.abc import Callable from enum import Enum @@ -193,7 +192,7 @@ class ColocateWorkerExtension: handles: list[tuple[Callable, tuple]] = payload for i, h in enumerate(handles): buffers[i] = rebuild_ipc(h, self.device.index) - socket.send_multipart([identity, pickle.dumps("ACK_HANDLES")]) + socket.send_multipart([identity, b"ACK_HANDLES"]) continue # === HANDLE BUFFERED MODEL UPDATES === @@ -216,14 +215,11 @@ class ColocateWorkerExtension: self.model_runner.model.load_weights(weights=weights) torch.cuda.synchronize() - # # --- Added verify step here --- - # if self.verify_weights_enabled: - # self.verify_weights(weights) - socket.send_multipart([identity, pickle.dumps(buf_id)]) + socket.send_multipart([identity, str(buf_id).encode()]) # === DONE SIGNAL === elif payload_type == PayloadType.DONE: - socket.send_multipart([identity, pickle.dumps("DONE")]) + socket.send_multipart([identity, b"DONE"]) break else: continue