Enable ZP Support for Machete (#20268)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere 2025-07-01 00:12:20 -07:00 committed by GitHub
parent 22e9d42040
commit 9909726d2a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 5 deletions

View File

@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
fn = lambda: ops.gptq_marlin_gemm( fn = lambda: ops.gptq_marlin_gemm(
a=bt.a, a=bt.a,
c=None,
b_q_weight=w_q, b_q_weight=w_q,
b_scales=w_s, b_scales=w_s,
global_scale=None,
b_zeros=w_zp, b_zeros=w_zp,
g_idx=g_idx, g_idx=g_idx,
perm=sort_indices, perm=sort_indices,

View File

@ -139,7 +139,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
def group_size_valid(shape: tuple[int, int, int], def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool: group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or group_size % shape[2] == 0 return group_size is None or group_size == -1 or shape[2] % group_size == 0
def machete_quantize_and_pack(atype: torch.dtype, def machete_quantize_and_pack(atype: torch.dtype,

View File

@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel):
return False, "Act reordering currently not supported by Machete, "\ return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\ "when the input features are partitioned across "\
"devices" "devices"
if c.zero_points:
return False, "Zero points currently not supported by Machete"
if c.weight_type not in query_machete_supported_quant_types( if c.weight_type not in query_machete_supported_quant_types(
c.zero_points): c.zero_points):
@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel):
# note assumes that # note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1} # `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module): def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config c = self.config
@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel):
x.data = x.data.contiguous() x.data = x.data.contiguous()
return x return x
def transform_w_zp(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=1)
w_s = getattr(layer, self.w_s_name).data
# pre-apply scales to zero-points
x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
return x
# Repack weights and scales for Machete # Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s) self._transform_param(layer, self.w_s_name, transform_w_s)
if c.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp)
def apply_weights(self, def apply_weights(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer) w_q, w_s, w_zp, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1]) x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel):
output = ops.machete_mm(a=x_2d, output = ops.machete_mm(a=x_2d,
b_q=w_q, b_q=w_q,
b_type=c.weight_type, b_type=c.weight_type,
b_group_zeros=None, b_group_zeros=w_zp,
b_group_scales=w_s, b_group_scales=w_s,
b_group_size=c.group_size) b_group_size=c.group_size)