From 3cf9289e08c601ab65f93f4109b39e06d10410cd Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 3 Nov 2024 01:47:34 +0200
Subject: [PATCH 1/9] cleanup code

---
 mochi_preview/dit/joint_model/layers.py | 22 ---------
 mochi_preview/t2v_synth_mochi.py        | 59 ++++++++++++-------------
 2 files changed, 29 insertions(+), 52 deletions(-)

diff --git a/mochi_preview/dit/joint_model/layers.py b/mochi_preview/dit/joint_model/layers.py
index aa40a67..b41d66c 100644
--- a/mochi_preview/dit/joint_model/layers.py
+++ b/mochi_preview/dit/joint_model/layers.py
@@ -62,28 +62,6 @@ class TimestepEmbedder(nn.Module):
         return t_emb


-class PooledCaptionEmbedder(nn.Module):
-    def __init__(
-        self,
-        caption_feature_dim: int,
-        hidden_size: int,
-        *,
-        bias: bool = True,
-        device: Optional[torch.device] = None,
-    ):
-        super().__init__()
-        self.caption_feature_dim = caption_feature_dim
-        self.hidden_size = hidden_size
-        self.mlp = nn.Sequential(
-            nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
-            nn.SiLU(),
-            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
-        )
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
 class FeedForward(nn.Module):
     def __init__(
         self,
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index a1c6fb7..c5c99cd 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -286,41 +286,40 @@ class T2VSynthMochiModel:
             sample_null["y_mask"], **latent_dims
         )

-        def model_fn(*, z, sigma, cfg_scale):
-            self.dit.to(self.device)
-            if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
-                autocast_dtype = torch.float16
-            else:
-                autocast_dtype = torch.bfloat16
-
-            nonlocal sample, sample_null
-            with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
-                if cfg_scale > 1.0:
-                    out_cond = self.dit(z, sigma, **sample)
-                    out_uncond = self.dit(z, sigma, **sample_null)
-                else:
-                    out_cond = self.dit(z, sigma, **sample)
-                    return out_cond
+        self.dit.to(self.device)
+        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
+            autocast_dtype = torch.float16
+        else:
+            autocast_dtype = torch.bfloat16

+        def model_fn(*, z, sigma, cfg_scale):
+            nonlocal sample, sample_null
+            if cfg_scale > 1.0:
+                out_cond = self.dit(z, sigma, **sample)
+                out_uncond = self.dit(z, sigma, **sample_null)
+            else:
+                out_cond = self.dit(z, sigma, **sample)
+                return out_cond
             return out_uncond + cfg_scale * (out_cond - out_uncond)

         comfy_pbar = ProgressBar(sample_steps)
-        for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
-            sigma = sigma_schedule[i]
-            dsigma = sigma - sigma_schedule[i + 1]
+        with torch.autocast(mm.get_autocast_device(self.device), dtype=autocast_dtype):
+            for i in tqdm(range(0, sample_steps), desc="Processing Samples", total=sample_steps):
+                sigma = sigma_schedule[i]
+                dsigma = sigma - sigma_schedule[i + 1]

-            # `pred` estimates `z_0 - eps`.
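# Background sketch (not part of the patch): the loop being rewritten here is
# a plain Euler integrator on the rectified-flow ODE. `pred` estimates
# `z_0 - eps` (a velocity), each step moves z by dsigma along it, and
# classifier-free guidance mixes the conditional and unconditional branches.
# Minimal standalone form, assuming a hypothetical `denoiser(z, sigma, cond)`:
#
#     def euler_cfg_sample(denoiser, z, sigma_schedule, cfg_scale, cond, uncond):
#         for i in range(len(sigma_schedule) - 1):
#             sigma = sigma_schedule[i]
#             dsigma = sigma - sigma_schedule[i + 1]
#             out_cond = denoiser(z, sigma, cond)
#             if cfg_scale > 1.0:
#                 out_uncond = denoiser(z, sigma, uncond)
#                 pred = out_uncond + cfg_scale * (out_cond - out_uncond)
#             else:
#                 pred = out_cond
#             z = z + dsigma * pred  # `pred` estimates z_0 - eps
#         return z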
+                pred = model_fn(
+                    z=z,
+                    sigma=torch.full([B], sigma, device=z.device),
+                    cfg_scale=cfg_schedule[i],
+                )
+                pred = pred.to(z)
+                z = z + dsigma * pred
+                if callback is not None:
+                    callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
+                else:
+                    comfy_pbar.update(1)

         self.dit.to(self.offload_device)
         logging.info(f"samples shape: {z.shape}")

From 0dc011d1b678bfcd334b9a03369f1bc5e7de9ce2 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 3 Nov 2024 18:36:23 +0200
Subject: [PATCH 2/9] cleanup and align more to comfy code, switch to using cpu seed as well

---
 .../dit/joint_model/asymm_models_joint.py | 253 ++++++------------
 .../dit/joint_model/context_parallel.py   | 163 -----------
 mochi_preview/dit/joint_model/layers.py   |   2 -
 mochi_preview/dit/joint_model/rope_mixed.py |  4 +-
 mochi_preview/dit/joint_model/utils.py    |  93 +------
 mochi_preview/t2v_synth_mochi.py          |  84 ++----
 mochi_preview/vae/cp_conv.py              | 152 -----------
 nodes.py                                  |   1 +
 8 files changed, 100 insertions(+), 652 deletions(-)
 delete mode 100644 mochi_preview/dit/joint_model/context_parallel.py
 delete mode 100644 mochi_preview/vae/cp_conv.py

diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index f4fdf6a..502a2df 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange

 from .layers import (
     FeedForward,
@@ -14,20 +13,12 @@ from .layers import (

 from .mod_rmsnorm import modulated_rmsnorm
 from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
-from .rope_mixed import (
-    compute_mixed_rotation,
-    create_position_matrix,
-)
+from .rope_mixed import (compute_mixed_rotation, create_position_matrix)
 from .temporal_rope import apply_rotary_emb_qk_real
-from .utils import (
-    pool_tokens,
-    modulate,
-    pad_and_split_xy,
-    unify_streams,
-)
+from .utils import (pool_tokens, modulate)

 try:
-    from flash_attn import flash_attn_varlen_qkvpacked_func
+    from flash_attn import flash_attn_func
     FLASH_ATTN_IS_AVAILABLE = True
 except ImportError:
     FLASH_ATTN_IS_AVAILABLE = False
@@ -45,6 +36,7 @@ backends.append(SDPBackend.EFFICIENT_ATTENTION)
 backends.append(SDPBackend.MATH)

 import comfy.model_management as mm
+from comfy.ldm.modules.attention import optimized_attention

 class AttentionPool(nn.Module):
     def __init__(
@@ -103,10 +95,9 @@ class AttentionPool(nn.Module):
         q = q.unsqueeze(2)  # (B, H, 1, head_dim)

         # Compute attention.
-        with sdpa_kernel(backends):
-            x = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask, dropout_p=0.0
-            )  # (B, H, 1, head_dim)
+        x = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, dropout_p=0.0
+        )  # (B, H, 1, head_dim)

         # Concatenate heads and run output.
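# Background sketch (not part of the patch): the pool above attends a single
# pooled query over the masked text tokens; with a boolean attn_mask, False
# entries are excluded from the softmax. Minimal form of the same computation:
#
#     import torch
#     import torch.nn.functional as F
#
#     def masked_attn_pool(q, k, v, key_mask):
#         # q: (B, H, 1, d); k, v: (B, H, L, d); key_mask: (B, L), True = keep
#         attn_mask = key_mask[:, None, None, :]   # broadcast over heads/queries
#         x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
#         return x.squeeze(2).flatten(1, 2)        # (B, H * d)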
         x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
@@ -166,79 +157,23 @@ class AsymmetricAttention(nn.Module):
             if update_y
             else nn.Identity()
         )
-
-    def run_qkv_y(self, y):
-        local_heads = self.num_heads
-        qkv_y = self.qkv_y(y)  # (B, L, 3 * dim)
-        qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
-        q_y, k_y, v_y = qkv_y.unbind(2)
-        return q_y, k_y, v_y
-
-    def prepare_qkv(
-        self,
-        x: torch.Tensor,  # (B, N, dim_x)
-        y: torch.Tensor,  # (B, L, dim_y)
-        *,
-        scale_x: torch.Tensor,
-        scale_y: torch.Tensor,
-        rope_cos: torch.Tensor,
-        rope_sin: torch.Tensor,
-        valid_token_indices: torch.Tensor,
-    ):
-        # Pre-norm for visual features
-        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
-
-        # Process visual features
-        qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
-        #assert qkv_x.dtype == torch.bfloat16
-
-        # Move QKV dimension to the front.
-        # B M (3 H d) -> 3 B M H d
-        B, M, _ = qkv_x.size()
-        qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
-        qkv_x = qkv_x.permute(2, 0, 1, 3, 4)
-
-        # Process text features
-        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
-        q_y, k_y, v_y = self.run_qkv_y(y)  # (B, L, local_heads, head_dim)
-        q_y = self.q_norm_y(q_y)
-        k_y = self.k_norm_y(k_y)
-
-        # Split qkv_x into q, k, v
-        q_x, k_x, v_x = qkv_x.unbind(0)  # (B, N, local_h, head_dim)
-        q_x = self.q_norm_x(q_x)
-        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
-        k_x = self.k_norm_x(k_x)
-        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
-
-        # Unite streams
-        qkv = unify_streams(
-            q_x,
-            k_x,
-            v_x,
-            q_y,
-            k_y,
-            v_y,
-            valid_token_indices,
-        )
-
-        return qkv
-
-    def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
+    def flash_attention(self, q, k ,v):
+        q = q.permute(0, 2, 1, 3)
+        k = k.permute(0, 2, 1, 3)
+        v = v.permute(0, 2, 1, 3)
+        b, _, _, dim_head = q.shape
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
-            out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
-                qkv,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen_in_batch,
+            out: torch.Tensor = flash_attn_func(  #q: (batch_size, seqlen, nheads, headdim)
+                q, k, v,
                 dropout_p=0.0,
                 softmax_scale=self.softmax_scale,
             )  # (total, local_heads, head_dim)
-        return out.view(total, local_dim)
+        out = out.permute(0, 2, 1, 3)
+        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

-    def sdpa_attention(self, qkv):
-        q, k, v = qkv.unbind(dim=1)
-        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
+    def sdpa_attention(self, q, k, v):
+        b, _, _, dim_head = q.shape
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             with sdpa_kernel(backends):
                 out = F.scaled_dot_product_attention(
                     q,
                     k,
                     v,
                     attn_mask=None,
                     dropout_p=0.0,
                     is_causal=False
                 )
-        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
+        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

-    def sage_attention(self, qkv):
-        #q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        q, k, v = qkv.unbind(dim=1)
-        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
-
+    def sage_attention(self, q, k, v):
+        b, _, _, dim_head = q.shape
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = sageattn(
                 q,
                 k,
                 v,
                 dropout_p=0.0,
                 is_causal=False
             )
-        #print(out.shape)
-        #out = rearrange(out, 'b h s d -> s (b h d)')
-        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
+        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

-    def comfy_attention(self, qkv):
-        from comfy.ldm.modules.attention import optimized_attention
-        q, k, v = qkv.unbind(dim=1)
-        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
+    def comfy_attention(self, q, k, v):
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = optimized_attention(
                 q,
                 k,
                 v,
                 heads = self.num_heads,
                 skip_reshape=True
             )
-        return out.squeeze(0)
+        return out

     @torch.compiler.disable()
     def run_attention(
         self,
-        qkv: torch.Tensor,  # (total <= B * (N + L), 3, local_heads, head_dim)
-        *,
-        B: int,
-        L: int,
-        M: int,
-        cu_seqlens: torch.Tensor,
-        max_seqlen_in_batch: int,
-        valid_token_indices: torch.Tensor,
-    ):
-        local_dim = self.num_heads * self.head_dim
-        total = qkv.size(0)
-
+        q,
+        k,
+        v,
+    ):
         if self.attention_mode == "flash_attn":
-            out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
+            out = self.flash_attention(q, k ,v)
         elif self.attention_mode == "sdpa":
-            out = self.sdpa_attention(qkv)
+            out = self.sdpa_attention(q, k, v)
         elif self.attention_mode == "sage_attn":
-            out = self.sage_attention(qkv)
+            out = self.sage_attention(q, k, v)
         elif self.attention_mode == "comfy":
-            out = self.comfy_attention(qkv)
-
-        x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
-        assert x.size() == (B, M, local_dim)
-        assert y.size() == (B, L, local_dim)
-
-        x = x.view(B, M, self.num_heads, self.head_dim)
-        x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
-        x = self.proj_x(x)  # (B, M, dim_x)
-        y = self.proj_y(y)  # (B, L, dim_y)
-        return x, y
+            out = self.comfy_attention(q, k, v)
+        return out

     def forward(
         self,
@@ -324,45 +234,41 @@
         *,
         scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
         scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
-        packed_indices: Dict[str, torch.Tensor] = None,
+        num_tokens,
         **rope_rotation,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass of asymmetric multi-modal attention.
-
-        Args:
-            x: (B, N, dim_x) tensor for visual tokens
-            y: (B, L, dim_y) tensor of text token features
-            packed_indices: Dict with keys for Flash Attention
-            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
+        rope_cos = rope_rotation.get("rope_cos")
+        rope_sin = rope_rotation.get("rope_sin")
+
+        # Pre-norm for visual features
+        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
+        # Process text features
+        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
+        q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)

-        Returns:
-            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
-            y: (B, L, dim_y) tensor of text token features after multi-modal attention
-        """
-        B, L, _ = y.shape
-        _, M, _ = x.shape
+        q_y = self.q_norm_y(q_y)
+        k_y = self.k_norm_y(k_y)

-        # Predict a packed QKV tensor from visual and text features.
-        # Don't checkpoint the all_to_all.
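# Background sketch (not part of the patch): the rewritten forward below
# replaces the packed/varlen path with a plain concatenation of the visual
# stream and the first num_tokens (non-padding) text tokens, a joint
# attention call, then a split back into the two streams. Shape walkthrough
# with toy sizes:
#
#     import torch
#     B, H, Nx, Ly, D = 1, 4, 16, 8, 32
#     num_tokens = 5                                    # valid text tokens
#     q_x = torch.randn(B, Nx, H, D)                    # visual queries
#     q_y = torch.randn(B, Ly, H, D)                    # text queries (padded)
#     q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
#     assert q.shape == (B, H, Nx + num_tokens, D)      # joint sequence
#     # ...attention yields xy of shape (B, Nx + num_tokens, H * D)...
#     xy = torch.randn(B, Nx + num_tokens, H * D)
#     x, y = torch.tensor_split(xy, (Nx,), dim=1)       # visual / text halves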
-        qkv = self.prepare_qkv(
-            x=x,
-            y=y,
-            scale_x=scale_x,
-            scale_y=scale_y,
-            rope_cos=rope_rotation.get("rope_cos"),
-            rope_sin=rope_rotation.get("rope_sin"),
-            valid_token_indices=packed_indices["valid_token_indices_kv"],
-        )  # (total <= B * (N + L), 3, local_heads, head_dim)
+        # Split qkv_x into q, k, v
+        q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)
+        q_x = self.q_norm_x(q_x)
+        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
+        k_x = self.k_norm_x(k_x)
+        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

-        x, y = self.run_attention(
-            qkv,
-            B=B,
-            L=L,
-            M=M,
-            cu_seqlens=packed_indices["cu_seqlens_kv"],
-            max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
-            valid_token_indices=packed_indices["valid_token_indices_kv"],
-        )
+        q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
+        k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
+        v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)
+
+        xy = self.run_attention(q, k, v)
+
+        x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
+        x = self.proj_x(x)
+        o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
+        o[:, :y.shape[1]] = y
+
+        y = self.proj_y(o)
         return x, y

 class AsymmetricJointBlock(nn.Module):
@@ -453,7 +359,7 @@ class AsymmetricJointBlock(nn.Module):
             scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
         else:
             scale_msa_y = mod_y
-        
+
         # Self-attention block.
         x_attn, y_attn = self.attn(
             x,
@@ -467,12 +373,12 @@ class AsymmetricJointBlock(nn.Module):
         x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
         if self.update_y:
             y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
-        
+
         # MLP block.
         x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
         if self.update_y:
             y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
-        
+
         return x, y

     def ff_block_x(self, x, scale_x, gate_x):
@@ -676,7 +582,6 @@ class AsymmDiTJoint(nn.Module):
         sigma: torch.Tensor,
         y_feat: List[torch.Tensor],
         y_mask: List[torch.Tensor],
-        packed_indices: Dict[str, torch.Tensor] = None,
         rope_cos: torch.Tensor = None,
        rope_sin: torch.Tensor = None,
     ):
         """Forward pass of DiT.

         Args:
             x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
             sigma: (B,) tensor of noise standard deviations
             y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
             y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
-            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
         """
         B, _, T, H, W = x.shape

         # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
         # Have to call sdpa_kernel outside of a torch.compile region.
+        num_tokens = max(1, torch.sum(y_mask[0]).item())
         with sdpa_kernel(backends):
             x, c, y_feat, rope_cos, rope_sin = self.prepare(
                 x, sigma, y_feat[0], y_mask[0]
             )
         del y_mask
-        
+
         for i, block in enumerate(self.blocks):
             x, y_feat = block(
                 x,
                 c,
                 y_feat,
                 rope_cos=rope_cos,
                 rope_sin=rope_sin,
-                packed_indices=packed_indices,
+                num_tokens=num_tokens,
             )  # (B, M, D), (B, L, D)
         del y_feat

         # Final layers don't use dense text features.
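# Background sketch (not part of the patch): the next hunk swaps
# einops.rearrange for an equivalent view/permute unpatchify. Quick shape
# check of the equivalence with toy sizes:
#
#     import torch
#     B, T, hp, wp, p, c = 1, 3, 4, 5, 2, 12
#     x = torch.randn(B, T * hp * wp, p * p * c)
#     x = x.view(B, T, hp, wp, p, p, c)
#     x = x.permute(0, 6, 1, 2, 4, 3, 5)    # -> B, c, T, hp, p1, wp, p2
#     x = x.reshape(B, c, T, hp * p, wp * p)
#     assert x.shape == (1, 12, 3, 8, 10)   # (B, c, T, hp*p, wp*p)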
-        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
-        x = rearrange(
-            x,
-            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
-            T=T,
-            hp=H // self.patch_size,
-            wp=W // self.patch_size,
-            p1=self.patch_size,
-            p2=self.patch_size,
-            c=self.out_channels,
-        )
+        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
+
+        hp = H // self.patch_size
+        wp = W // self.patch_size
+        p1 = self.patch_size
+        p2 = self.patch_size
+        c = self.out_channels
+
+        x = x.view(B, T, hp, wp, p1, p2, c)
+        x = x.permute(0, 6, 1, 2, 4, 3, 5)
+        x = x.reshape(B, c, T, hp * p1, wp * p2)
         return x

diff --git a/mochi_preview/dit/joint_model/context_parallel.py b/mochi_preview/dit/joint_model/context_parallel.py
deleted file mode 100644
index d93145d..0000000
--- a/mochi_preview/dit/joint_model/context_parallel.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import torch
-import torch.distributed as dist
-from einops import rearrange
-
-_CONTEXT_PARALLEL_GROUP = None
-_CONTEXT_PARALLEL_RANK = None
-_CONTEXT_PARALLEL_GROUP_SIZE = None
-_CONTEXT_PARALLEL_GROUP_RANKS = None
-
-
-def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
-    if not _CONTEXT_PARALLEL_GROUP:
-        return x
-
-    cp_rank, cp_size = get_cp_rank_size()
-    return x.tensor_split(cp_size, dim=dim)[cp_rank]
-
-
-def set_cp_group(cp_group, ranks, global_rank):
-    global \
-        _CONTEXT_PARALLEL_GROUP, \
-        _CONTEXT_PARALLEL_RANK, \
-        _CONTEXT_PARALLEL_GROUP_SIZE, \
-        _CONTEXT_PARALLEL_GROUP_RANKS
-    if _CONTEXT_PARALLEL_GROUP is not None:
-        raise RuntimeError("CP group already initialized.")
-    _CONTEXT_PARALLEL_GROUP = cp_group
-    _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
-    _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
-    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
-
-    assert (
-        _CONTEXT_PARALLEL_RANK == ranks.index(global_rank)
-    ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
-    assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
-        ranks
-    ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"
-
-
-def get_cp_group():
-    if _CONTEXT_PARALLEL_GROUP is None:
-        raise RuntimeError("CP group not initialized")
-    return _CONTEXT_PARALLEL_GROUP
-
-
-def is_cp_active():
-    return _CONTEXT_PARALLEL_GROUP is not None
-
-
-def get_cp_rank_size():
-    if _CONTEXT_PARALLEL_GROUP:
-        return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
-    else:
-        return 0, 1
-
-
-class AllGatherIntoTensorFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
-        ctx.reduce_dtype = reduce_dtype
-        ctx.group = group
-        ctx.batch_size = x.size(0)
-        group_size = dist.get_world_size(group)
-
-        x = x.contiguous()
-        output = torch.empty(
-            group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device
-        )
-        dist.all_gather_into_tensor(output, x, group=group)
-        return output
-
-
-def all_gather(tensor: torch.Tensor) -> torch.Tensor:
-    if not _CONTEXT_PARALLEL_GROUP:
-        return tensor
-
-    return AllGatherIntoTensorFunction.apply(
-        tensor, torch.float32, _CONTEXT_PARALLEL_GROUP
-    )
-
-
-@torch.compiler.disable()
-def _all_to_all_single(output, input, group):
-    # Disable compilation since torch compile changes contiguity.
-    assert input.is_contiguous(), "Input tensor must be contiguous."
-    assert output.is_contiguous(), "Output tensor must be contiguous."
-    return dist.all_to_all_single(output, input, group=group)
-
-
-class CollectTokens(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
-        """Redistribute heads and receive tokens.
-
-        Args:
-            qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]
-
-        Returns:
-            qkv: shape: [3, B, N, local_heads, head_dim]
-
-        where M is the number of local tokens,
-        N = cp_size * M is the number of global tokens,
-        local_heads = num_heads // cp_size is the number of local heads.
-        """
-        ctx.group = group
-        ctx.num_heads = num_heads
-        cp_size = dist.get_world_size(group)
-        assert num_heads % cp_size == 0
-        ctx.local_heads = num_heads // cp_size
-
-        qkv = rearrange(
-            qkv,
-            "B M (qkv G h d) -> G M h B (qkv d)",
-            qkv=3,
-            G=cp_size,
-            h=ctx.local_heads,
-        ).contiguous()
-
-        output_chunks = torch.empty_like(qkv)
-        _all_to_all_single(output_chunks, qkv, group=group)
-
-        return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
-
-
-def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
-    if not _CONTEXT_PARALLEL_GROUP:
-        # Move QKV dimension to the front.
-        # B M (3 H d) -> 3 B M H d
-        B, M, _ = x.size()
-        x = x.view(B, M, 3, num_heads, -1)
-        return x.permute(2, 0, 1, 3, 4)
-
-    return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)
-
-
-class CollectHeads(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
-        """Redistribute tokens and receive heads.
-
-        Args:
-            x: Output of attention. Shape: [B, N, local_heads, head_dim]
-
-        Returns:
-            Shape: [B, M, num_heads * head_dim]
-        """
-        ctx.group = group
-        ctx.local_heads = x.size(2)
-        ctx.head_dim = x.size(3)
-        group_size = dist.get_world_size(group)
-        x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
-        output = torch.empty_like(x)
-        _all_to_all_single(output, x, group=group)
-        del x
-        return rearrange(output, "G h M B D -> B M (G h D)")
-
-
-def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
-    if not _CONTEXT_PARALLEL_GROUP:
-        # Merge heads.
-        return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
-
-    return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
diff --git a/mochi_preview/dit/joint_model/layers.py b/mochi_preview/dit/joint_model/layers.py
index b41d66c..9d66921 100644
--- a/mochi_preview/dit/joint_model/layers.py
+++ b/mochi_preview/dit/joint_model/layers.py
@@ -130,8 +130,6 @@ class PatchEmbed(nn.Module):
             x = F.pad(x, (0, pad_w, 0, pad_h))

         x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
-        #print("x",x.dtype, x.device)
-        #print(self.proj.weight.dtype, self.proj.weight.device)
         x = self.proj(x)

         # Flatten temporal and spatial dimensions.
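Note: the next hunk comments out the functools.lru_cache on
create_position_matrix. The patch gives no rationale; a plausible reading
(an assumption, not stated by the author) is that caching tensors keyed on
call arguments can pin stale device/dtype state across runs. For orientation,
a simplified, uncached position matrix for a T x pH x pW token grid (a
sketch; the real function also rescales coordinates with centers()):

    import torch

    def position_matrix_sketch(T: int, pH: int, pW: int) -> torch.Tensor:
        # One (t, h, w) coordinate row per token: (T * pH * pW, 3).
        t = torch.arange(T, dtype=torch.float32)
        h = torch.arange(pH, dtype=torch.float32)
        w = torch.arange(pW, dtype=torch.float32)
        grid = torch.meshgrid(t, h, w, indexing="ij")
        return torch.stack(grid, dim=-1).reshape(-1, 3)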
diff --git a/mochi_preview/dit/joint_model/rope_mixed.py b/mochi_preview/dit/joint_model/rope_mixed.py
index f2952bd..d0102dc 100644
--- a/mochi_preview/dit/joint_model/rope_mixed.py
+++ b/mochi_preview/dit/joint_model/rope_mixed.py
@@ -1,4 +1,4 @@
-import functools
+#import functools
 import math

 import torch
@@ -21,7 +21,7 @@ def centers(start: float, stop, num, dtype=None, device=None):
     return (edges[:-1] + edges[1:]) / 2


-@functools.lru_cache(maxsize=1)
+#@functools.lru_cache(maxsize=1)
 def create_position_matrix(
     T: int,
     pH: int,
diff --git a/mochi_preview/dit/joint_model/utils.py b/mochi_preview/dit/joint_model/utils.py
index 85fd2df..0bcfbd3 100644
--- a/mochi_preview/dit/joint_model/utils.py
+++ b/mochi_preview/dit/joint_model/utils.py
@@ -28,95 +28,4 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
     mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
     pooled = (x * mask).sum(dim=1, keepdim=keepdim)
     return pooled
-
-
-class PadSplitXY(torch.autograd.Function):
-    """
-    Merge heads, pad and extract visual and text tokens,
-    and split along the sequence length.
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        xy: torch.Tensor,
-        indices: torch.Tensor,
-        B: int,
-        N: int,
-        L: int,
-        dtype: torch.dtype,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Args:
-            xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
-            indices: Valid token indices out of unpacked tensor. Shape: (total,)
-
-        Returns:
-            x: Visual tokens. Shape: (B, N, num_heads * head_dim).
-            y: Text tokens. Shape: (B, L, num_heads * head_dim).
-        """
-        ctx.save_for_backward(indices)
-        ctx.B, ctx.N, ctx.L = B, N, L
-        D = xy.size(1)
-
-        # Pad sequences to (B, N + L, dim).
-        assert indices.ndim == 1
-        output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
-        indices = indices.unsqueeze(1).expand(
-            -1, D
-        )  # (total,) -> (total, num_heads * head_dim)
-        output.scatter_(0, indices, xy)
-        xy = output.view(B, N + L, D)
-
-        # Split visual and text tokens along the sequence length.
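# Background sketch (not part of the patch): the scatter_ above re-expands
# packed (unpadded) tokens into a zero-padded (B * (N + L), D) buffer using
# their flat indices. Tiny demo of the same pattern:
#
#     import torch
#     xy = torch.arange(10.0).view(5, 2)       # 5 packed tokens, D=2
#     indices = torch.tensor([0, 1, 2, 6, 7])  # flat target positions
#     out = torch.zeros(8, 2)
#     out.scatter_(0, indices.unsqueeze(1).expand(-1, 2), xy)
#     # rows 3-5 stay zero (padding); rows 0-2 and 6-7 hold the tokens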
-        return torch.tensor_split(xy, (N,), dim=1)
-
-
-def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
-    return PadSplitXY.apply(xy, indices, B, N, L, dtype)
-
-
-class UnifyStreams(torch.autograd.Function):
-    """Unify visual and text streams."""
-
-    @staticmethod
-    def forward(
-        ctx,
-        q_x: torch.Tensor,
-        k_x: torch.Tensor,
-        v_x: torch.Tensor,
-        q_y: torch.Tensor,
-        k_y: torch.Tensor,
-        v_y: torch.Tensor,
-        indices: torch.Tensor,
-    ):
-        """
-        Args:
-            q_x: (B, N, num_heads, head_dim)
-            k_x: (B, N, num_heads, head_dim)
-            v_x: (B, N, num_heads, head_dim)
-            q_y: (B, L, num_heads, head_dim)
-            k_y: (B, L, num_heads, head_dim)
-            v_y: (B, L, num_heads, head_dim)
-            indices: (total <= B * (N + L))
-
-        Returns:
-            qkv: (total <= B * (N + L), 3, num_heads, head_dim)
-        """
-        ctx.save_for_backward(indices)
-        B, N, num_heads, head_dim = q_x.size()
-        ctx.B, ctx.N, ctx.L = B, N, q_y.size(1)
-        D = num_heads * head_dim
-
-        q = torch.cat([q_x, q_y], dim=1)
-        k = torch.cat([k_x, k_y], dim=1)
-        v = torch.cat([v_x, v_y], dim=1)
-        qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D)
-
-        indices = indices[:, None, None].expand(-1, 3, D)
-        qkv = torch.gather(qkv, 0, indices)  # (total, 3, num_heads * head_dim)
-        return qkv.unflatten(2, (num_heads, head_dim))
-
-
-def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
-    return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)
+    
\ No newline at end of file
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index c5c99cd..d6f2c60 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -33,9 +33,7 @@ except:
     pass

 import torch
-import torch.nn.functional as F
 import torch.utils.data
-from einops import rearrange, repeat

 #from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
@@ -79,48 +77,6 @@ def unnormalize_latents(
     assert z.size(1) == mean.size(0) == std.size(0)
     return z * std.to(z) + mean.to(z)

-
-def compute_packed_indices(
-    N: int,
-    text_mask: List[torch.Tensor],
-) -> Dict[str, torch.Tensor]:
-    """
-    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80
-
-    Args:
-        N: Number of visual tokens.
-        text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.
-
-    Returns:
-        packed_indices: Dict with keys for Flash Attention:
-            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
-              in the packed sequence.
-            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
-            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
-    """
-    # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
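# Background sketch (not part of the patch): the helper being deleted here
# built flash-attn varlen metadata from the text mask by prepending N
# always-valid visual tokens, then taking per-sample lengths and cumulative
# offsets. Worked example:
#
#     import torch
#     import torch.nn.functional as F
#     text_mask = torch.tensor([[True, True, False, False]])  # (B=1, L=4)
#     N = 6                                                   # visual tokens
#     mask = F.pad(text_mask, (N, 0), value=True)             # (B, N + L)
#     seqlens = mask.sum(dim=-1, dtype=torch.int32)           # tensor([8])
#     cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
#     valid_idx = torch.nonzero(mask.flatten()).flatten()     # 8 indices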
-    assert N > 0 and len(text_mask) == 1
-    text_mask = text_mask[0]
-
-    mask = F.pad(text_mask, (N, 0), value=True)  # (B, N + L)
-    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
-    valid_token_indices = torch.nonzero(
-        mask.flatten(), as_tuple=False
-    ).flatten()  # up to (B * (N + L),)
-
-    assert valid_token_indices.size(0) >= text_mask.size(0) * N  # At least (B * N,)
-    cu_seqlens = F.pad(
-        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
-    )
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-
-    return {
-        "cu_seqlens_kv": cu_seqlens,
-        "max_seqlen_in_batch_kv": max_seqlen_in_batch,
-        "valid_token_indices_kv": valid_token_indices,
-    }
-
 class T2VSynthMochiModel:
     def __init__(
         self,
@@ -166,6 +122,17 @@ class T2VSynthMochiModel:
         params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
         logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
         dit_sd = load_torch_file(dit_checkpoint_path)
+
+        #comfy format
+        prefix = "model.diffusion_model."
+        first_key = next(iter(dit_sd), None)
+        if first_key and first_key.startswith(prefix):
+            new_dit_sd = {
+                key[len(prefix):] if key.startswith(prefix) else key: value
+                for key, value in dit_sd.items()
+            }
+            dit_sd = new_dit_sd
+
         if "gguf" in dit_checkpoint_path.lower():
             logging.info("Loading GGUF model state_dict...")
             from .. import mz_gguf_loader
@@ -209,14 +176,6 @@ class T2VSynthMochiModel:
         self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
         self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)

-    def get_packed_indices(self, y_mask, *, lT, lW, lH):
-        patch_size = 2
-        N = lT * lH * lW // (patch_size**2)
-        assert len(y_mask) == 1
-        packed_indices = compute_packed_indices(N, y_mask)
-        self.move_to_device_(packed_indices)
-        return packed_indices
-
     def move_to_device_(self, sample):
         if isinstance(sample, dict):
             for key in sample.keys():
@@ -227,7 +186,7 @@ class T2VSynthMochiModel:
         torch.manual_seed(args["seed"])
         torch.cuda.manual_seed(args["seed"])

-        generator = torch.Generator(device=self.device)
+        generator = torch.Generator(device=torch.device("cpu"))
         generator.manual_seed(args["seed"])

         num_frames = args["num_frames"]
@@ -259,14 +218,13 @@ class T2VSynthMochiModel:
         T = (num_frames - 1) // temporal_downsample + 1
         H = height // spatial_downsample
         W = width // spatial_downsample
-        latent_dims = dict(lT=T, lW=W, lH=H)

         z = torch.randn(
             (B, C, T, H, W),
-            device=self.device,
+            device=torch.device("cpu"),
             generator=generator,
             dtype=torch.float32,
-        )
+        ).to(self.device)

         if in_samples is not None:
             z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device)
@@ -277,14 +235,7 @@ class T2VSynthMochiModel:
         sample_null = {
             "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
             "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
-        }
-
-        sample["packed_indices"] = self.get_packed_indices(
-            sample["y_mask"], **latent_dims
-        )
-        sample_null["packed_indices"] = self.get_packed_indices(
-            sample_null["y_mask"], **latent_dims
-        )
+        }

         self.dit.to(self.device)
         if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
@@ -292,7 +243,7 @@ class T2VSynthMochiModel:
         else:
             autocast_dtype = torch.bfloat16

-        def model_fn(*, z, sigma, cfg_scale):  
+        def model_fn(*, z, sigma, cfg_scale):
             nonlocal sample, sample_null
             if cfg_scale > 1.0:
                 out_cond = self.dit(z, sigma, **sample)
@@ -314,8 +265,7 @@ class T2VSynthMochiModel:
                     sigma=torch.full([B], sigma, device=z.device),
                     cfg_scale=cfg_schedule[i],
                 )
-                pred = pred.to(z)
-                z = z + dsigma * pred
+                z = z + dsigma * pred.to(z)
                 if callback is not None:
                     callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
                 else:
diff --git a/mochi_preview/vae/cp_conv.py b/mochi_preview/vae/cp_conv.py
deleted file mode 100644
index e5e96de..0000000
--- a/mochi_preview/vae/cp_conv.py
+++ /dev/null
@@ -1,152 +0,0 @@
-from typing import Tuple, Union
-
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-
-from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size
-
-
-def cast_tuple(t, length=1):
-    return t if isinstance(t, tuple) else ((t,) * length)
-
-
-def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
-    """
-    Forward pass that handles communication between ranks for inference.
-    Args:
-        x: Tensor of shape (B, C, T, H, W)
-        frames_to_send: int, number of frames to communicate between ranks
-    Returns:
-        output: Tensor of shape (B, C, T', H, W)
-    """
-    cp_rank, cp_world_size = cp.get_cp_rank_size()
-    if frames_to_send == 0 or cp_world_size == 1:
-        return x
-
-    group = get_cp_group()
-    global_rank = dist.get_rank()
-
-    # Send to next rank
-    if cp_rank < cp_world_size - 1:
-        assert x.size(2) >= frames_to_send
-        tail = x[:, :, -frames_to_send:].contiguous()
-        dist.send(tail, global_rank + 1, group=group)
-
-    # Receive from previous rank
-    if cp_rank > 0:
-        B, C, _, H, W = x.shape
-        recv_buffer = torch.empty(
-            (B, C, frames_to_send, H, W),
-            dtype=x.dtype,
-            device=x.device,
-        )
-        dist.recv(recv_buffer, global_rank - 1, group=group)
-        x = torch.cat([recv_buffer, x], dim=2)
-
-    return x
-
-
-def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
-    if max_T > x.size(2):
-        pad_T = max_T - x.size(2)
-        pad_dims = (0, 0, 0, 0, 0, pad_T)
-        return F.pad(x, pad_dims)
-    return x
-
-
-def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
-    """
-    Gathers all frames from all processes for inference.
-    Args:
-        x: Tensor of shape (B, C, T, H, W)
-    Returns:
-        output: Tensor of shape (B, C, T_total, H, W)
-    """
-    cp_rank, cp_size = get_cp_rank_size()
-    cp_group = get_cp_group()
-
-    # Ensure the tensor is contiguous for collective operations
-    x = x.contiguous()
-
-    # Get the local time dimension size
-    local_T = x.size(2)
-    local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)
-
-    # Gather all T sizes from all processes
-    all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
-    dist.all_gather(all_T, local_T_tensor, group=cp_group)
-    all_T = [t.item() for t in all_T]
-
-    # Pad the tensor at the end of the time dimension to match max_T
-    max_T = max(all_T)
-    x = _pad_to_max(x, max_T).contiguous()
-
-    # Prepare a list to hold the gathered tensors
-    gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
-
-    # Perform the all_gather operation
-    dist.all_gather(gathered_x, x, group=cp_group)
-
-    # Slice each gathered tensor back to its original T size
-    for idx, t_size in enumerate(all_T):
-        gathered_x[idx] = gathered_x[idx][:, :, :t_size]
-
-    return torch.cat(gathered_x, dim=2)
-
-
-def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
-    """Estimate memory usage based on input tensor size and data type."""
-    element_size = input.element_size()  # Size in bytes of each element
-    memory_bytes = input.numel() * element_size
-    memory_gb = memory_bytes / 1024**3
-    return memory_gb > max_gb
-
-
-class ContextParallelCausalConv3d(torch.nn.Conv3d):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size: Union[int, Tuple[int, int, int]],
-        stride: Union[int, Tuple[int, int, int]],
-        **kwargs,
-    ):
-        kernel_size = cast_tuple(kernel_size, 3)
-        stride = cast_tuple(stride, 3)
-        height_pad = (kernel_size[1] - 1) // 2
-        width_pad = (kernel_size[2] - 1) // 2
-
-        super().__init__(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            dilation=(1, 1, 1),
-            padding=(0, height_pad, width_pad),
-            **kwargs,
-        )
-
-    def forward(self, x: torch.Tensor):
-        cp_rank, cp_world_size = get_cp_rank_size()
-
-        context_size = self.kernel_size[0] - 1
-        if cp_rank == 0:
-            mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
-            x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)
-
-        if cp_world_size == 1:
-            return super().forward(x)
-
-        if all(s == 1 for s in self.stride):
-            # Receive some frames from previous rank.
-            x = cp_pass_frames(x, context_size)
-            return super().forward(x)
-
-        # Less efficient implementation for strided convs.
-        # All gather x, infer and chunk.
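# Background sketch (not part of the patch): the causal conv being deleted
# pads only the front of the time axis, so frame t never sees later frames.
# Single-device equivalent, assuming the conv was built with zero temporal
# padding (spatial padding handled by the conv itself):
#
#     import torch
#     import torch.nn.functional as F
#
#     def causal_conv3d(conv: torch.nn.Conv3d, x: torch.Tensor) -> torch.Tensor:
#         kT = conv.kernel_size[0]
#         # F.pad on (B, C, T, H, W) pads last dim first; this touches only T.
#         x = F.pad(x, (0, 0, 0, 0, kT - 1, 0))
#         return conv(x)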
-        x = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
-        x = super().forward(x)
-        x_chunks = x.tensor_split(cp_world_size, dim=2)
-        assert len(x_chunks) == cp_world_size
-        return x_chunks[cp_rank]
diff --git a/nodes.py b/nodes.py
index 65fd380..c78a682 100644
--- a/nodes.py
+++ b/nodes.py
@@ -473,6 +473,7 @@ class MochiSampler:
     CATEGORY = "MochiWrapper"

     def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
+        mm.unload_all_models()
         mm.soft_empty_cache()

         if opt_sigmas is not None:

From 4a7458ffd6cb72e7c29ca4e45f4ac099453ed262 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 3 Nov 2024 19:18:29 +0200
Subject: [PATCH 3/9] restore MochiEdit compatibility temporarily

---
 mochi_preview/dit/joint_model/asymm_models_joint.py | 1 +
 mochi_preview/t2v_synth_mochi.py                    | 4 ++++
 mochi_preview/vae/model.py                          | 2 --
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 502a2df..9650f40 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -584,6 +584,7 @@ class AsymmDiTJoint(nn.Module):
         y_mask: List[torch.Tensor],
         rope_cos: torch.Tensor = None,
         rope_sin: torch.Tensor = None,
+        packed_indices: Optional[dict] = None,
     ):
         """Forward pass of DiT.

diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index d6f2c60..ccec165 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -176,6 +176,10 @@ class T2VSynthMochiModel:
         self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
         self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)

+    def get_packed_indices(self, y_mask, **latent_dims):
+        # temporary dummy func for compatibility
+        return []
+
     def move_to_device_(self, sample):
         if isinstance(sample, dict):
             for key in sample.keys():
diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py
index 6cfeeae..56f937e 100644
--- a/mochi_preview/vae/model.py
+++ b/mochi_preview/vae/model.py
@@ -6,8 +6,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

-#from ..dit.joint_model.context_parallel import get_cp_rank_size
-#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
 from .latent_dist import LatentDistribution

 def cast_tuple(t, length=1):

From fd4a02e6a619c84e9a76c3dff9a3644eec5216b9 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 3 Nov 2024 19:41:08 +0200
Subject: [PATCH 4/9] Update nodes.py

---
 nodes.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nodes.py b/nodes.py
index c78a682..6196927 100644
--- a/nodes.py
+++ b/nodes.py
@@ -59,7 +59,7 @@ class MochiSigmaSchedule:
     RETURN_NAMES = ("sigmas",)
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"
-    DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
+    DESCRIPTION = "Sigma schedule to use with mochi wrapper sampler"

     def loadmodel(self, num_steps, threshold_noise, denoise, linear_steps=None):
         total_steps = num_steps

From 56b5dbbf828f86f59ce3e01004d707aeee7ad304 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 4 Nov 2024 14:11:58 +0200
Subject: [PATCH 5/9] Add different RMSNorm functions for testing

Initial testing shows that the RMSNorm from flash_attn.ops.triton.layer_norm
is ~8-10% faster; apex is untested as I don't currently have it installed.
---
 configs/vae_stats.json                        |  4 ---
 .../dit/joint_model/asymm_models_joint.py     | 36 ++++++++++++++++---
 mochi_preview/t2v_synth_mochi.py              | 13 +++----
 nodes.py                                      | 23 ++++++++----
 4 files changed, 51 insertions(+), 25 deletions(-)
 delete mode 100644 configs/vae_stats.json

diff --git a/configs/vae_stats.json b/configs/vae_stats.json
deleted file mode 100644
index e3278af..0000000
--- a/configs/vae_stats.json
+++ /dev/null
@@ -1,4 +0,0 @@
-{
-    "mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285],
-    "std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
-}
diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 9650f40..7443961 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -7,15 +7,15 @@ import torch.nn.functional as F
 from .layers import (
     FeedForward,
     PatchEmbed,
-    RMSNorm,
     TimestepEmbedder,
 )
+
 from .mod_rmsnorm import modulated_rmsnorm
-from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
-from .rope_mixed import (compute_mixed_rotation, create_position_matrix)
+from .residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm
+from .rope_mixed import compute_mixed_rotation, create_position_matrix
 from .temporal_rope import apply_rotary_emb_qk_real
-from .utils import (pool_tokens, modulate)
+from .utils import pool_tokens, modulate

 try:
     from flash_attn import flash_attn_func
@@ -119,6 +119,7 @@ class AsymmetricAttention(nn.Module):
         softmax_scale: Optional[float] = None,
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
     ):
         super().__init__()

@@ -145,6 +146,28 @@ class AsymmetricAttention(nn.Module):
         # Query and key normalization for stability.
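# Background sketch (not part of the patch): every backend selectable below
# (default, flash-attn Triton, flash-attn CUDA, apex) computes the same
# RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight, i.e. LayerNorm without
# mean subtraction or bias. Reference version for comparison:
#
#     import torch
#
#     class RefRMSNorm(torch.nn.Module):
#         def __init__(self, dim: int, eps: float = 1e-5, device=None):
#             super().__init__()
#             self.eps = eps
#             self.weight = torch.nn.Parameter(torch.ones(dim, device=device))
#
#         def forward(self, x):
#             var = x.float().pow(2).mean(dim=-1, keepdim=True)
#             return (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype) * self.weight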
         assert qk_norm
+        if rms_norm_func == "flash_attn_triton":  #use the same rms_norm_func
+            try:
+                from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm  #slightly faster
+                @torch.compiler.disable()  #cause NaNs when compiled for some reason
+                class RMSNorm(FlashTritonRMSNorm):
+                    pass
+            except:
+                raise ImportError("Flash Triton RMSNorm not available.")
+        elif rms_norm_func == "flash_attn":
+            try:
+                from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm  #slightly faster
+                @torch.compiler.disable()  #cause NaNs when compiled for some reason
+                class RMSNorm(FlashRMSNorm):
+                    pass
+            except:
+                raise ImportError("Flash RMSNorm not available.")
+        elif rms_norm_func == "apex":
+            from apex.normalization import FusedRMSNorm as ApexRMSNorm
+            class RMSNorm(ApexRMSNorm):
+                pass
+        else:
+            from .layers import RMSNorm
         self.q_norm_x = RMSNorm(self.head_dim, device=device)
         self.k_norm_x = RMSNorm(self.head_dim, device=device)
         self.q_norm_y = RMSNorm(self.head_dim, device=device)
@@ -210,7 +233,6 @@ class AsymmetricAttention(nn.Module):
         )
         return out

-    @torch.compiler.disable()
     def run_attention(
         self,
         q,
@@ -283,6 +305,7 @@ class AsymmetricJointBlock(nn.Module):
         update_y: bool = True,  # Whether to update text tokens in this block.
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         **block_kwargs,
     ):
         super().__init__()
@@ -304,6 +327,7 @@ class AsymmetricJointBlock(nn.Module):
             update_y=update_y,
             device=device,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
             **block_kwargs,
         )

@@ -450,6 +474,7 @@ class AsymmDiTJoint(nn.Module):
         rope_theta: float = 10000.0,
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         **block_kwargs,
     ):
         super().__init__()
@@ -518,6 +543,7 @@ class AsymmDiTJoint(nn.Module):
                 update_y=update_y,
                 device=device,
                 attention_mode=attention_mode,
+                rms_norm_func=rms_norm_func,
                 **block_kwargs,
             )

diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index ccec165..798a7aa 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -1,4 +1,3 @@
-import json
 from typing import Dict, List, Optional, Union

 #temporary patch to fix torch compile bug in Windows
@@ -35,7 +34,6 @@ except:

 import torch
 import torch.utils.data
-#from .dit.joint_model.context_parallel import get_cp_rank_size

 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
 import comfy.model_management as mm
@@ -83,11 +81,11 @@ class T2VSynthMochiModel:
         *,
         device: torch.device,
         offload_device: torch.device,
-        vae_stats_path: str,
         dit_checkpoint_path: str,
         weight_dtype: torch.dtype = torch.float8_e4m3fn,
         fp8_fastmode: bool = False,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         compile_args: Optional[Dict] = None,
         cublas_ops: Optional[bool] = False,
     ):
@@ -117,6 +115,7 @@ class T2VSynthMochiModel:
             t5_token_length=256,
             rope_theta=10000.0,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
         )

         params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
@@ -171,10 +170,6 @@ class T2VSynthMochiModel:
             model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])

         self.dit = model
-
-        vae_stats = json.load(open(vae_stats_path))
-        self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
-        self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)

     def get_packed_indices(self, y_mask, **latent_dims):
         # temporary dummy func for compatibility
@@ -233,8 +228,8 @@ class T2VSynthMochiModel:
             z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device)

         sample = {
-            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
-            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
+            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
+            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
         }
         sample_null = {
             "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
             "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
diff --git a/nodes.py b/nodes.py
index 6196927..75b33c7 100644
--- a/nodes.py
+++ b/nodes.py
@@ -105,6 +105,7 @@ class DownloadAndLoadMochiModel:
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                 "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
                 "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
+                "rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),
             },
         }

@@ -114,7 +115,7 @@ class DownloadAndLoadMochiModel:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"

-    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
+    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()

@@ -154,11 +155,11 @@ class DownloadAndLoadMochiModel:
         model = T2VSynthMochiModel(
             device=device,
             offload_device=offload_device,
-            vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
             dit_checkpoint_path=model_path,
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
             compile_args=compile_args,
             cublas_ops=cublas_ops
         )
@@ -201,6 +202,7 @@ class MochiModelLoader:
                 "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                 "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
                 "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
+                "rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),
             },
         }

@@ -209,7 +211,7 @@ class MochiModelLoader:
     FUNCTION = "loadmodel"
     CATEGORY = "MochiWrapper"

-    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
+    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()

@@ -226,6 +228,7 @@ class MochiModelLoader:
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
             compile_args=compile_args,
             cublas_ops=cublas_ops
         )
@@ -749,10 +752,16 @@ class MochiImageEncode:
         from .mochi_preview.vae.model import apply_tiled
         B, H, W, C = images.shape

-        images = images.unsqueeze(0) * 2 - 1
-        images = rearrange(images, "t b h w c -> t c b h w")
-        images = images.to(device)
-        print(images.shape)
+        import torchvision.transforms as transforms
+        normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+        input_image_tensor = rearrange(images, 'b h w c -> b c h w')
+        input_image_tensor = normalize(input_image_tensor).unsqueeze(0)
+        input_image_tensor = rearrange(input_image_tensor, 'b t c h w -> b c t h w', t=B)
+
+        #images = images.unsqueeze(0).sub_(0.5).div_(0.5)
+        #images = rearrange(input_image_tensor, "b c t h w -> t c b h w")
+        images = input_image_tensor.to(device)
+
         encoder.to(device)
         print("images before encoding", images.shape)
         with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):

From f94cf433313b842e36740723247c4c8d5dce0957 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 4 Nov 2024 15:06:20 +0200
Subject: [PATCH 6/9] Update nodes.py

---
 nodes.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nodes.py b/nodes.py
index 75b33c7..3c1e947 100644
--- a/nodes.py
+++ b/nodes.py
@@ -181,7 +181,7 @@ class DownloadAndLoadMochiModel:
             vae_sd = load_torch_file(vae_path)
             if is_accelerate_available:
                 for key in vae_sd:
-                    set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
+                    set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key])
             else:
                 vae.load_state_dict(vae_sd, strict=True)
                 vae.eval().to(torch.bfloat16).to("cpu")

From 78f9e7b8969f72a48df715a9eedc5686a3b1fa60 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 4 Nov 2024 15:10:15 +0200
Subject: [PATCH 7/9] Update nodes.py

---
 nodes.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nodes.py b/nodes.py
index 3c1e947..46a9c8e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -634,7 +634,7 @@ class MochiDecode:
             return torch.cat(result_rows, dim=3)

         vae.to(device)
-        with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
+        with torch.autocast(mm.get_autocast_device(device), dtype=vae.dtype):
             if enable_vae_tiling and frame_batch_size > T:
                 logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}")
                 frame_batch_size = T

From 5ca4bbf319ac864a30c2024acc932b172da0db89 Mon Sep 17 00:00:00 2001
From: Yoshimasa Niwa
Date: Tue, 5 Nov 2024 12:46:13 +0900
Subject: [PATCH 8/9] Workaround pad problem on mps

When using `torch.nn.functional.pad` on mps with a tensor larger than 2^16
(65536) elements, the output tensor can be corrupted. This patch moves the
tensor to the CPU to work around the problem. It does not noticeably impact
the speed of the VAE on mps.
---
 mochi_preview/vae/model.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py
index 56f937e..e26add7 100644
--- a/mochi_preview/vae/model.py
+++ b/mochi_preview/vae/model.py
@@ -94,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):
         raise NotImplementedError

+def mps_safe_pad(input, pad, mode):
+    if input.device.type == "mps" and input.numel() >= 2 ** 16:
+        device = input.device
+        input = input.to(device="cpu")
+        output = F.pad(input, pad, mode=mode)
+        return output.to(device=device)
+    else:
+        return F.pad(input, pad, mode=mode)

 class ContextParallelConv3d(SafeConv3d):
     def __init__(
@@ -136,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):

         # Apply padding.
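# Background sketch (not part of the patch): F.pad consumes the pad tuple
# from the last dimension backwards, so for a (B, C, T, H, W) tensor the
# 6-tuple (0, 0, 0, 0, pad_front, pad_back) touches only the time axis,
# which is the case mps_safe_pad wraps below:
#
#     import torch
#     import torch.nn.functional as F
#     x = torch.randn(1, 3, 4, 8, 8)
#     y = F.pad(x, (0, 0, 0, 0, 2, 1))   # W: +0, H: +0, T: 2 front + 1 back
#     assert y.shape == (1, 3, 7, 8, 8)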
         mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
         if self.context_parallel:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+            x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
         else:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
+            x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)

         return super().forward(x)

From fdecd4ee08874aa32ffc9496b60fc9db20c45707 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 5 Nov 2024 06:18:49 +0200
Subject: [PATCH 9/9] Update nodes.py

---
 nodes.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/nodes.py b/nodes.py
index 46a9c8e..1546747 100644
--- a/nodes.py
+++ b/nodes.py
@@ -223,7 +223,6 @@ class MochiModelLoader:
             model = T2VSynthMochiModel(
                 device=device,
                 offload_device=offload_device,
-                vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
                 dit_checkpoint_path=model_path,
                 weight_dtype=dtype,
                 fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
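Note: a recurring change in this series (patch 2) is creating the initial
noise with a CPU generator and only then moving it to the compute device, so
a given seed produces identical latents regardless of GPU or backend. The
pattern in isolation (a minimal sketch):

    import torch

    def seeded_noise(shape, seed: int, device: torch.device) -> torch.Tensor:
        # Sample on CPU for cross-device reproducibility, then transfer.
        gen = torch.Generator(device="cpu").manual_seed(seed)
        return torch.randn(shape, generator=gen, dtype=torch.float32).to(device)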