fix
This commit is contained in:
parent
a51ca0e907
commit
8fad965c58
@ -161,7 +161,35 @@ GGML_QUANT_SIZES = {
|
||||
"Q4_0": (32, 2 + 16),
|
||||
"Q8_0": (32, 2 + 32),
|
||||
}
|
||||
def dequantize_blocks_Q4_0(data, dtype=torch.float16):
|
||||
block_size, type_size = GGML_QUANT_SIZES["Q4_0"]
|
||||
|
||||
data = data.to(torch.uint8)
|
||||
shape = data.shape
|
||||
|
||||
rows = data.reshape(
|
||||
(-1, data.shape[-1])
|
||||
).view(torch.uint8)
|
||||
|
||||
n_blocks = rows.numel() // type_size
|
||||
blocks = data.reshape((n_blocks, type_size))
|
||||
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, qs = split_block_dims(blocks, 2)
|
||||
d = d.view(torch.float16)
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
|
||||
[0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
|
||||
qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
|
||||
|
||||
out = (d * qs)
|
||||
|
||||
out = out.reshape(quant_shape_from_byte_shape(
|
||||
shape,
|
||||
qtype="Q4_0",
|
||||
)).to(dtype)
|
||||
return out
|
||||
def dequantize_blocks_Q8_0(data, dtype=torch.float16):
|
||||
block_size, type_size = GGML_QUANT_SIZES["Q8_0"]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user