From 56b5dbbf828f86f59ce3e01004d707aeee7ad304 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 4 Nov 2024 14:11:58 +0200
Subject: [PATCH] Add different RMSNorm functions for testing

My initial testing shows that the RMSNorm from
flash_attn.ops.triton.layer_norm is ~8-10% faster; apex is untested as
I don't currently have it installed.
---
 configs/vae_stats.json                        |  4 ---
 .../dit/joint_model/asymm_models_joint.py     | 36 ++++++++++++++++---
 mochi_preview/t2v_synth_mochi.py              | 13 +++----
 nodes.py                                      | 23 ++++++++----
 4 files changed, 51 insertions(+), 25 deletions(-)
 delete mode 100644 configs/vae_stats.json

diff --git a/configs/vae_stats.json b/configs/vae_stats.json
deleted file mode 100644
index e3278af..0000000
--- a/configs/vae_stats.json
+++ /dev/null
@@ -1,4 +0,0 @@
-{
-    "mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285],
-    "std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
-}
diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 9650f40..7443961 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -7,15 +7,15 @@ import torch.nn.functional as F
 from .layers import (
     FeedForward,
     PatchEmbed,
-    RMSNorm,
     TimestepEmbedder,
 )
+
 from .mod_rmsnorm import modulated_rmsnorm
-from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
-from .rope_mixed import (compute_mixed_rotation, create_position_matrix)
+from .residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm
+from .rope_mixed import compute_mixed_rotation, create_position_matrix
 from .temporal_rope import apply_rotary_emb_qk_real
-from .utils import (pool_tokens, modulate)
+from .utils import pool_tokens, modulate
 
 try:
     from flash_attn import flash_attn_func
@@ -119,6 +119,7 @@ class AsymmetricAttention(nn.Module):
         softmax_scale: Optional[float] = None,
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
     ):
         super().__init__()
 
@@ -145,6 +146,28 @@ class AsymmetricAttention(nn.Module):
 
         # Query and key normalization for stability.
         assert qk_norm
+        if rms_norm_func == "flash_attn_triton":  # select which RMSNorm implementation to use
+            try:
+                from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm  # ~8-10% faster in initial testing
+                @torch.compiler.disable()  # causes NaNs when compiled for some reason
+                class RMSNorm(FlashTritonRMSNorm):
+                    pass
+            except ImportError as e:
+                raise ImportError("Flash Triton RMSNorm not available.") from e
+        elif rms_norm_func == "flash_attn":
+            try:
+                from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm  # slightly faster
+                @torch.compiler.disable()  # causes NaNs when compiled for some reason
+                class RMSNorm(FlashRMSNorm):
+                    pass
+            except ImportError as e:
+                raise ImportError("Flash RMSNorm not available.") from e
+        elif rms_norm_func == "apex":
+            from apex.normalization import FusedRMSNorm as ApexRMSNorm  # untested
+            class RMSNorm(ApexRMSNorm):
+                pass
+        else:
+            from .layers import RMSNorm
         self.q_norm_x = RMSNorm(self.head_dim, device=device)
         self.k_norm_x = RMSNorm(self.head_dim, device=device)
         self.q_norm_y = RMSNorm(self.head_dim, device=device)
@@ -210,7 +233,6 @@ class AsymmetricAttention(nn.Module):
         )
         return out
 
-    @torch.compiler.disable()
     def run_attention(
         self,
         q,
@@ -283,6 +305,7 @@ class AsymmetricJointBlock(nn.Module):
         update_y: bool = True,  # Whether to update text tokens in this block.
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         **block_kwargs,
     ):
         super().__init__()
@@ -304,6 +327,7 @@ class AsymmetricJointBlock(nn.Module):
             update_y=update_y,
             device=device,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
             **block_kwargs,
         )
 
@@ -450,6 +474,7 @@ class AsymmDiTJoint(nn.Module):
         rope_theta: float = 10000.0,
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         **block_kwargs,
     ):
         super().__init__()
@@ -518,6 +543,7 @@ class AsymmDiTJoint(nn.Module):
                 update_y=update_y,
                 device=device,
                 attention_mode=attention_mode,
+                rms_norm_func=rms_norm_func,
                 **block_kwargs,
             )
 
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index ccec165..798a7aa 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -1,4 +1,3 @@
-import json
 from typing import Dict, List, Optional, Union
 
 #temporary patch to fix torch compile bug in Windows
@@ -35,7 +34,6 @@ except:
 import torch
 import torch.utils.data
 
-#from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
 import comfy.model_management as mm
@@ -83,11 +81,11 @@ class T2VSynthMochiModel:
         *,
         device: torch.device,
         offload_device: torch.device,
-        vae_stats_path: str,
         dit_checkpoint_path: str,
         weight_dtype: torch.dtype = torch.float8_e4m3fn,
         fp8_fastmode: bool = False,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         compile_args: Optional[Dict] = None,
         cublas_ops: Optional[bool] = False,
     ):
@@ -117,6 +115,7 @@ class T2VSynthMochiModel:
             t5_token_length=256,
             rope_theta=10000.0,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
         )
 
         params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
@@ -171,10 +170,6 @@ class T2VSynthMochiModel:
                 model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
 
         self.dit = model
-
-        vae_stats = json.load(open(vae_stats_path))
-        self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
-        self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
 
     def get_packed_indices(self, y_mask, **latent_dims):
         # temporary dummy func for compatibility
@@ -233,8 +228,8 @@ class T2VSynthMochiModel:
             z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device)
 
         sample = {
-                "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
-                "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
+            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
+            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
         }
         sample_null = {
             "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
diff --git a/nodes.py b/nodes.py
index 6196927..75b33c7 100644
--- a/nodes.py
+++ b/nodes.py
@@ -105,6 +105,7 @@ class DownloadAndLoadMochiModel:
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                 "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
                 "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
+                "rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"], {"tooltip": "RMSNorm function to use; flash_attn, if available, seems to be faster; apex is untested",}),
             },
         }
 
@@ -114,7 +115,7 @@ class DownloadAndLoadMochiModel:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
 
-    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
+    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
 
@@ -154,11 +155,11 @@ class DownloadAndLoadMochiModel:
             model = T2VSynthMochiModel(
                 device=device,
                 offload_device=offload_device,
-                vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
                 dit_checkpoint_path=model_path,
                 weight_dtype=dtype,
                 fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
                 attention_mode=attention_mode,
+                rms_norm_func=rms_norm_func,
                 compile_args=compile_args,
                 cublas_ops=cublas_ops
             )
@@ -201,6 +202,7 @@ class MochiModelLoader:
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                 "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
                 "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
+                "rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"], {"tooltip": "RMSNorm function to use; flash_attn, if available, seems to be faster; apex is untested",}),
             },
         }
 
@@ -209,7 +211,7 @@
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"
 
-    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
+    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
 
@@ -226,6 +228,7 @@ class MochiModelLoader:
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
            compile_args=compile_args,
             cublas_ops=cublas_ops
         )
@@ -749,10 +752,16 @@ class MochiImageEncode:
         from .mochi_preview.vae.model import apply_tiled
 
         B, H, W, C = images.shape
-        images = images.unsqueeze(0) * 2 - 1
-        images = rearrange(images, "t b h w c -> t c b h w")
-        images = images.to(device)
-        print(images.shape)
+        import torchvision.transforms as transforms
+        normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+        input_image_tensor = rearrange(images, 'b h w c -> b c h w')
+        input_image_tensor = normalize(input_image_tensor).unsqueeze(0)
+        input_image_tensor = rearrange(input_image_tensor, 'b t c h w -> b c t h w', t=B)
+
+        #images = images.unsqueeze(0).sub_(0.5).div_(0.5)
+        #images = rearrange(input_image_tensor, "b c t h w -> t c b h w")
+        images = input_image_tensor.to(device)
+        encoder.to(device)
 
         print("images before encoding", images.shape)
         with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):