# https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/9616415220fd09388622f40f6609e4ed81f048a5/mz_gguf_loader.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class quantize_lazy_load():
    def __init__(self):
        self.device = None

    def __enter__(self):
        # Build modules on the meta device so construction allocates no real
        # weight storage; tensors are materialized later via model.to_empty().
        self.device = torch.device("meta")
        self.device.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.device.__exit__(exc_type, exc_value, traceback)
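

# Usage sketch (illustrative; `MyModel` and the checkpoint path are
# placeholders, not part of this file): build the model under the meta device
# so no real weights are allocated, then swap in the quantized linears and
# load the GGUF-packed state dict.
#
#     with quantize_lazy_load():
#         model = MyModel(config)
#     state_dict = torch.load("model_q8_0.pt")
#     model = quantize_load_state_dict(model, state_dict, device="cuda")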


def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
    if cublas_ops:
        try:
            from cublas_ops import cublas_half_matmul
            linear_ops = cublas_half_matmul
            setattr(model, "cublas_half_matmul", True)
            print("Using cublas_ops")
        except ImportError:
            raise ImportError(
                "Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops")
    else:
        linear_ops = F.linear
        setattr(model, "cublas_half_matmul", False)

    # Collect the dotted names of quantized linears from the state dict.
    # Note: this assumes every quantized layer shares a single qtype; if both
    # suffixes appear, the last key seen decides.
    quant_keys = []
    qtype = None
    for key in state_dict.keys():
        if key.endswith(".Q4_0_qweight"):
            quant_keys.append(key.replace(".Q4_0_qweight", ""))
            qtype = "Q4_0"
        elif key.endswith(".Q8_0_qweight"):
            quant_keys.append(key.replace(".Q8_0_qweight", ""))
            qtype = "Q8_0"

    # Swap each matching nn.Linear for a WQLinear_GGUF that holds the raw
    # GGUF-packed bytes.
    for name, module in model.named_modules():
        if name in quant_keys:
            q_linear = WQLinear_GGUF.from_linear(
                linear=module,
                device=device,
                qtype=qtype,
                linear_ops=linear_ops,
            )
            set_op_by_name(model, name, q_linear)

    model.to_empty(device=device)
    model.load_state_dict(state_dict, strict=False)
    return model
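

# Illustrative checkpoint layout (an assumption matching the suffix convention
# above): each quantized linear contributes entries such as
#     "blocks.0.attn.to_q.Q8_0_qweight"  -> uint8 GGUF-packed bytes
#     "blocks.0.attn.to_q.bias"          -> optional fp16 bias
# so quant_keys collects "blocks.0.attn.to_q" and qtype becomes "Q8_0".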


def set_op_by_name(layer, name, new_module):
    # Replace the submodule addressed by a dotted path, descending through
    # attributes and integer indices (for nn.ModuleList / nn.Sequential).
    levels = name.split(".")
    if len(levels) > 1:
        mod_ = layer
        for l_idx in range(len(levels) - 1):
            if levels[l_idx].isdigit():
                mod_ = mod_[int(levels[l_idx])]
            else:
                mod_ = getattr(mod_, levels[l_idx])
        setattr(mod_, levels[-1], new_module)
    else:
        setattr(layer, name, new_module)
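

# Example (hypothetical module tree):
#     set_op_by_name(model, "blocks.0.attn.to_q", q_linear)
# is equivalent to model.blocks[0].attn.to_q = q_linear.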


class WQLinear_GGUF(nn.Module):
    def __init__(
        self, in_features, out_features, bias, dev, qtype, linear_ops
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.qtype = qtype
        self.linear_ops = linear_ops

        # The packed weight is stored as raw bytes: each row of in_features
        # weights becomes in_features // block_size blocks of type_size bytes.
        qweight_shape = quant_shape_to_byte_shape(
            (out_features, in_features), qtype
        )
        self.register_buffer(
            f"{qtype}_qweight",
            torch.zeros(
                qweight_shape,
                dtype=torch.uint8,
                device=dev,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (out_features,),
                    dtype=torch.float16,
                    device=dev,
                ),
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear,
        device="cpu",
        qtype="Q4_0",
        linear_ops=F.linear
    ):
        # Create an empty quantized replacement with the same geometry; the
        # actual packed bytes arrive later from load_state_dict().
        q_linear = cls(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device,
            qtype=qtype,
            linear_ops=linear_ops,
        )
        return q_linear
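
    # Example (illustrative): from_linear(nn.Linear(4096, 4096, bias=False),
    # qtype="Q4_0") registers a (4096, 2304) uint8 buffer named
    # "Q4_0_qweight"; see quant_shape_to_byte_shape below.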

    def extra_repr(self) -> str:
        return (
            "in_features={}, out_features={}, bias={}, qtype={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.qtype,
            )
        )

    @torch.no_grad()
    def forward(self, x):
        # Dequantize the packed weight to x's dtype on the fly, then run a
        # regular linear. Only the Q8_0 path routes through self.linear_ops
        # (e.g. cublas_half_matmul); Q4_0 always uses F.linear.
        if self.qtype == "Q4_0":
            x = F.linear(x, dequantize_blocks_Q4_0(
                self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
        elif self.qtype == "Q8_0":
            dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
            x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
        else:
            raise ValueError(f"Unknown qtype: {self.qtype}")

        return x
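

def quantize_blocks_Q8_0_sketch(weight):
    # Illustrative inverse of dequantize_blocks_Q8_0 below; an assumption, as
    # the original file ships only dequantizers. Packs a (out_features,
    # in_features) float weight into GGUF Q8_0 bytes: per block of 32 values,
    # one fp16 scale d = absmax / 127 followed by 32 int8 quants.
    out_features, in_features = weight.shape
    blocks = weight.to(torch.float32).reshape(-1, 32)
    d = blocks.abs().amax(dim=1, keepdim=True) / 127.0
    qs = torch.round(blocks / d.clamp(min=1e-12)).clamp(-127, 127).to(torch.int8)
    packed = torch.cat(
        [d.to(torch.float16).view(torch.uint8), qs.view(torch.uint8)], dim=1)
    return packed.reshape(out_features, in_features // 32 * 34)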


def split_block_dims(blocks, *args):
    # Split along dim 1 into chunks of the given widths, with whatever is
    # left over as the final chunk.
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)
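
# e.g. split_block_dims(blocks, 2) on (n_blocks, 34) Q8_0 blocks yields the
# (n_blocks, 2) scale bytes and the (n_blocks, 32) quant bytes.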


def quant_shape_to_byte_shape(shape, qtype) -> tuple[int, ...]:
    # Map a logical weight shape to the shape of its packed byte tensor:
    # each row of block_size weights becomes type_size bytes.
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    if shape[-1] % block_size != 0:
        raise ValueError(
            f"Quantized tensor row size ({shape[-1]}) is not a multiple of {qtype} block size ({block_size})")
    return (*shape[:-1], shape[-1] // block_size * type_size)


def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
    # Inverse of quant_shape_to_byte_shape: recover the logical weight shape
    # from the packed byte shape.
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    if shape[-1] % type_size != 0:
        raise ValueError(
            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {qtype} type size ({type_size})")
    return (*shape[:-1], shape[-1] // type_size * block_size)
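
# Worked example: a (4096, 4096) weight quantized to Q4_0 has 4096 // 32 = 128
# blocks per row at 18 bytes each, so the packed buffer is (4096, 2304) uint8;
# quant_shape_from_byte_shape maps that back to (4096, 4096).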


# (block_size, type_size) per GGUF qtype: a block of block_size weights is
# stored as a 2-byte fp16 scale plus the packed quants.
GGML_QUANT_SIZES = {
    "Q4_0": (32, 2 + 16),  # 32 x 4-bit quants -> 16 bytes
    "Q8_0": (32, 2 + 32),  # 32 x int8 quants -> 32 bytes
}


def dequantize_blocks_Q4_0(data, dtype=torch.float16):
    # GGUF Q4_0: each 18-byte block is a fp16 scale `d` followed by 16 bytes
    # holding 32 4-bit quants (low nibbles first, then high nibbles).
    block_size, type_size = GGML_QUANT_SIZES["Q4_0"]

    data = data.to(torch.uint8)
    shape = data.shape

    n_blocks = data.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16)

    # Unpack the two nibbles of every byte, then recentre from [0, 15] to
    # [-8, 7].
    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8

    out = (d * qs)

    out = out.reshape(quant_shape_from_byte_shape(
        shape,
        qtype="Q4_0",
    )).to(dtype)
    return out


def dequantize_blocks_Q8_0(data, dtype=torch.float16):
    # GGUF Q8_0: each 34-byte block is a fp16 scale `d` followed by 32 int8
    # quants; dequantization is simply d * qs.
    block_size, type_size = GGML_QUANT_SIZES["Q8_0"]

    data = data.to(torch.uint8)
    shape = data.shape

    n_blocks = data.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(torch.float32)
    qs = qs.view(torch.int8).to(torch.float32)

    out = (d * qs)

    out = out.reshape(quant_shape_from_byte_shape(
        shape,
        qtype="Q8_0",
    )).to(dtype)
    return out
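

if __name__ == "__main__":
    # Minimal sanity check (an illustrative addition): hand-pack one Q8_0
    # block, i.e. a fp16 scale followed by 32 int8 quants, and confirm the
    # dequantizer reproduces scale * quants.
    scale = torch.tensor([0.5], dtype=torch.float16)
    quants = torch.arange(-16, 16, dtype=torch.int8)
    block = torch.cat([scale.view(torch.uint8), quants.view(torch.uint8)])
    out = dequantize_blocks_Q8_0(block.reshape(1, 34))
    expected = (scale.to(torch.float32) * quants.to(torch.float32)).to(torch.float16)
    assert torch.equal(out, expected.reshape(1, 32))
    print("Q8_0 dequantize OK:", out.flatten()[:4].tolist())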