update from main
commit 3535a846a8
@@ -1,4 +0,0 @@
-{
-"mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285],
-"std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
-}
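The deleted JSON held per-channel latent statistics (12 means and 12 standard deviations) of the kind used to normalize VAE latents. A minimal sketch of how such stats are typically applied, assuming a hypothetical loader and helper names that are not part of this repo:

import json
import torch

def load_latent_stats(path, device="cpu"):
    # Hypothetical helper: read {"mean": [...], "std": [...]} and reshape for broadcasting over (B, C, T, H, W).
    with open(path) as f:
        stats = json.load(f)
    mean = torch.tensor(stats["mean"], device=device).view(1, -1, 1, 1, 1)
    std = torch.tensor(stats["std"], device=device).view(1, -1, 1, 1, 1)
    return mean, std

def normalize_latents(z, mean, std):
    return (z - mean) / std   # before running the diffusion model

def unnormalize_latents(z, mean, std):
    return z * std + mean     # before decoding with the VAE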
@@ -1,163 +0,0 @@
-import torch
-import torch.distributed as dist
-from einops import rearrange
-
-_CONTEXT_PARALLEL_GROUP = None
-_CONTEXT_PARALLEL_RANK = None
-_CONTEXT_PARALLEL_GROUP_SIZE = None
-_CONTEXT_PARALLEL_GROUP_RANKS = None
-
-
-def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
-    if not _CONTEXT_PARALLEL_GROUP:
-        return x
-
-    cp_rank, cp_size = get_cp_rank_size()
-    return x.tensor_split(cp_size, dim=dim)[cp_rank]
-
-
-def set_cp_group(cp_group, ranks, global_rank):
-    global \
-        _CONTEXT_PARALLEL_GROUP, \
-        _CONTEXT_PARALLEL_RANK, \
-        _CONTEXT_PARALLEL_GROUP_SIZE, \
-        _CONTEXT_PARALLEL_GROUP_RANKS
-    if _CONTEXT_PARALLEL_GROUP is not None:
-        raise RuntimeError("CP group already initialized.")
-    _CONTEXT_PARALLEL_GROUP = cp_group
-    _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
-    _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
-    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
-
-    assert (
-        _CONTEXT_PARALLEL_RANK == ranks.index(global_rank)
-    ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
-    assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
-        ranks
-    ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"
-
-
-def get_cp_group():
-    if _CONTEXT_PARALLEL_GROUP is None:
-        raise RuntimeError("CP group not initialized")
-    return _CONTEXT_PARALLEL_GROUP
-
-
-def is_cp_active():
-    return _CONTEXT_PARALLEL_GROUP is not None
-
-
-def get_cp_rank_size():
-    if _CONTEXT_PARALLEL_GROUP:
-        return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
-    else:
-        return 0, 1
-
-
-class AllGatherIntoTensorFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
-        ctx.reduce_dtype = reduce_dtype
-        ctx.group = group
-        ctx.batch_size = x.size(0)
-        group_size = dist.get_world_size(group)
-
-        x = x.contiguous()
-        output = torch.empty(
-            group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device
-        )
-        dist.all_gather_into_tensor(output, x, group=group)
-        return output
-
-
-def all_gather(tensor: torch.Tensor) -> torch.Tensor:
-    if not _CONTEXT_PARALLEL_GROUP:
-        return tensor
-
-    return AllGatherIntoTensorFunction.apply(
-        tensor, torch.float32, _CONTEXT_PARALLEL_GROUP
-    )
-
-
-@torch.compiler.disable()
-def _all_to_all_single(output, input, group):
-    # Disable compilation since torch compile changes contiguity.
-    assert input.is_contiguous(), "Input tensor must be contiguous."
-    assert output.is_contiguous(), "Output tensor must be contiguous."
-    return dist.all_to_all_single(output, input, group=group)
-
-
-class CollectTokens(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
-        """Redistribute heads and receive tokens.
-
-        Args:
-            qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]
-
-        Returns:
-            qkv: shape: [3, B, N, local_heads, head_dim]
-
-        where M is the number of local tokens,
-        N = cp_size * M is the number of global tokens,
-        local_heads = num_heads // cp_size is the number of local heads.
-        """
-        ctx.group = group
-        ctx.num_heads = num_heads
-        cp_size = dist.get_world_size(group)
-        assert num_heads % cp_size == 0
-        ctx.local_heads = num_heads // cp_size
-
-        qkv = rearrange(
-            qkv,
-            "B M (qkv G h d) -> G M h B (qkv d)",
-            qkv=3,
-            G=cp_size,
-            h=ctx.local_heads,
-        ).contiguous()
-
-        output_chunks = torch.empty_like(qkv)
-        _all_to_all_single(output_chunks, qkv, group=group)
-
-        return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
-
-
-def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
-    if not _CONTEXT_PARALLEL_GROUP:
-        # Move QKV dimension to the front.
-        # B M (3 H d) -> 3 B M H d
-        B, M, _ = x.size()
-        x = x.view(B, M, 3, num_heads, -1)
-        return x.permute(2, 0, 1, 3, 4)
-
-    return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)
-
-
-class CollectHeads(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
-        """Redistribute tokens and receive heads.
-
-        Args:
-            x: Output of attention. Shape: [B, N, local_heads, head_dim]
-
-        Returns:
-            Shape: [B, M, num_heads * head_dim]
-        """
-        ctx.group = group
-        ctx.local_heads = x.size(2)
-        ctx.head_dim = x.size(3)
-        group_size = dist.get_world_size(group)
-        x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
-        output = torch.empty_like(x)
-        _all_to_all_single(output, x, group=group)
-        del x
-        return rearrange(output, "G h M B D -> B M (G h D)")
-
-
-def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
-    if not _CONTEXT_PARALLEL_GROUP:
-        # Merge heads.
-        return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
-
-    return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
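The CollectTokens docstring above fixes the shape contract: M local tokens become N = cp_size * M global tokens while num_heads shrinks to local_heads = num_heads // cp_size. A minimal single-process check of that contract, using only the cp_size == 1 fallback from all_to_all_collect_tokens (the sizes below are illustrative):

import torch

B, M, num_heads, head_dim, cp_size = 2, 8, 24, 128, 1  # cp_size = 1: no communication needed
qkv = torch.randn(B, M, 3 * num_heads * head_dim)

# Fallback path: B M (3 H d) -> 3 B M H d
out = qkv.view(B, M, 3, num_heads, -1).permute(2, 0, 1, 3, 4)

# The distributed path would yield [3, B, cp_size * M, num_heads // cp_size, head_dim] per rank;
# with cp_size = 1 that is exactly the fallback shape.
local_heads = num_heads // cp_size
assert out.shape == (3, B, cp_size * M, local_heads, head_dim)
print(out.shape)  # torch.Size([3, 2, 8, 24, 128])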
@@ -62,28 +62,6 @@ class TimestepEmbedder(nn.Module):
         return t_emb


-class PooledCaptionEmbedder(nn.Module):
-    def __init__(
-        self,
-        caption_feature_dim: int,
-        hidden_size: int,
-        *,
-        bias: bool = True,
-        device: Optional[torch.device] = None,
-    ):
-        super().__init__()
-        self.caption_feature_dim = caption_feature_dim
-        self.hidden_size = hidden_size
-        self.mlp = nn.Sequential(
-            nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
-            nn.SiLU(),
-            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
-        )
-
-    def forward(self, x):
-        return self.mlp(x)
-
-
 class FeedForward(nn.Module):
     def __init__(
         self,
@@ -152,8 +130,6 @@ class PatchEmbed(nn.Module):
             x = F.pad(x, (0, pad_w, 0, pad_h))

         x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
-        #print("x",x.dtype, x.device)
-        #print(self.proj.weight.dtype, self.proj.weight.device)
         x = self.proj(x)

         # Flatten temporal and spatial dimensions.
@@ -1,4 +1,4 @@
-import functools
+#import functools
 import math

 import torch
@@ -21,7 +21,7 @@ def centers(start: float, stop, num, dtype=None, device=None):
     return (edges[:-1] + edges[1:]) / 2


-@functools.lru_cache(maxsize=1)
+#@functools.lru_cache(maxsize=1)
 def create_position_matrix(
     T: int,
     pH: int,
@@ -28,95 +28,4 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
     mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
     pooled = (x * mask).sum(dim=1, keepdim=keepdim)
     return pooled
-
-
-class PadSplitXY(torch.autograd.Function):
-    """
-    Merge heads, pad and extract visual and text tokens,
-    and split along the sequence length.
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        xy: torch.Tensor,
-        indices: torch.Tensor,
-        B: int,
-        N: int,
-        L: int,
-        dtype: torch.dtype,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Args:
-            xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
-            indices: Valid token indices out of unpacked tensor. Shape: (total,)
-
-        Returns:
-            x: Visual tokens. Shape: (B, N, num_heads * head_dim).
-            y: Text tokens. Shape: (B, L, num_heads * head_dim).
-        """
-        ctx.save_for_backward(indices)
-        ctx.B, ctx.N, ctx.L = B, N, L
-        D = xy.size(1)
-
-        # Pad sequences to (B, N + L, dim).
-        assert indices.ndim == 1
-        output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
-        indices = indices.unsqueeze(1).expand(
-            -1, D
-        )  # (total,) -> (total, num_heads * head_dim)
-        output.scatter_(0, indices, xy)
-        xy = output.view(B, N + L, D)
-
-        # Split visual and text tokens along the sequence length.
-        return torch.tensor_split(xy, (N,), dim=1)
-
-
-def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
-    return PadSplitXY.apply(xy, indices, B, N, L, dtype)
-
-
-class UnifyStreams(torch.autograd.Function):
-    """Unify visual and text streams."""
-
-    @staticmethod
-    def forward(
-        ctx,
-        q_x: torch.Tensor,
-        k_x: torch.Tensor,
-        v_x: torch.Tensor,
-        q_y: torch.Tensor,
-        k_y: torch.Tensor,
-        v_y: torch.Tensor,
-        indices: torch.Tensor,
-    ):
-        """
-        Args:
-            q_x: (B, N, num_heads, head_dim)
-            k_x: (B, N, num_heads, head_dim)
-            v_x: (B, N, num_heads, head_dim)
-            q_y: (B, L, num_heads, head_dim)
-            k_y: (B, L, num_heads, head_dim)
-            v_y: (B, L, num_heads, head_dim)
-            indices: (total <= B * (N + L))
-
-        Returns:
-            qkv: (total <= B * (N + L), 3, num_heads, head_dim)
-        """
-        ctx.save_for_backward(indices)
-        B, N, num_heads, head_dim = q_x.size()
-        ctx.B, ctx.N, ctx.L = B, N, q_y.size(1)
-        D = num_heads * head_dim
-
-        q = torch.cat([q_x, q_y], dim=1)
-        k = torch.cat([k_x, k_y], dim=1)
-        v = torch.cat([v_x, v_y], dim=1)
-        qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D)
-
-        indices = indices[:, None, None].expand(-1, 3, D)
-        qkv = torch.gather(qkv, 0, indices)  # (total, 3, num_heads * head_dim)
-        return qkv.unflatten(2, (num_heads, head_dim))
-
-
-def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
-    return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)
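The removed PadSplitXY/UnifyStreams pair implements a scatter/gather packing scheme: valid-token indices map packed rows into a zero-padded (B * (N + L), D) buffer and back. A minimal single-tensor illustration of that round trip (toy sizes, no autograd):

import torch

B, N, L, D = 1, 3, 2, 4
valid = torch.tensor([0, 1, 2, 3])                  # first 4 of the 5 slots hold real tokens
packed = torch.arange(len(valid) * D, dtype=torch.float32).view(-1, D)

# pad_and_split_xy direction: scatter packed rows into a zero-padded buffer, then split N visual / L text tokens.
padded = torch.zeros(B * (N + L), D)
padded.scatter_(0, valid.unsqueeze(1).expand(-1, D), packed)
x, y = torch.tensor_split(padded.view(B, N + L, D), (N,), dim=1)
print(x.shape, y.shape)                             # torch.Size([1, 3, 4]) torch.Size([1, 2, 4])

# unify_streams direction: gather the valid rows back out of the padded buffer.
repacked = torch.gather(padded, 0, valid.unsqueeze(1).expand(-1, D))
assert torch.equal(repacked, packed)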
@@ -1,152 +0,0 @@
-from typing import Tuple, Union
-
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-
-from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size
-
-
-def cast_tuple(t, length=1):
-    return t if isinstance(t, tuple) else ((t,) * length)
-
-
-def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
-    """
-    Forward pass that handles communication between ranks for inference.
-    Args:
-        x: Tensor of shape (B, C, T, H, W)
-        frames_to_send: int, number of frames to communicate between ranks
-    Returns:
-        output: Tensor of shape (B, C, T', H, W)
-    """
-    cp_rank, cp_world_size = cp.get_cp_rank_size()
-    if frames_to_send == 0 or cp_world_size == 1:
-        return x
-
-    group = get_cp_group()
-    global_rank = dist.get_rank()
-
-    # Send to next rank
-    if cp_rank < cp_world_size - 1:
-        assert x.size(2) >= frames_to_send
-        tail = x[:, :, -frames_to_send:].contiguous()
-        dist.send(tail, global_rank + 1, group=group)
-
-    # Receive from previous rank
-    if cp_rank > 0:
-        B, C, _, H, W = x.shape
-        recv_buffer = torch.empty(
-            (B, C, frames_to_send, H, W),
-            dtype=x.dtype,
-            device=x.device,
-        )
-        dist.recv(recv_buffer, global_rank - 1, group=group)
-        x = torch.cat([recv_buffer, x], dim=2)
-
-    return x
-
-
-def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
-    if max_T > x.size(2):
-        pad_T = max_T - x.size(2)
-        pad_dims = (0, 0, 0, 0, 0, pad_T)
-        return F.pad(x, pad_dims)
-    return x
-
-
-def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
-    """
-    Gathers all frames from all processes for inference.
-    Args:
-        x: Tensor of shape (B, C, T, H, W)
-    Returns:
-        output: Tensor of shape (B, C, T_total, H, W)
-    """
-    cp_rank, cp_size = get_cp_rank_size()
-    cp_group = get_cp_group()
-
-    # Ensure the tensor is contiguous for collective operations
-    x = x.contiguous()
-
-    # Get the local time dimension size
-    local_T = x.size(2)
-    local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)
-
-    # Gather all T sizes from all processes
-    all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
-    dist.all_gather(all_T, local_T_tensor, group=cp_group)
-    all_T = [t.item() for t in all_T]
-
-    # Pad the tensor at the end of the time dimension to match max_T
-    max_T = max(all_T)
-    x = _pad_to_max(x, max_T).contiguous()
-
-    # Prepare a list to hold the gathered tensors
-    gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
-
-    # Perform the all_gather operation
-    dist.all_gather(gathered_x, x, group=cp_group)
-
-    # Slice each gathered tensor back to its original T size
-    for idx, t_size in enumerate(all_T):
-        gathered_x[idx] = gathered_x[idx][:, :, :t_size]
-
-    return torch.cat(gathered_x, dim=2)
-
-
-def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
-    """Estimate memory usage based on input tensor size and data type."""
-    element_size = input.element_size()  # Size in bytes of each element
-    memory_bytes = input.numel() * element_size
-    memory_gb = memory_bytes / 1024**3
-    return memory_gb > max_gb
-
-
-class ContextParallelCausalConv3d(torch.nn.Conv3d):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size: Union[int, Tuple[int, int, int]],
-        stride: Union[int, Tuple[int, int, int]],
-        **kwargs,
-    ):
-        kernel_size = cast_tuple(kernel_size, 3)
-        stride = cast_tuple(stride, 3)
-        height_pad = (kernel_size[1] - 1) // 2
-        width_pad = (kernel_size[2] - 1) // 2
-
-        super().__init__(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            dilation=(1, 1, 1),
-            padding=(0, height_pad, width_pad),
-            **kwargs,
-        )
-
-    def forward(self, x: torch.Tensor):
-        cp_rank, cp_world_size = get_cp_rank_size()
-
-        context_size = self.kernel_size[0] - 1
-        if cp_rank == 0:
-            mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
-            x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)
-
-        if cp_world_size == 1:
-            return super().forward(x)
-
-        if all(s == 1 for s in self.stride):
-            # Receive some frames from previous rank.
-            x = cp_pass_frames(x, context_size)
-            return super().forward(x)
-
-        # Less efficient implementation for strided convs.
-        # All gather x, infer and chunk.
-        x = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
-        x = super().forward(x)
-        x_chunks = x.tensor_split(cp_world_size, dim=2)
-        assert len(x_chunks) == cp_world_size
-        return x_chunks[cp_rank]
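The removed ContextParallelCausalConv3d pads (kernel_size[0] - 1) frames at the front of the time axis so outputs depend only on current and past frames, and passes that temporal context between ranks. A single-process sketch of just the causal-padding part (sizes are illustrative):

import torch
import torch.nn.functional as F

conv = torch.nn.Conv3d(4, 4, kernel_size=(3, 3, 3), stride=1, padding=(0, 1, 1))
x = torch.randn(1, 4, 8, 16, 16)                    # B C T H W

context_size = conv.kernel_size[0] - 1              # 2 frames of temporal context
x_padded = F.pad(x, (0, 0, 0, 0, context_size, 0))  # zero-pad only the front of T
y = conv(x_padded)
print(y.shape)                                      # torch.Size([1, 4, 8, 16, 16]); T is preserved, causally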
@@ -6,8 +6,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

-#from ..dit.joint_model.context_parallel import get_cp_rank_size
-#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
 from .latent_dist import LatentDistribution

 def cast_tuple(t, length=1):
@@ -96,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):

         raise NotImplementedError

+def mps_safe_pad(input, pad, mode):
+    if input.device.type == "mps" and input.numel() >= 2 ** 16:
+        device = input.device
+        input = input.to(device="cpu")
+        output = F.pad(input, pad, mode=mode)
+        return output.to(device=device)
+    else:
+        return F.pad(input, pad, mode=mode)
+
 class ContextParallelConv3d(SafeConv3d):
     def __init__(
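The new mps_safe_pad helper routes large pads through the CPU on Apple's MPS backend, presumably to sidestep padding limitations on big tensors there; elsewhere it is a plain F.pad. A small usage sketch, restating the helper as added above with an illustrative front-pad of two frames:

import torch
import torch.nn.functional as F

def mps_safe_pad(input, pad, mode):
    if input.device.type == "mps" and input.numel() >= 2 ** 16:
        device = input.device
        input = input.to(device="cpu")
        output = F.pad(input, pad, mode=mode)
        return output.to(device=device)
    else:
        return F.pad(input, pad, mode=mode)

x = torch.randn(1, 12, 8, 60, 106)                         # B C T H W latent-sized tensor
y = mps_safe_pad(x, (0, 0, 0, 0, 2, 0), mode="replicate")  # pad 2 frames at the front of T
print(y.shape)                                             # torch.Size([1, 12, 10, 60, 106])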
@@ -138,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):
         # Apply padding.
         mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
         if self.context_parallel:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
+            x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
         else:
-            x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
+            x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)


         return super().forward(x)
nodes.py
@@ -59,7 +59,7 @@ class MochiSigmaSchedule:
    RETURN_NAMES = ("sigmas",)
    FUNCTION = "loadmodel"
    CATEGORY = "MochiWrapper"
-    DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
+    DESCRIPTION = "Sigma schedule to use with mochi wrapper sampler"

    def loadmodel(self, num_steps, threshold_noise, denoise, linear_steps=None):
        total_steps = num_steps
@@ -105,6 +105,7 @@ class DownloadAndLoadMochiModel:
                "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
                "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
+                "rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),
            },
        }
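The new rms_norm_func input selects an RMSNorm backend (default, flash_attn_triton, flash_attn, apex); the actual dispatch lives in the model code, not in this hunk. For reference, a minimal sketch of what a standard RMSNorm computes (this is not the wrapper's implementation):

import torch

def rms_norm(x, weight, eps=1e-6):
    # Scale by the reciprocal root-mean-square over the last dimension, then apply a learned gain.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight

x = torch.randn(2, 16, 3072)
w = torch.ones(3072)
print(rms_norm(x, w).shape)  # torch.Size([2, 16, 3072])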
@@ -114,7 +115,7 @@ class DownloadAndLoadMochiModel:
    CATEGORY = "MochiWrapper"
    DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"

-    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
+    def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
@@ -154,11 +155,11 @@ class DownloadAndLoadMochiModel:
        model = T2VSynthMochiModel(
            device=device,
            offload_device=offload_device,
-            vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
            dit_checkpoint_path=model_path,
            weight_dtype=dtype,
            fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
            attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
            compile_args=compile_args,
            cublas_ops=cublas_ops
        )
@@ -180,7 +181,7 @@ class DownloadAndLoadMochiModel:
        vae_sd = load_torch_file(vae_path)
        if is_accelerate_available:
            for key in vae_sd:
-                set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
+                set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key])
        else:
            vae.load_state_dict(vae_sd, strict=True)
        vae.eval().to(torch.bfloat16).to("cpu")
@@ -201,6 +202,7 @@ class MochiModelLoader:
                "trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
                "compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
                "cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
+                "rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),

            },
        }
@@ -209,7 +211,7 @@ class MochiModelLoader:
    FUNCTION = "loadmodel"
    CATEGORY = "MochiWrapper"

-    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
+    def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):

        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
@@ -221,11 +223,11 @@ class MochiModelLoader:
        model = T2VSynthMochiModel(
            device=device,
            offload_device=offload_device,
-            vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
            dit_checkpoint_path=model_path,
            weight_dtype=dtype,
            fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
            attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
            compile_args=compile_args,
            cublas_ops=cublas_ops
        )
@@ -473,6 +475,7 @@ class MochiSampler:
    CATEGORY = "MochiWrapper"

    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
+        mm.unload_all_models()
        mm.soft_empty_cache()

        if opt_sigmas is not None:
@@ -630,7 +633,7 @@ class MochiDecode:
            return torch.cat(result_rows, dim=3)

        vae.to(device)
-        with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
+        with torch.autocast(mm.get_autocast_device(device), dtype=vae.dtype):
            if enable_vae_tiling and frame_batch_size > T:
                logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}")
                frame_batch_size = T
@@ -748,10 +751,16 @@ class MochiImageEncode:
        from .mochi_preview.vae.model import apply_tiled
        B, H, W, C = images.shape

-        images = images.unsqueeze(0) * 2 - 1
-        images = rearrange(images, "t b h w c -> t c b h w")
-        images = images.to(device)
-        print(images.shape)
+        import torchvision.transforms as transforms
+        normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+        input_image_tensor = rearrange(images, 'b h w c -> b c h w')
+        input_image_tensor = normalize(input_image_tensor).unsqueeze(0)
+        input_image_tensor = rearrange(input_image_tensor, 'b t c h w -> b c t h w', t=B)
+
+        #images = images.unsqueeze(0).sub_(0.5).div_(0.5)
+        #images = rearrange(input_image_tensor, "b c t h w -> t c b h w")
+        images = input_image_tensor.to(device)

        encoder.to(device)
        print("images before encoding", images.shape)
        with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):
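The MochiImageEncode change swaps the manual images.unsqueeze(0) * 2 - 1 scaling for torchvision's Normalize with mean 0.5 and std 0.5, which is the same [0, 1] to [-1, 1] mapping up to floating-point rounding. A quick check of that equivalence (shapes are illustrative; ComfyUI images are B H W C in [0, 1]):

import torch
import torchvision.transforms as transforms

images = torch.rand(3, 240, 424, 3)             # B H W C in [0, 1]
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

a = normalize(images.permute(0, 3, 1, 2))       # new path: b h w c -> b c h w, then normalize
b = images.permute(0, 3, 1, 2) * 2 - 1          # old path, in the same layout
print(torch.allclose(a, b, atol=1e-6))          # True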