From 62963d129e84d0a0904ee62dbab067a29216e7bf Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Wed, 3 Jul 2024 18:50:08 -0400
Subject: [PATCH] [ Misc ] Clean Up `CompressedTensorsW8A8` (#6113)

---
 tests/quantization/test_compressed_tensors.py |  9 ++--
 .../compressed_tensors/compressed_tensors.py  | 11 ++---
 .../compressed_tensors/schemes/__init__.py    |  5 +-
 .../schemes/compressed_tensors_w8a8.py        | 34 +++++++++++++-
 .../compressed_tensors_w8a8_dynamictoken.py   | 33 -------------
 .../compressed_tensors_w8a8_statictensor.py   | 47 -------------------
 6 files changed, 44 insertions(+), 95 deletions(-)
 delete mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
 delete mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py

diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index d5472f97a1c50..4cdda97dc728d 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -9,8 +9,7 @@ import torch
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
-    CompressedTensorsWNA16)
+    CompressedTensorsW8A8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationType)
 
@@ -38,9 +37,10 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
                           CompressedTensorsLinearMethod)
         assert isinstance(down_proj.quant_method,
                           CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
 
         assert qkv_proj.scheme.strategy == strategy
+        assert qkv_proj.scheme.is_static_input_scheme
         expected_type = (torch.int8 if quant_type == QuantizationType.INT else
                          torch.float8_e4m3fn)
 
@@ -79,7 +79,8 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
         qkv_proj = layer.self_attn.qkv_proj
 
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8)
+        assert not qkv_proj.scheme.is_static_input_scheme
         assert qkv_proj.scheme.strategy == strategy
         assert qkv_proj.weight.dtype is torch.int8
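
For reference, the two tests above exercise the two activation-quantization
modes that `is_static_input_scheme` now selects between. A minimal sketch of
the distinction in plain PyTorch (illustration only, not vLLM's kernels; the
function names here are hypothetical):

    import torch

    def quantize_static(x: torch.Tensor, input_scale: torch.Tensor):
        # Static: a single scale, calibrated offline and loaded from the
        # checkpoint, reused for every batch.
        x_q = torch.clamp(torch.round(x / input_scale), -128, 127).to(torch.int8)
        return x_q, input_scale

    def quantize_dynamic_per_token(x: torch.Tensor):
        # Dynamic per-token: one scale per row, recomputed from the
        # activations on every forward pass.
        x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_q = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
        return x_q, x_scale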
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index e88bbc361a5e0..8ca486d95941d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -9,8 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
-    CompressedTensorsWNA16)
+    CompressedTensorsW8A8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
@@ -150,12 +149,12 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         if self.quant_format == CompressionFormat.int_quantized.value:
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8StaticTensor(
-                    strategy=weight_quant.strategy)
+                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
+                                             is_static_input_scheme=True)
 
             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8DynamicToken(
-                    strategy=weight_quant.strategy)
+                return CompressedTensorsW8A8(strategy=weight_quant.strategy,
+                                             is_static_input_scheme=False)
 
         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index f6d20ce2c6f77..720b8c263298c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -3,9 +3,6 @@ from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
 from .compressed_tensors_w4a16_24 import (  # noqa: F401
     W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24)
-from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
-    CompressedTensorsW8A8DynamicToken)
-from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
-    CompressedTensorsW8A8StaticTensor)
+from .compressed_tensors_w8a8 import CompressedTensorsW8A8  # noqa: F401
 from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS  # noqa: F401
 from .compressed_tensors_wNa16 import CompressedTensorsWNA16  # noqa: F401

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
index 49779057659f0..dffe2a284458f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
@@ -3,6 +3,7 @@ from typing import Callable, List, Tuple, Union
 import torch
 from torch.nn import Parameter
 
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
@@ -12,8 +13,9 @@ from vllm.model_executor.utils import set_weight_attrs
 
 class CompressedTensorsW8A8(CompressedTensorsScheme):
 
-    def __init__(self, strategy: str):
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
 
     # Cutlass kernels support only per-tensor and per-channel cases.
     # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
@@ -36,6 +38,10 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
             layer.weight_scale = Parameter(weight_scale_channel,
                                            requires_grad=False)
 
+        # transpose weights for cutlass.
+        weight = layer.weight
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
@@ -75,3 +81,29 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
             "output_dim": 0,
             "weight_loader": weight_loader,
         })
+
+        # INPUT SCALE
+        # Static quantization: load from disk.
+        if self.is_static_input_scheme:
+            input_scale = Parameter(torch.empty(1, dtype=torch.float32),
+                                    requires_grad=False)
+            layer.register_parameter("input_scale", input_scale)
+            set_weight_attrs(input_scale, {
+                "weight_loader": weight_loader,
+                "ignore_warning": True,
+            })
+        # Dynamic quantization: set to None.
+        else:
+            layer.input_scale = None
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        # ops.scaled_int8_quant supports both dynamic and static quant.
+        # * dynamic, layer.input_scale is None and x_scale computed from x.
+        # * static, layer.input_scale is scalar and x_scale is input_scale.
+        x_q, x_scale = ops.scaled_int8_quant(x, layer.input_scale)
+
+        return ops.cutlass_scaled_mm(x_q,
+                                     layer.weight,
+                                     scale_a=x_scale,
+                                     scale_b=layer.weight_scale,
+                                     out_dtype=x.dtype)
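
The new `apply_weights` above is the heart of the consolidation: one code path
covering both modes, with `layer.input_scale is None` signaling dynamic
quantization. A pure-PyTorch sketch of what that path computes (reference
semantics only; the real work happens in the CUTLASS kernel behind
`ops.cutlass_scaled_mm`, and `w8a8_reference` is a hypothetical name):

    from typing import Optional

    import torch

    def w8a8_reference(x: torch.Tensor, weight_q: torch.Tensor,
                       weight_scale: torch.Tensor,
                       input_scale: Optional[torch.Tensor]) -> torch.Tensor:
        # weight_q: int8, already transposed to (K, N) by
        # process_weights_after_loading; weight_scale: per-channel, shape (N,).
        if input_scale is None:
            # Dynamic: derive per-token scales from the activations.
            input_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_q = torch.clamp(torch.round(x / input_scale), -128, 127).to(torch.int8)
        # The real kernel runs an int8 GEMM with int32 accumulation; float32
        # stands in for that here.
        acc = x_q.to(torch.float32) @ weight_q.to(torch.float32)
        # Undo both quantizations: row scales for activations, column scales
        # for weights.
        return (acc * input_scale * weight_scale).to(x.dtype)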
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
deleted file mode 100644
index 5fc05b8e682d6..0000000000000
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from typing import Callable, List
-
-import torch
-
-from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
-    CompressedTensorsW8A8)
-
-__all__ = ["CompressedTensorsW8A8DynamicToken"]
-
-
-class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):
-
-    def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype, weight_loader: Callable,
-                       **kwargs):
-
-        super().create_weights(
-            layer=layer,
-            output_partition_sizes=output_partition_sizes,
-            input_size_per_partition=input_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=weight_loader)
-
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-        weight = layer.weight
-        weight_scale = layer.weight_scale
-
-        x_q, input_scales = custom_ops.scaled_int8_quant(x)
-        return custom_ops.cutlass_scaled_mm(x_q, weight.t(), input_scales,
-                                            weight_scale, x.dtype)
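
Nothing is lost in this deletion: the old dynamic path called
`scaled_int8_quant(x)` with no scale, and the unified scheme reaches the same
behavior by passing `layer.input_scale`, which is None for dynamic schemes.
The call sites, side by side (both lines appear verbatim in this patch):

    # old, CompressedTensorsW8A8DynamicToken.apply_weights:
    x_q, input_scales = custom_ops.scaled_int8_quant(x)
    # new, CompressedTensorsW8A8.apply_weights (input_scale is None when dynamic):
    x_q, x_scale = ops.scaled_int8_quant(x, layer.input_scale)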
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
deleted file mode 100644
index 79f5358a365ed..0000000000000
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Callable, List
-
-import torch
-from torch.nn import Parameter
-
-from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
-    CompressedTensorsW8A8)
-from vllm.model_executor.utils import set_weight_attrs
-
-__all__ = ["CompressedTensorsW8A8StaticTensor"]
-
-
-class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):
-
-    def create_weights(self, layer: torch.nn.Module,
-                       output_partition_sizes: List[int],
-                       input_size_per_partition: int,
-                       params_dtype: torch.dtype, weight_loader: Callable,
-                       **kwargs):
-
-        super().create_weights(
-            layer=layer,
-            output_partition_sizes=output_partition_sizes,
-            input_size_per_partition=input_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=weight_loader)
-
-        input_scale = Parameter(torch.empty(1, dtype=torch.float32),
-                                requires_grad=False)
-
-        layer.register_parameter("input_scale", input_scale)
-        set_weight_attrs(input_scale, {
-            "weight_loader": weight_loader,
-            "ignore_warning": True,
-        })
-
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-        weight = layer.weight
-        weight_scale = layer.weight_scale
-        act_scale = layer.input_scale
-
-        # Input quantize
-        x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
-
-        return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
-                                            weight_scale, x.dtype)
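
One piece of shared machinery both deleted schemes relied on, and which the
base class keeps, is the per-tensor to per-channel scale expansion behind the
`weight_scale_channel` hunk above: CUTLASS supports only per-tensor or
per-channel scales, so a fused module (QKV, MLP) carrying N per-tensor scales
has them broadcast out to one scale per output channel. A sketch of that
expansion, assuming a `logical_widths` list holding each fused sub-module's
output size (the helper name and signature here are hypothetical):

    from typing import List

    import torch

    def convert_to_channelwise(weight_scale: torch.Tensor,
                               logical_widths: List[int]) -> torch.Tensor:
        # Broadcast one per-tensor scale per fused sub-module into a
        # per-output-channel scale vector of length sum(logical_widths).
        channel_scale = torch.empty(sum(logical_widths), dtype=torch.float32)
        start = 0
        for scale, width in zip(weight_scale, logical_widths):
            channel_scale[start:start + width] = scale
            start += width
        return channel_scale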