Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Core] Support weight_loader_v2 for UnquantizedLinearMethod (#23036)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
commit 4ed6b67da3
parent cb825af948
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -8,6 +8,7 @@ from unittest.mock import patch
 
 import torch
 import torch.nn as nn
+from packaging import version
 from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 
 from vllm.compilation.counter import compilation_counter
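The new `packaging` import backs the torch-version gate added further down in this file. As a standalone illustration of its comparison semantics (not part of the diff):

    from packaging import version

    # Proper version ordering, unlike lexicographic string comparison
    # ("2.10" < "2.8" as strings, but not as versions). Local segments
    # such as "+cu126" in torch builds parse fine.
    assert version.parse("2.7") <= version.parse("2.7.1+cu126") < version.parse("2.8")
    assert version.parse("2.10.0") > version.parse("2.8.0")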
@@ -300,13 +301,13 @@ def _support_torch_compile(
                 logger.debug(
                     "enable_cpp_symbolic_shape_guards config not available")
 
-        with patch.object(InliningInstructionTranslator, 'inline_call',
-                          patched_inline_call), torch._dynamo.config.patch(
-                              **dynamo_config_patches
-                          ), maybe_use_cudagraph_partition_wrapper(
-                              self.vllm_config):
+        with patch.object(
+                InliningInstructionTranslator, "inline_call",
+                patched_inline_call), torch._dynamo.config.patch(
+                    **dynamo_config_patches
+                ), maybe_use_cudagraph_partition_wrapper(
+                    self.vllm_config), _torch27_patch_tensor_subclasses():
             output = self.compiled_callable(*args, **kwargs)
 
         return output
 
         # usually, capturing the model once is enough, and then we can
@@ -367,3 +368,33 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
     if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             and compilation_config.use_inductor_graph_partition):
         torch._inductor.utils.set_customized_partition_wrappers(None)
+
+
+@contextlib.contextmanager
+def _torch27_patch_tensor_subclasses():
+    """
+    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.)
+    when using torch 2.7.0. This enables using weight_loader_v2 and the use
+    of `BasevLLMParameter`s without having to replace them with regular
+    tensors before `torch.compile`-time.
+    """
+    from vllm.model_executor.parameter import (BasevLLMParameter,
+                                               ModelWeightParameter,
+                                               RowvLLMParameter,
+                                               _ColumnvLLMParameter)
+
+    def return_false(*args, **kwargs):
+        return False
+
+    if not (version.parse("2.7") <= version.parse(
+            torch.__version__) < version.parse("2.8")):
+        yield
+        return
+
+    with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
+            BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
+            RowvLLMParameter
+    ]),
+          patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
+                return_false)):
+        yield
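For readers outside the vLLM tree, a minimal self-contained sketch of the same version-gating pattern, with a hypothetical subclass MyParam standing in for the vLLM parameter classes:

    import contextlib

    import torch
    from packaging import version


    class MyParam(torch.Tensor):
        # hypothetical stand-in for BasevLLMParameter and friends
        pass


    @contextlib.contextmanager
    def _maybe_patch_subclasses():
        # No-op on anything other than a torch 2.7.x release.
        if not (version.parse("2.7") <= version.parse(
                torch.__version__) < version.parse("2.8")):
            yield
            return
        # On 2.7.x, register the subclass so dynamo traces through it.
        with torch._dynamo.config.patch("traceable_tensor_subclasses",
                                        [MyParam]):
            yield


    with _maybe_patch_subclasses():
        pass  # e.g. invoke the compiled callable here, as the hunk above does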
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
 # yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            BlockQuantScaleParameter,
+                                           ModelWeightParameter,
                                            PackedColumnParameter,
                                            PackedvLLMParameter,
                                            PerTensorScaleParameter,
@@ -34,6 +35,7 @@ from vllm.utils import GiB_bytes
 logger = init_logger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
+    "UnquantizedLinearMethod",
     "CompressedTensorsLinearMethod",
     "CompressedTensorsLinearTransformMethod",
     "BitBLASLinearMethod",
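The list gates which quantization methods take the v2 loading path, keyed by class name. A schematic of how such a registry is consulted (hypothetical helper uses_weight_loader_v2, not the vLLM call site):

    WEIGHT_LOADER_V2_SUPPORTED = [
        "UnquantizedLinearMethod",
        "CompressedTensorsLinearMethod",
    ]

    def uses_weight_loader_v2(quant_method: object) -> bool:
        # Dispatch on the class name, mirroring how the list is keyed.
        return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED

    class UnquantizedLinearMethod:  # hypothetical stand-in
        pass

    assert uses_weight_loader_v2(UnquantizedLinearMethod())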
@@ -196,10 +198,14 @@ class UnquantizedLinearMethod(LinearMethodBase):
         # The amount of memory allocated for the weights is
         # sum(output_partition_sizes) * input_size_per_partition.
         try:
-            weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                           input_size_per_partition,
-                                           dtype=params_dtype),
-                               requires_grad=False)
+            weight_loader = extra_weight_attrs.pop("weight_loader")
+            weight = ModelWeightParameter(data=torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=params_dtype),
+                                          input_dim=1,
+                                          output_dim=0,
+                                          weight_loader=weight_loader)
         except torch.cuda.OutOfMemoryError as e:
             logger.error("Failed to create unquantized linear weights: %s", e)
             if torch.cuda.is_available():
@@ -212,7 +218,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
                 "Failed to create unquantized linear weights. "
                 "This may be caused by insufficient memory to allocate "
                 "the weight.") from e
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
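The change replaces a bare Parameter plus post-hoc set_weight_attrs with a parameter object that carries its sharding metadata and loader from construction. A minimal sketch of that pattern, with hypothetical names WeightParam and default_loader (not the vLLM classes):

    import torch

    def default_loader(param: torch.nn.Parameter,
                       loaded: torch.Tensor) -> None:
        param.data.copy_(loaded)

    class WeightParam(torch.nn.Parameter):
        def __new__(cls, data, input_dim, output_dim, weight_loader):
            self = super().__new__(cls, data, requires_grad=False)
            self.input_dim = input_dim      # dim sharded on the input side
            self.output_dim = output_dim    # dim sharded on the output side
            self.weight_loader = weight_loader
            return self

    # Metadata travels with the parameter instead of being attached later.
    w = WeightParam(torch.empty(8, 4), input_dim=1, output_dim=0,
                    weight_loader=default_loader)
    w.weight_loader(w, torch.randn(8, 4))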
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -61,9 +61,24 @@ class BasevLLMParameter(Parameter):
         self.tp_size = get_tensor_model_parallel_world_size()
 
     @property
-    def weight_loader(self):
+    def weight_loader(self) -> Callable:
+        # NOTE(@ksayers) some models such as mamba_mixer2 override the
+        # weight loader to support custom loading. In the future, model-specific
+        # weight loading should be implemented via Model.load_weights. In the
+        # meantime, support deleting and overriding the `weight_loader` attribute
+        if self._weight_loader is None:
+            raise AttributeError(f"{self.__class__.__name__} weight_loader "
+                                 "attribute has been deleted")
         return self._weight_loader
 
+    @weight_loader.setter
+    def weight_loader(self, value: Callable):
+        self._weight_loader = value
+
+    @weight_loader.deleter
+    def weight_loader(self):
+        self._weight_loader = None  # type: ignore[assignment]
+
     def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
         cond1 = self.data.ndim == 1 and self.data.numel() == 1
         cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
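The property emulates deletion by storing None and raising AttributeError from the getter, so hasattr() reports False after del while the attribute can still be reassigned. A minimal standalone sketch of the behavior (not vLLM code):

    from typing import Callable, Optional

    class Param:
        def __init__(self, weight_loader: Callable):
            self._weight_loader: Optional[Callable] = weight_loader

        @property
        def weight_loader(self) -> Callable:
            if self._weight_loader is None:
                raise AttributeError("weight_loader has been deleted")
            return self._weight_loader

        @weight_loader.setter
        def weight_loader(self, value: Callable):
            self._weight_loader = value

        @weight_loader.deleter
        def weight_loader(self):
            self._weight_loader = None

    p = Param(weight_loader=print)
    assert hasattr(p, "weight_loader")
    del p.weight_loader
    assert not hasattr(p, "weight_loader")  # getter now raises AttributeError
    p.weight_loader = print                 # models may override the loader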
@@ -97,6 +112,12 @@ class BasevLLMParameter(Parameter):
         assert shard_id in qkv_idxs
         return qkv_idxs[shard_id]
 
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        return super().__torch_function__(func, types, args, kwargs)
+
 
 class _ColumnvLLMParameter(BasevLLMParameter):
     """
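The added __torch_function__ is a pass-through to the default implementation; defining it explicitly gives the subclass a dispatch hook that torch.compile can reason about while keeping normal torch-op behavior. A standalone sketch with a hypothetical subclass TracedParam (not vLLM code):

    import torch

    class TracedParam(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            return super().__torch_function__(func, types, args, kwargs)

    t = torch.ones(2, 2).as_subclass(TracedParam)
    out = t + 1                           # dispatches through __torch_function__
    assert isinstance(out, TracedParam)   # results keep the subclass type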