mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 21:47:10 +08:00
[Fix] remove pickle and clean code
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
parent
c594843112
commit
c64ec86c33
@@ -179,7 +179,7 @@ class RayTrainingActor:
|
||||
s = self.zmq_context.socket(zmq.DEALER)
|
||||
s.connect(self.zmq_handle)
|
||||
s.send_pyobj(handles)
|
||||
_ = s.recv_pyobj()
|
||||
s.recv_multipart() # ACK for handles
|
||||
|
||||
# === Partition tensors into buffers ===
|
||||
offset = 0
|
||||
@@ -238,17 +238,16 @@ class RayTrainingActor:
|
||||
# Receive buffer-free ACKs
|
||||
try:
|
||||
while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
|
||||
ack = s.recv_pyobj(flags=zmq.NOBLOCK)
|
||||
if isinstance(ack, int):
|
||||
free_buffers.append(ack)
|
||||
inflight -= 1
|
||||
print(f"[Training] Ack received: buffer {ack} now free.")
|
||||
ack = int(s.recv_multipart(flags=zmq.NOBLOCK)[0].decode())
|
||||
free_buffers.append(ack)
|
||||
inflight -= 1
|
||||
print(f"[Training] Ack received: buffer {ack} now free.")
|
||||
except zmq.Again:
|
||||
pass
|
||||
|
||||
# Signal done
|
||||
s.send_pyobj((None, None))
|
||||
_ = s.recv_pyobj()
|
||||
s.recv_multipart() # DONE ack
|
||||
s.close()
|
||||
del buffers
|
||||
gc.collect()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import pickle
|
||||
from ast import Dict, Tuple
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
@@ -193,7 +192,7 @@ class ColocateWorkerExtension:
|
||||
handles: list[tuple[Callable, tuple]] = payload
|
||||
for i, h in enumerate(handles):
|
||||
buffers[i] = rebuild_ipc(h, self.device.index)
|
||||
socket.send_multipart([identity, pickle.dumps("ACK_HANDLES")])
|
||||
socket.send_multipart([identity, b"ACK_HANDLES"])
|
||||
continue
|
||||
|
||||
# === HANDLE BUFFERED MODEL UPDATES ===
|
||||
@@ -216,14 +215,11 @@ class ColocateWorkerExtension:
|
||||
|
||||
self.model_runner.model.load_weights(weights=weights)
|
||||
torch.cuda.synchronize()
|
||||
# # --- Added verify step here ---
|
||||
# if self.verify_weights_enabled:
|
||||
# self.verify_weights(weights)
|
||||
socket.send_multipart([identity, pickle.dumps(buf_id)])
|
||||
socket.send_multipart([identity, str(buf_id).encode()])
|
||||
|
||||
# === DONE SIGNAL ===
|
||||
elif payload_type == PayloadType.DONE:
|
||||
socket.send_multipart([identity, pickle.dumps("DONE")])
|
||||
socket.send_multipart([identity, b"DONE"])
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user