[Bugfix] gptq_marlin: Ensure g_idx_sort_indices is not a Parameter (#5108)

This commit is contained in:
Alexander Matveev 2024-05-29 20:30:18 -04:00 committed by GitHub
parent 4fbcb0f27e
commit 5bf185a1c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -298,14 +298,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
},
)
g_idx_sort_indices = Parameter(
torch.empty(
g_idx.shape,
dtype=torch.int32,
),
requires_grad=False,
g_idx_sort_indices = torch.empty(
g_idx.shape,
dtype=torch.int32,
)
set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
# Scales
scales = Parameter(
@@ -356,9 +352,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.g_idx_sort_indices = g_idx_sort_indices
layer.workspace = workspace
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition