Add GGUF_Q8_0
parent 1f25400bc2
commit a51ca0e907
@@ -19,18 +19,20 @@ class quantize_lazy_load():
 
 
 def quantize_load_state_dict(model, state_dict, device="cpu"):
-    Q4_0_qkey = []
+    quant_keys = []
     for key in state_dict.keys():
         if key.endswith(".Q4_0_qweight"):
-            Q4_0_qkey.append(key.replace(".Q4_0_qweight", ""))
+            quant_keys.append(key.replace(".Q4_0_qweight", ""))
+        elif key.endswith(".Q8_0_qweight"):
+            quant_keys.append(key.replace(".Q8_0_qweight", ""))
 
     for name, module in model.named_modules():
-        if name in Q4_0_qkey:
+        if name in quant_keys:
             #print(name)
             q_linear = WQLinear_GGUF.from_linear(
                 linear=module,
                 device=device,
-                qtype="Q4_0",
+                qtype="Q8_0",
             )
             set_op_by_name(model, name, q_linear)
 
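For context, a sketch of how this loader path is typically driven: read a GGUF-quantized safetensors file, build the model under the quantize_lazy_load context manager, then let quantize_load_state_dict swap the matching Linear modules for WQLinear_GGUF. The model constructor below is a hypothetical placeholder, not part of this commit.

import torch
from safetensors.torch import load_file

sd = load_file("mochi_preview_dit_GGUF_Q8_0.safetensors")  # keys end in ".Q8_0_qweight"
with quantize_lazy_load():          # assumed usage: defer real weight allocation
    model = build_transformer()     # hypothetical model constructor
quantize_load_state_dict(model, sd, device="cuda")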
@@ -122,6 +124,9 @@ class WQLinear_GGUF(nn.Module):
         if self.qtype == "Q4_0":
             x = F.linear(x, dequantize_blocks_Q4_0(
                 self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+        elif self.qtype == "Q8_0":
+            x = F.linear(x, dequantize_blocks_Q8_0(
+                self.Q8_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")
 
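The forward path keeps the weight packed and dequantizes it to the activation dtype on every call before the matmul. The same pattern in isolation, with generic names rather than the module's actual attributes:

import torch
import torch.nn.functional as F

def dequant_linear(x, packed_weight, bias, dequant_fn):
    # Just-in-time dequantization: unpack to x.dtype, then a normal F.linear.
    w = dequant_fn(packed_weight, x.dtype)        # (out_features, in_features)
    return F.linear(x, w, bias.to(x.dtype) if bias is not None else None)

This trades extra compute per forward pass for keeping the weights at roughly 8.5 bits per value in memory (34 bytes per 32-weight Q8_0 block).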
@@ -139,7 +144,7 @@ def quant_shape_to_byte_shape(shape, qtype) -> tuple[int, ...]:
     block_size, type_size = GGML_QUANT_SIZES[qtype]
     if shape[-1] % block_size != 0:
         raise ValueError(
-            f"Quantized tensor row size ({shape[-1]}) is not a multiple of Q4_0 block size ({block_size})")
+            f"Quantized tensor row size ({shape[-1]}) is not a multiple of {qtype} block size ({block_size})")
     return (*shape[:-1], shape[-1] // block_size * type_size)
 
 
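To make the shape arithmetic concrete for Q8_0: each block covers 32 weights and occupies 2 bytes of fp16 scale plus 32 int8 values, i.e. 34 bytes. A quick check of both conversions, with 3072 as an illustrative row size:

block_size, type_size = 32, 2 + 32            # Q8_0 block: fp16 scale + 32 int8 quants
row = 3072                                    # illustrative in_features
assert row % block_size == 0
bytes_per_row = row // block_size * type_size           # 96 blocks * 34 bytes = 3264
assert bytes_per_row // type_size * block_size == row   # byte shape round-trips to 3072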
@@ -148,17 +153,17 @@ def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
     block_size, type_size = GGML_QUANT_SIZES[qtype]
     if shape[-1] % type_size != 0:
         raise ValueError(
-            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of Q4_0 type size ({type_size})")
+            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {qtype} type size ({type_size})")
     return (*shape[:-1], shape[-1] // type_size * block_size)
 
 
 GGML_QUANT_SIZES = {
     "Q4_0": (32, 2 + 16),
+    "Q8_0": (32, 2 + 32),
 }
 
 
-def dequantize_blocks_Q4_0(data, dtype=torch.float16):
-    block_size, type_size = GGML_QUANT_SIZES["Q4_0"]
+def dequantize_blocks_Q8_0(data, dtype=torch.float16):
+    block_size, type_size = GGML_QUANT_SIZES["Q8_0"]
 
     data = data.to(torch.uint8)
     shape = data.shape
@@ -173,17 +178,14 @@ def dequantize_blocks_Q4_0(data, dtype=torch.float16):
     n_blocks = blocks.shape[0]
 
     d, qs = split_block_dims(blocks, 2)
-    d = d.view(torch.float16)
+    d = d.view(torch.float16).to(torch.float32)
 
-    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
-        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
-    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
+    qs = qs.view(torch.int8).to(torch.float32)
 
     out = (d * qs)
 
     out = out.reshape(quant_shape_from_byte_shape(
         shape,
-        qtype="Q4_0",
+        qtype="Q8_0",
     )).to(dtype)
     return out
 
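For reference, the packing that dequantize_blocks_Q8_0 undoes follows GGML's Q8_0 layout: per block of 32 weights, one fp16 scale followed by 32 int8 quants. The commit only ships the dequantizer; the encoder below is a hypothetical sketch that mirrors that layout, not code from this repository.

import torch

def quantize_blocks_Q8_0(weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical packer: per 32-weight block, store an fp16 scale (2 bytes)
    # followed by 32 int8 quants (32 bytes), matching GGML_QUANT_SIZES["Q8_0"].
    block_size = 32
    assert weight.shape[-1] % block_size == 0
    w = weight.to(torch.float32).reshape(-1, block_size)
    d = w.abs().amax(dim=-1, keepdim=True) / 127.0          # per-block scale
    d = torch.where(d == 0, torch.ones_like(d), d)          # avoid div-by-zero on zero blocks
    qs = torch.round(w / d).clamp(-127, 127).to(torch.int8)
    d_bytes = d.to(torch.float16).view(torch.uint8)         # (n_blocks, 2)
    qs_bytes = qs.view(torch.uint8)                         # (n_blocks, 32)
    packed = torch.cat([d_bytes, qs_bytes], dim=-1)         # (n_blocks, 34)
    return packed.reshape(*weight.shape[:-1], -1)           # byte shape per row

Round-tripping a weight through this packer and the new dequantizer should reproduce it up to int8 rounding error, which makes a convenient sanity check for the Q8_0 path.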
nodes.py
@@ -49,7 +49,8 @@ class DownloadAndLoadMochiModel:
                 [
                     "mochi_preview_dit_fp8_e4m3fn.safetensors",
                     "mochi_preview_dit_bf16.safetensors",
-                    "mochi_preview_dit_GGUF_Q4_0_v2.safetensors"
+                    "mochi_preview_dit_GGUF_Q4_0_v2.safetensors",
+                    "mochi_preview_dit_GGUF_Q8_0.safetensors",
                 ],
                 {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'", },
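The new list entry just exposes the Q8_0 checkpoint in the loader dropdown. Outside ComfyUI the same file can be fetched directly; a sketch using huggingface_hub, with the repo and target directory taken from the tooltip above:

from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="Kijai/Mochi_preview_comfy",
    filename="mochi_preview_dit_GGUF_Q8_0.safetensors",
    local_dir="models/diffusion_models/mochi",
)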