[Misc] Update compressed tensors lifecycle to remove prefix from create_weights (#7825)
commit 015e6cc252
parent 760e9f71a8
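Summary of the change: the quantization ignore-list check and scheme selection move out of CompressedTensorsLinearMethod.create_weights (which required threading a prefix kwarg through every linear layer's create_weights call) and into CompressedTensorsConfig.get_quant_method, which already receives the layer's prefix. Ignored layers now get vLLM's stock UnquantizedLinearMethod, making the CompressedTensorsUnquantized scheme dead code. A minimal sketch of the resulting lifecycle; the driver lines and the example prefix are illustrative assumptions, only the method names come from the diff:

    # Lifecycle after this commit (sketch; driver code and prefix are assumed).
    quant_method = quant_config.get_quant_method(
        layer, prefix="model.layers.0.mlp.down_proj")
    # ignored prefix   -> UnquantizedLinearMethod()
    # LinearBase layer -> scheme resolved once via get_scheme() and cached on
    #                     layer.scheme, then CompressedTensorsLinearMethod
    # Attention layer  -> CompressedTensorsKVCacheMethod
    quant_method.create_weights(layer, ...)  # prefix kwarg no longer needed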
vllm/model_executor/layers/linear.py

@@ -208,8 +208,7 @@ class ReplicatedLinear(LinearBase):
                                          self.input_size,
                                          self.output_size,
                                          self.params_dtype,
-                                         weight_loader=self.weight_loader,
-                                         prefix=prefix)
+                                         weight_loader=self.weight_loader)
 
         if bias:
             self.bias = Parameter(
@@ -307,8 +306,7 @@ class ColumnParallelLinear(LinearBase):
             params_dtype=self.params_dtype,
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
-                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
-            prefix=prefix)
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -976,8 +974,7 @@ class RowParallelLinear(LinearBase):
             params_dtype=self.params_dtype,
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
-                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
-            prefix=prefix)
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
                              "results can lead to incorrect results")
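With the prefix consumed at get_quant_method time, all three parallel linear variants call create_weights without it. For orientation, a sketch of the post-change call site in ColumnParallelLinear, assembled from the diff context above (the surrounding constructor code is an assumption, not quoted from the commit):

    # Sketch of the post-change call site (assembled from the diff context;
    # arguments beyond the lines shown above are assumptions).
    self.quant_method.create_weights(
        layer=self,
        input_size_per_partition=self.input_size,
        output_partition_sizes=self.output_partition_sizes,
        input_size=self.input_size,
        output_size=self.output_size,
        params_dtype=self.params_dtype,
        weight_loader=(
            self.weight_loader_v2 if self.quant_method.__class__.__name__
            in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))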
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -3,15 +3,15 @@ from typing import Any, Dict, List, Optional
 import torch
 from pydantic import BaseModel
 
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
-    CompressedTensorsScheme, CompressedTensorsUnquantized,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     QuantizationType, find_matched_target, is_activation_quantization_format,
@@ -52,15 +52,20 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_name(self) -> str:
         return "compressed_tensors"
 
-    # TODO (@robertgshaw2-neuralmagic): do layer skipping though here
-    # rather than though create_weights to match other methods
     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
+
+        # Check if the layer is skipped for quantization.
+        # TODO (@robertgshaw2): support module names
+        if should_ignore_layer(prefix, ignore=self.ignore):
+            return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
+            scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
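should_ignore_layer matches the layer's prefix against the ignore list from the compressed-tensors checkpoint config, which mixes literal module names with re:-prefixed regexes. A simplified re-implementation for illustration only; the real helper in vllm's compressed_tensors utils additionally handles vLLM's fused modules (e.g. qkv_proj):

    import re
    from typing import Iterable

    def should_ignore_layer_sketch(prefix: str, ignore: Iterable[str]) -> bool:
        # Illustration only: vLLM's real should_ignore_layer also maps fused
        # vLLM modules (e.g. qkv_proj) back to their checkpoint names.
        for rule in ignore:
            if rule.startswith("re:"):
                if re.match(rule[3:], prefix):
                    return True
            elif prefix == rule:
                return True
        return False

    assert should_ignore_layer_sketch("lm_head", ["lm_head"])
    assert not should_ignore_layer_sketch("model.layers.0.mlp.down_proj",
                                          ["lm_head"])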
@@ -281,15 +286,11 @@ class CompressedTensorsConfig(QuantizationConfig):
         to select the CompressedTensorsScheme used for infernece.
         """
-        # Check if the layer is skipped for quantization.
-        # TODO (@robertgshaw2): support module names
-        if should_ignore_layer(layer_name, ignore=self.ignore):
-            return CompressedTensorsUnquantized()
 
         # Find the "target" in the compressed-tensors config
         # that our layer conforms to.
         # TODO (@robertgshaw): add compressed-tensors as dep
         # so we do not have to re-write these functions
         # need to make accelerate optional in ct to do this
         matched_target = find_matched_target(
             layer_name=layer_name,
             module=layer,
@@ -327,10 +328,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         details
         """
         weight_loader = extra_weight_attrs.get("weight_loader")
-        layer_name = extra_weight_attrs.get("prefix")
-
-        scheme = self.quantization_config.get_scheme(layer, layer_name)
-        scheme.create_weights(
+        layer.scheme.create_weights(
             layer=layer,
             input_size=input_size,
             input_size_per_partition=input_size_per_partition,
@@ -339,8 +337,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
             params_dtype=params_dtype,
             weight_loader=weight_loader)
 
-        layer.scheme = scheme
-
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
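Because get_quant_method now caches the resolved scheme on the module, both weight creation and the forward pass read layer.scheme. The apply body itself is not part of this diff; a condensed, assumed form of that dispatch for orientation:

    # Condensed dispatch (assumption: inferred from the scheme interface;
    # the apply body is not shown in this commit).
    def apply(self, layer, x, bias=None):
        return layer.scheme.apply_weights(layer, x, bias=bias)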
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py

@@ -1,5 +1,4 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme
-from .compressed_tensors_unquantized import CompressedTensorsUnquantized
 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
                                           CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
@@ -10,7 +9,6 @@ from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
 
 __all__ = [
     "CompressedTensorsScheme",
-    "CompressedTensorsUnquantized",
     "CompressedTensorsWNA16",
     "CompressedTensorsW8A16Fp8",
     "CompressedTensorsW4A16Sparse24",
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py (deleted)

@@ -1,49 +0,0 @@
-from typing import Callable, List, Optional
-
-import torch
-import torch.nn.functional as F
-
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme)
-from vllm.model_executor.parameter import ModelWeightParameter
-
-__all__ = ["CompressedTensorsUnquantized"]
-
-
-class CompressedTensorsUnquantized(CompressedTensorsScheme):
-    """
-    Implements the scheme for all layers which are ignored
-    in the CompressedTensors config. The input and loaded weight are used
-    in a linear transformation.
-    """
-
-    @classmethod
-    def get_min_capability(cls) -> int:
-        # volta and up
-        return 70
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # required by torch.compile to be torch.nn.Parameter
-        layer.weight = torch.nn.Parameter(layer.weight.data,
-                                          requires_grad=False)
-
-    def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype, weight_loader: Callable,
-                       **kwargs):
-
-        weight = ModelWeightParameter(data=torch.empty(
-            sum(output_partition_sizes),
-            input_size_per_partition,
-            dtype=params_dtype),
-                                      input_dim=1,
-                                      output_dim=0,
-                                      weight_loader=weight_loader)
-
-        layer.register_parameter("weight", weight)
-
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-
-        return F.linear(x, layer.weight, bias)
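Deleting this scheme loses no behavior: for an ignored layer, UnquantizedLinearMethod likewise registers a dense weight, and its forward pass reduces to the same plain linear transform the deleted file computed. A self-contained check of that equivalence (shapes are arbitrary):

    import torch
    import torch.nn.functional as F

    # What both the deleted scheme and UnquantizedLinearMethod compute for an
    # ignored layer: y = x @ W^T + b.
    x = torch.randn(2, 8)
    weight = torch.randn(16, 8)  # (sum(output_partition_sizes), input_size_per_partition)
    bias = torch.randn(16)
    out = F.linear(x, weight, bias)
    assert out.shape == (2, 16)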