diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 4374fd98012f6..ae440743fdf8e 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -298,14 +298,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             },
         )
 
-        g_idx_sort_indices = Parameter(
-            torch.empty(
-                g_idx.shape,
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
+        g_idx_sort_indices = torch.empty(
+            g_idx.shape,
+            dtype=torch.int32,
         )
-        set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
 
         # Scales
         scales = Parameter(
@@ -356,9 +352,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
 
         layer.register_parameter("qweight", qweight)
         layer.register_parameter("g_idx", g_idx)
-        layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
         layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
+        layer.g_idx_sort_indices = g_idx_sort_indices
         layer.workspace = workspace
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
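
Note (illustrative, not part of the patch): the diff above stops registering g_idx_sort_indices as an nn.Parameter and instead stores it as a plain tensor attribute on the layer, so it no longer appears in named_parameters() or state_dict(). A minimal sketch of the two patterns, using a hypothetical bare nn.Module as a stand-in for the quantized linear layer:

import torch
from torch import nn
from torch.nn.parameter import Parameter

layer = nn.Module()  # hypothetical stand-in for the quantized linear layer
g_idx = torch.empty(128, dtype=torch.int32)  # hypothetical group-index tensor

# Before: registered as a Parameter, so it shows up in named_parameters()
# and state_dict(), and a weight loader may try to fill it from a checkpoint.
layer.register_parameter(
    "g_idx_sort_indices",
    Parameter(torch.empty(g_idx.shape, dtype=torch.int32), requires_grad=False),
)
assert "g_idx_sort_indices" in dict(layer.named_parameters())

# After: kept as a plain attribute, i.e. internal scratch state that the
# layer fills in itself and that parameter iteration never sees.
del layer.g_idx_sort_indices
layer.g_idx_sort_indices = torch.empty(g_idx.shape, dtype=torch.int32)
assert "g_idx_sort_indices" not in dict(layer.named_parameters())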