update from main

This commit is contained in:
kijai 2024-11-05 07:08:14 +02:00
commit 3535a846a8
8 changed files with 33 additions and 452 deletions

View File

@@ -1,4 +0,0 @@
{
"mean": [-0.06730895953510081, -0.038011381506090416, -0.07477820912866141, -0.05565264470995561, 0.012767231469026969, -0.04703542746246419, 0.043896967884726704, -0.09346305707025976, -0.09918314763016893, -0.008729793427399178, -0.011931556316503654, -0.0321993391887285],
"std": [0.9263795028493863, 0.9248894543193766, 0.9393059390890617, 0.959253732819592, 0.8244560132752793, 0.917259975397747, 0.9294154431013696, 1.3720942357788521, 0.881393668867029, 0.9168315692124348, 0.9185249279345552, 0.9274757570805041]
}
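For context, the removed vae_stats.json stores per-channel mean/std for the 12-channel VAE latent space; stats like these are typically used to whiten latents before diffusion and un-whiten them before decoding. A minimal sketch of that usage for latents shaped (B, C, T, H, W); the helper names are illustrative, not from this repository:

import json
import torch

def load_latent_stats(path, device=None, dtype=torch.float32):
    # Load per-channel statistics and reshape them for broadcasting over (B, C, T, H, W).
    with open(path) as f:
        stats = json.load(f)
    mean = torch.tensor(stats["mean"], device=device, dtype=dtype).view(1, -1, 1, 1, 1)
    std = torch.tensor(stats["std"], device=device, dtype=dtype).view(1, -1, 1, 1, 1)
    return mean, std

def normalize_latents(z, mean, std):
    # Whiten encoder output before it enters the diffusion model.
    return (z - mean) / std

def denormalize_latents(z, mean, std):
    # Undo the whitening before handing latents to the VAE decoder.
    return z * std + mean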

View File

@@ -1,163 +0,0 @@
import torch
import torch.distributed as dist
from einops import rearrange
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_RANK = None
_CONTEXT_PARALLEL_GROUP_SIZE = None
_CONTEXT_PARALLEL_GROUP_RANKS = None
def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
return x
cp_rank, cp_size = get_cp_rank_size()
return x.tensor_split(cp_size, dim=dim)[cp_rank]
def set_cp_group(cp_group, ranks, global_rank):
global \
_CONTEXT_PARALLEL_GROUP, \
_CONTEXT_PARALLEL_RANK, \
_CONTEXT_PARALLEL_GROUP_SIZE, \
_CONTEXT_PARALLEL_GROUP_RANKS
if _CONTEXT_PARALLEL_GROUP is not None:
raise RuntimeError("CP group already initialized.")
_CONTEXT_PARALLEL_GROUP = cp_group
_CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
_CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
_CONTEXT_PARALLEL_GROUP_RANKS = ranks
assert (
_CONTEXT_PARALLEL_RANK == ranks.index(global_rank)
), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
ranks
), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"
def get_cp_group():
if _CONTEXT_PARALLEL_GROUP is None:
raise RuntimeError("CP group not initialized")
return _CONTEXT_PARALLEL_GROUP
def is_cp_active():
return _CONTEXT_PARALLEL_GROUP is not None
def get_cp_rank_size():
if _CONTEXT_PARALLEL_GROUP:
return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
else:
return 0, 1
class AllGatherIntoTensorFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
ctx.reduce_dtype = reduce_dtype
ctx.group = group
ctx.batch_size = x.size(0)
group_size = dist.get_world_size(group)
x = x.contiguous()
output = torch.empty(
group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device
)
dist.all_gather_into_tensor(output, x, group=group)
return output
def all_gather(tensor: torch.Tensor) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
return tensor
return AllGatherIntoTensorFunction.apply(
tensor, torch.float32, _CONTEXT_PARALLEL_GROUP
)
@torch.compiler.disable()
def _all_to_all_single(output, input, group):
# Disable compilation since torch compile changes contiguity.
assert input.is_contiguous(), "Input tensor must be contiguous."
assert output.is_contiguous(), "Output tensor must be contiguous."
return dist.all_to_all_single(output, input, group=group)
class CollectTokens(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
"""Redistribute heads and receive tokens.
Args:
qkv: concatenated query, key, and value. Shape: [B, M, 3 * num_heads * head_dim]
Returns:
qkv: shape: [3, B, N, local_heads, head_dim]
where M is the number of local tokens,
N = cp_size * M is the number of global tokens,
local_heads = num_heads // cp_size is the number of local heads.
"""
ctx.group = group
ctx.num_heads = num_heads
cp_size = dist.get_world_size(group)
assert num_heads % cp_size == 0
ctx.local_heads = num_heads // cp_size
qkv = rearrange(
qkv,
"B M (qkv G h d) -> G M h B (qkv d)",
qkv=3,
G=cp_size,
h=ctx.local_heads,
).contiguous()
output_chunks = torch.empty_like(qkv)
_all_to_all_single(output_chunks, qkv, group=group)
return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
# Move QKV dimension to the front.
# B M (3 H d) -> 3 B M H d
B, M, _ = x.size()
x = x.view(B, M, 3, num_heads, -1)
return x.permute(2, 0, 1, 3, 4)
return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)
class CollectHeads(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
"""Redistribute tokens and receive heads.
Args:
x: Output of attention. Shape: [B, N, local_heads, head_dim]
Returns:
Shape: [B, M, num_heads * head_dim]
"""
ctx.group = group
ctx.local_heads = x.size(2)
ctx.head_dim = x.size(3)
group_size = dist.get_world_size(group)
x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
output = torch.empty_like(x)
_all_to_all_single(output, x, group=group)
del x
return rearrange(output, "G h M B D -> B M (G h D)")
def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
# Merge heads.
return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
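For reference, a minimal sketch of how this (now removed) context-parallel module would be wired up under torch.distributed; the torchrun-style launch and the single flat group over all ranks are assumptions for illustration, not taken from this repository:

import torch.distributed as dist

def init_context_parallel():
    # Assumes launch via torchrun (env:// initialization); one group spans all ranks,
    # and each rank holds one shard of the input sequence.
    dist.init_process_group(backend="nccl")
    global_rank = dist.get_rank()
    ranks = list(range(dist.get_world_size()))
    cp_group = dist.new_group(ranks)
    set_cp_group(cp_group, ranks, global_rank)

# Typical flow inside an attention block:
#   qkv = all_to_all_collect_tokens(qkv, num_heads)  # [3, B, N, local_heads, head_dim]
#   out = attention(qkv[0], qkv[1], qkv[2])          # [B, N, local_heads, head_dim]
#   out = all_to_all_collect_heads(out)              # [B, M, num_heads * head_dim]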

View File

@@ -62,28 +62,6 @@ class TimestepEmbedder(nn.Module):
return t_emb
class PooledCaptionEmbedder(nn.Module):
def __init__(
self,
caption_feature_dim: int,
hidden_size: int,
*,
bias: bool = True,
device: Optional[torch.device] = None,
):
super().__init__()
self.caption_feature_dim = caption_feature_dim
self.hidden_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
)
def forward(self, x):
return self.mlp(x)
class FeedForward(nn.Module):
def __init__(
self,
@@ -152,8 +130,6 @@ class PatchEmbed(nn.Module):
x = F.pad(x, (0, pad_w, 0, pad_h))
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
#print("x",x.dtype, x.device)
#print(self.proj.weight.dtype, self.proj.weight.device)
x = self.proj(x)
# Flatten temporal and spatial dimensions.

View File

@@ -1,4 +1,4 @@
import functools
#import functools
import math
import torch
@@ -21,7 +21,7 @@ def centers(start: float, stop, num, dtype=None, device=None):
return (edges[:-1] + edges[1:]) / 2
@functools.lru_cache(maxsize=1)
#@functools.lru_cache(maxsize=1)
def create_position_matrix(
T: int,
pH: int,

View File

@@ -29,94 +29,3 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
return pooled
class PadSplitXY(torch.autograd.Function):
"""
Merge heads, pad and extract visual and text tokens,
and split along the sequence length.
"""
@staticmethod
def forward(
ctx,
xy: torch.Tensor,
indices: torch.Tensor,
B: int,
N: int,
L: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
indices: Valid token indices out of unpacked tensor. Shape: (total,)
Returns:
x: Visual tokens. Shape: (B, N, num_heads * head_dim).
y: Text tokens. Shape: (B, L, num_heads * head_dim).
"""
ctx.save_for_backward(indices)
ctx.B, ctx.N, ctx.L = B, N, L
D = xy.size(1)
# Pad sequences to (B, N + L, dim).
assert indices.ndim == 1
output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
indices = indices.unsqueeze(1).expand(
-1, D
) # (total,) -> (total, num_heads * head_dim)
output.scatter_(0, indices, xy)
xy = output.view(B, N + L, D)
# Split visual and text tokens along the sequence length.
return torch.tensor_split(xy, (N,), dim=1)
def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
return PadSplitXY.apply(xy, indices, B, N, L, dtype)
class UnifyStreams(torch.autograd.Function):
"""Unify visual and text streams."""
@staticmethod
def forward(
ctx,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
q_y: torch.Tensor,
k_y: torch.Tensor,
v_y: torch.Tensor,
indices: torch.Tensor,
):
"""
Args:
q_x: (B, N, num_heads, head_dim)
k_x: (B, N, num_heads, head_dim)
v_x: (B, N, num_heads, head_dim)
q_y: (B, L, num_heads, head_dim)
k_y: (B, L, num_heads, head_dim)
v_y: (B, L, num_heads, head_dim)
indices: (total <= B * (N + L))
Returns:
qkv: (total <= B * (N + L), 3, num_heads, head_dim)
"""
ctx.save_for_backward(indices)
B, N, num_heads, head_dim = q_x.size()
ctx.B, ctx.N, ctx.L = B, N, q_y.size(1)
D = num_heads * head_dim
q = torch.cat([q_x, q_y], dim=1)
k = torch.cat([k_x, k_y], dim=1)
v = torch.cat([v_x, v_y], dim=1)
qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D)
indices = indices[:, None, None].expand(-1, 3, D)
qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim)
return qkv.unflatten(2, (num_heads, head_dim))
def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)
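As a rough illustration of how these packing helpers fit together: visual tokens are always kept, valid text tokens are selected by a mask, and indices records their positions in the flattened (B * (N + L)) sequence. This is a self-contained sketch; the mask-based construction of indices and the dummy shapes are assumptions, not code from this repository.

import torch

B, N, L, num_heads, head_dim = 2, 8, 4, 3, 16
D = num_heads * head_dim
q_x, k_x, v_x = (torch.randn(B, N, num_heads, head_dim) for _ in range(3))
q_y, k_y, v_y = (torch.randn(B, L, num_heads, head_dim) for _ in range(3))
text_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)  # valid text tokens

# Visual tokens are always valid; text tokens are valid where the mask is True.
valid = torch.cat([torch.ones(B, N, dtype=torch.bool), text_mask], dim=1)  # (B, N + L)
indices = valid.flatten().nonzero(as_tuple=True)[0]                        # (total,)

qkv = unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices)  # (total, 3, num_heads, head_dim)
out = qkv[:, 0].reshape(-1, D)                              # stand-in for packed attention output
x, y = pad_and_split_xy(out, indices, B, N, L, out.dtype)   # (B, N, D), (B, L, D)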

View File

@@ -1,152 +0,0 @@
from typing import Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
"""
Forward pass that handles communication between ranks for inference.
Args:
x: Tensor of shape (B, C, T, H, W)
frames_to_send: int, number of frames to communicate between ranks
Returns:
output: Tensor of shape (B, C, T', H, W)
"""
cp_rank, cp_world_size = get_cp_rank_size()
if frames_to_send == 0 or cp_world_size == 1:
return x
group = get_cp_group()
global_rank = dist.get_rank()
# Send to next rank
if cp_rank < cp_world_size - 1:
assert x.size(2) >= frames_to_send
tail = x[:, :, -frames_to_send:].contiguous()
dist.send(tail, global_rank + 1, group=group)
# Receive from previous rank
if cp_rank > 0:
B, C, _, H, W = x.shape
recv_buffer = torch.empty(
(B, C, frames_to_send, H, W),
dtype=x.dtype,
device=x.device,
)
dist.recv(recv_buffer, global_rank - 1, group=group)
x = torch.cat([recv_buffer, x], dim=2)
return x
def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
if max_T > x.size(2):
pad_T = max_T - x.size(2)
pad_dims = (0, 0, 0, 0, 0, pad_T)
return F.pad(x, pad_dims)
return x
def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
"""
Gathers all frames from all processes for inference.
Args:
x: Tensor of shape (B, C, T, H, W)
Returns:
output: Tensor of shape (B, C, T_total, H, W)
"""
cp_rank, cp_size = get_cp_rank_size()
cp_group = get_cp_group()
# Ensure the tensor is contiguous for collective operations
x = x.contiguous()
# Get the local time dimension size
local_T = x.size(2)
local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)
# Gather all T sizes from all processes
all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
dist.all_gather(all_T, local_T_tensor, group=cp_group)
all_T = [t.item() for t in all_T]
# Pad the tensor at the end of the time dimension to match max_T
max_T = max(all_T)
x = _pad_to_max(x, max_T).contiguous()
# Prepare a list to hold the gathered tensors
gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
# Perform the all_gather operation
dist.all_gather(gathered_x, x, group=cp_group)
# Slice each gathered tensor back to its original T size
for idx, t_size in enumerate(all_T):
gathered_x[idx] = gathered_x[idx][:, :, :t_size]
return torch.cat(gathered_x, dim=2)
def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
"""Estimate memory usage based on input tensor size and data type."""
element_size = input.element_size() # Size in bytes of each element
memory_bytes = input.numel() * element_size
memory_gb = memory_bytes / 1024**3
return memory_gb > max_gb
class ContextParallelCausalConv3d(torch.nn.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
**kwargs,
):
kernel_size = cast_tuple(kernel_size, 3)
stride = cast_tuple(stride, 3)
height_pad = (kernel_size[1] - 1) // 2
width_pad = (kernel_size[2] - 1) // 2
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=(1, 1, 1),
padding=(0, height_pad, width_pad),
**kwargs,
)
def forward(self, x: torch.Tensor):
cp_rank, cp_world_size = get_cp_rank_size()
context_size = self.kernel_size[0] - 1
if cp_rank == 0:
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)
if cp_world_size == 1:
return super().forward(x)
if all(s == 1 for s in self.stride):
# Receive some frames from previous rank.
x = cp_pass_frames(x, context_size)
return super().forward(x)
# Less efficient implementation for strided convs.
# All gather x, infer and chunk.
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
x = super().forward(x)
x_chunks = x.tensor_split(cp_world_size, dim=2)
assert len(x_chunks) == cp_world_size
return x_chunks[cp_rank]
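On a single rank, the causal behaviour above reduces to padding (kernel_size[0] - 1) frames at the front of the time axis so each output frame only sees current and past frames. A minimal single-device sketch of that padding with a plain Conv3d (illustrative shapes, not code from this repository):

import torch
import torch.nn.functional as F

conv = torch.nn.Conv3d(4, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
x = torch.randn(1, 4, 8, 32, 32)                 # (B, C, T, H, W)
context = conv.kernel_size[0] - 1
x = F.pad(x, (0, 0, 0, 0, context, 0))           # pad only earlier frames (causal)
y = conv(x)                                      # (1, 4, 8, 32, 32): T is preserved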

View File

@@ -6,8 +6,6 @@ import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
#from ..dit.joint_model.context_parallel import get_cp_rank_size
#from ..vae.cp_conv import cp_pass_frames, gather_all_frames
from .latent_dist import LatentDistribution
def cast_tuple(t, length=1):
@@ -96,6 +94,14 @@ class StridedSafeConv3d(torch.nn.Conv3d):
raise NotImplementedError
def mps_safe_pad(input, pad, mode):
if input.device.type == "mps" and input.numel() >= 2 ** 16:
device = input.device
input = input.to(device="cpu")
output = F.pad(input, pad, mode=mode)
return output.to(device=device)
else:
return F.pad(input, pad, mode=mode)
class ContextParallelConv3d(SafeConv3d):
def __init__(
@@ -138,9 +144,9 @@ class ContextParallelConv3d(SafeConv3d):
# Apply padding.
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
if self.context_parallel:
x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
else:
x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
x = mps_safe_pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
return super().forward(x)
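For context, the mps_safe_pad helper added above appears to work around a padding limitation on Apple MPS by round-tripping large tensors through the CPU; the 2 ** 16 element threshold is the commit's own heuristic. A small illustrative usage, assuming mps_safe_pad from this file is in scope and that the device selection below fits your setup:

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(1, 4, 16, 64, 64, device=device)     # (B, C, T, H, W), well over 2 ** 16 elements
# Same call signature as F.pad; large MPS tensors are padded on CPU and moved back.
x = mps_safe_pad(x, (0, 0, 0, 0, 2, 0), mode="replicate")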

View File

@@ -59,7 +59,7 @@ class MochiSigmaSchedule:
RETURN_NAMES = ("sigmas",)
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"
DESCRIPTION = "Sigma schedule to use with mochi wrapper sampler"
def loadmodel(self, num_steps, threshold_noise, denoise, linear_steps=None):
total_steps = num_steps
@@ -105,6 +105,7 @@ class DownloadAndLoadMochiModel:
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
"cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
"rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),
},
}
@@ -114,7 +115,7 @@ class DownloadAndLoadMochiModel:
CATEGORY = "MochiWrapper"
DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
def loadmodel(self, model, vae, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@@ -154,11 +155,11 @@ class DownloadAndLoadMochiModel:
model = T2VSynthMochiModel(
device=device,
offload_device=offload_device,
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
dit_checkpoint_path=model_path,
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
compile_args=compile_args,
cublas_ops=cublas_ops
)
@@ -180,7 +181,7 @@ class DownloadAndLoadMochiModel:
vae_sd = load_torch_file(vae_path)
if is_accelerate_available:
for key in vae_sd:
set_module_tensor_to_device(vae, key, dtype=torch.float32, device=device, value=vae_sd[key])
set_module_tensor_to_device(vae, key, dtype=torch.bfloat16, device=offload_device, value=vae_sd[key])
else:
vae.load_state_dict(vae_sd, strict=True)
vae.eval().to(torch.bfloat16).to("cpu")
@@ -201,6 +202,7 @@ class MochiModelLoader:
"trigger": ("CONDITIONING", {"tooltip": "Dummy input for forcing execution order",}),
"compile_args": ("MOCHICOMPILEARGS", {"tooltip": "Optional torch.compile arguments",}),
"cublas_ops": ("BOOLEAN", {"tooltip": "tested on 4090, unsure of gpu requirements, enables faster linear ops for the GGUF models, for more info:'https://github.com/aredden/torch-cublas-hgemm'",}),
"rms_norm_func": (["default", "flash_attn_triton", "flash_attn", "apex"],{"tooltip": "RMSNorm function to use, flash_attn if available seems to be faster, apex untested",}),
},
}
@@ -209,7 +211,7 @@ class MochiModelLoader:
FUNCTION = "loadmodel"
CATEGORY = "MochiWrapper"
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False):
def loadmodel(self, model_name, precision, attention_mode, trigger=None, compile_args=None, cublas_ops=False, rms_norm_func="default"):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
@@ -221,11 +223,11 @@ class MochiModelLoader:
model = T2VSynthMochiModel(
device=device,
offload_device=offload_device,
vae_stats_path=os.path.join(script_directory, "configs", "vae_stats.json"),
dit_checkpoint_path=model_path,
weight_dtype=dtype,
fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
compile_args=compile_args,
cublas_ops=cublas_ops
)
@@ -473,6 +475,7 @@ class MochiSampler:
CATEGORY = "MochiWrapper"
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
mm.unload_all_models()
mm.soft_empty_cache()
if opt_sigmas is not None:
@@ -630,7 +633,7 @@ class MochiDecode:
return torch.cat(result_rows, dim=3)
vae.to(device)
with torch.autocast(mm.get_autocast_device(device), dtype=torch.bfloat16):
with torch.autocast(mm.get_autocast_device(device), dtype=vae.dtype):
if enable_vae_tiling and frame_batch_size > T:
logging.warning(f"Frame batch size is larger than the number of samples, setting to {T}")
frame_batch_size = T
@@ -748,10 +751,16 @@ class MochiImageEncode:
from .mochi_preview.vae.model import apply_tiled
B, H, W, C = images.shape
images = images.unsqueeze(0) * 2 - 1
images = rearrange(images, "t b h w c -> t c b h w")
images = images.to(device)
print(images.shape)
import torchvision.transforms as transforms
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
input_image_tensor = rearrange(images, 'b h w c -> b c h w')
input_image_tensor = normalize(input_image_tensor).unsqueeze(0)
input_image_tensor = rearrange(input_image_tensor, 'b t c h w -> b c t h w', t=B)
#images = images.unsqueeze(0).sub_(0.5).div_(0.5)
#images = rearrange(input_image_tensor, "b c t h w -> t c b h w")
images = input_image_tensor.to(device)
encoder.to(device)
print("images before encoding", images.shape)
with torch.autocast(mm.get_autocast_device(device), dtype=encoder.dtype):