diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py index b2044d3..9130c7e 100644 --- a/mz_gguf_loader.py +++ b/mz_gguf_loader.py @@ -161,7 +161,35 @@ GGML_QUANT_SIZES = { "Q4_0": (32, 2 + 16), "Q8_0": (32, 2 + 32), } +def dequantize_blocks_Q4_0(data, dtype=torch.float16): + block_size, type_size = GGML_QUANT_SIZES["Q4_0"] + data = data.to(torch.uint8) + shape = data.shape + + rows = data.reshape( + (-1, data.shape[-1]) + ).view(torch.uint8) + + n_blocks = rows.numel() // type_size + blocks = data.reshape((n_blocks, type_size)) + + n_blocks = blocks.shape[0] + + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 + + out = (d * qs) + + out = out.reshape(quant_shape_from_byte_shape( + shape, + qtype="Q4_0", + )).to(dtype) + return out def dequantize_blocks_Q8_0(data, dtype=torch.float16): block_size, type_size = GGML_QUANT_SIZES["Q8_0"]