diff --git a/fp8_optimization.py b/fp8_optimization.py
index 9d6d712..c06913b 100644
--- a/fp8_optimization.py
+++ b/fp8_optimization.py
@@ -6,6 +6,10 @@ import torch.nn as nn
 def fp8_linear_forward(cls, original_dtype, input):
     weight_dtype = cls.weight.dtype
     if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+        tensor_2d = False
+        if len(input.shape) == 2:
+            tensor_2d = True
+            input = input.unsqueeze(1)
         if len(input.shape) == 3:
             if weight_dtype == torch.float8_e4m3fn:
                 inn = input.reshape(-1, input.shape[2]).to(torch.float8_e5m2)
@@ -26,6 +30,9 @@ def fp8_linear_forward(cls, original_dtype, input):
 
             if isinstance(o, tuple):
                 o = o[0]
+
+            if tensor_2d:
+                return o.reshape(input.shape[0], -1)
 
             return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
         else:
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 116665f..e75e52a 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -4,6 +4,7 @@ from typing import Dict, List
 import torch
 import torch.nn.functional as F
 import torch.utils.data
+from einops import rearrange, repeat
 
 from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
@@ -125,7 +126,13 @@ class T2VSynthMochiModel:
         params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
         print(f"Loading model state_dict from {dit_checkpoint_path}...")
         dit_sd = load_torch_file(dit_checkpoint_path)
-        if is_accelerate_available:
+        if "gguf" in dit_checkpoint_path.lower():
+            from .. import mz_gguf_loader
+            import importlib
+            importlib.reload(mz_gguf_loader)
+            with mz_gguf_loader.quantize_lazy_load():
+                model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
+        elif is_accelerate_available:
             print("Using accelerate to load and assign model weights to device...")
             for name, param in model.named_parameters():
                 if not any(keyword in name for keyword in params_to_keep):
@@ -261,43 +268,59 @@ class T2VSynthMochiModel:
             dtype=torch.float32,
         )
 
-        # if batch_cfg:
-        #     sample_batched["packed_indices"] = self.get_packed_indices(
-        #         sample_batched["y_mask"], **latent_dims
-        #     )
-        #     z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
-        # else:
-
-        sample = {
+        if batch_cfg: #WIP
+            pos_embeds = args["positive_embeds"]["embeds"].to(self.device)
+            neg_embeds = args["negative_embeds"]["embeds"].to(self.device)
+            pos_attention_mask = args["positive_embeds"]["attention_mask"].to(self.device)
+            neg_attention_mask = args["negative_embeds"]["attention_mask"].to(self.device)
+            print(neg_embeds.shape)
+            y_feat = torch.cat((pos_embeds, neg_embeds))
+            y_mask = torch.cat((pos_attention_mask, neg_attention_mask))
+            zero_last_n_prompts = B# if neg_prompt == "" else 0
+            y_feat[-zero_last_n_prompts:] = 0
+            y_mask[-zero_last_n_prompts:] = False
+
+            sample_batched = {
+                "y_mask": [y_mask],
+                "y_feat": [y_feat]
+            }
+            sample_batched["packed_indices"] = self.get_packed_indices(
+                sample_batched["y_mask"], **latent_dims
+            )
+            z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
+            print("sample_batched y_mask",sample_batched["y_mask"])
+            print("y_mask type",type(sample_batched["y_mask"])) #"
+            print("ymask 0 shape",sample_batched["y_mask"][0].shape)#torch.Size([2, 256])
+        else:
+            sample = {
                 "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
                 "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
-        }
-        sample_null = {
-            "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
-            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
-        }
+            }
+            sample_null = {
+                "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
+                "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
+            }
 
-        sample["packed_indices"] = self.get_packed_indices(
-            sample["y_mask"], **latent_dims
-        )
-        sample_null["packed_indices"] = self.get_packed_indices(
-            sample_null["y_mask"], **latent_dims
-        )
+            sample["packed_indices"] = self.get_packed_indices(
+                sample["y_mask"], **latent_dims
+            )
+            sample_null["packed_indices"] = self.get_packed_indices(
+                sample_null["y_mask"], **latent_dims
+            )
 
         def model_fn(*, z, sigma, cfg_scale):
             self.dit.to(self.device)
-            # if batch_cfg:
-            #     with torch.autocast("cuda", dtype=torch.bfloat16):
-            #         out = self.dit(z, sigma, **sample_batched)
-            #     out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
-            #else:
+            if batch_cfg:
+                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
+                    out = self.dit(z, sigma, **sample_batched)
+                out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
+            else:
+                nonlocal sample, sample_null
+                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
+                    out_cond = self.dit(z, sigma, **sample)
+                    out_uncond = self.dit(z, sigma, **sample_null)
 
-            nonlocal sample, sample_null
-            with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
-                out_cond = self.dit(z, sigma, **sample)
-                out_uncond = self.dit(z, sigma, **sample_null)
             assert out_cond.shape == out_uncond.shape
-
             return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond
 
         comfy_pbar = ProgressBar(sample_steps)
diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py
new file mode 100644
index 0000000..085e9c4
--- /dev/null
+++ b/mz_gguf_loader.py
@@ -0,0 +1,189 @@
+# https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/9616415220fd09388622f40f6609e4ed81f048a5/mz_gguf_loader.py
+
+import torch
+import torch.nn as nn
+import gc
+
+
+class quantize_lazy_load():
+    def __init__(self):
+        self.device = None
+
+    def __enter__(self):
+        self.device = torch.device("meta")
+        self.device.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.device.__exit__(exc_type, exc_value, traceback)
+
+
+def quantize_load_state_dict(model, state_dict, device="cpu"):
+    Q4_0_qkey = []
+    for key in state_dict.keys():
+        if key.endswith(".Q4_0_qweight"):
+            Q4_0_qkey.append(key.replace(".Q4_0_qweight", ""))
+
+    for name, module in model.named_modules():
+        if name in Q4_0_qkey:
+            print(name)
+            q_linear = WQLinear_GGUF.from_linear(
+                linear=module,
+                device=device,
+                qtype="Q4_0",
+            )
+            set_op_by_name(model, name, q_linear)
+
+    model.to_empty(device=device)
+    model.load_state_dict(state_dict, strict=False)
+    model.to(device)
+    return model
+
+
+def set_op_by_name(layer, name, new_module):
+    levels = name.split(".")
+    if len(levels) > 1:
+        mod_ = layer
+        for l_idx in range(len(levels) - 1):
+            if levels[l_idx].isdigit():
+                mod_ = mod_[int(levels[l_idx])]
+            else:
+                mod_ = getattr(mod_, levels[l_idx])
+        setattr(mod_, levels[-1], new_module)
+    else:
+        setattr(layer, name, new_module)
+
+
+import torch.nn.functional as F
+
+
+class WQLinear_GGUF(nn.Module):
+    def __init__(
+        self, in_features, out_features, bias, dev, qtype="Q4_0"
+    ):
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.qtype = qtype
+
+        qweight_shape = quant_shape_to_byte_shape(
+            (out_features, in_features), qtype
+        )
+        self.register_buffer(
+            f"{qtype}_qweight",
+            torch.zeros(
+                qweight_shape,
+                dtype=torch.uint8,
+                device=dev,
+            ),
+        )
+        if bias:
+            self.register_buffer(
+                "bias",
+                torch.zeros(
+                    (out_features),
+                    dtype=torch.float16,
+                    device=dev,
+                ),
+            )
+        else:
+            self.bias = None
+
+    @classmethod
+    def from_linear(
+        cls, linear,
+        device="cpu",
+        qtype="Q4_0",
+    ):
+        q_linear = cls(
+            linear.in_features,
+            linear.out_features,
+            linear.bias is not None,
+            device,
+            qtype=qtype,
+        )
+        return q_linear
+
+    def extra_repr(self) -> str:
+        return (
+            "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
+                self.in_features,
+                self.out_features,
+                self.bias is not None,
+                self.w_bit,
+                self.group_size,
+            )
+        )
+
+    @torch.no_grad()
+    def forward(self, x):
+        # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))
+        if self.qtype == "Q4_0":
+            x = F.linear(x, dequantize_blocks_Q4_0(
+                self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+        else:
+            raise ValueError(f"Unknown qtype: {self.qtype}")
+
+        return x
+
+
+def split_block_dims(blocks, *args):
+    n_max = blocks.shape[1]
+    dims = list(args) + [n_max - sum(args)]
+    return torch.split(blocks, dims, dim=1)
+
+
+def quant_shape_to_byte_shape(shape, qtype) -> tuple[int, ...]:
+    # shape = shape[::-1]
+    block_size, type_size = GGML_QUANT_SIZES[qtype]
+    if shape[-1] % block_size != 0:
+        raise ValueError(
+            f"Quantized tensor row size ({shape[-1]}) is not a multiple of Q4_0 block size ({block_size})")
+    return (*shape[:-1], shape[-1] // block_size * type_size)
+
+
+def quant_shape_from_byte_shape(shape, qtype) -> tuple[int, ...]:
+    # shape = shape[::-1]
+    block_size, type_size = GGML_QUANT_SIZES[qtype]
+    if shape[-1] % type_size != 0:
+        raise ValueError(
+            f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of Q4_0 type size ({type_size})")
+    return (*shape[:-1], shape[-1] // type_size * block_size)
+
+
+GGML_QUANT_SIZES = {
+    "Q4_0": (32, 2 + 16),
+}
+
+
+def dequantize_blocks_Q4_0(data, dtype=torch.float16):
+    block_size, type_size = GGML_QUANT_SIZES["Q4_0"]
+
+    data = data.to(torch.uint8)
+    shape = data.shape
+
+    rows = data.reshape(
+        (-1, data.shape[-1])
+    ).view(torch.uint8)
+
+    n_blocks = rows.numel() // type_size
+    blocks = data.reshape((n_blocks, type_size))
+
+    n_blocks = blocks.shape[0]
+
+    d, qs = split_block_dims(blocks, 2)
+    d = d.view(torch.float16)
+
+    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+        [0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
+    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
+
+    out = (d * qs)
+
+    out = out.reshape(quant_shape_from_byte_shape(
+        shape,
+        qtype="Q4_0",
+    )).to(dtype)
+    return out
+
diff --git a/nodes.py b/nodes.py
index 7d19db2..af9a1e0 100644
--- a/nodes.py
+++ b/nodes.py
@@ -49,6 +49,7 @@ class DownloadAndLoadMochiModel:
                     [
                         "mochi_preview_dit_fp8_e4m3fn.safetensors",
                         "mochi_preview_dit_bf16.safetensors",
+                        "mochi_preview_dit_GGUF_Q4_0_v1.safetensors"
                     ],
                     {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/diffusion_models/mochi'", },
                 ),
@@ -208,6 +209,7 @@ class MochiSampler:
                 "steps": ("INT", {"default": 50, "min": 2}),
                 "cfg": ("FLOAT", {"default": 4.5, "min": 0.0, "max": 30.0, "step": 0.01}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+                #"batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
             },
         }