Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-01 03:28:42 +08:00)
[ Misc ] Rs/compressed tensors cleanup (#5432)
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
parent d74674bbd9
commit 15985680e2
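In summary, the cleanup shown in the hunks below folds the repeated set_weight_attrs calls in each compressed-tensors scheme into a single call per parameter, drops the unused input/weight zero-point parameters from the W8A8 schemes, and adds bfloat16 to the supported activation dtypes. For orientation, here is a minimal sketch of what a set_weight_attrs-style helper does; the function name and body are assumptions based on how it is used in the diff, not the actual vLLM implementation:

from typing import Any, Dict, Optional

import torch


def set_weight_attrs_sketch(weight: torch.Tensor,
                            attrs: Optional[Dict[str, Any]]) -> None:
    # Assumed behavior: attach each key/value pair to the parameter so the
    # checkpoint loader can read attributes such as "weight_loader",
    # "input_dim", or "logical_widths" later.
    if attrs is None:
        return
    for key, value in attrs.items():
        setattr(weight, key, value)

Under that model, merging several calls into one dict is behavior-preserving: the same attributes end up on the tensor either way.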
@@ -26,7 +26,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         return []
 
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.float16]
+        return [torch.float16, torch.bfloat16]
 
     # Need to figure it out
     def get_min_capability(self) -> int:
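With bfloat16 added to get_supported_act_dtypes, a dtype check along the following lines would now accept bf16 models quantized with compressed-tensors. The helper is hypothetical and only illustrates the effect of the change:

from typing import List

import torch


def check_act_dtype(model_dtype: torch.dtype,
                    supported: List[torch.dtype]) -> None:
    # Hypothetical validation: reject activation dtypes the scheme cannot run.
    if model_dtype not in supported:
        raise ValueError(f"{model_dtype} not supported, expected one of {supported}")


# Previously only float16 passed; after this change bfloat16 passes as well.
check_act_dtype(torch.bfloat16, [torch.float16, torch.bfloat16])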
@@ -64,10 +64,9 @@ class CompressedTensorsW4A16(CompressedTensorsScheme):
                 "input_dim": 1,
                 "output_dim": 0,
                 "packed_dim": 1,
-                "pack_factor": pack_factor
+                "pack_factor": pack_factor,
+                "weight_loader": weight_loader
             })
-        set_weight_attrs(weight, {"weight_loader": weight_loader})
 
         layer.register_parameter("weight_packed", weight)
 
         weight_scale = Parameter(
@@ -79,11 +78,12 @@ class CompressedTensorsW4A16(CompressedTensorsScheme):
             requires_grad=False,
         )
 
-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
-        set_weight_attrs(weight_scale, {
-            "input_dim": weight_scale_dim,
-            "output_dim": 0
-        })
+        set_weight_attrs(
+            weight_scale, {
+                "weight_loader": weight_loader,
+                "input_dim": weight_scale_dim,
+                "output_dim": 0
+            })
         layer.register_parameter("weight_scale", weight_scale)
 
         # A 2D array defining the original shape of the weights
@@ -92,7 +92,10 @@ class CompressedTensorsW4A16(CompressedTensorsScheme):
                                  requires_grad=False)
 
         layer.register_parameter("weight_shape", weight_shape)
-        set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
+        set_weight_attrs(weight_shape, {
+            "weight_loader": weight_loader,
+            "ignore_warning": True,
+        })
 
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
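For context on the packed_dim / pack_factor attributes kept in the W4A16 hunks above: packing 4-bit weights into a 32-bit container gives a pack factor of 32 / 4 = 8, so the packed tensor shrinks by that factor along its input dimension. The shapes below are illustrative only and are not taken from the diff:

import torch

# Illustrative sizes, not values from the diff.
output_size_per_partition = 4096
input_size_per_partition = 11008
pack_factor = 32 // 4  # eight 4-bit values per int32

packed_weight = torch.empty(output_size_per_partition,
                            input_size_per_partition // pack_factor,
                            dtype=torch.int32)
print(packed_weight.shape)  # torch.Size([4096, 1376])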
@@ -48,9 +48,6 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
         weight_scale_dim = sum(
             output_partition_sizes) if is_tensor_partitioned else 1
 
-        weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
-                                      requires_grad=False)
-
         weight_scale = Parameter(torch.empty(weight_scale_dim,
                                              dtype=torch.float32),
                                  requires_grad=False)
@@ -61,20 +58,21 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
                                  requires_grad=False)
 
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        set_weight_attrs(weight, {"weight_loader": weight_loader})
-        set_weight_attrs(weight, {"logical_widths": output_partition_sizes})
-
-        layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
         set_weight_attrs(
-            weight_scale, {
-                "shard_splitter": self.scales_shard_splitter,
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "weight_loader": weight_loader,
                 "logical_widths": output_partition_sizes
             })
 
-        layer.register_parameter("weight_zero_point", weight_zero_point)
-        set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
+        layer.register_parameter("weight_scale", weight_scale)
+        set_weight_attrs(
+            weight_scale, {
+                "weight_loader": weight_loader,
+                "shard_splitter": self.scales_shard_splitter,
+                "logical_widths": output_partition_sizes
+            })
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
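The attributes consolidated above are consumed at checkpoint-load time. The sketch below shows one plausible way a loader could use logical_widths to place each fused shard's scales into the shared weight_scale tensor; it is an assumed flow inferred from the attribute names, not the scales_shard_splitter used by the scheme:

from typing import List

import torch


def place_shard_scales(weight_scale: torch.Tensor,
                       loaded_scales: torch.Tensor,
                       shard_id: int,
                       logical_widths: List[int]) -> None:
    # Assumed semantics: each fused shard (e.g. q/k/v) owns a contiguous
    # slice of the flattened scale tensor, sized by its logical width.
    offset = sum(logical_widths[:shard_id])
    size = logical_widths[shard_id]
    weight_scale[offset:offset + size].copy_(loaded_scales)


scales = torch.empty(4 + 4 + 8)
place_shard_scales(scales, torch.ones(8), shard_id=2, logical_widths=[4, 4, 8])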
@@ -39,22 +39,16 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
 
-        # TODO: remove zero_point parameters once the configs given remove them
-
         is_tensor_partitioned = len(output_partition_sizes) != 1
         weight_scale_dim = sum(
             output_partition_sizes) if is_tensor_partitioned else 1
 
         input_scale = Parameter(torch.empty(1, dtype=torch.float32),
                                 requires_grad=False)
-        input_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
-                                     requires_grad=False)
 
         weight_scale = Parameter(torch.empty(weight_scale_dim,
                                              dtype=torch.float32),
                                  requires_grad=False)
-        weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
-                                      requires_grad=False)
 
         weight = Parameter(torch.empty(sum(output_partition_sizes),
                                        input_size_per_partition,
@@ -72,11 +66,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
             "weight_loader": weight_loader,
             "ignore_warning": True,
         })
-        layer.register_parameter("input_zero_point", input_zero_point)
-        set_weight_attrs(input_zero_point, {
-            "weight_loader": weight_loader,
-            "ignore_warning": True,
-        })
         layer.register_parameter("weight_scale", weight_scale)
         set_weight_attrs(
             weight_scale, {
@@ -85,11 +74,6 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
                 "logical_widths": output_partition_sizes,
                 "ignore_warning": True,
             })
-        layer.register_parameter("weight_zero_point", weight_zero_point)
-        set_weight_attrs(weight_zero_point, {
-            "weight_loader": weight_loader,
-            "ignore_warning": True
-        })
 
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
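The removed input_zero_point and weight_zero_point parameters had no consumer in these schemes, presumably because the W8A8 paths use symmetric int8 quantization, where the zero point is identically zero and only the scale matters. A generic worked example of the symmetric per-tensor round trip (plain PyTorch math, not the custom GPU kernel the scheme dispatches to):

import torch


def quantize_symmetric_int8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Symmetric quantization: zero_point is 0 by construction, so the only
    # parameter the matmul needs alongside the int8 values is the scale.
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)


def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale


x = torch.randn(4, 8)
scale = x.abs().max() / 127.0
x_hat = dequantize_int8(quantize_symmetric_int8(x, scale), scale)
assert (x - x_hat).abs().max() <= scale / 2 + 1e-6  # rounding error bounded by scale / 2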