[Core] Support weight_loader_v2 for UnquantizedLinearMethod (#23036)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Kyle Sayers 2025-09-24 01:30:26 +01:00 committed by yewentao256
parent cb825af948
commit 4ed6b67da3
3 changed files with 70 additions and 12 deletions

View File

@@ -8,6 +8,7 @@ from unittest.mock import patch
 
 import torch
 import torch.nn as nn
+from packaging import version
 from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 
 from vllm.compilation.counter import compilation_counter
@@ -300,13 +301,13 @@ def _support_torch_compile(
                 logger.debug(
                     "enable_cpp_symbolic_shape_guards config not available")
 
-        with patch.object(InliningInstructionTranslator, 'inline_call',
-                          patched_inline_call), torch._dynamo.config.patch(
-                              **dynamo_config_patches
-                          ), maybe_use_cudagraph_partition_wrapper(
-                              self.vllm_config):
+        with patch.object(
+                InliningInstructionTranslator, "inline_call",
+                patched_inline_call), torch._dynamo.config.patch(
+                    **dynamo_config_patches
+                ), maybe_use_cudagraph_partition_wrapper(
+                    self.vllm_config), _torch27_patch_tensor_subclasses():
             output = self.compiled_callable(*args, **kwargs)
         return output
 
         # usually, capturing the model once is enough, and then we can
@@ -367,3 +368,33 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
     if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             and compilation_config.use_inductor_graph_partition):
         torch._inductor.utils.set_customized_partition_wrappers(None)
+
+
+@contextlib.contextmanager
+def _torch27_patch_tensor_subclasses():
+    """
+    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.)
+    when using torch 2.7.0. This enables using `weight_loader_v2` and the use
+    of `BasevLLMParameter`s without having to replace them with regular
+    tensors before `torch.compile` time.
+    """
+    from vllm.model_executor.parameter import (BasevLLMParameter,
+                                               ModelWeightParameter,
+                                               RowvLLMParameter,
+                                               _ColumnvLLMParameter)
+
+    def return_false(*args, **kwargs):
+        return False
+
+    if not (version.parse("2.7") <= version.parse(
+            torch.__version__) < version.parse("2.8")):
+        yield
+        return
+
+    with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
+            BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
+            RowvLLMParameter
+    ]),
+          patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
+                return_false)):
+        yield
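
A minimal usage sketch (not part of the diff; `model` and `inputs` are placeholder names): the decorator change above wraps the compiled call in this context manager, so the dynamo patches are active only while torch 2.7.x traces the model:

    import torch

    # On torch 2.7.x the context manager marks the vLLM parameter subclasses
    # as traceable and bypasses the torch-function dispatch check; on other
    # torch versions it is a no-op.
    with _torch27_patch_tensor_subclasses():
        compiled_forward = torch.compile(model.forward)
        output = compiled_forward(inputs)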

View File

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
 # yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            BlockQuantScaleParameter,
+                                           ModelWeightParameter,
                                            PackedColumnParameter,
                                            PackedvLLMParameter,
                                            PerTensorScaleParameter,
@@ -34,6 +35,7 @@ from vllm.utils import GiB_bytes
 logger = init_logger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
+    "UnquantizedLinearMethod",
     "CompressedTensorsLinearMethod",
    "CompressedTensorsLinearTransformMethod",
     "BitBLASLinearMethod",
@@ -196,10 +198,14 @@ class UnquantizedLinearMethod(LinearMethodBase):
         # The amount of memory allocated for the weights is
         # sum(output_partition_sizes) * input_size_per_partition.
         try:
-            weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                           input_size_per_partition,
-                                           dtype=params_dtype),
-                               requires_grad=False)
+            weight_loader = extra_weight_attrs.pop("weight_loader")
+            weight = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=params_dtype),
+                                          input_dim=1,
+                                          output_dim=0,
+                                          weight_loader=weight_loader)
         except torch.cuda.OutOfMemoryError as e:
             logger.error("Failed to create unquantized linear weights: %s", e)
             if torch.cuda.is_available():
@@ -212,7 +218,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
                 "Failed to create unquantized linear weights. "
                 "This may be caused by insufficient memory to allocate "
                 "the weight.") from e
 
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
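
The net effect: the weight now carries its shard dims and loader on the parameter object itself, instead of having `input_dim`/`output_dim` attached afterwards with `set_weight_attrs`. A self-contained toy illustration of that pattern (deliberately not vLLM's classes, since `BasevLLMParameter` also queries tensor-parallel state at construction time):

    import torch
    from torch.nn import Parameter

    class SketchWeightParameter(Parameter):
        # Toy stand-in for ModelWeightParameter: the parameter records its own
        # shard dims and the callable used to load it.
        def __new__(cls, data, input_dim, output_dim, weight_loader):
            self = super().__new__(cls, data, requires_grad=False)
            self.input_dim = input_dim
            self.output_dim = output_dim
            self.weight_loader = weight_loader
            return self

    def copy_loader(param, loaded_weight):
        param.data.copy_(loaded_weight)

    weight = SketchWeightParameter(torch.empty(256, 128),
                                   input_dim=1, output_dim=0,
                                   weight_loader=copy_loader)
    weight.weight_loader(weight, torch.randn(256, 128))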

View File

@@ -61,9 +61,24 @@ class BasevLLMParameter(Parameter):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     @property
-    def weight_loader(self):
+    def weight_loader(self) -> Callable:
+        # NOTE(@ksayers) some models such as mamba_mixer2 override the
+        # weight loader to support custom loading. In the future, model-specific
+        # weight loading should be implemented via Model.load_weights. In the
+        # meantime, support deleting and overriding the `weight_loader` attribute.
+        if self._weight_loader is None:
+            raise AttributeError(f"{self.__class__.__name__} weight_loader "
+                                 "attribute has been deleted")
         return self._weight_loader
 
+    @weight_loader.setter
+    def weight_loader(self, value: Callable):
+        self._weight_loader = value
+
+    @weight_loader.deleter
+    def weight_loader(self):
+        self._weight_loader = None  # type: ignore[assignment]
+
     def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
         cond1 = self.data.ndim == 1 and self.data.numel() == 1
         cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
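
A short sketch of what the setter and deleter enable (assuming `param` is an already constructed `BasevLLMParameter` subclass; models such as mamba_mixer2 rely on being able to swap or remove the loader):

    def custom_loader(param, loaded_weight):
        # model-specific loading logic would go here
        param.data.copy_(loaded_weight)

    param.weight_loader = custom_loader   # override via the new setter
    del param.weight_loader               # deleter stores None rather than failing
    try:
        param.weight_loader               # the property now raises AttributeError
    except AttributeError:
        pass
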
@@ -97,6 +112,12 @@ class BasevLLMParameter(Parameter):
         assert shard_id in qkv_idxs
         return qkv_idxs[shard_id]
 
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        return super().__torch_function__(func, types, args, kwargs)
+
 
 class _ColumnvLLMParameter(BasevLLMParameter):
     """