diff --git a/tests/utils_/test_serial_utils.py b/tests/utils_/test_serial_utils.py
index 51b2e4de02693..ff48cc7ed2ded 100644
--- a/tests/utils_/test_serial_utils.py
+++ b/tests/utils_/test_serial_utils.py
@@ -38,3 +38,48 @@ def test_encode_and_decode(embed_dtype: str, endianness: str):
         name_1="new",
         tol=1e-2,
     )
+
+
+@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys())
+@torch.inference_mode()
+def test_binary2tensor_no_warning(embed_dtype: str):
+    """Check that binary2tensor decodes without a non-writable-buffer warning.
+
+    Regression test for issue #26781, where torch.frombuffer on non-writable
+    bytes would emit: "UserWarning: The given buffer is not writable..."
+    """
+    import warnings
+
+    expected = torch.rand(10, 20, device="cpu", dtype=torch.float32)
+    payload = tensor2binary(expected, embed_dtype, "native")
+
+    # Record every warning raised while decoding; the "always" filter keeps
+    # Python-level warning filters from hiding anything inside this block.
+    # NOTE(review): torch may emit this warning only once per process --
+    # confirm an earlier trigger cannot make this check pass vacuously.
+    with warnings.catch_warnings(record=True) as caught:
+        warnings.simplefilter("always")
+        decoded = binary2tensor(payload, expected.shape, embed_dtype, "native")
+
+    # Keep only the UserWarnings that mention the non-writable buffer.
+    buffer_warnings = []
+    for entry in caught:
+        if not issubclass(entry.category, UserWarning):
+            continue
+        if "not writable" in str(entry.message):
+            buffer_warnings.append(entry)
+    assert len(buffer_warnings) == 0, (
+        f"Expected no warnings about non-writable buffers, got: {buffer_warnings}"
+    )
+
+    # Compare in float32, with a tolerance chosen per embed dtype.
+    decoded_float = decoded.to(torch.float32)
+    if embed_dtype in ("float32", "float16"):
+        tolerance = 0.001
+    elif embed_dtype == "bfloat16":
+        tolerance = 0.01
+    else:  # fp8
+        tolerance = 0.1
+    torch.testing.assert_close(
+        expected, decoded_float, atol=tolerance, rtol=tolerance
+    )
diff --git a/vllm/utils/serial_utils.py b/vllm/utils/serial_utils.py
index 07db5eaf74c8d..91f33034223a6 100644
--- a/vllm/utils/serial_utils.py
+++ b/vllm/utils/serial_utils.py
@@ -107,7 +107,10 @@ def binary2tensor(
     torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
     np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]
 
-    np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape)
+    # Use bytearray to create a mutable copy of the binary data.
+    # This ensures np.frombuffer returns a writable array, avoiding
+    # UserWarning from torch.from_numpy on read-only arrays.
+    np_array = np.frombuffer(bytearray(binary), dtype=np_dtype).reshape(shape)
 
     if endianness != "native" and endianness != sys_byteorder:
         np_array = np_array.byteswap()