update, works

This commit is contained in:
kijai 2024-11-05 08:41:50 +02:00
parent 3535a846a8
commit 24a7edfca6
3 changed files with 225 additions and 275 deletions

View File

@ -3,31 +3,22 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .layers import (
FeedForward,
PatchEmbed,
RMSNorm,
TimestepEmbedder,
)
from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
from .rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
)
from .residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm
from .rope_mixed import compute_mixed_rotation, create_position_matrix
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
pool_tokens,
modulate,
pad_and_split_xy,
unify_streams,
)
from .utils import pool_tokens, modulate
try:
from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn import flash_attn_func
FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
FLASH_ATTN_IS_AVAILABLE = False
@ -45,6 +36,7 @@ backends.append(SDPBackend.EFFICIENT_ATTENTION)
backends.append(SDPBackend.MATH)
import comfy.model_management as mm
from comfy.ldm.modules.attention import optimized_attention
class AttentionPool(nn.Module):
def __init__(
@ -103,16 +95,16 @@ class AttentionPool(nn.Module):
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
with sdpa_kernel(backends):
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x
#region Attention
class AsymmetricAttention(nn.Module):
def __init__(
self,
@ -128,6 +120,7 @@ class AsymmetricAttention(nn.Module):
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
rms_norm_func: bool = False,
):
super().__init__()
@ -154,6 +147,28 @@ class AsymmetricAttention(nn.Module):
# Query and key normalization for stability.
assert qk_norm
if rms_norm_func == "flash_attn_triton": #use the same rms_norm_func
try:
from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm #slightly faster
@torch.compiler.disable() #cause NaNs when compiled for some reason
class RMSNorm(FlashTritonRMSNorm):
pass
except:
raise ImportError("Flash Triton RMSNorm not available.")
elif rms_norm_func == "flash_attn":
try:
from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm #slightly faster
@torch.compiler.disable() #cause NaNs when compiled for some reason
class RMSNorm(FlashRMSNorm):
pass
except:
raise ImportError("Flash RMSNorm not available.")
elif rms_norm_func == "apex":
from apex.normalization import FusedRMSNorm as ApexRMSNorm
class RMSNorm(ApexRMSNorm):
pass
else:
from .layers import RMSNorm
self.q_norm_x = RMSNorm(self.head_dim, device=device)
self.k_norm_x = RMSNorm(self.head_dim, device=device)
self.q_norm_y = RMSNorm(self.head_dim, device=device)
@ -166,79 +181,23 @@ class AsymmetricAttention(nn.Module):
if update_y
else nn.Identity()
)
def run_qkv_y(self, y):
local_heads = self.num_heads
qkv_y = self.qkv_y(y) # (B, L, 3 * dim)
qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
q_y, k_y, v_y = qkv_y.unbind(2)
return q_y, k_y, v_y
def prepare_qkv(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
*,
scale_x: torch.Tensor,
scale_y: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
valid_token_indices: torch.Tensor,
):
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
# Process visual features
qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
#assert qkv_x.dtype == torch.bfloat16
# Move QKV dimension to the front.
# B M (3 H d) -> 3 B M H d
B, M, _ = qkv_x.size()
qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
qkv_x = qkv_x.permute(2, 0, 1, 3, 4)
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
# Split qkv_x into q, k, v
q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
# Unite streams
qkv = unify_streams(
q_x,
k_x,
v_x,
q_y,
k_y,
v_y,
valid_token_indices,
)
return qkv
def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
def flash_attention(self, q, k ,v):
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
b, _, _, dim_head = q.shape
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen_in_batch,
out: torch.Tensor = flash_attn_func( #q: (batch_size, seqlen, nheads, headdim)
q, k, v,
dropout_p=0.0,
softmax_scale=self.softmax_scale,
) # (total, local_heads, head_dim)
return out.view(total, local_dim)
out = out.permute(0, 2, 1, 3)
return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def sdpa_attention(self, qkv):
q, k, v = qkv.unbind(dim=1)
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
def sdpa_attention(self, q, k, v):
b, _, _, dim_head = q.shape
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
with sdpa_kernel(backends):
out = F.scaled_dot_product_attention(
@ -249,13 +208,10 @@ class AsymmetricAttention(nn.Module):
dropout_p=0.0,
is_causal=False
)
return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def sage_attention(self, qkv):
#q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
q, k, v = qkv.unbind(dim=1)
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
def sage_attention(self, q, k, v):
b, _, _, dim_head = q.shape
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = sageattn(
q,
@ -265,14 +221,9 @@ class AsymmetricAttention(nn.Module):
dropout_p=0.0,
is_causal=False
)
#print(out.shape)
#out = rearrange(out, 'b h s d -> s (b h d)')
return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def comfy_attention(self, qkv):
from comfy.ldm.modules.attention import optimized_attention
q, k, v = qkv.unbind(dim=1)
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
def comfy_attention(self, q, k, v):
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = optimized_attention(
q,
@ -281,41 +232,23 @@ class AsymmetricAttention(nn.Module):
heads = self.num_heads,
skip_reshape=True
)
return out.squeeze(0)
return out
@torch.compiler.disable()
def run_attention(
self,
qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim)
*,
B: int,
L: int,
M: int,
cu_seqlens: torch.Tensor,
max_seqlen_in_batch: int,
valid_token_indices: torch.Tensor,
):
local_dim = self.num_heads * self.head_dim
total = qkv.size(0)
q,
k,
v,
):
if self.attention_mode == "flash_attn":
out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
out = self.flash_attention(q, k ,v)
elif self.attention_mode == "sdpa":
out = self.sdpa_attention(qkv)
out = self.sdpa_attention(q, k, v)
elif self.attention_mode == "sage_attn":
out = self.sage_attention(qkv)
out = self.sage_attention(q, k, v)
elif self.attention_mode == "comfy":
out = self.comfy_attention(qkv)
x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
assert x.size() == (B, M, local_dim)
assert y.size() == (B, L, local_dim)
x = x.view(B, M, self.num_heads, self.head_dim)
x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
x = self.proj_x(x) # (B, M, dim_x)
y = self.proj_y(y) # (B, L, dim_y)
return x, y
out = self.comfy_attention(q, k, v)
return out
def forward(
self,
@ -324,47 +257,44 @@ class AsymmetricAttention(nn.Module):
*,
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
packed_indices: Dict[str, torch.Tensor] = None,
num_tokens,
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass of asymmetric multi-modal attention.
Args:
x: (B, N, dim_x) tensor for visual tokens
y: (B, L, dim_y) tensor of text token features
packed_indices: Dict with keys for Flash Attention
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
rope_cos = rope_rotation.get("rope_cos")
rope_sin = rope_rotation.get("rope_sin")
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
Returns:
x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
y: (B, L, dim_y) tensor of text token features after multi-modal attention
"""
B, L, _ = y.shape
_, M, _ = x.shape
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
# Predict a packed QKV tensor from visual and text features.
# Don't checkpoint the all_to_all.
qkv = self.prepare_qkv(
x=x,
y=y,
scale_x=scale_x,
scale_y=scale_y,
rope_cos=rope_rotation.get("rope_cos"),
rope_sin=rope_rotation.get("rope_sin"),
valid_token_indices=packed_indices["valid_token_indices_kv"],
) # (total <= B * (N + L), 3, local_heads, head_dim)
# Split qkv_x into q, k, v
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
x, y = self.run_attention(
qkv,
B=B,
L=L,
M=M,
cu_seqlens=packed_indices["cu_seqlens_kv"],
max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
valid_token_indices=packed_indices["valid_token_indices_kv"],
)
q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)
xy = self.run_attention(q, k, v)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
o[:, :y.shape[1]] = y
y = self.proj_y(o)
return x, y
#region Blocks
class AsymmetricJointBlock(nn.Module):
def __init__(
self,
@ -377,6 +307,7 @@ class AsymmetricJointBlock(nn.Module):
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
rms_norm_func: str = "default",
**block_kwargs,
):
super().__init__()
@ -401,6 +332,7 @@ class AsymmetricJointBlock(nn.Module):
update_y=update_y,
device=device,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
**block_kwargs,
)
@ -432,7 +364,8 @@ class AsymmetricJointBlock(nn.Module):
c: torch.Tensor,
y: torch.Tensor,
fastercache_counter: Optional[int] = 0,
fastercache_start_step: Optional[int] = 15,
fastercache_start_step: Optional[int] = 1000,
fastercache_device: Optional[torch.device] = None,
**attn_kwargs,
):
"""Forward pass of a block.
@ -459,10 +392,11 @@ class AsymmetricJointBlock(nn.Module):
else:
scale_msa_y = mod_y
#fastercache
#region fastercache
B = x.shape[0]
#print("x", x.shape) #([1, 9540, 3072])
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_x_attention[-1].shape[0] >= B:
print("using fastercache")
x_attn = (
self.cached_x_attention[1][:B] +
(self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])
@ -483,22 +417,24 @@ class AsymmetricJointBlock(nn.Module):
**attn_kwargs,
)
if fastercache_counter == fastercache_start_step:
self.cached_x_attention = [x_attn, x_attn]
self.cached_y_attention = [y_attn, y_attn]
print("caching attention")
self.cached_x_attention = [x_attn.to(fastercache_device), x_attn.to(fastercache_device)]
self.cached_y_attention = [y_attn.to(fastercache_device), y_attn.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
self.cached_x_attention[-1].copy_(x_attn)
self.cached_y_attention[-1].copy_(y_attn)
print("updating attention")
self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))
assert x_attn.size(1) == N
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
if self.update_y:
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
# MLP block.
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
if self.update_y:
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
return x, y
def ff_block_x(self, x, scale_x, gate_x):
@ -542,7 +478,7 @@ class FinalLayer(nn.Module):
x = self.linear(x)
return x
#region Model
class AsymmDiTJoint(nn.Module):
"""
Diffusion model with a Transformer backbone.
@ -570,6 +506,7 @@ class AsymmDiTJoint(nn.Module):
rope_theta: float = 10000.0,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
rms_norm_func: str = "default",
**block_kwargs,
):
super().__init__()
@ -638,6 +575,7 @@ class AsymmDiTJoint(nn.Module):
update_y=update_y,
device=device,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
**block_kwargs,
)
@ -701,12 +639,11 @@ class AsymmDiTJoint(nn.Module):
x: torch.Tensor,
sigma: torch.Tensor,
fastercache_counter: int,
fastercache_start_step: int,
y_feat: List[torch.Tensor],
y_mask: List[torch.Tensor],
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
fastercache: Optional[Dict[str, torch.Tensor]] = None,
):
"""Forward pass of DiT.
@ -715,18 +652,25 @@ class AsymmDiTJoint(nn.Module):
sigma: (B,) tensor of noise standard deviations
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
"""
B, _, T, H, W = x.shape
# Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
# Have to call sdpa_kernel outside of a torch.compile region.
num_tokens = max(1, torch.sum(y_mask[0]).item())
with sdpa_kernel(backends):
x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat[0], y_mask[0]
)
del y_mask
if fastercache is not None:
fastercache_start_step = fastercache["start_step"]
fastercache_device = fastercache["cache_device"]
else:
fastercache_start_step = 1000
fastercache_device = None
for i, block in enumerate(self.blocks):
x, y_feat = block(
x,
@ -734,22 +678,24 @@ class AsymmDiTJoint(nn.Module):
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
packed_indices=packed_indices,
num_tokens=num_tokens,
fastercache_counter = fastercache_counter,
fastercache_start_step = fastercache_start_step,
fastercache_device = fastercache_device,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
x = rearrange(
x,
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
T=T,
hp=H // self.patch_size,
wp=W // self.patch_size,
p1=self.patch_size,
p2=self.patch_size,
c=self.out_channels,
)
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
hp = H // self.patch_size
wp = W // self.patch_size
p1 = self.patch_size
p2 = self.patch_size
c = self.out_channels
x = x.view(B, T, hp, wp, p1, p2, c)
x = x.permute(0, 6, 1, 2, 4, 3, 5)
x = x.reshape(B, c, T, hp * p1, wp * p2)
return x

View File

@ -1,5 +1,5 @@
import json
from typing import Dict, List, Optional, Union
from einops import rearrange
#temporary patch to fix torch compile bug in Windows
def patched_write_atomic(
@ -33,11 +33,8 @@ except:
pass
import torch
import torch.nn.functional as F
import torch.utils.data
from einops import rearrange, repeat
#from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file
import comfy.model_management as mm
@ -95,59 +92,17 @@ def unnormalize_latents(
assert z.size(1) == mean.size(0) == std.size(0)
return z * std.to(z) + mean.to(z)
def compute_packed_indices(
N: int,
text_mask: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""
Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80
Args:
N: Number of visual tokens.
text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.
Returns:
packed_indices: Dict with keys for Flash Attention:
- valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
in the packed sequence.
- cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
- max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
"""
# Create an expanded token mask saying which tokens are valid across both visual and text tokens.
assert N > 0 and len(text_mask) == 1
text_mask = text_mask[0]
mask = F.pad(text_mask, (N, 0), value=True) # (B, N + L)
seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,)
valid_token_indices = torch.nonzero(
mask.flatten(), as_tuple=False
).flatten() # up to (B * (N + L),)
assert valid_token_indices.size(0) >= text_mask.size(0) * N # At least (B * N,)
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
)
max_seqlen_in_batch = seqlens_in_batch.max().item()
return {
"cu_seqlens_kv": cu_seqlens,
"max_seqlen_in_batch_kv": max_seqlen_in_batch,
"valid_token_indices_kv": valid_token_indices,
}
class T2VSynthMochiModel:
def __init__(
self,
*,
device: torch.device,
offload_device: torch.device,
vae_stats_path: str,
dit_checkpoint_path: str,
weight_dtype: torch.dtype = torch.float8_e4m3fn,
fp8_fastmode: bool = False,
attention_mode: str = "sdpa",
rms_norm_func: str = "default",
compile_args: Optional[Dict] = None,
cublas_ops: Optional[bool] = False,
):
@ -177,11 +132,23 @@ class T2VSynthMochiModel:
t5_token_length=256,
rope_theta=10000.0,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
)
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
dit_sd = load_torch_file(dit_checkpoint_path)
#comfy format
prefix = "model.diffusion_model."
first_key = next(iter(dit_sd), None)
if first_key and first_key.startswith(prefix):
new_dit_sd = {
key[len(prefix):] if key.startswith(prefix) else key: value
for key, value in dit_sd.items()
}
dit_sd = new_dit_sd
if "gguf" in dit_checkpoint_path.lower():
logging.info("Loading GGUF model state_dict...")
from .. import mz_gguf_loader
@ -220,18 +187,10 @@ class T2VSynthMochiModel:
model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
self.dit = model
vae_stats = json.load(open(vae_stats_path))
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
def get_packed_indices(self, y_mask, *, lT, lW, lH):
patch_size = 2
N = lT * lH * lW // (patch_size**2)
assert len(y_mask) == 1
packed_indices = compute_packed_indices(N, y_mask)
self.move_to_device_(packed_indices)
return packed_indices
def get_packed_indices(self, y_mask, **latent_dims):
# temporary dummy func for compatibility
return []
def move_to_device_(self, sample):
if isinstance(sample, dict):
@ -243,7 +202,7 @@ class T2VSynthMochiModel:
torch.manual_seed(args["seed"])
torch.cuda.manual_seed(args["seed"])
generator = torch.Generator(device=self.device)
generator = torch.Generator(device=torch.device("cpu"))
generator.manual_seed(args["seed"])
num_frames = args["num_frames"]
@ -275,50 +234,49 @@ class T2VSynthMochiModel:
T = (num_frames - 1) // temporal_downsample + 1
H = height // spatial_downsample
W = width // spatial_downsample
latent_dims = dict(lT=T, lW=W, lH=H)
z = torch.randn(
(B, C, T, H, W),
device=self.device,
device=torch.device("cpu"),
generator=generator,
dtype=torch.float32,
)
).to(self.device)
if in_samples is not None:
z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device)
sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)],
"fastercache": args["fastercache"] if args["fastercache"] is not None else None
}
sample_null = {
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
}
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
"fastercache": args["fastercache"] if args["fastercache"] is not None else None
}
sample["packed_indices"] = self.get_packed_indices(
sample["y_mask"], **latent_dims
)
sample_null["packed_indices"] = self.get_packed_indices(
sample_null["y_mask"], **latent_dims
)
self.use_fastercache = True
if args["fastercache"]:
self.fastercache_start_step = args["fastercache"]["start_step"]
self.fastercache_lf_step = args["fastercache"]["lf_step"]
self.fastercache_hf_step = args["fastercache"]["hf_step"]
else:
self.fastercache_start_step = 1000
self.fastercache_counter = 0
self.fastercache_start_step = 15
self.fastercache_lf_step = 40
self.fastercache_hf_step = 30
def model_fn(*, z, sigma, cfg_scale):
nonlocal sample, sample_null
if self.use_fastercache:
if args["fastercache"]:
self.fastercache_counter+=1
if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
out_cond = self.dit(z, sigma,self.fastercache_counter, self.fastercache_start_step, **sample)
out_cond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample)
(bb, cc, tt, hh, ww) = out_cond.shape
cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond.float())
#lf_step = 40
#hf_step = 30
if self.fastercache_counter <= self.fastercache_lf_step:
self.delta_lf = self.delta_lf * 1.1
if self.fastercache_counter >= self.fastercache_hf_step:
@ -334,12 +292,19 @@ class T2VSynthMochiModel:
return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
else:
out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample)
out_uncond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample_null)
#print("out_cond.shape",out_cond.shape) #([1, 12, 3, 60, 106])
out_cond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample)
out_uncond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample_null)
if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, cc, tt, hh, ww) = out_cond.shape
cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@ -352,7 +317,6 @@ class T2VSynthMochiModel:
return out_uncond + cfg_scale * (out_cond - out_uncond)
comfy_pbar = ProgressBar(sample_steps)
if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
@ -373,13 +337,19 @@ class T2VSynthMochiModel:
sigma=torch.full([B], sigma, device=z.device),
cfg_scale=cfg_schedule[i],
)
pred = pred.to(z)
z = z + dsigma * pred
z = z + dsigma * pred.to(z)
if callback is not None:
callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
else:
comfy_pbar.update(1)
if args["fastercache"] is not None:
for block in self.dit.blocks:
if (hasattr, block, "cached_x_attention") and block.cached_x_attention is not None:
block.cached_x_attention = None
block.cached_y_attention = None
self.dit.to(self.offload_device)
mm.soft_empty_cache()
logging.info(f"samples shape: {z.shape}")
return z

View File

@ -446,6 +446,36 @@ class MochiTextEncode:
}
return (t5_embeds, clip,)
#region FasterCache
class MochiFasterCache:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"start_step": ("INT", {"default": 10, "min": 0, "max": 1024, "step": 1, "tooltip": "The step to start caching, sigma schedule should be adjusted accordingly"}),
"hf_step": ("INT", {"default": 22, "min": 0, "max": 1024, "step": 1}),
"lf_step": ("INT", {"default": 28, "min": 0, "max": 1024, "step": 1}),
"cache_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
},
}
RETURN_TYPES = ("FASTERCACHEARGS",)
RETURN_NAMES = ("fastercache", )
FUNCTION = "args"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "FasterCache (https://github.com/Vchitect/FasterCache) settings for the MochiWrapper"
def args(self, start_step, hf_step, lf_step, cache_device):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
fastercache = {
"start_step" : start_step,
"hf_step" : hf_step,
"lf_step" : lf_step,
"cache_device" : device if cache_device == "main_device" else offload_device
}
return (fastercache,)
#region Sampler
class MochiSampler:
@classmethod
@ -466,6 +496,7 @@ class MochiSampler:
"cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
"opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}),
"samples": ("LATENT", ),
"fastercache": ("FASTERCACHEARGS", {"tooltip": "Optional FasterCache settings"}),
}
}
@ -474,7 +505,7 @@ class MochiSampler:
FUNCTION = "process"
CATEGORY = "MochiWrapper"
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None, fastercache=None):
mm.unload_all_models()
mm.soft_empty_cache()
@ -517,6 +548,7 @@ class MochiSampler:
"negative_embeds": negative,
"seed": seed,
"samples": samples["samples"] if samples is not None else None,
"fastercache": fastercache
}
latents = model.run(args)
@ -848,7 +880,8 @@ NODE_CLASS_MAPPINGS = {
"MochiTorchCompileSettings": MochiTorchCompileSettings,
"MochiImageEncode": MochiImageEncode,
"MochiLatentPreview": MochiLatentPreview,
"MochiSigmaSchedule": MochiSigmaSchedule
"MochiSigmaSchedule": MochiSigmaSchedule,
"MochiFasterCache": MochiFasterCache
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@ -862,5 +895,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MochiTorchCompileSettings": "Mochi Torch Compile Settings",
"MochiImageEncode": "Mochi Image Encode",
"MochiLatentPreview": "Mochi Latent Preview",
"MochiSigmaSchedule": "Mochi Sigma Schedule"
"MochiSigmaSchedule": "Mochi Sigma Schedule",
"MochiFasterCache": "Mochi Faster Cache"
}