[Fix] remove pickle and clean code

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
knlnguyen1802 2025-11-13 11:14:54 +08:00
parent c594843112
commit c64ec86c33
2 changed files with 9 additions and 14 deletions

View File

@ -179,7 +179,7 @@ class RayTrainingActor:
s = self.zmq_context.socket(zmq.DEALER)
s.connect(self.zmq_handle)
s.send_pyobj(handles)
_ = s.recv_pyobj()
s.recv_multipart() # ACK for handles
# === Partition tensors into buffers ===
offset = 0
@ -238,17 +238,16 @@ class RayTrainingActor:
# Receive buffer-free ACKs
try:
while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
ack = s.recv_pyobj(flags=zmq.NOBLOCK)
if isinstance(ack, int):
free_buffers.append(ack)
inflight -= 1
print(f"[Training] Ack received: buffer {ack} now free.")
ack = int(s.recv_multipart(flags=zmq.NOBLOCK)[0].decode())
free_buffers.append(ack)
inflight -= 1
print(f"[Training] Ack received: buffer {ack} now free.")
except zmq.Again:
pass
# Signal done
s.send_pyobj((None, None))
_ = s.recv_pyobj()
s.recv_multipart() # DONE ack
s.close()
del buffers
gc.collect()

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import pickle
from ast import Dict, Tuple
from collections.abc import Callable
from enum import Enum
@ -193,7 +192,7 @@ class ColocateWorkerExtension:
handles: list[tuple[Callable, tuple]] = payload
for i, h in enumerate(handles):
buffers[i] = rebuild_ipc(h, self.device.index)
socket.send_multipart([identity, pickle.dumps("ACK_HANDLES")])
socket.send_multipart([identity, b"ACK_HANDLES"])
continue
# === HANDLE BUFFERED MODEL UPDATES ===
@ -216,14 +215,11 @@ class ColocateWorkerExtension:
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()
# # --- Added verify step here ---
# if self.verify_weights_enabled:
# self.verify_weights(weights)
socket.send_multipart([identity, pickle.dumps(buf_id)])
socket.send_multipart([identity, str(buf_id).encode()])
# === DONE SIGNAL ===
elif payload_type == PayloadType.DONE:
socket.send_multipart([identity, pickle.dumps("DONE")])
socket.send_multipart([identity, b"DONE"])
break
else:
continue