update from main
This commit is contained in:
commit
3535a846a8
@ -1,4 +0,0 @@
|
||||
{
|
||||
"mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285],
|
||||
"std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
|
||||
}
|
||||
@ -1,163 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from einops import rearrange
|
||||
|
||||
_CONTEXT_PARALLEL_GROUP = None
|
||||
_CONTEXT_PARALLEL_RANK = None
|
||||
_CONTEXT_PARALLEL_GROUP_SIZE = None
|
||||
_CONTEXT_PARALLEL_GROUP_RANKS = None
|
||||
|
||||
|
||||
def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
|
||||
if not _CONTEXT_PARALLEL_GROUP:
|
||||
return x
|
||||
|
||||
cp_rank, cp_size = get_cp_rank_size()
|
||||
return x.tensor_split(cp_size, dim=dim)[cp_rank]
|
||||
|
||||
|
||||
def set_cp_group(cp_group, ranks, global_rank):
|
||||
global \
|
||||
_CONTEXT_PARALLEL_GROUP, \
|
||||
_CONTEXT_PARALLEL_RANK, \
|
||||
_CONTEXT_PARALLEL_GROUP_SIZE, \
|
||||
_CONTEXT_PARALLEL_GROUP_RANKS
|
||||
if _CONTEXT_PARALLEL_GROUP is not None:
|
||||
raise RuntimeError("CP group already initialized.")
|
||||
_CONTEXT_PARALLEL_GROUP = cp_group
|
||||
_CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
|
||||
_CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
|
||||
_CONTEXT_PARALLEL_GROUP_RANKS = ranks
|
||||
|
||||
assert (
|
||||
_CONTEXT_PARALLEL_RANK == ranks.index(global_rank)
|
||||
), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
|
||||
assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
|
||||
ranks
|
||||
), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"
|
||||
|
||||
|
||||
def get_cp_group():
|
||||
if _CONTEXT_PARALLEL_GROUP is None:
|
||||
raise RuntimeError("CP group not initialized")
|
||||
return _CONTEXT_PARALLEL_GROUP
|
||||
|
||||
|
||||
def is_cp_active():
|
||||
return _CONTEXT_PARALLEL_GROUP is not None
|
||||
|
||||
|
||||
def get_cp_rank_size():
|
||||
if _CONTEXT_PARALLEL_GROUP:
|
||||
return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
|
||||
else:
|
||||
return 0, 1
|
||||
|
||||
|
||||
class AllGatherIntoTensorFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
|
||||
ctx.reduce_dtype = reduce_dtype
|
||||
ctx.group = group
|
||||
ctx.batch_size = x.size(0)
|
||||
group_size = dist.get_world_size(group)
|
||||
|
||||
x = x.contiguous()
|
||||
output = torch.empty(
|
||||
group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device
|
||||
)
|
||||
dist.all_gather_into_tensor(output, x, group=group)
|
||||
return output
|
||||
|
||||
|
||||
def all_gather(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if not _CONTEXT_PARALLEL_GROUP:
|
||||
return tensor
|
||||
|
||||
return AllGatherIntoTensorFunction.apply(
|
||||
tensor, torch.float32, _CONTEXT_PARALLEL_GROUP
|
||||
)
|
||||
|
||||
|
||||
@torch.compiler.disable()
|
||||
def _all_to_all_single(output, input, group):
|
||||
# Disable compilation since torch compile changes contiguity.
|
||||
assert input.is_contiguous(), "Input tensor must be contiguous."
|
||||
assert output.is_contiguous(), "Output tensor must be contiguous."
|
||||
return dist.all_to_all_single(output, input, group=group)
|
||||
|
||||
|
||||
class CollectTokens(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
|
||||
"""Redistribute heads and receive tokens.
|
||||
|
||||
Args:
|
||||
qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]
|
||||
|
||||
Returns:
|
||||
qkv: shape: [3, B, N, local_heads, head_dim]
|
||||
|
||||
where M is the number of local tokens,
|
||||
N = cp_size * M is the number of global tokens,
|
||||
local_heads = num_heads // cp_size is the number of local heads.
|
||||
"""
|
||||
ctx.group = group
|
||||
ctx.num_heads = num_heads
|
||||
cp_size = dist.get_world_size(group)
|
||||
assert num_heads % cp_size == 0
|
||||
ctx.local_heads = num_heads // cp_size
|
||||
|
||||
qkv = rearrange(
|
||||
qkv,
|
||||
"B M (qkv G h d) -> G M h B (qkv d)",
|
||||
qkv=3,
|
||||
G=cp_size,
|
||||
h=ctx.local_heads,
|
||||
).contiguous()
|
||||
|
||||
output_chunks = torch.empty_like(qkv)
|
||||
_all_to_all_single(output_chunks, qkv, group=group)
|
||||
|
||||
return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
|
||||
|
||||
|
||||
def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
|
||||
if not _CONTEXT_PARALLEL_GROUP:
|
||||
# Move QKV dimension to the front.
|
||||
# B M (3 H d) -> 3 B M H d
|
||||
B, M, _ = x.size()
|
||||
x = x.view(B, M, 3, num_heads, -1)
|
||||
return x.permute(2, 0, 1, 3, 4)
|
||||
|
||||
return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)
|
||||
|
||||
|
||||
class CollectHeads(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
|
||||
"""Redistribute tokens and receive heads.
|
||||
|
||||
Args:
|
||||
x: Output of attention. Shape: [B, N, local_heads, head_dim]
|
||||
|
||||
Returns:
|
||||
Shape: [B, M, num_heads * head_dim]
|
||||
"""
|
||||
ctx.group = group
|
||||
ctx.local_heads = x.size(2)
|
||||
ctx.head_dim = x.size(3)
|
||||
group_size = dist.get_world_size(group)
|
||||
x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
|
||||
output = torch.empty_like(x)
|
||||
_all_to_all_single(output, x, group=group)
|
||||
del x
|
||||
return rearrange(output, "G h M B D -> B M (G h D)")
|
||||
|
||||
|
||||
def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
|
||||
if not _CONTEXT_PARALLEL_GROUP:
|
||||
# Merge heads.
|
||||
return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
|
||||
|
||||
return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
|
||||
@ -62,28 +62,6 @@ class TimestepEmbedder(nn.Module):
|
||||
return t_emb
|
||||
|
||||
|
||||
class PooledCaptionEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
caption_feature_dim: int,
|
||||
hidden_size: int,
|
||||
*,
|
||||
bias: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.caption_feature_dim = caption_feature_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -152,8 +130,6 @@ class PatchEmbed(nn.Module):
|
||||
x = F.pad(x, (0, pad_w, 0, pad_h))
|
||||
|
||||
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
|
||||
#print("x",x.dtype, x.device)
|
||||
#print(self.proj.weight.dtype, self.proj.weight.device)
|
||||
x = self.proj(x)
|
||||
|
||||
# Flatten temporal and spatial dimensions.
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import functools
|
||||
#import functools
|
||||
import math
|
||||
|
||||
import torch
|
||||
@ -21,7 +21,7 @@ def centers(start: float, stop, num, dtype=None, device=None):
|
||||
return (edges[:-1] + edges[1:]) / 2
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
#@functools.lru_cache(maxsize=1)
|
||||
def create_position_matrix(
|
||||
T: int,
|
||||
pH: int,
|
||||
|
||||
@ -28,95 +28,4 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
|
||||
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
|
||||
return pooled
|
||||
|
||||
|
||||
class PadSplitXY(torch.autograd.Function):
|
||||
"""
|
||||
Merge heads, pad and extract visual and text tokens,
|
||||
and split along the sequence length.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
xy: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
B: int,
|
||||
N: int,
|
||||
L: int,
|
||||
dtype: torch.dtype,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
|
||||
indices: Valid token indices out of unpacked tensor. Shape: (total,)
|
||||
|
||||
Returns:
|
||||
x: Visual tokens. Shape: (B, N, num_heads * head_dim).
|
||||
y: Text tokens. Shape: (B, L, num_heads * head_dim).
|
||||
"""
|
||||
ctx.save_for_backward(indices)
|
||||
ctx.B, ctx.N, ctx.L = B, N, L
|
||||
D = xy.size(1)
|
||||
|
||||
# Pad sequences to (B, N + L, dim).
|
||||
assert indices.ndim == 1
|
||||
output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
|
||||
indices = indices.unsqueeze(1).expand(
|
||||
-1, D
|
||||
) # (total,) -> (total, num_heads * head_dim)
|
||||
output.scatter_(0, indices, xy)
|
||||
xy = output.view(B, N + L, D)
|
||||
|
||||
# Split visual and text tokens along the sequence length.
|
||||
return torch.tensor_split(xy, (N,), dim=1)
|
||||
|
||||
|
||||
def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return PadSplitXY.apply(xy, indices, B, N, L, dtype)
|
||||
|
||||
|
||||
class UnifyStreams(torch.autograd.Function):
|
||||
"""Unify visual and text streams."""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
q_x: torch.Tensor,
|
||||
k_x: torch.Tensor,
|
||||
v_x: torch.Tensor,
|
||||
q_y: torch.Tensor,
|
||||
k_y: torch.Tensor,
|
||||
v_y: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
q_x: (B, N, num_heads, head_dim)
|
||||
k_x: (B, N, num_heads, head_dim)
|
||||
v_x: (B, N, num_heads, head_dim)
|
||||
q_y: (B, L, num_heads, head_dim)
|
||||
k_y: (B, L, num_heads, head_dim)
|
||||
v_y: (B, L, num_heads, head_dim)
|
||||
indices: (total <= B * (N + L))
|
||||
|
||||
Returns:
|
||||
qkv: (total <= B * (N + L), 3, num_heads, head_dim)
|
||||
"""
|
||||
ctx.save_for_backward(indices)
|
||||
B, N, num_heads, head_dim = q_x.size()
|
||||
ctx.B, ctx.N, ctx.L = B, N, q_y.size(1)
|
||||
D = num_heads * head_dim
|
||||
|
||||
q = torch.cat([q_x, q_y], dim=1)
|
||||
k = torch.cat([k_x, k_y], dim=1)
|
||||
v = torch.cat([v_x, v_y], dim=1)
|
||||
qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D)
|
||||
|
||||
indices = indices[:, None, None].expand(-1, 3, D)
|
||||
qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim)
|
||||
return qkv.unflatten(2, (num_heads, head_dim))
|
||||
|
||||
|
||||
def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
|
||||
return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)
|
||||
|
||||
@ -1,152 +0,0 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size
|
||||
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
return t if isinstance(t, tuple) else ((t,) * length)
|
||||
|
||||
|
||||
def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass that handles communication between ranks for inference.
|
||||
Args:
|
||||
x: Tensor of shape (B, C, T, H, W)
|
||||
frames_to_send: int, number of frames to communicate between ranks
|
||||
Returns:
|
||||
output: Tensor of shape (B, C, T', H, W)
|
||||
"""
|
||||
cp_rank, cp_world_size = cp.get_cp_rank_size()
|
||||
if frames_to_send == 0 or cp_world_size == 1:
|
||||
return x
|
||||
|
||||
group = get_cp_group()
|
||||
global_rank = dist.get_rank()
|
||||
|
||||
# Send to next rank
|
||||
if cp_rank < cp_world_size - 1:
|
||||
assert x.size(2) >= frames_to_send
|
||||
tail = x[:, :, -frames_to_send:].contiguous()
|
||||
dist.send(tail, global_rank + 1, group=group)
|
||||
|
||||
# Receive from previous rank
|
||||
if cp_rank > 0:
|
||||
B, C, _, H, W = x.shape
|
||||
recv_buffer = torch.empty(
|
||||
(B, C, frames_to_send, H, W),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
dist.recv(recv_buffer, global_rank - 1, group=group)
|
||||
x = torch.cat([recv_buffer, x], dim=2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
|
||||
if max_T > x.size(2):
|
||||
pad_T = max_T - x.size(2)
|
||||
pad_dims = (0, 0, 0, 0, 0, pad_T)
|
||||
return F.pad(x, pad_dims)
|
||||
return x
|
||||
|
||||
|
||||
def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Gathers all frames from all processes for inference.
|
||||
Args:
|
||||
x: Tensor of shape (B, C, T, H, W)
|
||||
Returns:
|
||||
output: Tensor of shape (B, C, T_total, H, W)
|
||||
"""
|
||||
cp_rank, cp_size = get_cp_rank_size()
|
||||
cp_group = get_cp_group()
|
||||
|
||||
# Ensure the tensor is contiguous for collective operations
|
||||
x = x.contiguous()
|
||||
|
||||
# Get the local time dimension size
|
||||
local_T = x.size(2)
|
||||
local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)
|
||||
|
||||
# Gather all T sizes from all processes
|
||||
all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
|
||||
dist.all_gather(all_T, local_T_tensor, group=cp_group)
|
||||
all_T = [t.item() for t in all_T]
|
||||
|
||||
# Pad the tensor at the end of the time dimension to match max_T
|
||||
max_T = max(all_T)
|
||||
x = _pad_to_max(x, max_T).contiguous()
|
||||
|
||||
# Prepare a list to hold the gathered tensors
|
||||
gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
|
||||
|
||||
# Perform the all_gather operation
|
||||
dist.all_gather(gathered_x, x, group=cp_group)
|
||||
|
||||
# Slice each gathered tensor back to its original T size
|
||||
for idx, t_size in enumerate(all_T):
|
||||
gathered_x[idx] = gathered_x[idx][:, :, :t_size]
|
||||
|
||||
return torch.cat(gathered_x, dim=2)
|
||||
|
||||
|
||||
def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
|
||||
"""Estimate memory usage based on input tensor size and data type."""
|
||||
element_size = input.element_size() # Size in bytes of each element
|
||||
memory_bytes = input.numel() * element_size
|
||||
memory_gb = memory_bytes / 1024**3
|
||||
return memory_gb > max_gb
|
||||
|
||||
|
||||
class ContextParallelCausalConv3d(torch.nn.Conv3d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
**kwargs,
|
||||
):
|
||||
kernel_size = cast_tuple(kernel_size, 3)
|
||||
stride = cast_tuple(stride, 3)
|
||||
height_pad = (kernel_size[1] - 1) // 2
|
||||
width_pad = (kernel_size[2] - 1) // 2
|
||||
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=(1, 1, 1),
|
||||
padding=(0, height_pad, width_pad),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
cp_rank, cp_world_size = get_cp_rank_size()
|
||||
|
||||
context_size = self.kernel_size[0] - 1
|
||||
if cp_rank == 0:
|
||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||
x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)
|
||||
|
||||
if cp_world_size == 1:
|
||||
return super().forward(x)
|
||||
|
||||
if all(s == 1 for s in self.stride):
|
||||
# Receive some frames from previous rank.
|
||||
x = cp_pass_frames(x, context_size)
|
||||
return super().forward(x)
|
||||
|
||||
# Less efficient implementation for strided convs.
|
||||
# All gather x, infer and chunk.
|
||||
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
|
||||
x = super().forward(x)
|
||||
x_chunks = x.tensor_split(cp_world_size, dim=2)
|
||||
assert len(x_chunks) == cp_world_size
|
||||
return x_chunks[cp_rank]
|
||||
@ -6,8 +6,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
#from ..dit.joint_model.context_parallel import get_cp_rank_size
|
||||
#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
|
||||
from .latent_dist import LatentDistribution
|
||||
|
||||
def cast_tuple(t, length=1):
|
||||
@ -96,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def mps_safe_pad(input, pad, mode):
|
||||
if input.device.type == "mps" and input.numel() >= 2 ** 16:
|
||||
device = input.device
|
||||
input = input.to(device="cpu")
|
||||
output = F.pad(input, pad, mode=mode)
|
||||
return output.to(device=device)
|
||||
else:
|
||||
return F.pad(input, pad, mode=mode)
|
||||
|
||||
class ContextParallelConv3d(SafeConv3d):
|
||||
def __init__(
|
||||
@ -138,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):
|
||||
# Apply padding.
|
||||
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
|
||||
if self.context_parallel:
|
||||
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
|
||||
else:
|
||||
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
||||
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
|
||||
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
31
nodes.py
31
nodes.py
@ -59,7 +59,7 @@ class MochiSigmaSchedule:
|
||||
RETURN_NAMES = ("sigmas",)
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "MochiWrapper"
|
||||
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
|
||||
DESCRIPTION = "Sigma schedule to use with mochi wrapper sampler"
|
||||
|
||||
def loadmodel(self, num_steps, threshold_noise, denoise, linear_steps=None):
|
||||
total_steps = num_steps
|
||||
@ -105,6 +105,7 @@ class DownloadAndLoadMochiModel:
|
||||
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
|
||||
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
|
||||
"cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
|
||||
"rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),
|
||||
},
|
||||
}
|
||||
|
||||
@ -114,7 +115,7 @@ class DownloadAndLoadMochiModel:
|
||||
CATEGORY = "MochiWrapper"
|
||||
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
|
||||
|
||||
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
|
||||
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
@ -154,11 +155,11 @@ class DownloadAndLoadMochiModel:
|
||||
model = T2VSynthMochiModel(
|
||||
device=device,
|
||||
offload_device=offload_device,
|
||||
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
|
||||
dit_checkpoint_path=model_path,
|
||||
weight_dtype=dtype,
|
||||
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
||||
attention_mode=attention_mode,
|
||||
rms_norm_func=rms_norm_func,
|
||||
compile_args=compile_args,
|
||||
cublas_ops=cublas_ops
|
||||
)
|
||||
@ -180,7 +181,7 @@ class DownloadAndLoadMochiModel:
|
||||
vae_sd = load_torch_file(vae_path)
|
||||
if is_accelerate_available:
|
||||
for key in vae_sd:
|
||||
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
|
||||
set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key])
|
||||
else:
|
||||
vae.load_state_dict(vae_sd, strict=True)
|
||||
vae.eval().to(torch.bfloat16).to("cpu")
|
||||
@ -201,6 +202,7 @@ class MochiModelLoader:
|
||||
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
|
||||
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
|
||||
"cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
|
||||
"rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),
|
||||
|
||||
},
|
||||
}
|
||||
@ -209,7 +211,7 @@ class MochiModelLoader:
|
||||
FUNCTION = "loadmodel"
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
|
||||
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
|
||||
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
@ -221,11 +223,11 @@ class MochiModelLoader:
|
||||
model = T2VSynthMochiModel(
|
||||
device=device,
|
||||
offload_device=offload_device,
|
||||
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
|
||||
dit_checkpoint_path=model_path,
|
||||
weight_dtype=dtype,
|
||||
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
|
||||
attention_mode=attention_mode,
|
||||
rms_norm_func=rms_norm_func,
|
||||
compile_args=compile_args,
|
||||
cublas_ops=cublas_ops
|
||||
)
|
||||
@ -473,6 +475,7 @@ class MochiSampler:
|
||||
CATEGORY = "MochiWrapper"
|
||||
|
||||
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
|
||||
mm.unload_all_models()
|
||||
mm.soft_empty_cache()
|
||||
|
||||
if opt_sigmas is not None:
|
||||
@ -630,7 +633,7 @@ class MochiDecode:
|
||||
return torch.cat(result_rows, dim=3)
|
||||
|
||||
vae.to(device)
|
||||
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
|
||||
with torch.autocast(mm.get_autocast_device(device), dtype=vae.dtype):
|
||||
if enable_vae_tiling and frame_batch_size > T:
|
||||
logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}")
|
||||
frame_batch_size = T
|
||||
@ -748,10 +751,16 @@ class MochiImageEncode:
|
||||
from .mochi_preview.vae.model import apply_tiled
|
||||
B, H, W, C = images.shape
|
||||
|
||||
images = images.unsqueeze(0) * 2 - 1
|
||||
images = rearrange(images, "t b h w c -> t c b h w")
|
||||
images = images.to(device)
|
||||
print(images.shape)
|
||||
import torchvision.transforms as transforms
|
||||
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
|
||||
input_image_tensor = rearrange(images, 'b h w c -> b c h w')
|
||||
input_image_tensor = normalize(input_image_tensor).unsqueeze(0)
|
||||
input_image_tensor = rearrange(input_image_tensor, 'b t c h w -> b c t h w', t=B)
|
||||
|
||||
#images = images.unsqueeze(0).sub_(0.5).div_(0.5)
|
||||
#images = rearrange(input_image_tensor, "b c t h w -> t c b h w")
|
||||
images = input_image_tensor.to(device)
|
||||
|
||||
encoder.to(device)
|
||||
print("images before encoding", images.shape)
|
||||
with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user