cleanup, sampler output name fix

parent 195da244df
commit 3395aa8ca0
@@ -5,8 +5,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
-from ..dit.joint_model.context_parallel import get_cp_rank_size
-from ..vae.cp_conv import cp_pass_frames, gather_all_frames
+#from ..dit.joint_model.context_parallel import get_cp_rank_size
+#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
 
 
 def cast_tuple(t, length=1):
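With these imports commented out, every surviving call site of get_cp_rank_size, cp_pass_frames, and gather_all_frames has to go too, which is what the later hunks do. An alternative (not what this commit does) would be single-process stubs so the call sites could stay untouched; a minimal hypothetical sketch:

# Hypothetical single-process stand-ins, not part of this commit.
def get_cp_rank_size():
    return 0, 1  # one rank, world size 1

def cp_pass_frames(x, context_size):
    return x  # no neighbor rank to receive frames from

def gather_all_frames(x):
    return x  # the local tensor already is the global one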
@@ -125,8 +125,6 @@ class ContextParallelConv3d(SafeConv3d):
         )
 
     def forward(self, x: torch.Tensor):
-        cp_rank, cp_world_size = get_cp_rank_size()
-
         # Compute padding amounts.
         context_size = self.kernel_size[0] - 1
         if self.causal:
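The hunk cuts off at the causal branch, but the padding arithmetic it leads into is the standard causal-convolution split: all kernel_size[0] - 1 frames of temporal context are padded onto the front when the conv is causal, and divided between front and back otherwise. A sketch of that convention (the branch bodies lie outside this diff, so treat the details as an assumption):

def temporal_padding(kernel_size_t: int, causal: bool) -> tuple[int, int]:
    # Frames of context the kernel needs beyond the current frame.
    context_size = kernel_size_t - 1
    if causal:
        return context_size, 0  # pad only the past
    front = context_size // 2   # center the receptive field
    return front, context_size - front

# A causal 3-frame kernel pads (2, 0); an acausal one pads (1, 1).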
@@ -139,31 +137,9 @@ class ContextParallelConv3d(SafeConv3d):
 
         # Apply padding.
-        assert self.padding_mode == "replicate"  # DEBUG
         mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
-        if self.context_parallel and cp_world_size == 1:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
-        else:
-            if cp_rank == 0:
-                x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
-            elif cp_rank == cp_world_size - 1 and pad_back:
-                x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode)
-
-        if self.context_parallel and cp_world_size == 1:
-            return super().forward(x)
-
-        if self.stride[0] == 1:
-            # Receive some frames from previous rank.
-            x = cp_pass_frames(x, context_size)
-            return super().forward(x)
-
-        # Less efficient implementation for strided convs.
-        # All gather x, infer and chunk.
-        assert (
-            x.dtype == torch.bfloat16
-        ), f"Expected x to be of type torch.bfloat16, got {x.dtype}"
-
-        x = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
-        return StridedSafeConv3d.forward(self, x, local_shard=True)
+        x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+        return super().forward(x)
 
 
 class Conv1x1(nn.Linear):
     """*1x1 Conv implemented with a linear layer."""
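The deleted branches were the context-parallel machinery: with stride 1, each rank pulled the trailing kernel_size[0] - 1 frames from its predecessor via cp_pass_frames (a halo exchange), while strided convs fell back to gathering every rank's frames with gather_all_frames and convolving only the local shard. On a single GPU all of that collapses to pad-then-convolve. A sketch of the resulting method, reconstructed from the surviving lines (the pad_front/pad_back assignments are assumed, as above):

def forward(self, x: torch.Tensor):
    # x: [B, C, T, H, W]. Pad the temporal axis, then run the plain conv.
    context_size = self.kernel_size[0] - 1
    if self.causal:
        pad_front, pad_back = context_size, 0
    else:
        pad_front = context_size // 2
        pad_back = context_size - pad_front
    mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
    # F.pad pads from the last dim backward, so the final pair pads dim 2 (time).
    x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
    return super().forward(x)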
@@ -217,8 +193,8 @@ class DepthToSpaceTime(nn.Module):
             sw=self.spatial_expansion,
         )
 
-        cp_rank, _ = get_cp_rank_size()
-        if self.temporal_expansion > 1 and cp_rank == 0:
+        # cp_rank, _ = cp.get_cp_rank_size()
+        if self.temporal_expansion > 1:  # and cp_rank == 0:
             # Drop the first self.temporal_expansion - 1 frames.
             # This is because we always want the 3x3x3 conv filter to only apply
             # to the first frame, and the first frame doesn't need to be repeated.
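Dropping the cp_rank == 0 guard is safe here because on a single GPU rank 0 is the only rank; in the multi-GPU original, only the first rank held the first frames and therefore only it dropped them. For context, DepthToSpaceTime unfolds channel blocks into extra time and space samples with an einops rearrange and then discards the first temporal_expansion - 1 frames, so the first latent frame maps to exactly one output frame. A standalone sketch with hypothetical shapes:

import torch
from einops import rearrange

B, C, T, H, W = 1, 4, 3, 8, 8
st, sh, sw = 2, 2, 2  # temporal / spatial expansion factors
x = torch.randn(B, C * st * sh * sw, T, H, W)

x = rearrange(
    x, "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
    st=st, sh=sh, sw=sw,
)
x = x[:, :, st - 1 :]  # drop the first st - 1 frames
print(x.shape)  # torch.Size([1, 4, 5, 16, 16])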
@@ -457,10 +433,10 @@ class CausalUpsampleBlock(nn.Module):
 
 
 def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
-    attn_block = AttentionBlock(channels) if has_attention else None
+    #attn_block = AttentionBlock(channels) if has_attention else None
 
     return ResBlock(
-        channels, affine=True, attn_block=attn_block, **block_kwargs
+        channels, affine=True, attn_block=None, **block_kwargs
     )
 
 
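One behavioral note on this last hunk: hardcoding attn_block=None means block_fn now ignores has_attention entirely, so a caller passing has_attention=True silently gets a ResBlock without attention. A hypothetical variant that fails loudly instead, if that is a concern:

def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
    # Attention was removed in this build; refuse rather than ignore the flag.
    assert not has_attention, "AttentionBlock support was removed"
    return ResBlock(channels, affine=True, attn_block=None, **block_kwargs)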