Cleanup; fix sampler output name

This commit is contained in:
kijai 2024-10-27 20:02:44 +02:00
parent 195da244df
commit 3395aa8ca0
2 changed files with 9 additions and 33 deletions

View File

@ -5,8 +5,8 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from ..dit.joint_model.context_parallel import get_cp_rank_size #from ..dit.joint_model.context_parallel import get_cp_rank_size
from ..vae.cp_conv import cp_pass_frames, gather_all_frames #from ..vae.cp_conv import cp_pass_frames, gather_all_frames
def cast_tuple(t, length=1): def cast_tuple(t, length=1):
@ -125,8 +125,6 @@ class ContextParallelConv3d(SafeConv3d):
) )
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
cp_rank, cp_world_size = get_cp_rank_size()
# Compute padding amounts. # Compute padding amounts.
context_size = self.kernel_size[0] - 1 context_size = self.kernel_size[0] - 1
if self.causal: if self.causal:
@ -139,30 +137,8 @@ class ContextParallelConv3d(SafeConv3d):
# Apply padding. # Apply padding.
assert self.padding_mode == "replicate" # DEBUG assert self.padding_mode == "replicate" # DEBUG
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
if self.context_parallel and cp_world_size == 1: x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode) return super().forward(x)
else:
if cp_rank == 0:
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
elif cp_rank == cp_world_size - 1 and pad_back:
x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode)
if self.context_parallel and cp_world_size == 1:
return super().forward(x)
if self.stride[0] == 1:
# Receive some frames from previous rank.
x = cp_pass_frames(x, context_size)
return super().forward(x)
# Less efficient implementation for strided convs.
# All gather x, infer and chunk.
assert (
x.dtype == torch.bfloat16
), f"Expected x to be of type torch.bfloat16, got {x.dtype}"
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
return StridedSafeConv3d.forward(self, x, local_shard=True)
class Conv1x1(nn.Linear): class Conv1x1(nn.Linear):
@ -217,8 +193,8 @@ class DepthToSpaceTime(nn.Module):
sw=self.spatial_expansion, sw=self.spatial_expansion,
) )
cp_rank, _ = get_cp_rank_size() # cp_rank, _ = cp.get_cp_rank_size()
if self.temporal_expansion > 1 and cp_rank == 0: if self.temporal_expansion > 1: # and cp_rank == 0:
# Drop the first self.temporal_expansion - 1 frames. # Drop the first self.temporal_expansion - 1 frames.
# This is because we always want the 3x3x3 conv filter to only apply # This is because we always want the 3x3x3 conv filter to only apply
# to the first frame, and the first frame doesn't need to be repeated. # to the first frame, and the first frame doesn't need to be repeated.
@ -457,10 +433,10 @@ class CausalUpsampleBlock(nn.Module):
def block_fn(channels, *, has_attention: bool = False, **block_kwargs): def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
attn_block = AttentionBlock(channels) if has_attention else None #attn_block = AttentionBlock(channels) if has_attention else None
return ResBlock( return ResBlock(
channels, affine=True, attn_block=attn_block, **block_kwargs channels, affine=True, attn_block=None, **block_kwargs
) )

View File

@ -364,7 +364,7 @@ class MochiSampler:
} }
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
RETURN_NAMES = ("model", "samples",) RETURN_NAMES = ("samples",)
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "MochiWrapper" CATEGORY = "MochiWrapper"