mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 06:17:52 +08:00
[Fix] remove pickle and clean code
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
parent
c594843112
commit
c64ec86c33
@ -179,7 +179,7 @@ class RayTrainingActor:
|
|||||||
s = self.zmq_context.socket(zmq.DEALER)
|
s = self.zmq_context.socket(zmq.DEALER)
|
||||||
s.connect(self.zmq_handle)
|
s.connect(self.zmq_handle)
|
||||||
s.send_pyobj(handles)
|
s.send_pyobj(handles)
|
||||||
_ = s.recv_pyobj()
|
s.recv_multipart() # ACK for handles
|
||||||
|
|
||||||
# === Partition tensors into buffers ===
|
# === Partition tensors into buffers ===
|
||||||
offset = 0
|
offset = 0
|
||||||
@ -238,17 +238,16 @@ class RayTrainingActor:
|
|||||||
# Receive buffer-free ACKs
|
# Receive buffer-free ACKs
|
||||||
try:
|
try:
|
||||||
while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
|
while s.getsockopt(zmq.EVENTS) & zmq.POLLIN:
|
||||||
ack = s.recv_pyobj(flags=zmq.NOBLOCK)
|
ack = int(s.recv_multipart(flags=zmq.NOBLOCK)[0].decode())
|
||||||
if isinstance(ack, int):
|
free_buffers.append(ack)
|
||||||
free_buffers.append(ack)
|
inflight -= 1
|
||||||
inflight -= 1
|
print(f"[Training] Ack received: buffer {ack} now free.")
|
||||||
print(f"[Training] Ack received: buffer {ack} now free.")
|
|
||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Signal done
|
# Signal done
|
||||||
s.send_pyobj((None, None))
|
s.send_pyobj((None, None))
|
||||||
_ = s.recv_pyobj()
|
s.recv_multipart() # DONE ack
|
||||||
s.close()
|
s.close()
|
||||||
del buffers
|
del buffers
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import gc
|
import gc
|
||||||
import pickle
|
|
||||||
from ast import Dict, Tuple
|
from ast import Dict, Tuple
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -193,7 +192,7 @@ class ColocateWorkerExtension:
|
|||||||
handles: list[tuple[Callable, tuple]] = payload
|
handles: list[tuple[Callable, tuple]] = payload
|
||||||
for i, h in enumerate(handles):
|
for i, h in enumerate(handles):
|
||||||
buffers[i] = rebuild_ipc(h, self.device.index)
|
buffers[i] = rebuild_ipc(h, self.device.index)
|
||||||
socket.send_multipart([identity, pickle.dumps("ACK_HANDLES")])
|
socket.send_multipart([identity, b"ACK_HANDLES"])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# === HANDLE BUFFERED MODEL UPDATES ===
|
# === HANDLE BUFFERED MODEL UPDATES ===
|
||||||
@ -216,14 +215,11 @@ class ColocateWorkerExtension:
|
|||||||
|
|
||||||
self.model_runner.model.load_weights(weights=weights)
|
self.model_runner.model.load_weights(weights=weights)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
# # --- Added verify step here ---
|
socket.send_multipart([identity, str(buf_id).encode()])
|
||||||
# if self.verify_weights_enabled:
|
|
||||||
# self.verify_weights(weights)
|
|
||||||
socket.send_multipart([identity, pickle.dumps(buf_id)])
|
|
||||||
|
|
||||||
# === DONE SIGNAL ===
|
# === DONE SIGNAL ===
|
||||||
elif payload_type == PayloadType.DONE:
|
elif payload_type == PayloadType.DONE:
|
||||||
socket.send_multipart([identity, pickle.dumps("DONE")])
|
socket.send_multipart([identity, b"DONE"])
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user