mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2026-04-30 21:25:46 +08:00
GGUF Q4 works
This commit is contained in:
parent
fb246f95ef
commit
ea5ee0b017
@ -206,7 +206,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
|
if fp8_transformer == "enabled" or fp8_transformer == "fastmode":
|
||||||
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
|
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
|
||||||
if "1.5" in model:
|
if "1.5" in model:
|
||||||
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding", "norm","ofs_embedding", "norm_final", "norm_out", "proj_out"}
|
params_to_keep.update({"norm1.linear.weight", "norm_k", "norm_q","ofs_embedding", "norm_final", "norm_out", "proj_out"})
|
||||||
for name, param in transformer.named_parameters():
|
for name, param in transformer.named_parameters():
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
param.data = param.data.to(torch.float8_e4m3fn)
|
param.data = param.data.to(torch.float8_e4m3fn)
|
||||||
@ -214,7 +214,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
if fp8_transformer == "fastmode":
|
if fp8_transformer == "fastmode":
|
||||||
from .fp8_optimization import convert_fp8_linear
|
from .fp8_optimization import convert_fp8_linear
|
||||||
if "1.5" in model:
|
if "1.5" in model:
|
||||||
params_to_keep.update({"ff"})
|
params_to_keep.update({"ff"}) #otherwise NaNs
|
||||||
convert_fp8_linear(transformer, dtype, params_to_keep=params_to_keep)
|
convert_fp8_linear(transformer, dtype, params_to_keep=params_to_keep)
|
||||||
|
|
||||||
with open(scheduler_path) as f:
|
with open(scheduler_path) as f:
|
||||||
@ -422,11 +422,11 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
|
params_to_keep = {"patch_embed", "pos_embedding", "time_embedding"}
|
||||||
cast_dtype = torch.float16
|
cast_dtype = torch.float16
|
||||||
elif "1_5" in model:
|
elif "1_5" in model:
|
||||||
params_to_keep = {"patch_embed", "time_embedding", "ofs_embedding", "norm_final", "norm_out", "proj_out", "norm"}
|
params_to_keep = {"norm1.linear.weight", "patch_embed", "time_embedding", "ofs_embedding", "norm_final", "norm_out", "proj_out"}
|
||||||
cast_dtype = torch.bfloat16
|
cast_dtype = torch.bfloat16
|
||||||
for name, param in transformer.named_parameters():
|
for name, param in transformer.named_parameters():
|
||||||
if not any(keyword in name for keyword in params_to_keep):
|
if not any(keyword in name for keyword in params_to_keep):
|
||||||
param.data = param.data.to(torch.bfloat16)
|
param.data = param.data.to(torch.float8_e4m3fn)
|
||||||
else:
|
else:
|
||||||
param.data = param.data.to(cast_dtype)
|
param.data = param.data.to(cast_dtype)
|
||||||
#for name, param in transformer.named_parameters():
|
#for name, param in transformer.named_parameters():
|
||||||
@ -438,8 +438,11 @@ class DownloadAndLoadCogVideoGGUFModel:
|
|||||||
transformer.attention_mode = attention_mode
|
transformer.attention_mode = attention_mode
|
||||||
|
|
||||||
if fp8_fastmode:
|
if fp8_fastmode:
|
||||||
|
params_to_keep = {"patch_embed", "lora", "pos_embedding", "time_embedding"}
|
||||||
|
if "1.5" in model:
|
||||||
|
params_to_keep.update({"ff","norm1.linear.weight", "norm_k", "norm_q","ofs_embedding", "norm_final", "norm_out", "proj_out"})
|
||||||
from .fp8_optimization import convert_fp8_linear
|
from .fp8_optimization import convert_fp8_linear
|
||||||
convert_fp8_linear(transformer, vae_dtype)
|
convert_fp8_linear(transformer, vae_dtype, params_to_keep=params_to_keep)
|
||||||
|
|
||||||
if compile == "torch":
|
if compile == "torch":
|
||||||
# compilation
|
# compilation
|
||||||
|
|||||||
@ -19,17 +19,21 @@ class quantize_lazy_load():
|
|||||||
|
|
||||||
|
|
||||||
def quantize_load_state_dict(model, state_dict, device="cpu"):
|
def quantize_load_state_dict(model, state_dict, device="cpu"):
|
||||||
Q4_0_qkey = []
|
quant_keys = []
|
||||||
for key in state_dict.keys():
|
for key in state_dict.keys():
|
||||||
if key.endswith(".Q4_0_qweight"):
|
if key.endswith(".Q4_0_qweight"):
|
||||||
Q4_0_qkey.append(key.replace(".Q4_0_qweight", ""))
|
quant_keys.append(key.replace(".Q4_0_qweight", ""))
|
||||||
|
qtype = "Q4_0"
|
||||||
|
elif key.endswith(".Q8_0_qweight"):
|
||||||
|
quant_keys.append(key.replace(".Q8_0_qweight", ""))
|
||||||
|
qtype = "Q8_0"
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if name in Q4_0_qkey:
|
if name in quant_keys:
|
||||||
q_linear = WQLinear_GGUF.from_linear(
|
q_linear = WQLinear_GGUF.from_linear(
|
||||||
linear=module,
|
linear=module,
|
||||||
device=device,
|
device=device,
|
||||||
qtype="Q4_0",
|
qtype=qtype,
|
||||||
)
|
)
|
||||||
set_op_by_name(model, name, q_linear)
|
set_op_by_name(model, name, q_linear)
|
||||||
|
|
||||||
@ -117,14 +121,14 @@ class WQLinear_GGUF(nn.Module):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
|
|
||||||
if self.qtype == "Q4_0":
|
if self.qtype == "Q4_0":
|
||||||
x = F.linear(x, dequantize_blocks_Q4_0(
|
dequant = dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype)
|
||||||
self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
|
elif self.qtype == "Q8_0":
|
||||||
|
dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown qtype: {self.qtype}")
|
raise ValueError(f"Unknown qtype: {self.qtype}")
|
||||||
|
|
||||||
return x
|
return F.linear(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
|
||||||
|
|
||||||
|
|
||||||
def split_block_dims(blocks, *args):
|
def split_block_dims(blocks, *args):
|
||||||
@ -153,6 +157,7 @@ def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
|
|||||||
|
|
||||||
GGML_QUANT_SIZES = {
|
GGML_QUANT_SIZES = {
|
||||||
"Q4_0": (32, 2 + 16),
|
"Q4_0": (32, 2 + 16),
|
||||||
|
"Q8_0": (32, 2 + 32),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -186,3 +191,31 @@ def dequantize_blocks_Q4_0(data, dtype=torch.float16):
|
|||||||
)).to(dtype)
|
)).to(dtype)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def dequantize_blocks_Q8_0(data, dtype=torch.float16):
    """Expand packed Q8_0 blocks into a dense tensor of ``dtype``.

    The input is treated as raw bytes and regrouped into blocks of
    ``GGML_QUANT_SIZES["Q8_0"][1]`` bytes each. The leading 2 bytes of a
    block are reinterpreted as a float16 scale and the remaining bytes as
    int8 quantized values; the result is ``scale * values``.

    Args:
        data: Tensor holding the packed block bytes.
        dtype: Output dtype of the dequantized tensor.

    Returns:
        The dequantized tensor, reshaped to whatever
        ``quant_shape_from_byte_shape`` reports for the byte shape of
        ``data`` and cast to ``dtype``.
    """
    _, bytes_per_block = GGML_QUANT_SIZES["Q8_0"]

    raw = data.to(torch.uint8)
    byte_shape = raw.shape

    # Flatten everything but the last axis; only the total byte count is
    # needed from this view to work out how many blocks are present.
    flat_rows = raw.reshape((-1, raw.shape[-1])).view(torch.uint8)
    block_count = flat_rows.numel() // bytes_per_block
    blocks = raw.reshape((block_count, bytes_per_block))

    # split_block_dims is defined elsewhere in this file; it is called here
    # to separate the 2 scale bytes from the rest of each block
    # (NOTE(review): confirm its exact split semantics).
    scale, quants = split_block_dims(blocks, 2)

    # Reinterpret the raw bytes, then do the arithmetic in float32.
    scale = scale.view(torch.float16).to(torch.float32)
    quants = quants.view(torch.int8).to(torch.float32)
    dequantized = scale * quants

    target_shape = quant_shape_from_byte_shape(byte_shape, qtype="Q8_0")
    return dequantized.reshape(target_shape).to(dtype)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user