From 4ed6b67da3cac69c4b4b8cc56a7e186032a7c7d5 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 24 Sep 2025 01:30:26 +0100
Subject: [PATCH] [Core] Support weight_loader_v2 for `UnquantizedLinearMethod` (#23036)

Signed-off-by: Kyle Sayers
Signed-off-by: yewentao256
---
 vllm/compilation/decorators.py       | 43 ++++++++++++++++++++++++----
 vllm/model_executor/layers/linear.py | 16 +++++++----
 vllm/model_executor/parameter.py     | 23 ++++++++++++++-
 3 files changed, 70 insertions(+), 12 deletions(-)

diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index b7a6e23c1aa79..6e9a36a2b0b99 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -8,6 +8,7 @@ from unittest.mock import patch
 
 import torch
 import torch.nn as nn
+from packaging import version
 from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 
 from vllm.compilation.counter import compilation_counter
@@ -300,13 +301,13 @@ def _support_torch_compile(
             logger.debug(
                 "enable_cpp_symbolic_shape_guards config not available")
 
-        with patch.object(InliningInstructionTranslator, 'inline_call',
-                          patched_inline_call), torch._dynamo.config.patch(
-                              **dynamo_config_patches
-                          ), maybe_use_cudagraph_partition_wrapper(
-                              self.vllm_config):
+        with patch.object(
+                InliningInstructionTranslator, "inline_call",
+                patched_inline_call), torch._dynamo.config.patch(
+                    **dynamo_config_patches
+                ), maybe_use_cudagraph_partition_wrapper(
+                    self.vllm_config), _torch27_patch_tensor_subclasses():
             output = self.compiled_callable(*args, **kwargs)
-
         return output
 
         # usually, capturing the model once is enough, and then we can
@@ -367,3 +368,33 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
     if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             and compilation_config.use_inductor_graph_partition):
         torch._inductor.utils.set_customized_partition_wrappers(None)
+
+
+@contextlib.contextmanager
+def _torch27_patch_tensor_subclasses():
+    """
+    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.)
+    when using torch 2.7.0. This enables using weight_loader_v2 and the use
+    of `BasevLLMParameter`s without having to replace them with regular
+    tensors before `torch.compile`-time.
+ """ + from vllm.model_executor.parameter import (BasevLLMParameter, + ModelWeightParameter, + RowvLLMParameter, + _ColumnvLLMParameter) + + def return_false(*args, **kwargs): + return False + + if version.parse("2.7") <= version.parse( + torch.__version__) < version.parse("2.8"): + yield + return + + with (torch._dynamo.config.patch("traceable_tensor_subclasses", [ + BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter, + RowvLLMParameter + ]), + patch("torch._dynamo.variables.torch.can_dispatch_torch_function", + return_false)): + yield diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5bf96398bc710..df5bced6b2288 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm # yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, BlockQuantScaleParameter, + ModelWeightParameter, PackedColumnParameter, PackedvLLMParameter, PerTensorScaleParameter, @@ -34,6 +35,7 @@ from vllm.utils import GiB_bytes logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ + "UnquantizedLinearMethod", "CompressedTensorsLinearMethod", "CompressedTensorsLinearTransformMethod", "BitBLASLinearMethod", @@ -196,10 +198,14 @@ class UnquantizedLinearMethod(LinearMethodBase): # The amount of memory allocated for the weights is # sum(output_partition_sizes) * input_size_per_partition. try: - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) + weight_loader = extra_weight_attrs.pop("weight_loader") + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) except torch.cuda.OutOfMemoryError as e: logger.error("Failed to create unquantized linear weights: %s", e) if torch.cuda.is_available(): @@ -212,7 +218,7 @@ class UnquantizedLinearMethod(LinearMethodBase): "Failed to create unquantized linear weights. " "This may be caused by insufficient memory to allocate " "the weight.") from e - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 03e5e5809b678..66add98dab443 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -61,9 +61,24 @@ class BasevLLMParameter(Parameter): self.tp_size = get_tensor_model_parallel_world_size() @property - def weight_loader(self): + def weight_loader(self) -> Callable: + # NOTE(@ksayers) some models such as mamba_mixer2 override the + # weight loader to support custom loading. In the future, model-specific + # weight loading should be implemented via Model.load_weights. 
+        # meantime, support deleting and overriding the `weight_loader` attribute
+        if self._weight_loader is None:
+            raise AttributeError(f"{self.__class__.__name__} weight_loader "
+                                 "attribute has been deleted")
         return self._weight_loader
 
+    @weight_loader.setter
+    def weight_loader(self, value: Callable):
+        self._weight_loader = value
+
+    @weight_loader.deleter
+    def weight_loader(self):
+        self._weight_loader = None  # type: ignore[assignment]
+
     def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
         cond1 = self.data.ndim == 1 and self.data.numel() == 1
         cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
@@ -97,6 +112,12 @@ class BasevLLMParameter(Parameter):
         assert shard_id in qkv_idxs
         return qkv_idxs[shard_id]
 
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        return super().__torch_function__(func, types, args, kwargs)
+
 
 class _ColumnvLLMParameter(BasevLLMParameter):
     """
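
---

Note: the sketches below are reviewer illustrations, not part of the applied diff.

The gate in `_torch27_patch_tensor_subclasses` applies the dynamo patches only
when torch is in the [2.7, 2.8) range and is a no-op otherwise. A minimal
standalone check of the same gating logic, assuming only `torch` and
`packaging` are installed:

    import torch
    from packaging import version

    # Same gate as the context manager above: patch only on torch 2.7.x.
    v = version.parse(torch.__version__)
    needs_patch = version.parse("2.7") <= v < version.parse("2.8")
    print(f"torch {v}: dynamo subclass patches "
          f"{'applied' if needs_patch else 'skipped'}")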
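
The `weight_loader` property, setter, and deleter added to `BasevLLMParameter`
let models such as mamba_mixer2 delete and re-set the loader instead of
failing. A standalone sketch of the same pattern, using a hypothetical
`ParamWithLoader` class (not vLLM code):

    from typing import Callable, Optional

    class ParamWithLoader:
        """Mirrors the weight_loader property pattern from the patch."""

        def __init__(self, weight_loader: Callable):
            self._weight_loader: Optional[Callable] = weight_loader

        @property
        def weight_loader(self) -> Callable:
            if self._weight_loader is None:
                raise AttributeError(f"{self.__class__.__name__} "
                                     "weight_loader attribute has been "
                                     "deleted")
            return self._weight_loader

        @weight_loader.setter
        def weight_loader(self, value: Callable):
            self._weight_loader = value

        @weight_loader.deleter
        def weight_loader(self):
            self._weight_loader = None

    def copy_loader(param, loaded):
        param.data.copy_(loaded)

    p = ParamWithLoader(copy_loader)
    p.weight_loader = copy_loader  # overriding is supported
    del p.weight_loader            # deleting nulls the internal slot
    try:
        p.weight_loader
    except AttributeError as e:
        print(e)  # ParamWithLoader weight_loader attribute has been deleted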
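
The `__torch_function__` classmethod added to `BasevLLMParameter` only
normalizes `kwargs=None` to an empty dict before deferring to the default
`torch.Tensor` handler, so torch ops keep dispatching through the subclass.
A minimal sketch on a hypothetical subclass (not vLLM code):

    import torch

    class TaggedTensor(torch.Tensor):
        # Same shape as the classmethod added to BasevLLMParameter above.
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            return super().__torch_function__(func, types, args, kwargs)

    t = torch.ones(2, 2).as_subclass(TaggedTensor)
    out = t + t
    print(type(out).__name__)  # TaggedTensor: ops still return the subclass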