cleanup, sampler output name fix
This commit is contained in:
parent
195da244df
commit
3395aa8ca0
@ -5,8 +5,8 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from ..dit.joint_model.context_parallel import get_cp_rank_size
|
#from ..dit.joint_model.context_parallel import get_cp_rank_size
|
||||||
from ..vae.cp_conv import cp_pass_frames, gather_all_frames
|
#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||||
|
|
||||||
|
|
||||||
def cast_tuple(t, length=1):
|
def cast_tuple(t, length=1):
|
||||||
@ -125,8 +125,6 @@ class ContextParallelConv3d(SafeConv3d):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
cp_rank, cp_world_size = get_cp_rank_size()
|
|
||||||
|
|
||||||
# Compute padding amounts.
|
# Compute padding amounts.
|
||||||
context_size = self.kernel_size[0] - 1
|
context_size = self.kernel_size[0] - 1
|
||||||
if self.causal:
|
if self.causal:
|
||||||
@ -139,30 +137,8 @@ class ContextParallelConv3d(SafeConv3d):
|
|||||||
# Apply padding.
|
# Apply padding.
|
||||||
assert self.padding_mode == "replicate" # DEBUG
|
assert self.padding_mode == "replicate" # DEBUG
|
||||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||||
if self.context_parallel and cp_world_size == 1:
|
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||||
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
return super().forward(x)
|
||||||
else:
|
|
||||||
if cp_rank == 0:
|
|
||||||
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
|
||||||
elif cp_rank == cp_world_size - 1 and pad_back:
|
|
||||||
x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode)
|
|
||||||
|
|
||||||
if self.context_parallel and cp_world_size == 1:
|
|
||||||
return super().forward(x)
|
|
||||||
|
|
||||||
if self.stride[0] == 1:
|
|
||||||
# Receive some frames from previous rank.
|
|
||||||
x = cp_pass_frames(x, context_size)
|
|
||||||
return super().forward(x)
|
|
||||||
|
|
||||||
# Less efficient implementation for strided convs.
|
|
||||||
# All gather x, infer and chunk.
|
|
||||||
assert (
|
|
||||||
x.dtype == torch.bfloat16
|
|
||||||
), f"Expected x to be of type torch.bfloat16, got {x.dtype}"
|
|
||||||
|
|
||||||
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
|
|
||||||
return StridedSafeConv3d.forward(self, x, local_shard=True)
|
|
||||||
|
|
||||||
|
|
||||||
class Conv1x1(nn.Linear):
|
class Conv1x1(nn.Linear):
|
||||||
@ -217,8 +193,8 @@ class DepthToSpaceTime(nn.Module):
|
|||||||
sw=self.spatial_expansion,
|
sw=self.spatial_expansion,
|
||||||
)
|
)
|
||||||
|
|
||||||
cp_rank, _ = get_cp_rank_size()
|
# cp_rank, _ = cp.get_cp_rank_size()
|
||||||
if self.temporal_expansion > 1 and cp_rank == 0:
|
if self.temporal_expansion > 1: # and cp_rank == 0:
|
||||||
# Drop the first self.temporal_expansion - 1 frames.
|
# Drop the first self.temporal_expansion - 1 frames.
|
||||||
# This is because we always want the 3x3x3 conv filter to only apply
|
# This is because we always want the 3x3x3 conv filter to only apply
|
||||||
# to the first frame, and the first frame doesn't need to be repeated.
|
# to the first frame, and the first frame doesn't need to be repeated.
|
||||||
@ -457,10 +433,10 @@ class CausalUpsampleBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
|
def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
|
||||||
attn_block = AttentionBlock(channels) if has_attention else None
|
#attn_block = AttentionBlock(channels) if has_attention else None
|
||||||
|
|
||||||
return ResBlock(
|
return ResBlock(
|
||||||
channels, affine=True, attn_block=attn_block, **block_kwargs
|
channels, affine=True, attn_block=None, **block_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user