From daf80325426f3d49260cd1986815691f0d5016c3 Mon Sep 17 00:00:00 2001 From: yurekami Date: Thu, 25 Dec 2025 02:44:56 +0900 Subject: [PATCH 1/2] [Code Quality] Add missing return type annotations to model_executor utils MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds missing return type annotations to functions in the model_executor module to improve code quality and IDE support. ## Changes ### model_executor/utils.py - `set_weight_attrs()` -> `None` - `replace_parameter()` -> `None` ### model_executor/model_loader/weight_utils.py - `enable_hf_transfer()` -> `None` - `get_lock()` -> `filelock.FileLock` - `_shared_pointers()` -> `list[list[str]]` - `enable_tqdm()` -> `bool` 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: yurekami --- vllm/model_executor/model_loader/weight_utils.py | 14 ++++++++------ vllm/model_executor/utils.py | 6 ++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 0c5961561a7d9..a93ff74a257e3 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -66,7 +66,7 @@ logger = init_logger(__name__) temp_dir = tempfile.gettempdir() -def enable_hf_transfer(): +def enable_hf_transfer() -> None: """automatically activates hf_transfer""" if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: try: @@ -87,7 +87,9 @@ class DisabledTqdm(tqdm): super().__init__(*args, **kwargs) -def get_lock(model_name_or_path: str | Path, cache_dir: str | None = None): +def get_lock( + model_name_or_path: str | Path, cache_dir: str | None = None +) -> filelock.FileLock: lock_dir = cache_dir or temp_dir model_name_or_path = str(model_name_or_path) os.makedirs(os.path.dirname(lock_dir), exist_ok=True) @@ -178,11 +180,11 @@ def maybe_download_from_modelscope( return None -def _shared_pointers(tensors): - ptrs = defaultdict(list) +def _shared_pointers(tensors: dict[str, torch.Tensor]) -> list[list[str]]: + ptrs: dict[int, list[str]] = defaultdict(list) for k, v in tensors.items(): ptrs[v.data_ptr()].append(k) - failing = [] + failing: list[list[str]] = [] for _, names in ptrs.items(): if len(names) > 1: failing.append(names) @@ -602,7 +604,7 @@ def filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[s _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 -def enable_tqdm(use_tqdm_on_load: bool): +def enable_tqdm(use_tqdm_on_load: bool) -> bool: return use_tqdm_on_load and ( not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index b89371d987541..6344b86fcb5b6 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -19,7 +19,7 @@ def set_random_seed(seed: int | None) -> None: def set_weight_attrs( weight: torch.Tensor, weight_attrs: dict[str, Any] | None, -): +) -> None: """Set attributes on a weight tensor. This method is used to set attributes on a weight tensor. This method @@ -50,7 +50,9 @@ def set_weight_attrs( setattr(weight, key, value) -def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor): +def replace_parameter( + layer: torch.nn.Module, param_name: str, new_data: torch.Tensor +) -> None: """ Replace a parameter of a layer while maintaining the ability to reload the weight. Called within implementations of the `process_weights_after_loading` method. From ef89079712a5dcd53f1d50a89529aea1ac1d9a13 Mon Sep 17 00:00:00 2001 From: yurekami Date: Thu, 25 Dec 2025 02:46:18 +0900 Subject: [PATCH 2/2] fix: suppress UserWarning for non-writable buffer in binary2tensor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use bytearray to create a mutable copy of the binary data before passing to np.frombuffer. This ensures the numpy array is writable, avoiding UserWarning from torch.from_numpy on read-only arrays. The warning occurred because base64.b64decode returns immutable bytes, and np.frombuffer on immutable bytes returns a read-only array. When converted to a torch tensor via torch.from_numpy, PyTorch would emit: "UserWarning: The given buffer is not writable..." This fix maintains the efficient numpy-based conversion while ensuring compatibility with all embed_dtype formats (float32, float16, bfloat16, fp8_e4m3, fp8_e5m2) that numpy doesn't natively support. Fixes #26781 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: yurekami --- tests/utils_/test_serial_utils.py | 39 +++++++++++++++++++++++++++++++ vllm/utils/serial_utils.py | 5 +++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/utils_/test_serial_utils.py b/tests/utils_/test_serial_utils.py index 51b2e4de02693..ff48cc7ed2ded 100644 --- a/tests/utils_/test_serial_utils.py +++ b/tests/utils_/test_serial_utils.py @@ -38,3 +38,42 @@ def test_encode_and_decode(embed_dtype: str, endianness: str): name_1="new", tol=1e-2, ) + + +@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys()) +@torch.inference_mode() +def test_binary2tensor_no_warning(embed_dtype: str): + """Test that binary2tensor does not emit UserWarning about non-writable buffers. + + This addresses issue #26781 where torch.frombuffer on non-writable bytes + would emit: "UserWarning: The given buffer is not writable..." + """ + import warnings + + tensor = torch.rand(10, 20, device="cpu", dtype=torch.float32) + binary = tensor2binary(tensor, embed_dtype, "native") + + # Capture warnings during binary2tensor call + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = binary2tensor(binary, tensor.shape, embed_dtype, "native") + + # Filter for the specific UserWarning about non-writable buffers + buffer_warnings = [ + warning + for warning in w + if issubclass(warning.category, UserWarning) + and "not writable" in str(warning.message) + ] + assert len(buffer_warnings) == 0, ( + f"Expected no warnings about non-writable buffers, got: {buffer_warnings}" + ) + + # Verify the result is correct + result_float = result.to(torch.float32) + if embed_dtype in ["float32", "float16"]: + torch.testing.assert_close(tensor, result_float, atol=0.001, rtol=0.001) + elif embed_dtype == "bfloat16": + torch.testing.assert_close(tensor, result_float, atol=0.01, rtol=0.01) + else: # fp8 + torch.testing.assert_close(tensor, result_float, atol=0.1, rtol=0.1) diff --git a/vllm/utils/serial_utils.py b/vllm/utils/serial_utils.py index 07db5eaf74c8d..91f33034223a6 100644 --- a/vllm/utils/serial_utils.py +++ b/vllm/utils/serial_utils.py @@ -107,7 +107,10 @@ def binary2tensor( torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype] np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype] - np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape) + # Use bytearray to create a mutable copy of the binary data. + # This ensures np.frombuffer returns a writable array, avoiding + # UserWarning from torch.from_numpy on read-only arrays. + np_array = np.frombuffer(bytearray(binary), dtype=np_dtype).reshape(shape) if endianness != "native" and endianness != sys_byteorder: np_array = np_array.byteswap()