[Misc] Update compressed tensors lifecycle to remove prefix from create_weights (#7825)

Dipika Sikka, 2024-08-26 20:09:34 -04:00 (committed by GitHub)
commit 015e6cc252 (parent 760e9f71a8)
4 changed files with 17 additions and 75 deletions
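Summary of the change: the per-layer scheme is now resolved once in CompressedTensorsConfig.get_quant_method (which already receives the layer's prefix), stored on the layer as layer.scheme, and consumed by create_weights. The linear layers therefore no longer forward a prefix keyword, and the CompressedTensorsUnquantized scheme becomes unnecessary because skipped layers fall back to UnquantizedLinearMethod. A minimal, self-contained sketch of that flow, using toy classes rather than vLLM's real signatures:

# Toy illustration of the reworked lifecycle; class names and signatures are
# simplified stand-ins, not vLLM's actual APIs.
class ToyScheme:
    def __init__(self, name: str):
        self.name = name

    def create_weights(self, layer) -> None:
        # In vLLM this is where the scheme registers its quantized parameters.
        layer.created_by = self.name


class ToyUnquantizedMethod:
    def create_weights(self, layer) -> None:
        layer.created_by = "unquantized"


class ToyLinearMethod:
    def create_weights(self, layer) -> None:  # note: no prefix argument
        layer.scheme.create_weights(layer)


class ToyCompressedTensorsConfig:
    def __init__(self, ignore):
        self.ignore = ignore

    def get_scheme(self, layer, layer_name: str) -> ToyScheme:
        return ToyScheme(f"scheme:{layer_name}")

    def get_quant_method(self, layer, prefix: str):
        # Skipping now happens here instead of inside create_weights.
        if prefix in self.ignore:
            return ToyUnquantizedMethod()
        layer.scheme = self.get_scheme(layer, layer_name=prefix)
        return ToyLinearMethod()


class ToyLayer:
    pass


config = ToyCompressedTensorsConfig(ignore=["lm_head"])
for prefix in ["model.layers.0.mlp.down_proj", "lm_head"]:
    layer = ToyLayer()
    method = config.get_quant_method(layer, prefix=prefix)
    method.create_weights(layer)
    print(prefix, "->", layer.created_by)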


@@ -208,8 +208,7 @@ class ReplicatedLinear(LinearBase):
             self.input_size,
             self.output_size,
             self.params_dtype,
-            weight_loader=self.weight_loader,
-            prefix=prefix)
+            weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
@@ -307,8 +306,7 @@ class ColumnParallelLinear(LinearBase):
             params_dtype=self.params_dtype,
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
-                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
-            prefix=prefix)
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -976,8 +974,7 @@ class RowParallelLinear(LinearBase):
             params_dtype=self.params_dtype,
             weight_loader=(
                 self.weight_loader_v2 if self.quant_method.__class__.__name__
-                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
-            prefix=prefix)
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
                              "results can lead to incorrect results")


@@ -3,15 +3,15 @@ from typing import Any, Dict, List, Optional
 import torch
 from pydantic import BaseModel
 
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
-    CompressedTensorsScheme, CompressedTensorsUnquantized,
-    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
-    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
-    CompressedTensorsWNA16)
+    CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
+    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
+    CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     QuantizationType, find_matched_target, is_activation_quantization_format,
@@ -52,15 +52,20 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_name(self) -> str:
         return "compressed_tensors"
 
-    # TODO (@robertgshaw2-neuralmagic): do layer skipping though here
-    # rather than though create_weights to match other methods
     def get_quant_method(
         self,
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
+
+        # Check if the layer is skipped for quantization.
+        # TODO (@robertgshaw2): support module names
+        if should_ignore_layer(prefix, ignore=self.ignore):
+            return UnquantizedLinearMethod()
         if isinstance(layer, LinearBase):
+            scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
@@ -281,15 +286,11 @@ class CompressedTensorsConfig(QuantizationConfig):
         to select the CompressedTensorsScheme used for infernece.
         """
 
-        # Check if the layer is skipped for quantization.
-        # TODO (@robertgshaw2): support module names
-        if should_ignore_layer(layer_name, ignore=self.ignore):
-            return CompressedTensorsUnquantized()
-
         # Find the "target" in the compressed-tensors config
         # that our layer conforms to.
         # TODO (@robertgshaw): add compressed-tensors as dep
         # so we do not have to re-write these functions
+        # need to make accelerate optional in ct to do this
         matched_target = find_matched_target(
             layer_name=layer_name,
             module=layer,
@@ -327,10 +328,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         details
         """
         weight_loader = extra_weight_attrs.get("weight_loader")
-        layer_name = extra_weight_attrs.get("prefix")
-        scheme = self.quantization_config.get_scheme(layer, layer_name)
-        scheme.create_weights(
+        layer.scheme.create_weights(
             layer=layer,
             input_size=input_size,
             input_size_per_partition=input_size_per_partition,
@@ -339,8 +337,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
             params_dtype=params_dtype,
             weight_loader=weight_loader)
 
-        layer.scheme = scheme
-
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,


@@ -1,5 +1,4 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme
-from .compressed_tensors_unquantized import CompressedTensorsUnquantized
 from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
                                           CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
@@ -10,7 +9,6 @@ from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
 
 __all__ = [
     "CompressedTensorsScheme",
-    "CompressedTensorsUnquantized",
     "CompressedTensorsWNA16",
     "CompressedTensorsW8A16Fp8",
     "CompressedTensorsW4A16Sparse24",


@@ -1,49 +0,0 @@
-from typing import Callable, List, Optional
-
-import torch
-import torch.nn.functional as F
-
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme)
-from vllm.model_executor.parameter import ModelWeightParameter
-
-__all__ = ["CompressedTensorsUnquantized"]
-
-
-class CompressedTensorsUnquantized(CompressedTensorsScheme):
-    """
-    Implements the scheme for all layers which are ignored
-    in the CompressedTensors config. The input and loaded weight are used
-    in a linear transformation.
-    """
-
-    @classmethod
-    def get_min_capability(cls) -> int:
-        # volta and up
-        return 70
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # required by torch.compile to be torch.nn.Parameter
-        layer.weight = torch.nn.Parameter(layer.weight.data,
-                                          requires_grad=False)
-
-    def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype, weight_loader: Callable,
-                       **kwargs):
-
-        weight = ModelWeightParameter(data=torch.empty(
-            sum(output_partition_sizes),
-            input_size_per_partition,
-            dtype=params_dtype),
-                                      input_dim=1,
-                                      output_dim=0,
-                                      weight_loader=weight_loader)
-
-        layer.register_parameter("weight", weight)
-
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-
-        return F.linear(x, layer.weight, bias)
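The deleted scheme did nothing quantization-specific: it registered a dense weight and applied F.linear, which is essentially the behavior the generic UnquantizedLinearMethod path already provides. A toy sketch of that computation (made-up shapes, none of vLLM's parameter machinery):

# Toy demonstration of what the removed "unquantized" path computes: a plain
# dense linear transform over the loaded weight.
import torch
import torch.nn.functional as F

in_features, out_features = 16, 8
# The scheme just registered a dense weight parameter for the weight loader...
weight = torch.nn.Parameter(torch.randn(out_features, in_features),
                            requires_grad=False)

x = torch.randn(2, in_features)
# ...and apply_weights was a plain dense matmul over it.
y = F.linear(x, weight, bias=None)
print(y.shape)  # torch.Size([2, 8])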