mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 03:15:01 +08:00
[Core] Support weight_loader_v2 for UnquantizedLinearMethod (#23036)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
cb825af948
commit
4ed6b67da3
@ -8,6 +8,7 @@ from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
@ -300,13 +301,13 @@ def _support_torch_compile(
|
||||
logger.debug(
|
||||
"enable_cpp_symbolic_shape_guards config not available")
|
||||
|
||||
with patch.object(InliningInstructionTranslator, 'inline_call',
|
||||
patched_inline_call), torch._dynamo.config.patch(
|
||||
**dynamo_config_patches
|
||||
), maybe_use_cudagraph_partition_wrapper(
|
||||
self.vllm_config):
|
||||
with patch.object(
|
||||
InliningInstructionTranslator, "inline_call",
|
||||
patched_inline_call), torch._dynamo.config.patch(
|
||||
**dynamo_config_patches
|
||||
), maybe_use_cudagraph_partition_wrapper(
|
||||
self.vllm_config), _torch27_patch_tensor_subclasses():
|
||||
output = self.compiled_callable(*args, **kwargs)
|
||||
|
||||
return output
|
||||
|
||||
# usually, capturing the model once is enough, and then we can
|
||||
@ -367,3 +368,33 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and compilation_config.use_inductor_graph_partition):
|
||||
torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _torch27_patch_tensor_subclasses():
|
||||
"""
|
||||
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
|
||||
using torch 2.7.0. This enables using weight_loader_v2 and the use of
|
||||
`BasevLLMParameters` without having to replace them with regular tensors
|
||||
before `torch.compile`-time.
|
||||
"""
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter,
|
||||
_ColumnvLLMParameter)
|
||||
|
||||
def return_false(*args, **kwargs):
|
||||
return False
|
||||
|
||||
if version.parse("2.7") <= version.parse(
|
||||
torch.__version__) < version.parse("2.8"):
|
||||
yield
|
||||
return
|
||||
|
||||
with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
|
||||
BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
|
||||
RowvLLMParameter
|
||||
]),
|
||||
patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
|
||||
return_false)):
|
||||
yield
|
||||
|
||||
@ -22,6 +22,7 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
# yapf: disable
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
PerTensorScaleParameter,
|
||||
@ -34,6 +35,7 @@ from vllm.utils import GiB_bytes
|
||||
logger = init_logger(__name__)
|
||||
|
||||
WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"UnquantizedLinearMethod",
|
||||
"CompressedTensorsLinearMethod",
|
||||
"CompressedTensorsLinearTransformMethod",
|
||||
"BitBLASLinearMethod",
|
||||
@ -196,10 +198,14 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
# The amount of memory allocated for the weights is
|
||||
# sum(output_partition_sizes) * input_size_per_partition.
|
||||
try:
|
||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
weight_loader = extra_weight_attrs.pop("weight_loader")
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
except torch.cuda.OutOfMemoryError as e:
|
||||
logger.error("Failed to create unquantized linear weights: %s", e)
|
||||
if torch.cuda.is_available():
|
||||
@ -212,7 +218,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
"Failed to create unquantized linear weights. "
|
||||
"This may be caused by insufficient memory to allocate "
|
||||
"the weight.") from e
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
|
||||
@ -61,9 +61,24 @@ class BasevLLMParameter(Parameter):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
@property
|
||||
def weight_loader(self):
|
||||
def weight_loader(self) -> Callable:
|
||||
# NOTE(@ksayers) some models such as mamba_mixer2 override the
|
||||
# weight loader to support custom loading. In the future, model-specific
|
||||
# weight loading should be implemented via Model.load_weights. In the
|
||||
# meantime, support deleting and overriding `weight_loader`` attribute
|
||||
if self._weight_loader is None:
|
||||
raise AttributeError(f"{self.__class__.__name__} weight_loader "
|
||||
"attribute has been deleted")
|
||||
return self._weight_loader
|
||||
|
||||
@weight_loader.setter
|
||||
def weight_loader(self, value: Callable):
|
||||
self._weight_loader = value
|
||||
|
||||
@weight_loader.deleter
|
||||
def weight_loader(self):
|
||||
self._weight_loader = None # type: ignore[assignment]
|
||||
|
||||
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
|
||||
cond1 = self.data.ndim == 1 and self.data.numel() == 1
|
||||
cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
|
||||
@ -97,6 +112,12 @@ class BasevLLMParameter(Parameter):
|
||||
assert shard_id in qkv_idxs
|
||||
return qkv_idxs[shard_id]
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
|
||||
class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user