diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index ba0d67d..e812869 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -15,6 +15,7 @@ from .layers import (
     RMSNorm,
     TimestepEmbedder,
 )
+
 from .mod_rmsnorm import modulated_rmsnorm
 from .residual_tanh_gated_rmsnorm import (
     residual_tanh_gated_rmsnorm,
@@ -140,7 +141,7 @@ class AsymmetricAttention(nn.Module):

         # Process visual features
         qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
-        assert qkv_x.dtype == torch.bfloat16
+        #assert qkv_x.dtype == torch.bfloat16
         qkv_x = all_to_all_collect_tokens(
             qkv_x, self.num_heads
         )  # (3, B, N, local_h, head_dim)
diff --git a/mochi_preview/dit/joint_model/temporal_rope.py b/mochi_preview/dit/joint_model/temporal_rope.py
index a8276db..1ac6f89 100644
--- a/mochi_preview/dit/joint_model/temporal_rope.py
+++ b/mochi_preview/dit/joint_model/temporal_rope.py
@@ -19,7 +19,7 @@ def apply_rotary_emb_qk_real(
     Returns:
         torch.Tensor: The input tensor with rotary embeddings applied.
     """
-    assert xqk.dtype == torch.bfloat16
+    #assert xqk.dtype == torch.bfloat16
     # Split the last dimension into even and odd parts
     xqk_even = xqk[..., 0::2]
     xqk_odd = xqk[..., 1::2]
@@ -30,5 +30,5 @@ def apply_rotary_emb_qk_real(

     # Interleave the results back into the original shape
     out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
-    assert out.dtype == torch.bfloat16
+    #assert out.dtype == torch.bfloat16
     return out
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 7504afa..bf592bb 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -1,7 +1,7 @@
 import json
 from typing import Dict, List, Optional, Union

-#temporary patch to fix bug in Windows
+#temporary patch to fix torch compile bug in Windows
 def patched_write_atomic(
     path_: str,
     content: Union[str, bytes],
@@ -24,7 +24,7 @@ def patched_write_atomic(
     write_mode = "w" if isinstance(content, str) else "wb"
     with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f:
         f.write(content)
-    shutil.copy2(src=tmp_path, dst=path) #to allow overwriting cache files
+    shutil.copy2(src=tmp_path, dst=path) #changed to allow overwriting cache files
     os.remove(tmp_path)
 try:
     import torch._inductor.codecache
@@ -37,7 +37,7 @@
 import torch.nn.functional as F
 import torch.utils.data
 from einops import rearrange, repeat
-from .dit.joint_model.context_parallel import get_cp_rank_size
+#from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
 import comfy.model_management as mm
@@ -133,9 +133,11 @@ class T2VSynthMochiModel:
         fp8_fastmode: bool = False,
         attention_mode: str = "sdpa",
         compile_args: Optional[Dict] = None,
+        cublas_ops: Optional[bool] = False,
     ):
         super().__init__()
         self.device = device
+        self.weight_dtype = weight_dtype
         self.offload_device = offload_device

         logging.info("Initializing model...")
@@ -170,7 +172,7 @@ class T2VSynthMochiModel:
             import importlib
             importlib.reload(mz_gguf_loader)
             with mz_gguf_loader.quantize_lazy_load():
-                model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu")
+                model = mz_gguf_loader.quantize_load_state_dict(model, dit_sd, device="cpu", cublas_ops=cublas_ops)
         elif is_accelerate_available:
             logging.info("Using accelerate to load and assign model weights to device...")
             for name, param in model.named_parameters():
@@ -207,51 +209,6 @@ class T2VSynthMochiModel:
         self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
         self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)

-    def get_conditioning(self, prompts, *, zero_last_n_prompts: int):
-        B = len(prompts)
-        assert (
-            0 <= zero_last_n_prompts <= B
-        ), f"zero_last_n_prompts should be between 0 and {B}, got {zero_last_n_prompts}"
-        tokenize_kwargs = dict(
-            prompt=prompts,
-            padding="max_length",
-            return_tensors="pt",
-            truncation=True,
-        )
-
-        t5_toks = self.t5_tokenizer(**tokenize_kwargs, max_length=MAX_T5_TOKEN_LENGTH)
-        caption_input_ids_t5 = t5_toks["input_ids"]
-        caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
-        del t5_toks
-
-        assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
-        assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
-
-        if zero_last_n_prompts > 0:
-            # Zero the last N prompts
-            caption_input_ids_t5[-zero_last_n_prompts:] = 0
-            caption_attention_mask_t5[-zero_last_n_prompts:] = False
-
-        caption_input_ids_t5 = caption_input_ids_t5.to(self.device, non_blocking=True)
-        caption_attention_mask_t5 = caption_attention_mask_t5.to(
-            self.device, non_blocking=True
-        )
-
-        y_mask = [caption_attention_mask_t5]
-        y_feat = []
-
-        self.t5_enc.to(self.device)
-        y_feat.append(
-            self.t5_enc(
-                caption_input_ids_t5, caption_attention_mask_t5
-            ).last_hidden_state.detach().to(torch.float32)
-        )
-        self.t5_enc.to(self.offload_device)
-        # Sometimes returns a tensor, othertimes a tuple, not sure why
-        # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
-        assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
-        return dict(y_mask=y_mask, y_feat=y_feat)
-
     def get_packed_indices(self, y_mask, *, lT, lW, lH):
         patch_size = 2
         N = lT * lH * lW // (patch_size**2)
@@ -364,7 +321,11 @@ class T2VSynthMochiModel:
                 out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
             else:
                 nonlocal sample, sample_null
-                with torch.autocast(mm.get_autocast_device(self.device), dtype=torch.bfloat16):
+                if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+                    autocast_dtype = torch.float16
+                else:
+                    autocast_dtype = torch.bfloat16
+                with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
                     out_cond = self.dit(z, sigma, **sample)
                     out_uncond = self.dit(z, sigma, **sample_null)
@@ -390,11 +351,11 @@ class T2VSynthMochiModel:
                 z = z + dsigma * pred
                 comfy_pbar.update(1)

-        cp_rank, cp_size = get_cp_rank_size()
+        #cp_rank, cp_size = get_cp_rank_size()
         if batch_cfg:
             z = z[:B]
-        z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
-        self.dit.to(self.offload_device)
+        #z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
+        self.dit.to(self.offload_device, non_blocking=True)

         samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
         logging.info(f"samples shape: {samples.shape}")
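The autocast change in t2v_synth_mochi.py keys off the `cublas_half_matmul` flag that the GGUF loader below sets on the model: torch-cublas-hgemm accelerates fp16 GEMMs, so those models must run under float16 autocast instead of the usual bfloat16. A minimal sketch of that selection logic, assuming only a `dit` module that may or may not carry the flag (the helper name `pick_autocast_dtype` is hypothetical, not part of the patch):

import torch

def pick_autocast_dtype(dit: torch.nn.Module) -> torch.dtype:
    # Models tagged by the GGUF loader use the cublas fp16 kernel, so they
    # need float16 autocast; everything else keeps the bfloat16 default.
    if getattr(dit, "cublas_half_matmul", False):
        return torch.float16
    return torch.bfloat16

# Usage, mirroring the sampler loop above:
# with torch.autocast("cuda", dtype=pick_autocast_dtype(dit)):
#     out_cond = dit(z, sigma, **sample)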
diff --git a/mz_gguf_loader.py b/mz_gguf_loader.py
index a6fdf85..1b66c18 100644
--- a/mz_gguf_loader.py
+++ b/mz_gguf_loader.py
@@ -2,7 +2,8 @@
 import torch
 import torch.nn as nn
-import gc
+import torch.nn.functional as F
+


 class quantize_lazy_load():
@@ -18,7 +19,17 @@ class quantize_lazy_load():
         self.device.__exit__(exc_type, exc_value, traceback)


-def quantize_load_state_dict(model, state_dict, device="cpu"):
+def quantize_load_state_dict(model, state_dict, device="cpu", cublas_ops=False):
+    if cublas_ops:
+        try:
+            from cublas_ops import cublas_half_matmul
+            linear_ops = cublas_half_matmul
+            print("Using cublas_ops")
+        except ImportError:
+            raise ImportError("Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) to use cublas_ops")
+    else:
+        linear_ops = F.linear
+        pass
     quant_keys = []
     for key in state_dict.keys():
         if key.endswith(".Q4_0_qweight"):
@@ -35,12 +46,14 @@
                 linear=module,
                 device=device,
                 qtype=qtype,
+                linear_ops=linear_ops
             )
             set_op_by_name(model, name, q_linear)

     model.to_empty(device=device)
     model.load_state_dict(state_dict, strict=False)
-    model.to(device)
+    if cublas_ops:
+        setattr(model, "cublas_half_matmul", True)
     return model
@@ -57,19 +70,16 @@
     else:
         setattr(layer, name, new_module)

-
-import torch.nn.functional as F
-
-
 class WQLinear_GGUF(nn.Module):
     def __init__(
-        self, in_features, out_features, bias, dev, qtype="Q4_0"
+        self, in_features, out_features, bias, dev, qtype, linear_ops
     ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.qtype = qtype
+        self.linear_ops = linear_ops

         qweight_shape = quant_shape_to_byte_shape(
             (out_features, in_features), qtype
         )
@@ -99,6 +109,7 @@
         cls,
         linear,
         device="cpu", qtype="Q4_0",
+        linear_ops=F.linear
     ):
         q_linear = cls(
             linear.in_features,
@@ -106,6 +117,7 @@
             linear.bias is not None,
             device,
             qtype=qtype,
+            linear_ops=linear_ops
         )

         return q_linear
@@ -120,6 +132,7 @@
             )
         )

+    @torch.no_grad()
     def forward(self, x):
         # x = torch.matmul(x, dequantize_blocks_Q4_0(self.qweight))

@@ -127,8 +140,11 @@
             x = F.linear(x, dequantize_blocks_Q4_0(
                 self.Q4_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
         elif self.qtype == "Q8_0":
-            x = F.linear(x, dequantize_blocks_Q8_0(
-                self.Q8_0_qweight, x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
+            dequant = dequantize_blocks_Q8_0(self.Q8_0_qweight, x.dtype)
+
+            #x = F.linear(x, dequant, self.bias.to(x.dtype) if self.bias is not None else None)
+            x = self.linear_ops(x, dequant, bias=self.bias.to(x.dtype) if self.bias is not None else None)
+
         else:
             raise ValueError(f"Unknown qtype: {self.qtype}")
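The mz_gguf_loader change is a dependency-injection pattern: each quantized linear layer receives its GEMM at construction time instead of hard-coding F.linear in forward(). A simplified sketch of the idea, with a plain weight tensor standing in for the packed GGUF blocks and `InjectedLinear` as a hypothetical analogue of WQLinear_GGUF:

import torch
import torch.nn as nn
import torch.nn.functional as F

class InjectedLinear(nn.Module):
    def __init__(self, weight, bias=None, linear_ops=F.linear):
        super().__init__()
        # linear_ops is F.linear by default, or cublas_half_matmul when the
        # loader was called with cublas_ops=True.
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.bias = bias
        self.linear_ops = linear_ops

    @torch.no_grad()
    def forward(self, x):
        # The real layer dequantizes Q8_0 blocks first; only the matmul
        # itself is swapped out.
        return self.linear_ops(x, self.weight, bias=self.bias)

Injecting the op once at load time keeps forward() branch-free, and the Q4_0 path, which still calls F.linear directly, is untouched.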
diff --git a/nodes.py b/nodes.py
index e1c342e..50156f1 100644
--- a/nodes.py
+++ b/nodes.py
@@ -72,6 +72,7 @@ class DownloadAndLoadMochiModel:
             "optional": {
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                 "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
+                "cublas_ops": ("BOOLEAN", {"tooltip": "Enables faster fp16 linear ops from https://github.com/aredden/torch-cublas-hgemm. Tested on a 4090; GPU requirements are unclear.",}),
             },
         }

@@ -81,7 +82,7 @@ class DownloadAndLoadMochiModel:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"

-    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None):
+    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -126,7 +127,8 @@ class DownloadAndLoadMochiModel:
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
             attention_mode=attention_mode,
-            compile_args=compile_args
+            compile_args=compile_args,
+            cublas_ops=cublas_ops
         )
         with (init_empty_weights() if is_accelerate_available else nullcontext()):
             vae = Decoder(
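For reference, how the pieces fit together, written as a hypothetical end-to-end sketch (the helper `select_linear_ops` does not exist in the patch; it just restates the guarded import inside quantize_load_state_dict):

import torch.nn.functional as F

def select_linear_ops(cublas_ops: bool = False):
    # cublas is strictly opt-in; F.linear stays the default GEMM.
    if not cublas_ops:
        return F.linear
    try:
        from cublas_ops import cublas_half_matmul
        return cublas_half_matmul
    except ImportError as e:
        raise ImportError(
            "Install cublas_ops (https://github.com/aredden/torch-cublas-hgemm) "
            "to use cublas_ops"
        ) from e

The node toggle flows from DownloadAndLoadMochiModel.loadmodel into T2VSynthMochiModel, the loader injects the chosen op into every WQLinear_GGUF and tags the model with cublas_half_matmul=True, and the sampler then switches autocast from bfloat16 to float16 to match the half-precision kernel.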