diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 0f896f187ecb9..f73d0511e01fc 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: fn = lambda: ops.gptq_marlin_gemm( a=bt.a, + c=None, b_q_weight=w_q, b_scales=w_s, + global_scale=None, b_zeros=w_zp, g_idx=g_idx, perm=sort_indices, diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index 998171baaf2de..a4fb9874c4906 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool: - return group_size is None or group_size == -1 or group_size % shape[2] == 0 + return group_size is None or group_size == -1 or shape[2] % group_size == 0 def machete_quantize_and_pack(atype: torch.dtype, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index c7c45861875af..a75f3ac8d5033 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel): return False, "Act reordering currently not supported by Machete, "\ "when the input features are partitioned across "\ "devices" - if c.zero_points: - return False, "Zero points currently not supported by Machete" if c.weight_type not in query_machete_supported_quant_types( c.zero_points): @@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel): # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_scale` is: {input_dim = 0, output_dim = 1} + # `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module): c = self.config @@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel): x.data = x.data.contiguous() return x + def transform_w_zp(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1) + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=1) + w_s = getattr(layer, self.w_s_name).data + # pre-apply scales to zero-points + x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous() + return x + # Repack weights and scales for Machete self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) + if c.zero_points: + self._transform_param(layer, self.w_zp_name, transform_w_zp) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: c = self.config - w_q, w_s, _, _ = self._get_weight_params(layer) + w_q, w_s, w_zp, _ = self._get_weight_params(layer) x_2d = x.reshape(-1, x.shape[-1]) out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) @@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel): output = ops.machete_mm(a=x_2d, b_q=w_q, b_type=c.weight_type, - b_group_zeros=None, + b_group_zeros=w_zp, b_group_scales=w_s, b_group_size=c.group_size)