fix: suppress UserWarning for non-writable buffer in binary2tensor

Use bytearray to create a mutable copy of the binary data before
passing to np.frombuffer. This ensures the numpy array is writable,
avoiding UserWarning from torch.from_numpy on read-only arrays.

The warning occurred because base64.b64decode returns immutable bytes,
and np.frombuffer on immutable bytes returns a read-only array. When
converted to a torch tensor via torch.from_numpy, PyTorch would emit:
"UserWarning: The given buffer is not writable..."

This fix maintains the efficient numpy-based conversion while ensuring
compatibility with all embed_dtype formats (float32, float16, bfloat16,
fp8_e4m3, fp8_e5m2) that numpy doesn't natively support.

Fixes #26781

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: yurekami <yurekami@users.noreply.github.com>
This commit is contained in:
yurekami 2025-12-25 02:46:18 +09:00
parent daf8032542
commit ef89079712
2 changed files with 43 additions and 1 deletions

View File

@ -38,3 +38,42 @@ def test_encode_and_decode(embed_dtype: str, endianness: str):
name_1="new",
tol=1e-2,
)
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys())
@torch.inference_mode()
def test_binary2tensor_no_warning(embed_dtype: str):
"""Test that binary2tensor does not emit UserWarning about non-writable buffers.
This addresses issue #26781 where torch.frombuffer on non-writable bytes
would emit: "UserWarning: The given buffer is not writable..."
"""
import warnings
tensor = torch.rand(10, 20, device="cpu", dtype=torch.float32)
binary = tensor2binary(tensor, embed_dtype, "native")
# Capture warnings during binary2tensor call
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = binary2tensor(binary, tensor.shape, embed_dtype, "native")
# Filter for the specific UserWarning about non-writable buffers
buffer_warnings = [
warning
for warning in w
if issubclass(warning.category, UserWarning)
and "not writable" in str(warning.message)
]
assert len(buffer_warnings) == 0, (
f"Expected no warnings about non-writable buffers, got: {buffer_warnings}"
)
# Verify the result is correct
result_float = result.to(torch.float32)
if embed_dtype in ["float32", "float16"]:
torch.testing.assert_close(tensor, result_float, atol=0.001, rtol=0.001)
elif embed_dtype == "bfloat16":
torch.testing.assert_close(tensor, result_float, atol=0.01, rtol=0.01)
else: # fp8
torch.testing.assert_close(tensor, result_float, atol=0.1, rtol=0.1)

View File

@ -107,7 +107,10 @@ def binary2tensor(
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]
np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape)
# Use bytearray to create a mutable copy of the binary data.
# This ensures np.frombuffer returns a writable array, avoiding
# UserWarning from torch.from_numpy on read-only arrays.
np_array = np.frombuffer(bytearray(binary), dtype=np_dtype).reshape(shape)
if endianness != "native" and endianness != sys_byteorder:
np_array = np_array.byteswap()