[Security] Serialize using safetensors instead of pickle in Mooncake Pipe (#14228)

Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
This commit is contained in:
Kuntai Du 2025-03-04 15:10:32 -06:00 committed by GitHub
parent c2bd2196fc
commit 288ca110f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,13 +2,14 @@
import json
import os
import pickle
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Union
import torch
import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
@ -237,14 +238,13 @@ class MooncakePipe(KVPipeBase):
return hash(tensor.data_ptr())
def _send_impl(self, tensor: torch.Tensor) -> None:
    """Serialize ``tensor`` with safetensors and send it as one message.

    The payload is a safetensors blob holding a single entry under the
    key ``"tensor"``; the receiving side must read the same key.
    safetensors is used instead of pickle so that the peer never has to
    unpickle bytes that arrived over the wire.

    Args:
        tensor: The tensor to transmit.
    """
    # NOTE(review): safetensors is expected to reject non-contiguous
    # tensors (pickle accepted them); ``contiguous()`` is a no-op for
    # tensors that are already contiguous -- confirm against the
    # safetensors version in use.
    payload = safetensors_save({"tensor": tensor.contiguous()})
    # Send exactly once: a second send of a pickled copy would desync
    # the receiver, which reads one message per tensor.
    self.transfer_engine.send_bytes(payload)
def _recv_impl(self) -> torch.Tensor:
    """Receive one message and decode it into a tensor on ``self.device``.

    The message is expected to be a safetensors blob with a single entry
    under the key ``"tensor"``.

    Returns:
        The decoded tensor, moved to ``self.device``.

    Raises:
        KeyError: If the decoded payload has no ``"tensor"`` entry.
    """
    data = self.transfer_engine.recv_bytes()
    # Deliberately NOT pickle.loads(): unpickling bytes received from a
    # peer can execute arbitrary code. safetensors only parses tensor
    # metadata and raw buffers.
    tensors = safetensors_load(data)
    return tensors["tensor"].to(self.device)
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""Send tensor to the target process."""