diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index a1621e7..68401e4 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -3,31 +3,22 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 
 from .layers import (
     FeedForward,
     PatchEmbed,
-    RMSNorm,
     TimestepEmbedder,
 )
+
 from .mod_rmsnorm import modulated_rmsnorm
-from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
-from .rope_mixed import (
-    compute_mixed_rotation,
-    create_position_matrix,
-)
+from .residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm
+from .rope_mixed import compute_mixed_rotation, create_position_matrix
 from .temporal_rope import apply_rotary_emb_qk_real
-from .utils import (
-    pool_tokens,
-    modulate,
-    pad_and_split_xy,
-    unify_streams,
-)
+from .utils import pool_tokens, modulate
 
 try:
-    from flash_attn import flash_attn_varlen_qkvpacked_func
+    from flash_attn import flash_attn_func
     FLASH_ATTN_IS_AVAILABLE = True
 except ImportError:
     FLASH_ATTN_IS_AVAILABLE = False
@@ -45,6 +36,7 @@ backends.append(SDPBackend.EFFICIENT_ATTENTION)
 backends.append(SDPBackend.MATH)
 
 import comfy.model_management as mm
+from comfy.ldm.modules.attention import optimized_attention
 
 class AttentionPool(nn.Module):
     def __init__(
@@ -103,16 +95,16 @@ class AttentionPool(nn.Module):
         q = q.unsqueeze(2)  # (B, H, 1, head_dim)
 
         # Compute attention.
-        with sdpa_kernel(backends):
-            x = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask, dropout_p=0.0
-            )  # (B, H, 1, head_dim)
+        x = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, dropout_p=0.0
+        )  # (B, H, 1, head_dim)
 
         # Concatenate heads and run output.
         x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
         x = self.to_out(x)
         return x
-
+
+#region Attention
 class AsymmetricAttention(nn.Module):
     def __init__(
         self,
@@ -128,6 +120,7 @@ class AsymmetricAttention(nn.Module):
         softmax_scale: Optional[float] = None,
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
     ):
         super().__init__()
 
@@ -154,6 +147,28 @@ class AsymmetricAttention(nn.Module):
         # Query and key normalization for stability.
         assert qk_norm
+        if rms_norm_func == "flash_attn_triton":  # select which RMSNorm implementation to use
+            try:
+                from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm  # slightly faster
+                @torch.compiler.disable()  # causes NaNs when compiled for some reason
+                class RMSNorm(FlashTritonRMSNorm):
+                    pass
+            except ImportError:
+                raise ImportError("Flash Triton RMSNorm not available.")
+        elif rms_norm_func == "flash_attn":
+            try:
+                from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm  # slightly faster
+                @torch.compiler.disable()  # causes NaNs when compiled for some reason
+                class RMSNorm(FlashRMSNorm):
+                    pass
+            except ImportError:
+                raise ImportError("Flash RMSNorm not available.")
+        elif rms_norm_func == "apex":
+            from apex.normalization import FusedRMSNorm as ApexRMSNorm
+            class RMSNorm(ApexRMSNorm):
+                pass
+        else:
+            from .layers import RMSNorm
         self.q_norm_x = RMSNorm(self.head_dim, device=device)
         self.k_norm_x = RMSNorm(self.head_dim, device=device)
         self.q_norm_y = RMSNorm(self.head_dim, device=device)
         self.k_norm_y = RMSNorm(self.head_dim, device=device)
@@ -166,79 +181,23 @@ class AsymmetricAttention(nn.Module):
             if update_y
             else nn.Identity()
         )
-
-    def run_qkv_y(self, y):
-        local_heads = self.num_heads
-        qkv_y = self.qkv_y(y)  # (B, L, 3 * dim)
-        qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
-        q_y, k_y, v_y = qkv_y.unbind(2)
-        return q_y, k_y, v_y
-
-    def prepare_qkv(
-        self,
-        x: torch.Tensor,  # (B, N, dim_x)
-        y: torch.Tensor,  # (B, L, dim_y)
-        *,
-        scale_x: torch.Tensor,
-        scale_y: torch.Tensor,
-        rope_cos: torch.Tensor,
-        rope_sin: torch.Tensor,
-        valid_token_indices: torch.Tensor,
-    ):
-        # Pre-norm for visual features
-        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
-
-        # Process visual features
-        qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
-        #assert qkv_x.dtype == torch.bfloat16
-
-        # Move QKV dimension to the front.
-        # B M (3 H d) -> 3 B M H d
-        B, M, _ = qkv_x.size()
-        qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
-        qkv_x = qkv_x.permute(2, 0, 1, 3, 4)
-
-        # Process text features
-        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
-        q_y, k_y, v_y = self.run_qkv_y(y)  # (B, L, local_heads, head_dim)
-        q_y = self.q_norm_y(q_y)
-        k_y = self.k_norm_y(k_y)
-
-        # Split qkv_x into q, k, v
-        q_x, k_x, v_x = qkv_x.unbind(0)  # (B, N, local_h, head_dim)
-        q_x = self.q_norm_x(q_x)
-        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
-        k_x = self.k_norm_x(k_x)
-        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
-
-        # Unite streams
-        qkv = unify_streams(
-            q_x,
-            k_x,
-            v_x,
-            q_y,
-            k_y,
-            v_y,
-            valid_token_indices,
-        )
-
-        return qkv
-
-    def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
+    def flash_attention(self, q, k, v):
+        q = q.permute(0, 2, 1, 3)
+        k = k.permute(0, 2, 1, 3)
+        v = v.permute(0, 2, 1, 3)
+        b, _, _, dim_head = q.shape
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
-            out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
-                qkv,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen_in_batch,
+            out: torch.Tensor = flash_attn_func(  # expects q, k, v as (batch_size, seqlen, nheads, headdim)
+                q, k, v,
                 dropout_p=0.0,
                 softmax_scale=self.softmax_scale,
             )  # (total, local_heads, head_dim)
-        return out.view(total, local_dim)
+        out = out.permute(0, 2, 1, 3)  # (B, H, S, D)
+        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)  # (B, S, H * D)
 
-    def sdpa_attention(self, qkv):
-        q, k, v = qkv.unbind(dim=1)
-        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
+    def sdpa_attention(self, q, k, v):
+        b, _, _, dim_head = q.shape
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             with sdpa_kernel(backends):
                 out = F.scaled_dot_product_attention(
@@ -249,13 +208,10 @@ class AsymmetricAttention(nn.Module):
                     dropout_p=0.0,
                     is_causal=False
                 )
-        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
+        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
 
-    def sage_attention(self, qkv):
-        #q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        q, k, v = qkv.unbind(dim=1)
-        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
-
+    def sage_attention(self, q, k, v):
+        b, _, _, dim_head = q.shape
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = sageattn(
                 q,
@@ -265,14 +221,9 @@ class AsymmetricAttention(nn.Module):
                 dropout_p=0.0,
                 is_causal=False
             )
-        #print(out.shape)
-        #out = rearrange(out, 'b h s d -> s (b h d)')
-        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
+        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
 
-    def comfy_attention(self, qkv):
-        from comfy.ldm.modules.attention import optimized_attention
-        q, k, v = qkv.unbind(dim=1)
-        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
+    def comfy_attention(self, q, k, v):
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = optimized_attention(
                 q,
@@ -281,41 +232,23 @@ class AsymmetricAttention(nn.Module):
                 heads = self.num_heads,
                 skip_reshape=True
             )
-        return out.squeeze(0)
+        return out
 
-    @torch.compiler.disable()
     def run_attention(
         self,
-        qkv: torch.Tensor,  # (total <= B * (N + L), 3, local_heads, head_dim)
-        *,
-        B: int,
-        L: int,
-        M: int,
-        cu_seqlens: torch.Tensor,
-        max_seqlen_in_batch: int,
-        valid_token_indices: torch.Tensor,
-    ):
-        local_dim = self.num_heads * self.head_dim
-        total = qkv.size(0)
-
+        q,
+        k,
+        v,
+    ):
         if self.attention_mode == "flash_attn":
-            out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
+            out = self.flash_attention(q, k, v)
         elif self.attention_mode == "sdpa":
-            out = self.sdpa_attention(qkv)
+            out = self.sdpa_attention(q, k, v)
         elif self.attention_mode == "sage_attn":
-            out = self.sage_attention(qkv)
+            out = self.sage_attention(q, k, v)
         elif self.attention_mode == "comfy":
-            out = self.comfy_attention(qkv)
-
-        x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
-        assert x.size() == (B, M, local_dim)
-        assert y.size() == (B, L, local_dim)
-
-        x = x.view(B, M, self.num_heads, self.head_dim)
-        x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
-        x = self.proj_x(x)  # (B, M, dim_x)
-        y = self.proj_y(y)  # (B, L, dim_y)
-        return x, y
+            out = self.comfy_attention(q, k, v)
+        return out
 
     def forward(
         self,
@@ -324,47 +257,44 @@ class AsymmetricAttention(nn.Module):
         *,
         scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
         scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
-        packed_indices: Dict[str, torch.Tensor] = None,
+        num_tokens,
         **rope_rotation,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass of asymmetric multi-modal attention.
-        Args:
-            x: (B, N, dim_x) tensor for visual tokens
-            y: (B, L, dim_y) tensor of text token features
-            packed_indices: Dict with keys for Flash Attention
-            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
+        rope_cos = rope_rotation.get("rope_cos")
+        rope_sin = rope_rotation.get("rope_sin")
+
+        # Pre-norm for visual features
+        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
+        # Process text features
+        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
+        q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, L, num_heads, head_dim)
 
-        Returns:
-            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
-            y: (B, L, dim_y) tensor of text token features after multi-modal attention
-        """
-        B, L, _ = y.shape
-        _, M, _ = x.shape
+        q_y = self.q_norm_y(q_y)
+        k_y = self.k_norm_y(k_y)
 
-        # Predict a packed QKV tensor from visual and text features.
-        # Don't checkpoint the all_to_all.
-        qkv = self.prepare_qkv(
-            x=x,
-            y=y,
-            scale_x=scale_x,
-            scale_y=scale_y,
-            rope_cos=rope_rotation.get("rope_cos"),
-            rope_sin=rope_rotation.get("rope_sin"),
-            valid_token_indices=packed_indices["valid_token_indices_kv"],
-        )  # (total <= B * (N + L), 3, local_heads, head_dim)
+        # Split qkv_x into q, k, v
+        q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)
+        q_x = self.q_norm_x(q_x)
+        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
+        k_x = self.k_norm_x(k_x)
+        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
 
-        x, y = self.run_attention(
-            qkv,
-            B=B,
-            L=L,
-            M=M,
-            cu_seqlens=packed_indices["cu_seqlens_kv"],
-            max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
-            valid_token_indices=packed_indices["valid_token_indices_kv"],
-        )
+        q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
+        k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
+        v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)
+
+        xy = self.run_attention(q, k, v)
+
+        x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
+        x = self.proj_x(x)
+        o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
+        o[:, :y.shape[1]] = y
+
+        y = self.proj_y(o)
         return x, y
-
+
+#region Blocks
 class AsymmetricJointBlock(nn.Module):
     def __init__(
         self,
@@ -377,6 +307,7 @@ class AsymmetricJointBlock(nn.Module):
         update_y: bool = True,  # Whether to update text tokens in this block.
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         **block_kwargs,
     ):
         super().__init__()
@@ -401,6 +332,7 @@ class AsymmetricJointBlock(nn.Module):
             update_y=update_y,
             device=device,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
             **block_kwargs,
         )
 
@@ -432,7 +364,8 @@ class AsymmetricJointBlock(nn.Module):
         c: torch.Tensor,
         y: torch.Tensor,
         fastercache_counter: Optional[int] = 0,
-        fastercache_start_step: Optional[int] = 15,
+        fastercache_start_step: Optional[int] = 1000,
+        fastercache_device: Optional[torch.device] = None,
         **attn_kwargs,
    ):
         """Forward pass of a block.
@@ -459,10 +392,11 @@ class AsymmetricJointBlock(nn.Module):
         else:
             scale_msa_y = mod_y
 
-        #fastercache
+        #region fastercache
         B = x.shape[0]
         #print("x", x.shape) #([1, 9540, 3072])
         if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_x_attention[-1].shape[0] >= B:
+            print("using fastercache")
             x_attn = (
                 self.cached_x_attention[1][:B]
                 + (self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])
@@ -483,22 +417,24 @@ class AsymmetricJointBlock(nn.Module):
                 **attn_kwargs,
             )
             if fastercache_counter == fastercache_start_step:
-                self.cached_x_attention = [x_attn, x_attn]
-                self.cached_y_attention = [y_attn, y_attn]
+                print("caching attention")
+                self.cached_x_attention = [x_attn.to(fastercache_device), x_attn.to(fastercache_device)]
+                self.cached_y_attention = [y_attn.to(fastercache_device), y_attn.to(fastercache_device)]
             elif fastercache_counter > fastercache_start_step:
-                self.cached_x_attention[-1].copy_(x_attn)
-                self.cached_y_attention[-1].copy_(y_attn)
+                print("updating attention")
+                self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
+                self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))
 
         assert x_attn.size(1) == N
         x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
         if self.update_y:
             y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
-
+
         # MLP block.
         x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
         if self.update_y:
             y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
-
+
         return x, y
 
     def ff_block_x(self, x, scale_x, gate_x):
@@ -542,7 +478,7 @@ class FinalLayer(nn.Module):
         x = self.linear(x)
         return x
 
-
+#region Model
 class AsymmDiTJoint(nn.Module):
     """
     Diffusion model with a Transformer backbone.
@@ -570,6 +506,7 @@ class AsymmDiTJoint(nn.Module):
         rope_theta: float = 10000.0,
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         **block_kwargs,
     ):
         super().__init__()
@@ -638,6 +575,7 @@ class AsymmDiTJoint(nn.Module):
                 update_y=update_y,
                 device=device,
                 attention_mode=attention_mode,
+                rms_norm_func=rms_norm_func,
                 **block_kwargs,
             )
 
@@ -701,12 +639,11 @@ class AsymmDiTJoint(nn.Module):
         x: torch.Tensor,
         sigma: torch.Tensor,
         fastercache_counter: int,
-        fastercache_start_step: int,
         y_feat: List[torch.Tensor],
         y_mask: List[torch.Tensor],
-        packed_indices: Dict[str, torch.Tensor] = None,
         rope_cos: torch.Tensor = None,
         rope_sin: torch.Tensor = None,
+        fastercache: Optional[Dict[str, torch.Tensor]] = None,
     ):
         """Forward pass of DiT.
 
@@ -715,18 +652,25 @@ class AsymmDiTJoint(nn.Module):
             sigma: (B,) tensor of noise standard deviations
             y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
             y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
-            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
         """
         B, _, T, H, W = x.shape
 
         # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
        # Have to call sdpa_kernel outside of a torch.compile region.
+        num_tokens = max(1, torch.sum(y_mask[0]).item())
         with sdpa_kernel(backends):
             x, c, y_feat, rope_cos, rope_sin = self.prepare(
                 x, sigma, y_feat[0], y_mask[0]
             )
         del y_mask
 
+        if fastercache is not None:
+            fastercache_start_step = fastercache["start_step"]
+            fastercache_device = fastercache["cache_device"]
+        else:
+            fastercache_start_step = 1000
+            fastercache_device = None
+
         for i, block in enumerate(self.blocks):
             x, y_feat = block(
                 x,
@@ -734,22 +678,24 @@ class AsymmDiTJoint(nn.Module):
                 y_feat,
                 rope_cos=rope_cos,
                 rope_sin=rope_sin,
-                packed_indices=packed_indices,
+                num_tokens=num_tokens,
                 fastercache_counter = fastercache_counter,
                 fastercache_start_step = fastercache_start_step,
+                fastercache_device = fastercache_device,
+
             )  # (B, M, D), (B, L, D)
 
         del y_feat  # Final layers don't use dense text features.
-        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
-        x = rearrange(
-            x,
-            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
-            T=T,
-            hp=H // self.patch_size,
-            wp=W // self.patch_size,
-            p1=self.patch_size,
-            p2=self.patch_size,
-            c=self.out_channels,
-        )
+        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
+
+        hp = H // self.patch_size
+        wp = W // self.patch_size
+        p1 = self.patch_size
+        p2 = self.patch_size
+        c = self.out_channels
+
+        x = x.view(B, T, hp, wp, p1, p2, c)
+        x = x.permute(0, 6, 1, 2, 4, 3, 5)
+        x = x.reshape(B, c, T, hp * p1, wp * p2)
 
         return x
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index e713141..ecc152b 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -1,5 +1,5 @@
-import json
 from typing import Dict, List, Optional, Union
+from einops import rearrange
 
 #temporary patch to fix torch compile bug in Windows
 def patched_write_atomic(
@@ -33,11 +33,8 @@ except:
     pass
 
 import torch
-import torch.nn.functional as F
 import torch.utils.data
 
-from einops import rearrange, repeat
-#from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
 import comfy.model_management as mm
@@ -95,59 +92,17 @@ def unnormalize_latents(
     assert z.size(1) == mean.size(0) == std.size(0)
     return z * std.to(z) + mean.to(z)
 
-
-def compute_packed_indices(
-    N: int,
-    text_mask: List[torch.Tensor],
-) -> Dict[str, torch.Tensor]:
-    """
-    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80
-
-    Args:
-        N: Number of visual tokens.
-        text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.
-
-    Returns:
-        packed_indices: Dict with keys for Flash Attention:
-            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
-              in the packed sequence.
-            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
-            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
-    """
-    # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
-    assert N > 0 and len(text_mask) == 1
-    text_mask = text_mask[0]
-
-    mask = F.pad(text_mask, (N, 0), value=True)  # (B, N + L)
-    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
-    valid_token_indices = torch.nonzero(
-        mask.flatten(), as_tuple=False
-    ).flatten()  # up to (B * (N + L),)
-
-    assert valid_token_indices.size(0) >= text_mask.size(0) * N  # At least (B * N,)
-    cu_seqlens = F.pad(
-        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
-    )
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-
-    return {
-        "cu_seqlens_kv": cu_seqlens,
-        "max_seqlen_in_batch_kv": max_seqlen_in_batch,
-        "valid_token_indices_kv": valid_token_indices,
-    }
-
 class T2VSynthMochiModel:
     def __init__(
         self,
         *,
         device: torch.device,
         offload_device: torch.device,
-        vae_stats_path: str,
         dit_checkpoint_path: str,
         weight_dtype: torch.dtype = torch.float8_e4m3fn,
         fp8_fastmode: bool = False,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         compile_args: Optional[Dict] = None,
         cublas_ops: Optional[bool] = False,
     ):
@@ -177,11 +132,23 @@ class T2VSynthMochiModel:
             t5_token_length=256,
             rope_theta=10000.0,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
         )
 
         params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
         logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
         dit_sd = load_torch_file(dit_checkpoint_path)
+
+        #comfy format
+        prefix = "model.diffusion_model."
+        first_key = next(iter(dit_sd), None)
+        if first_key and first_key.startswith(prefix):
+            new_dit_sd = {
+                key[len(prefix):] if key.startswith(prefix) else key: value
+                for key, value in dit_sd.items()
+            }
+            dit_sd = new_dit_sd
+
         if "gguf" in dit_checkpoint_path.lower():
             logging.info("Loading GGUF model state_dict...")
             from .. import mz_gguf_loader
@@ -220,18 +187,10 @@ class T2VSynthMochiModel:
                 model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
 
         self.dit = model
-
-        vae_stats = json.load(open(vae_stats_path))
-        self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
-        self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
 
-    def get_packed_indices(self, y_mask, *, lT, lW, lH):
-        patch_size = 2
-        N = lT * lH * lW // (patch_size**2)
-        assert len(y_mask) == 1
-        packed_indices = compute_packed_indices(N, y_mask)
-        self.move_to_device_(packed_indices)
-        return packed_indices
+    def get_packed_indices(self, y_mask, **latent_dims):
+        # temporary dummy func for compatibility
+        return []
 
     def move_to_device_(self, sample):
         if isinstance(sample, dict):
@@ -243,7 +202,7 @@ class T2VSynthMochiModel:
         torch.manual_seed(args["seed"])
         torch.cuda.manual_seed(args["seed"])
 
-        generator = torch.Generator(device=self.device)
+        generator = torch.Generator(device=torch.device("cpu"))
         generator.manual_seed(args["seed"])
 
         num_frames = args["num_frames"]
@@ -275,50 +234,49 @@ class T2VSynthMochiModel:
         T = (num_frames - 1) // temporal_downsample + 1
         H = height // spatial_downsample
         W = width // spatial_downsample
-        latent_dims = dict(lT=T, lW=W, lH=H)
 
         z = torch.randn(
             (B, C, T, H, W),
-            device=self.device,
+            device=torch.device("cpu"),
             generator=generator,
             dtype=torch.float32,
-        )
+        ).to(self.device)
 
         if in_samples is not None:
             z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device)
 
         sample = {
-            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
-            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
+            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
+            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)],
+            "fastercache": args["fastercache"] if args["fastercache"] is not None else None
         }
         sample_null = {
             "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
-            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
-        }
+            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
+            "fastercache": args["fastercache"] if args["fastercache"] is not None else None
+        }
 
-        sample["packed_indices"] = self.get_packed_indices(
-            sample["y_mask"], **latent_dims
-        )
-        sample_null["packed_indices"] = self.get_packed_indices(
-            sample_null["y_mask"], **latent_dims
-        )
 
-        self.use_fastercache = True
+        if args["fastercache"]:
+            self.fastercache_start_step = args["fastercache"]["start_step"]
+            self.fastercache_lf_step = args["fastercache"]["lf_step"]
+            self.fastercache_hf_step = args["fastercache"]["hf_step"]
+        else:
+            self.fastercache_start_step = 1000
         self.fastercache_counter = 0
-        self.fastercache_start_step = 15
-        self.fastercache_lf_step = 40
-        self.fastercache_hf_step = 30
 
         def model_fn(*, z, sigma, cfg_scale):
             nonlocal sample, sample_null
-            if self.use_fastercache:
+            if args["fastercache"]:
                self.fastercache_counter+=1
             if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
-                out_cond = self.dit(z, sigma,self.fastercache_counter, self.fastercache_start_step, **sample)
+                out_cond = self.dit(
+                    z,
+                    sigma,
+                    self.fastercache_counter,
+                    **sample)
                 (bb, cc, tt, hh, ww) = out_cond.shape
                 cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
 
                 lf_c, hf_c = fft(cond.float())
 
-                #lf_step = 40
-                #hf_step = 30
                 if self.fastercache_counter <= self.fastercache_lf_step:
                    self.delta_lf = self.delta_lf * 1.1
                if self.fastercache_counter >= self.fastercache_hf_step:
@@ -334,12 +292,19 @@ class T2VSynthMochiModel:
 
                return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
            else:
-                out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample)
-                out_uncond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample_null)
-                #print("out_cond.shape",out_cond.shape) #([1, 12, 3, 60, 106])
+                out_cond = self.dit(
+                    z,
+                    sigma,
+                    self.fastercache_counter,
+                    **sample)
+
+                out_uncond = self.dit(
+                    z,
+                    sigma,
+                    self.fastercache_counter,
+                    **sample_null)
 
                if self.fastercache_counter >= self.fastercache_start_step + 1:
                    (bb, cc, tt, hh, ww) = out_cond.shape
                    cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
                    uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@@ -352,7 +317,6 @@ class T2VSynthMochiModel:
 
                return out_uncond + cfg_scale * (out_cond - out_uncond)
 
-
        comfy_pbar = ProgressBar(sample_steps)
 
        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
@@ -373,13 +337,19 @@ class T2VSynthMochiModel:
                    sigma=torch.full([B], sigma, device=z.device),
                    cfg_scale=cfg_schedule[i],
                )
-                pred = pred.to(z)
-                z = z + dsigma * pred
+                z = z + dsigma * pred.to(z)
 
                if callback is not None:
                    callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
                else:
                    comfy_pbar.update(1)
+
+        if args["fastercache"] is not None:
+            for block in self.dit.blocks:
+                if hasattr(block, "cached_x_attention") and block.cached_x_attention is not None:
+                    block.cached_x_attention = None
+                    block.cached_y_attention = None
 
        self.dit.to(self.offload_device)
+        mm.soft_empty_cache()
        logging.info(f"samples shape: {z.shape}")
        return z
diff --git a/nodes.py b/nodes.py
index 1546747..e6b38cf 100644
--- a/nodes.py
+++ b/nodes.py
@@ -446,6 +446,36 @@ class MochiTextEncode:
        }
        return (t5_embeds, clip,)
 
+#region FasterCache
+class MochiFasterCache:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "start_step": ("INT", {"default": 10, "min": 0, "max": 1024, "step": 1, "tooltip": "The step to start caching, sigma schedule should be adjusted accordingly"}),
+                "hf_step": ("INT", {"default": 22, "min": 0, "max": 1024, "step": 1}),
+                "lf_step": ("INT", {"default": 28, "min": 0, "max": 1024, "step": 1}),
+                "cache_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
+            },
+        }
+
+    RETURN_TYPES = ("FASTERCACHEARGS",)
+    RETURN_NAMES = ("fastercache", )
+    FUNCTION = "args"
+    CATEGORY = "MochiWrapper"
+    DESCRIPTION = "FasterCache (https://github.com/Vchitect/FasterCache) settings for the MochiWrapper"
+
+    def args(self, start_step, hf_step, lf_step, cache_device):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        fastercache = {
+            "start_step" : start_step,
+            "hf_step" : hf_step,
+            "lf_step" : lf_step,
+            "cache_device" : device if cache_device == "main_device" else offload_device
+        }
+        return (fastercache,)
+
 #region Sampler
 class MochiSampler:
    @classmethod
@@ -466,6 +496,7 @@ def INPUT_TYPES(s):
                "cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
                "opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}),
                "samples": ("LATENT", ),
+                "fastercache": ("FASTERCACHEARGS", {"tooltip": "Optional FasterCache settings"}),
            }
        }
 
@@ -474,7 +505,7 @@
FUNCTION = "process" CATEGORY = "MochiWrapper" - def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None): + def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None, fastercache=None): mm.unload_all_models() mm.soft_empty_cache() @@ -517,6 +548,7 @@ class MochiSampler: "negative_embeds": negative, "seed": seed, "samples": samples["samples"] if samples is not None else None, + "fastercache": fastercache } latents = model.run(args) @@ -848,7 +880,8 @@ NODE_CLASS_MAPPINGS = { "MochiTorchCompileSettings": MochiTorchCompileSettings, "MochiImageEncode": MochiImageEncode, "MochiLatentPreview": MochiLatentPreview, - "MochiSigmaSchedule": MochiSigmaSchedule + "MochiSigmaSchedule": MochiSigmaSchedule, + "MochiFasterCache": MochiFasterCache } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadMochiModel": "(Down)load Mochi Model", @@ -862,5 +895,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "MochiTorchCompileSettings": "Mochi Torch Compile Settings", "MochiImageEncode": "Mochi Image Encode", "MochiLatentPreview": "Mochi Latent Preview", - "MochiSigmaSchedule": "Mochi Sigma Schedule" + "MochiSigmaSchedule": "Mochi Sigma Schedule", + "MochiFasterCache": "Mochi Faster Cache" }