mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 15:18:05 +08:00
fix: suppress UserWarning for non-writable buffer in binary2tensor
Use bytearray to create a mutable copy of the binary data before passing to np.frombuffer. This ensures the numpy array is writable, avoiding UserWarning from torch.from_numpy on read-only arrays. The warning occurred because base64.b64decode returns immutable bytes, and np.frombuffer on immutable bytes returns a read-only array. When converted to a torch tensor via torch.from_numpy, PyTorch would emit: "UserWarning: The given buffer is not writable..." This fix maintains the efficient numpy-based conversion while ensuring compatibility with all embed_dtype formats (float32, float16, bfloat16, fp8_e4m3, fp8_e5m2) that numpy doesn't natively support. Fixes #26781 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: yurekami <yurekami@users.noreply.github.com>
This commit is contained in:
parent
daf8032542
commit
ef89079712
@ -38,3 +38,42 @@ def test_encode_and_decode(embed_dtype: str, endianness: str):
|
||||
name_1="new",
|
||||
tol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys())
|
||||
@torch.inference_mode()
|
||||
def test_binary2tensor_no_warning(embed_dtype: str):
|
||||
"""Test that binary2tensor does not emit UserWarning about non-writable buffers.
|
||||
|
||||
This addresses issue #26781 where torch.frombuffer on non-writable bytes
|
||||
would emit: "UserWarning: The given buffer is not writable..."
|
||||
"""
|
||||
import warnings
|
||||
|
||||
tensor = torch.rand(10, 20, device="cpu", dtype=torch.float32)
|
||||
binary = tensor2binary(tensor, embed_dtype, "native")
|
||||
|
||||
# Capture warnings during binary2tensor call
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
result = binary2tensor(binary, tensor.shape, embed_dtype, "native")
|
||||
|
||||
# Filter for the specific UserWarning about non-writable buffers
|
||||
buffer_warnings = [
|
||||
warning
|
||||
for warning in w
|
||||
if issubclass(warning.category, UserWarning)
|
||||
and "not writable" in str(warning.message)
|
||||
]
|
||||
assert len(buffer_warnings) == 0, (
|
||||
f"Expected no warnings about non-writable buffers, got: {buffer_warnings}"
|
||||
)
|
||||
|
||||
# Verify the result is correct
|
||||
result_float = result.to(torch.float32)
|
||||
if embed_dtype in ["float32", "float16"]:
|
||||
torch.testing.assert_close(tensor, result_float, atol=0.001, rtol=0.001)
|
||||
elif embed_dtype == "bfloat16":
|
||||
torch.testing.assert_close(tensor, result_float, atol=0.01, rtol=0.01)
|
||||
else: # fp8
|
||||
torch.testing.assert_close(tensor, result_float, atol=0.1, rtol=0.1)
|
||||
|
||||
@ -107,7 +107,10 @@ def binary2tensor(
|
||||
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
|
||||
np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]
|
||||
|
||||
np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape)
|
||||
# Use bytearray to create a mutable copy of the binary data.
|
||||
# This ensures np.frombuffer returns a writable array, avoiding
|
||||
# UserWarning from torch.from_numpy on read-only arrays.
|
||||
np_array = np.frombuffer(bytearray(binary), dtype=np_dtype).reshape(shape)
|
||||
|
||||
if endianness != "native" and endianness != sys_byteorder:
|
||||
np_array = np_array.byteswap()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user