Add different RMSNorm functions for testing

Initial testing on my end shows that the RMSNorm from flash_attn.ops.triton.layer_norm is ~8-10% faster; apex is untested as I don't currently have it installed.
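For context, a minimal sketch of how such a comparison could be timed. This is not code from this repo: the reference RMSNorm, tensor shapes, dtype, and iteration counts below are illustrative assumptions, and it requires a CUDA device with flash-attn installed.

```python
# Rough timing sketch: plain PyTorch RMSNorm vs. flash_attn's Triton RMSNorm.
# Shapes, dtype and iteration counts are arbitrary; results will vary by GPU.
import torch
from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm

class SimpleRMSNorm(torch.nn.Module):
    # Reference implementation for comparison; the repo's .layers.RMSNorm may differ in detail.
    def __init__(self, dim, eps=1e-5, device=None):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim, device=device))

    def forward(self, x):
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight

def bench(norm, x, iters=200):
    for _ in range(10):  # warmup
        norm(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        norm(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call

head_dim = 128  # hypothetical head dimension
x = torch.randn(2, 24, 4096, head_dim, device="cuda", dtype=torch.float16)
plain = SimpleRMSNorm(head_dim, device="cuda").to(torch.float16)
triton = FlashTritonRMSNorm(head_dim, device="cuda").to(torch.float16)
print("plain  :", bench(plain, x), "ms")
print("triton :", bench(triton, x), "ms")
```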
This commit is contained in:
kijai 2024-11-04 14:11:58 +02:00
parent fd4a02e6a6
commit 56b5dbbf82
4 changed files with 51 additions and 25 deletions

View File

@ -1,4 +0,0 @@
{
"mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285],
"std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
}

View File

@ -7,15 +7,15 @@ import torch.nn.functional as F
from .layers import (
FeedForward,
PatchEmbed,
RMSNorm,
TimestepEmbedder,
)
from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
from .rope_mixed import (compute_mixed_rotation, create_position_matrix)
from .residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm
from .rope_mixed import compute_mixed_rotation, create_position_matrix
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (pool_tokens, modulate)
from .utils import pool_tokens, modulate
try:
from flash_attn import flash_attn_func
@ -119,6 +119,7 @@ class AsymmetricAttention(nn.Module):
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
rms_norm_func: str = "default",
):
super().__init__()
@ -145,6 +146,28 @@ class AsymmetricAttention(nn.Module):
# Query and key normalization for stability.
assert qk_norm
if rms_norm_func == "flash_attn_triton":  # select the RMSNorm implementation used for the q/k norms below
    try:
        from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm  # slightly faster
        @torch.compiler.disable()  # causes NaNs when compiled for some reason
        class RMSNorm(FlashTritonRMSNorm):
            pass
    except ImportError as e:
        raise ImportError("Flash Triton RMSNorm not available.") from e
elif rms_norm_func == "flash_attn":
    try:
        from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm  # slightly faster
        @torch.compiler.disable()  # causes NaNs when compiled for some reason
        class RMSNorm(FlashRMSNorm):
            pass
    except ImportError as e:
        raise ImportError("Flash RMSNorm not available.") from e
elif rms_norm_func == "apex":
    from apex.normalization import FusedRMSNorm as ApexRMSNorm  # untested (see commit message)
    class RMSNorm(ApexRMSNorm):
        pass
else:
    from .layers import RMSNorm
self.q_norm_x = RMSNorm(self.head_dim, device=device)
self.k_norm_x = RMSNorm(self.head_dim, device=device)
self.q_norm_y = RMSNorm(self.head_dim, device=device)
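Each branch above defines a local RMSNorm subclass with the same constructor interface, so the q/k norm construction that follows stays identical regardless of the backend chosen. The flash-attn variants are wrapped in @torch.compiler.disable() because they produce NaNs under torch.compile, and their imports are guarded so a missing install raises a clear ImportError. The apex path is unguarded and untested; note that apex's FusedRMSNorm may not accept the device keyword passed to these norms, so that branch may need adjustment once it can be tested.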
@ -210,7 +233,6 @@ class AsymmetricAttention(nn.Module):
)
return out
@torch.compiler.disable()
def run_attention(
self,
q,
@ -283,6 +305,7 @@ class AsymmetricJointBlock(nn.Module):
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
rms_norm_func: str = "default",
**block_kwargs,
):
super().__init__()
@ -304,6 +327,7 @@ class AsymmetricJointBlock(nn.Module):
update_y=update_y,
device=device,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
**block_kwargs,
)
@ -450,6 +474,7 @@ class AsymmDiTJoint(nn.Module):
rope_theta: float = 10000.0,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
rms_norm_func: str = "default",
**block_kwargs,
):
super().__init__()
@ -518,6 +543,7 @@ class AsymmDiTJoint(nn.Module):
update_y=update_y,
device=device,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
**block_kwargs,
)

View File

@ -1,4 +1,3 @@
import json
from typing import Dict, List, Optional, Union
#temporary patch to fix torch compile bug in Windows
@ -35,7 +34,6 @@ except:
import torch
import torch.utils.data
#from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file
import comfy.model_management as mm
@ -83,11 +81,11 @@ class T2VSynthMochiModel:
*,
device: torch.device,
offload_device: torch.device,
vae_stats_path: str,
dit_checkpoint_path: str,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
fp8_fastmode: bool = False,
attention_mode: str = "sdpa",
rms_norm_func: str = "default",
compile_args: Optional[Dict] = None,
cublas_ops: Optional[bool] = False,
):
@ -117,6 +115,7 @@ class T2VSynthMochiModel:
t5_token_length=256,
rope_theta=10000.0,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
)
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
@ -171,10 +170,6 @@ class T2VSynthMochiModel:
model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
self.dit = model
vae_stats = json.load(open(vae_stats_path))
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
def get_packed_indices(self, y_mask, **latent_dims):
# temporary dummy func for compatibility
@ -233,8 +228,8 @@ class T2VSynthMochiModel:
z = z * sigma_schedule[0] + (1 - sigma_schedule[0]) * in_samples.to(self.device)
sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
}
sample_null = {
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],

View File

@ -105,6 +105,7 @@ class DownloadAndLoadMochiModel:
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
"cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
"rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),
},
}
@ -114,7 +115,7 @@ class DownloadAndLoadMochiModel:
CATEGORY = "MochiWrapper"
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -154,11 +155,11 @@ class DownloadAndLoadMochiModel:
model = T2VSynthMochiModel(
device=device,
offload_device=offload_device,
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
dit_checkpoint_path=model_path,
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
compile_args=compile_args,
cublas_ops=cublas_ops
)
@ -201,6 +202,7 @@ class MochiModelLoader:
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
"cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
"rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),
},
}
@ -209,7 +211,7 @@ class MochiModelLoader:
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@ -226,6 +228,7 @@ class MochiModelLoader:
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
compile_args=compile_args,
cublas_ops=cublas_ops
)
@ -749,10 +752,16 @@ class MochiImageEncode:
from .mochi_preview.vae.model import apply_tiled
B, H, W, C = images.shape
images = images.unsqueeze(0) * 2 - 1
images = rearrange(images, "t b h w c -> t c b h w")
images = images.to(device)
print(images.shape)
import torchvision.transforms as transforms
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
input_image_tensor = rearrange(images, 'b h w c -> b c h w')
input_image_tensor = normalize(input_image_tensor).unsqueeze(0)
input_image_tensor = rearrange(input_image_tensor, 'b t c h w -> b c t h w', t=B)
#images = images.unsqueeze(0).sub_(0.5).div_(0.5)
#images = rearrange(input_image_tensor, "b c t h w -> t c b h w")
images = input_image_tensor.to(device)
encoder.to(device)
print("images before encoding", images.shape)
with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):
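A note on this last hunk: for image tensors in [0, 1], torchvision's Normalize with mean = std = 0.5 computes (x - 0.5) / 0.5 = x * 2 - 1, i.e. the same scaling as the line it replaces. A quick illustrative check (shapes are arbitrary, not from the repo):

```python
# Illustrative check (not part of the commit): Normalize(mean=0.5, std=0.5)
# applies (x - 0.5) / 0.5 == x * 2 - 1, the scaling this hunk replaces.
import torch
import torchvision.transforms as transforms

x = torch.rand(4, 3, 64, 64)  # hypothetical batch of RGB frames in [0, 1]
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
assert torch.allclose(normalize(x), x * 2 - 1, atol=1e-6)
```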