[ Misc ] Rs/compressed tensors cleanup (#5432)

Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
This commit is contained in:
Robert Shaw 2024-06-14 13:01:46 -04:00 committed by GitHub
parent d74674bbd9
commit 15985680e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 24 additions and 39 deletions

View File

@ -26,7 +26,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return []
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16]
return [torch.float16, torch.bfloat16]
# Need to figure it out
def get_min_capability(self) -> int:

View File

@ -64,10 +64,9 @@ class CompressedTensorsW4A16(CompressedTensorsScheme):
"input_dim": 1,
"output_dim": 0,
"packed_dim": 1,
"pack_factor": pack_factor
"pack_factor": pack_factor,
"weight_loader": weight_loader
})
set_weight_attrs(weight, {"weight_loader": weight_loader})
layer.register_parameter("weight_packed", weight)
weight_scale = Parameter(
@ -79,11 +78,12 @@ class CompressedTensorsW4A16(CompressedTensorsScheme):
requires_grad=False,
)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
set_weight_attrs(weight_scale, {
"input_dim": weight_scale_dim,
"output_dim": 0
})
set_weight_attrs(
weight_scale, {
"weight_loader": weight_loader,
"input_dim": weight_scale_dim,
"output_dim": 0
})
layer.register_parameter("weight_scale", weight_scale)
# A 2D array defining the original shape of the weights
@ -92,7 +92,10 @@ class CompressedTensorsW4A16(CompressedTensorsScheme):
requires_grad=False)
layer.register_parameter("weight_shape", weight_shape)
set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
set_weight_attrs(weight_shape, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition

View File

@ -48,9 +48,6 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)
weight_scale = Parameter(torch.empty(weight_scale_dim,
dtype=torch.float32),
requires_grad=False)
@ -61,20 +58,21 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(weight, {"weight_loader": weight_loader})
set_weight_attrs(weight, {"logical_widths": output_partition_sizes})
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
set_weight_attrs(
weight_scale, {
"shard_splitter": self.scales_shard_splitter,
weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
"logical_widths": output_partition_sizes
})
layer.register_parameter("weight_zero_point", weight_zero_point)
set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(
weight_scale, {
"weight_loader": weight_loader,
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes
})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight

View File

@ -39,22 +39,16 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
# TODO: remove zero_point parameters once the configs given remove them
is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
input_scale = Parameter(torch.empty(1, dtype=torch.float32),
requires_grad=False)
input_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)
weight_scale = Parameter(torch.empty(weight_scale_dim,
dtype=torch.float32),
requires_grad=False)
weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
@ -72,11 +66,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.register_parameter("input_zero_point", input_zero_point)
set_weight_attrs(input_zero_point, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(
weight_scale, {
@ -85,11 +74,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
"logical_widths": output_partition_sizes,
"ignore_warning": True,
})
layer.register_parameter("weight_zero_point", weight_zero_point)
set_weight_attrs(weight_zero_point, {
"weight_loader": weight_loader,
"ignore_warning": True
})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight