kijai 2024-10-25 01:36:34 +03:00
parent a51ca0e907
commit 8fad965c58

@@ -161,7 +161,35 @@ GGML_QUANT_SIZES = {
"Q4_0": (32, 2 + 16),
"Q8_0": (32, 2 + 32),
}

def dequantize_blocks_Q4_0(data, dtype=torch.float16):
    # Q4_0: blocks of 32 weights stored as a float16 scale followed by
    # 16 bytes of packed 4-bit values; weight = d * (q - 8).
    block_size, type_size = GGML_QUANT_SIZES["Q4_0"]
    data = data.to(torch.uint8)
    shape = data.shape
    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)
    n_blocks = rows.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))
    n_blocks = blocks.shape[0]

    # First 2 bytes of each block hold the float16 scale, the rest the quants.
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16)

    # Unpack low and high nibbles, then recenter from [0, 15] to [-8, 7].
    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8

    out = (d * qs)
    out = out.reshape(quant_shape_from_byte_shape(
        shape,
        qtype="Q4_0",
    )).to(dtype)
    return out
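
For reference, a Q4_0 block packs 32 weights as one float16 scale plus 16 bytes of 4-bit values, and each weight dequantizes as d * (q - 8). Below is a minimal round-trip sketch, not part of the commit: the names d, q, packed, and block are illustrative, and it assumes the split_block_dims and quant_shape_from_byte_shape helpers defined elsewhere in this file are importable alongside dequantize_blocks_Q4_0.

import torch

d = torch.tensor([0.5], dtype=torch.float16)         # hypothetical per-block scale
q = torch.randint(0, 16, (32,), dtype=torch.uint8)   # 32 unsigned 4-bit quants

# Byte i carries q[i] in its low nibble and q[i + 16] in its high nibble,
# matching the [0, 4] shift order used by dequantize_blocks_Q4_0 above.
packed = (q[:16] | (q[16:] << 4)).to(torch.uint8)
block = torch.cat([d.view(torch.uint8), packed])     # 2 + 16 = 18 raw bytes

out = dequantize_blocks_Q4_0(block.reshape(1, 18))   # -> (1, 32) float16
expected = d * (q.to(torch.int8) - 8)
assert torch.equal(out.flatten(), expected.to(torch.float16).flatten())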

def dequantize_blocks_Q8_0(data, dtype=torch.float16):
    block_size, type_size = GGML_QUANT_SIZES["Q8_0"]