cleanup, sampler output name fix

kijai 2024-10-27 20:02:44 +02:00
parent 195da244df
commit 3395aa8ca0
2 changed files with 9 additions and 33 deletions

View File

@@ -5,8 +5,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from ..dit.joint_model.context_parallel import get_cp_rank_size
-from ..vae.cp_conv import cp_pass_frames, gather_all_frames
+#from ..dit.joint_model.context_parallel import get_cp_rank_size
+#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
 
 
 def cast_tuple(t, length=1):
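
Note: with these imports commented out, any remaining call to get_cp_rank_size or cp_pass_frames in this module would raise a NameError, which is why the later hunks also strip the call sites. A hypothetical alternative, sketched here for illustration only (it is not what this commit does), would be to fall back to a single-process stub when the context-parallel helpers are unavailable:

    # Hypothetical import guard; the commit simply comments the imports out instead.
    try:
        from ..dit.joint_model.context_parallel import get_cp_rank_size
    except ImportError:
        def get_cp_rank_size():
            # Single-process default: rank 0 in a world of size 1.
            return 0, 1
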
@@ -125,8 +125,6 @@ class ContextParallelConv3d(SafeConv3d):
         )
 
     def forward(self, x: torch.Tensor):
-        cp_rank, cp_world_size = get_cp_rank_size()
-
         # Compute padding amounts.
         context_size = self.kernel_size[0] - 1
         if self.causal:
@@ -139,31 +137,9 @@ class ContextParallelConv3d(SafeConv3d):
         # Apply padding.
         assert self.padding_mode == "replicate" # DEBUG
         mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
-        if self.context_parallel and cp_world_size == 1:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
-        else:
-            if cp_rank == 0:
-                x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
-            elif cp_rank == cp_world_size - 1 and pad_back:
-                x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode)
-
-        if self.context_parallel and cp_world_size == 1:
-            return super().forward(x)
-
-        if self.stride[0] == 1:
-            # Receive some frames from previous rank.
-            x = cp_pass_frames(x, context_size)
-            return super().forward(x)
-
-        # Less efficient implementation for strided convs.
-        # All gather x, infer and chunk.
-        assert (
-            x.dtype == torch.bfloat16
-        ), f"Expected x to be of type torch.bfloat16, got {x.dtype}"
-        x = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
-        return StridedSafeConv3d.forward(self, x, local_shard=True)
-
+        x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+        return super().forward(x)
 
 
 class Conv1x1(nn.Linear):
     """*1x1 Conv implemented with a linear layer."""
@@ -217,8 +193,8 @@ class DepthToSpaceTime(nn.Module):
             sw=self.spatial_expansion,
         )
 
-        cp_rank, _ = get_cp_rank_size()
-        if self.temporal_expansion > 1 and cp_rank == 0:
+        # cp_rank, _ = cp.get_cp_rank_size()
+        if self.temporal_expansion > 1: # and cp_rank == 0:
             # Drop the first self.temporal_expansion - 1 frames.
             # This is because we always want the 3x3x3 conv filter to only apply
             # to the first frame, and the first frame doesn't need to be repeated.
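
The dropped-frames comment above refers to the depth-to-space-time expansion. Roughly, with assumed factor names and shapes rather than the module's exact implementation, channels are unfolded into temporal and spatial factors and the first temporal_expansion - 1 output frames are dropped so the first latent frame is not repeated:

    import torch
    from einops import rearrange

    def depth_to_space_time(x: torch.Tensor, texp: int, sh: int, sw: int) -> torch.Tensor:
        # x: [B, C*texp*sh*sw, T, H, W] -> [B, C, T*texp, H*sh, W*sw]
        x = rearrange(
            x, "b (c st sh sw) t h w -> b c (t st) (h sh) (w sw)",
            st=texp, sh=sh, sw=sw,
        )
        if texp > 1:
            # Drop the first texp - 1 frames; the first frame maps to exactly one output frame.
            x = x[:, :, texp - 1 :]
        return x

    x = torch.randn(1, 8 * 2 * 2 * 2, 4, 8, 8)    # texp=2, sh=sw=2
    print(depth_to_space_time(x, 2, 2, 2).shape)  # torch.Size([1, 8, 7, 16, 16])
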
@@ -457,10 +433,10 @@ class CausalUpsampleBlock(nn.Module):
 def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
-    attn_block = AttentionBlock(channels) if has_attention else None
+    #attn_block = AttentionBlock(channels) if has_attention else None
     return ResBlock(
-        channels, affine=True, attn_block=attn_block, **block_kwargs
+        channels, affine=True, attn_block=None, **block_kwargs
     )
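
With attn_block hardcoded to None, the has_attention flag no longer affects the upsample blocks. A generic sketch of the pattern being disabled (an assumed module, not the repo's ResBlock), where passing attn_block=None skips the attention path entirely:

    import torch
    import torch.nn as nn

    class ResBlockSketch(nn.Module):
        def __init__(self, channels, attn_block=None):
            super().__init__()
            self.conv = nn.Conv3d(channels, channels, kernel_size=3, padding=1)
            self.attn = attn_block  # None -> the attention branch below is never taken

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x + self.conv(x)
            if self.attn is not None:
                x = self.attn(x)
            return x

    block = ResBlockSketch(8, attn_block=None)      # attention disabled
    out = block(torch.randn(1, 8, 4, 16, 16))       # -> [1, 8, 4, 16, 16]
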

View File

@@ -364,7 +364,7 @@ class MochiSampler:
         }
 
     RETURN_TYPES = ("LATENT",)
-    RETURN_NAMES = ("model", "samples",)
+    RETURN_NAMES = ("samples",)
     FUNCTION = "process"
     CATEGORY = "MochiWrapper"
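
The rename matters because ComfyUI pairs RETURN_NAMES with RETURN_TYPES positionally: this node declares a single LATENT output, so a two-element name tuple mislabels it. A minimal sketch of the convention, using a generic example node rather than the wrapper's actual class:

    class ExampleLatentNode:
        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"samples": ("LATENT",)}}

        RETURN_TYPES = ("LATENT",)
        RETURN_NAMES = ("samples",)  # one name per entry in RETURN_TYPES
        FUNCTION = "process"
        CATEGORY = "Example"

        def process(self, samples):
            # The returned tuple must also line up with RETURN_TYPES.
            return (samples,)
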