[Fix] remove pickle and clean code

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
knlnguyen1802 2025-11-13 11:14:54 +08:00
parent c594843112
commit c64ec86c33
2 changed files with 9 additions and 14 deletions

View File

@ -179,7 +179,7 @@ class RayTrainingActor:
s = self.zmq_context.socket(zmq.DEALER) s = self.zmq_context.socket(zmq.DEALER)
s.connect(self.zmq_handle) s.connect(self.zmq_handle)
s.send_pyobj(handles) s.send_pyobj(handles)
_ = s.recv_pyobj() s.recv_multipart() # ACK for handles
# === Partition tensors into buffers === # === Partition tensors into buffers ===
offset = 0 offset = 0
@ -238,17 +238,16 @@ class RayTrainingActor:
# Receive buffer-free ACKs # Receive buffer-free ACKs
try: try:
while s.getsockopt(zmq.EVENTS) & zmq.POLLIN: while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
ack = s.recv_pyobj(flags=zmq.NOBLOCK) ack = int(s.recv_multipart(flags=zmq.NOBLOCK)[0].decode())
if isinstance(ack, int): free_buffers.append(ack)
free_buffers.append(ack) inflight -= 1
inflight -= 1 print(f"[Training] Ack received: buffer {ack} now free.")
print(f"[Training] Ack received: buffer {ack} now free.")
except zmq.Again: except zmq.Again:
pass pass
# Signal done # Signal done
s.send_pyobj((None, None)) s.send_pyobj((None, None))
_ = s.recv_pyobj() s.recv_multipart() # DONE ack
s.close() s.close()
del buffers del buffers
gc.collect() gc.collect()

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc import gc
import pickle
from ast import Dict, Tuple from ast import Dict, Tuple
from collections.abc import Callable from collections.abc import Callable
from enum import Enum from enum import Enum
@ -193,7 +192,7 @@ class ColocateWorkerExtension:
handles: list[tuple[Callable, tuple]] = payload handles: list[tuple[Callable, tuple]] = payload
for i, h in enumerate(handles): for i, h in enumerate(handles):
buffers[i] = rebuild_ipc(h, self.device.index) buffers[i] = rebuild_ipc(h, self.device.index)
socket.send_multipart([identity, pickle.dumps("ACK_HANDLES")]) socket.send_multipart([identity, b"ACK_HANDLES"])
continue continue
# === HANDLE BUFFERED MODEL UPDATES === # === HANDLE BUFFERED MODEL UPDATES ===
@ -216,14 +215,11 @@ class ColocateWorkerExtension:
self.model_runner.model.load_weights(weights=weights) self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize() torch.cuda.synchronize()
# # --- Added verify step here --- socket.send_multipart([identity, str(buf_id).encode()])
# if self.verify_weights_enabled:
# self.verify_weights(weights)
socket.send_multipart([identity, pickle.dumps(buf_id)])
# === DONE SIGNAL === # === DONE SIGNAL ===
elif payload_type == PayloadType.DONE: elif payload_type == PayloadType.DONE:
socket.send_multipart([identity, pickle.dumps("DONE")]) socket.send_multipart([identity, b"DONE"])
break break
else: else:
continue continue