From 3395aa8ca07d5a6726959b7871609dfd28c49868 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 27 Oct 2024 20:02:44 +0200
Subject: [PATCH] cleanup, sampler output name fix

---
 mochi_preview/vae/model.py | 40 ++++++++------------------------------
 nodes.py                   |  2 +-
 2 files changed, 9 insertions(+), 33 deletions(-)

diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py
index 823343d..385eefd 100644
--- a/mochi_preview/vae/model.py
+++ b/mochi_preview/vae/model.py
@@ -5,8 +5,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from ..dit.joint_model.context_parallel import get_cp_rank_size
-from ..vae.cp_conv import cp_pass_frames, gather_all_frames
+#from ..dit.joint_model.context_parallel import get_cp_rank_size
+#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
 
 
 def cast_tuple(t, length=1):
@@ -125,8 +125,6 @@ class ContextParallelConv3d(SafeConv3d):
         )
 
     def forward(self, x: torch.Tensor):
-        cp_rank, cp_world_size = get_cp_rank_size()
-
         # Compute padding amounts.
         context_size = self.kernel_size[0] - 1
         if self.causal:
@@ -139,30 +137,8 @@ class ContextParallelConv3d(SafeConv3d):
         # Apply padding.
         assert self.padding_mode == "replicate"  # DEBUG
         mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
-        if self.context_parallel and cp_world_size == 1:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
-        else:
-            if cp_rank == 0:
-                x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
-            elif cp_rank == cp_world_size - 1 and pad_back:
-                x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode)
-
-        if self.context_parallel and cp_world_size == 1:
-            return super().forward(x)
-
-        if self.stride[0] == 1:
-            # Receive some frames from previous rank.
-            x = cp_pass_frames(x, context_size)
-            return super().forward(x)
-
-        # Less efficient implementation for strided convs.
-        # All gather x, infer and chunk.
-        assert (
-            x.dtype == torch.bfloat16
-        ), f"Expected x to be of type torch.bfloat16, got {x.dtype}"
-
-        x = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
-        return StridedSafeConv3d.forward(self, x, local_shard=True)
+        x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+        return super().forward(x)
 
 
 class Conv1x1(nn.Linear):
@@ -217,8 +193,8 @@ class DepthToSpaceTime(nn.Module):
             sw=self.spatial_expansion,
         )
 
-        cp_rank, _ = get_cp_rank_size()
-        if self.temporal_expansion > 1 and cp_rank == 0:
+        # cp_rank, _ = cp.get_cp_rank_size()
+        if self.temporal_expansion > 1: # and cp_rank == 0:
             # Drop the first self.temporal_expansion - 1 frames.
             # This is because we always want the 3x3x3 conv filter to only apply
             # to the first frame, and the first frame doesn't need to be repeated.
@@ -457,10 +433,10 @@ class CausalUpsampleBlock(nn.Module):
 
 
 def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
-    attn_block = AttentionBlock(channels) if has_attention else None
+    #attn_block = AttentionBlock(channels) if has_attention else None
 
     return ResBlock(
-        channels, affine=True, attn_block=attn_block, **block_kwargs
+        channels, affine=True, attn_block=None, **block_kwargs
     )
 
 
diff --git a/nodes.py b/nodes.py
index 358941b..072fe6f 100644
--- a/nodes.py
+++ b/nodes.py
@@ -364,7 +364,7 @@ class MochiSampler:
         }
 
     RETURN_TYPES = ("LATENT",)
-    RETURN_NAMES = ("model", "samples",)
+    RETURN_NAMES = ("samples",)
     FUNCTION = "process"
     CATEGORY = "MochiWrapper"
 