Add first GGUF test version
parent d699fae213
commit f4c13b1ef4
@@ -6,6 +6,10 @@ import torch.nn as nn
def fp8_linear_forward(cls, original_dtype, input):
    weight_dtype = cls.weight.dtype
    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        tensor_2d = False
        if len(input.shape) == 2:
            tensor_2d = True
            input = input.unsqueeze(1)
        if len(input.shape) == 3:
            if weight_dtype == torch.float8_e4m3fn:
                inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
@@ -26,6 +30,9 @@ def fp8_linear_forward(cls, original_dtype, input):

            if isinstance(o, tuple):
                o = o[0]

            if tensor_2d:
                return o.reshape(input.shape[0], -1)

            return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
        else:
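A minimal sketch of the shape round-trip the new tensor_2d flag implements (illustrative sizes and names, not from the commit):

    import torch

    x2d = torch.randn(8, 3072)              # (tokens, features)
    x3d = x2d.unsqueeze(1)                  # (tokens, 1, features)
    flat = x3d.reshape(-1, x3d.shape[2])    # 2D view the fp8 matmul consumes
    out = flat.reshape(x3d.shape[0], -1)    # back to (tokens, features) on return
    assert out.shape == x2d.shape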
@@ -4,6 +4,7 @@ from typing import Dict, List
import torch
import torch.nn.functional as F
import torch.utils.data
from einops import rearrange, repeat

from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
@@ -125,7 +126,13 @@ class T2VSynthMochiModel:
        params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
        print(f"Loading model state_dict from {dit_checkpoint_path}...")
        dit_sd = load_torch_file(dit_checkpoint_path)
        if "gguf" in dit_checkpoint_path.lower():
            from .. import mz_gguf_loader
            import importlib
            importlib.reload(mz_gguf_loader)
            with mz_gguf_loader.quantize_lazy_load():
                model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
        elif is_accelerate_available:
            print("Using accelerate to load and assign model weights to device...")
            for name, param in model.named_parameters():
                if not any(keyword in name for keyword in params_to_keep):
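The GGUF branch defers all allocation: quantize_lazy_load() (defined in the new file below) enters torch's meta device, so the swapped-in quantized modules own no real storage until to_empty() materializes them. A minimal sketch of that pattern (illustrative layer, not from the commit):

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        layer = nn.Linear(4096, 4096)       # shapes only, no memory allocated
    layer = layer.to_empty(device="cpu")    # materialize uninitialized storage
    # real values then arrive via load_state_dict(..., strict=False)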
@@ -261,43 +268,59 @@ class T2VSynthMochiModel:
            dtype=torch.float32,
        )

        # if batch_cfg:
        #     sample_batched["packed_indices"] = self.get_packed_indices(
        #         sample_batched["y_mask"], **latent_dims
        #     )
        #     z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
        # else:

        if batch_cfg:  # WIP
            pos_embeds = args["positive_embeds"]["embeds"].to(self.device)
            neg_embeds = args["negative_embeds"]["embeds"].to(self.device)
            pos_attention_mask = args["positive_embeds"]["attention_mask"].to(self.device)
            neg_attention_mask = args["negative_embeds"]["attention_mask"].to(self.device)
            print(neg_embeds.shape)
            y_feat = torch.cat((pos_embeds, neg_embeds))
            y_mask = torch.cat((pos_attention_mask, neg_attention_mask))
            zero_last_n_prompts = B  # if neg_prompt == "" else 0
            y_feat[-zero_last_n_prompts:] = 0
            y_mask[-zero_last_n_prompts:] = False

            sample_batched = {
                "y_mask": [y_mask],
                "y_feat": [y_feat]
            }
            sample_batched["packed_indices"] = self.get_packed_indices(
                sample_batched["y_mask"], **latent_dims
            )
            z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
            print("sample_batched y_mask", sample_batched["y_mask"])
            print("y_mask type", type(sample_batched["y_mask"]))  # <class 'list'>
            print("ymask 0 shape", sample_batched["y_mask"][0].shape)  # torch.Size([2, 256])
        else:
            sample = {
                "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
                "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
            }
            sample_null = {
                "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
                "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
            }

            sample["packed_indices"] = self.get_packed_indices(
                sample["y_mask"], **latent_dims
            )
            sample_null["packed_indices"] = self.get_packed_indices(
                sample_null["y_mask"], **latent_dims
            )

        def model_fn(*, z, sigma, cfg_scale):
            self.dit.to(self.device)
            # if batch_cfg:
            #     with torch.autocast("cuda", dtype=torch.bfloat16):
            #         out = self.dit(z, sigma, **sample_batched)
            #     out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
            # else:
            if batch_cfg:
                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
                    out = self.dit(z, sigma, **sample_batched)
                out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
            else:
                nonlocal sample, sample_null
                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
                    out_cond = self.dit(z, sigma, **sample)
                    out_uncond = self.dit(z, sigma, **sample_null)

            assert out_cond.shape == out_uncond.shape

            return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond

        comfy_pbar = ProgressBar(sample_steps)
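Both paths end in the same classifier-free-guidance blend: the unconditional prediction is pushed toward the conditional one by cfg_scale. A standalone sketch of the batched variant (illustrative tensors, not from the commit):

    import torch

    cfg_scale = 4.5
    out = torch.randn(2, 4, 8)              # stacked [cond, uncond] forward pass
    out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
    guided = out_uncond + cfg_scale * (out_cond - out_uncond)
    assert guided.shape == out_cond.shape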
mz_gguf_loader.py (new file, 189 lines)
@@ -0,0 +1,189 @@
# https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/9616415220fd09388622f40f6609e4ed81f048a5/mz_gguf_loader.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import gc


class quantize_lazy_load():
    def __init__(self):
        self.device = None

    def __enter__(self):
        # Enter torch's meta device: modules created inside this context
        # allocate no real storage until they are materialized.
        self.device = torch.device("meta")
        self.device.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.device.__exit__(exc_type, exc_value, traceback)
def quantize_load_state_dict(model, state_dict, device="cpu"):
    # Collect the names of all modules that have a packed Q4_0 weight in the
    # checkpoint, then swap the matching nn.Linear layers for WQLinear_GGUF
    # before any real weights are loaded.
    Q4_0_qkey = []
    for key in state_dict.keys():
        if key.endswith(".Q4_0_qweight"):
            Q4_0_qkey.append(key.replace(".Q4_0_qweight", ""))

    for name, module in model.named_modules():
        if name in Q4_0_qkey:
            print(name)
            q_linear = WQLinear_GGUF.from_linear(
                linear=module,
                device=device,
                qtype="Q4_0",
            )
            set_op_by_name(model, name, q_linear)

    # Materialize meta-device parameters, then fill them from the checkpoint.
    model.to_empty(device=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    return model
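# Usage sketch (hypothetical caller; build_model() stands in for any
# constructor whose nn.Linear names match the ".Q4_0_qweight" keys in the
# checkpoint — see the T2VSynthMochiModel hunk above for the real call site):
#
#   with quantize_lazy_load():
#       model = build_model()          # modules land on the meta device
#   model = quantize_load_state_dict(model, state_dict, device="cpu")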
def set_op_by_name(layer, name, new_module):
    levels = name.split(".")
    if len(levels) > 1:
        mod_ = layer
        for l_idx in range(len(levels) - 1):
            if levels[l_idx].isdigit():
                mod_ = mod_[int(levels[l_idx])]
            else:
                mod_ = getattr(mod_, levels[l_idx])
        setattr(mod_, levels[-1], new_module)
    else:
        setattr(layer, name, new_module)
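# Example: set_op_by_name(model, "blocks.0.attn.qkv", q_linear) walks
# model.blocks[0].attn (digit segments index into ModuleList/Sequential
# containers) and then rebinds its .qkv attribute to the quantized module.

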
class WQLinear_GGUF(nn.Module):
    def __init__(
        self, in_features, out_features, bias, dev, qtype="Q4_0"
    ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.qtype = qtype

        # The packed weight is kept as raw uint8 bytes; its byte shape is
        # derived from the logical (out_features, in_features) shape.
        qweight_shape = quant_shape_to_byte_shape(
            (out_features, in_features), qtype
        )
        self.register_buffer(
            f"{qtype}_qweight",
            torch.zeros(
                qweight_shape,
                dtype=torch.uint8,
                device=dev,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (out_features),
                    dtype=torch.float16,
                    device=dev,
                ),
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear,
        device="cpu",
        qtype="Q4_0",
    ):
        # Only the shape and bias presence are taken from the original layer;
        # the quantized bytes arrive later via load_state_dict.
        q_linear = cls(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            device,
            qtype=qtype,
        )
        return q_linear

    def extra_repr(self) -> str:
        # Note: the upstream AWQ version printed w_bit/group_size, which this
        # class does not define; report qtype instead.
        return (
            "in_features={}, out_features={}, bias={}, qtype={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.qtype,
            )
        )

    @torch.no_grad()
    def forward(self, x):
        # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
        if self.qtype == "Q4_0":
            # Dequantize the packed weight to x's dtype on each call, then
            # fall back to a regular linear. Memory is saved at rest at the
            # cost of a dequantize per forward.
            x = F.linear(
                x,
                dequantize_blocks_Q4_0(self.Q4_0_qweight, x.dtype),
                self.bias.to(x.dtype) if self.bias is not None else None,
            )
        else:
            raise ValueError(f"Unknown qtype: {self.qtype}")

        return x

def split_block_dims(blocks, *args):
    # Split each row into fixed-size leading chunks plus the remainder.
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)

def quant_shape_to_byte_shape(shape, qtype) -> tuple[int, ...]:
    # shape = shape[::-1]
    # Logical (rows, cols) -> packed byte shape: each block of `block_size`
    # weights occupies `type_size` bytes.
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    if shape[-1] % block_size != 0:
        raise ValueError(
            f"Quantized tensor row size ({shape[-1]}) is not a multiple of {qtype} block size ({block_size})")
    return (*shape[:-1], shape[-1] // block_size * type_size)

def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
    # shape = shape[::-1]
    # Packed byte shape -> logical shape (inverse of the function above).
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    if shape[-1] % type_size != 0:
        raise ValueError(
            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {qtype} type size ({type_size})")
    return (*shape[:-1], shape[-1] // type_size * block_size)

GGML_QUANT_SIZES = {
    # qtype: (block_size, type_size)
    "Q4_0": (32, 2 + 16),
}

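# Q4_0 block layout: 18 bytes hold 32 weights — a float16 scale `d` (2 bytes)
# followed by 16 bytes of packed 4-bit quants, two per byte. The low nibbles
# carry the first 16 values of the block and the high nibbles the last 16;
# dequantization is w = d * (q - 8).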
def dequantize_blocks_Q4_0(data, dtype=torch.float16):
    block_size, type_size = GGML_QUANT_SIZES["Q4_0"]

    data = data.to(torch.uint8)
    shape = data.shape

    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = data.reshape((n_blocks, type_size))

    n_blocks = blocks.shape[0]

    # First 2 bytes of each block are the float16 scale, the rest the quants.
    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16)

    # Unpack low and high nibbles into 32 signed values per block.
    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8

    out = (d * qs)

    out = out.reshape(quant_shape_from_byte_shape(
        shape,
        qtype="Q4_0",
    )).to(dtype)
    return out
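A quick sanity check for the dequantizer (hypothetical test, not part of the commit): hand-pack one Q4_0 block and confirm the scale/offset arithmetic:

    import torch

    d = torch.tensor([1.5], dtype=torch.float16)                # per-block scale
    qs = torch.full((16,), 0x99, dtype=torch.uint8)             # every nibble = 9
    block = torch.cat([d.view(torch.uint8), qs]).unsqueeze(0)   # one 18-byte block
    out = dequantize_blocks_Q4_0(block)
    assert out.shape == (1, 32) and torch.all(out == 1.5 * (9 - 8))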
nodes.py (2 additions)
@@ -49,6 +49,7 @@ class DownloadAndLoadMochiModel:
                [
                    "mochi_preview_dit_fp8_e4m3fn.safetensors",
                    "mochi_preview_dit_bf16.safetensors",
                    "mochi_preview_dit_GGUF_Q4_0_v1.safetensors"
                ],
                {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'", },
@@ -208,6 +209,7 @@ class MochiSampler:
                "steps": ("INT", {"default": 50, "min": 2}),
                "cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                # "batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
            },
        }