update, works

This commit is contained in:
kijai 2024-11-05 08:41:50 +02:00
parent 3535a846a8
commit 24a7edfca6
3 changed files with 225 additions and 275 deletions

View File

@ -3,31 +3,22 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange
from .layers import ( from .layers import (
FeedForward, FeedForward,
PatchEmbed, PatchEmbed,
RMSNorm,
TimestepEmbedder, TimestepEmbedder,
) )
from .mod_rmsnorm import modulated_rmsnorm from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm) from .residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm
from .rope_mixed import ( from .rope_mixed import compute_mixed_rotation, create_position_matrix
compute_mixed_rotation,
create_position_matrix,
)
from .temporal_rope import apply_rotary_emb_qk_real from .temporal_rope import apply_rotary_emb_qk_real
from .utils import ( from .utils import pool_tokens, modulate
pool_tokens,
modulate,
pad_and_split_xy,
unify_streams,
)
try: try:
from flash_attn import flash_attn_varlen_qkvpacked_func from flash_attn import flash_attn_func
FLASH_ATTN_IS_AVAILABLE = True FLASH_ATTN_IS_AVAILABLE = True
except ImportError: except ImportError:
FLASH_ATTN_IS_AVAILABLE = False FLASH_ATTN_IS_AVAILABLE = False
@ -45,6 +36,7 @@ backends.append(SDPBackend.EFFICIENT_ATTENTION)
backends.append(SDPBackend.MATH) backends.append(SDPBackend.MATH)
import comfy.model_management as mm import comfy.model_management as mm
from comfy.ldm.modules.attention import optimized_attention
class AttentionPool(nn.Module): class AttentionPool(nn.Module):
def __init__( def __init__(
@ -103,16 +95,16 @@ class AttentionPool(nn.Module):
q = q.unsqueeze(2) # (B, H, 1, head_dim) q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention. # Compute attention.
with sdpa_kernel(backends): x = F.scaled_dot_product_attention(
x = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=0.0
q, k, v, attn_mask=attn_mask, dropout_p=0.0 ) # (B, H, 1, head_dim)
) # (B, H, 1, head_dim)
# Concatenate heads and run output. # Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim) x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x) x = self.to_out(x)
return x return x
#region Attention
class AsymmetricAttention(nn.Module): class AsymmetricAttention(nn.Module):
def __init__( def __init__(
self, self,
@ -128,6 +120,7 @@ class AsymmetricAttention(nn.Module):
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
attention_mode: str = "sdpa", attention_mode: str = "sdpa",
rms_norm_func: bool = False,
): ):
super().__init__() super().__init__()
@ -154,6 +147,28 @@ class AsymmetricAttention(nn.Module):
# Query and key normalization for stability. # Query and key normalization for stability.
assert qk_norm assert qk_norm
if rms_norm_func == "flash_attn_triton": #use the same rms_norm_func
try:
from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm #slightly faster
@torch.compiler.disable() #cause NaNs when compiled for some reason
class RMSNorm(FlashTritonRMSNorm):
pass
except:
raise ImportError("Flash Triton RMSNorm not available.")
elif rms_norm_func == "flash_attn":
try:
from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm #slightly faster
@torch.compiler.disable() #cause NaNs when compiled for some reason
class RMSNorm(FlashRMSNorm):
pass
except:
raise ImportError("Flash RMSNorm not available.")
elif rms_norm_func == "apex":
from apex.normalization import FusedRMSNorm as ApexRMSNorm
class RMSNorm(ApexRMSNorm):
pass
else:
from .layers import RMSNorm
self.q_norm_x = RMSNorm(self.head_dim, device=device) self.q_norm_x = RMSNorm(self.head_dim, device=device)
self.k_norm_x = RMSNorm(self.head_dim, device=device) self.k_norm_x = RMSNorm(self.head_dim, device=device)
self.q_norm_y = RMSNorm(self.head_dim, device=device) self.q_norm_y = RMSNorm(self.head_dim, device=device)
@ -166,79 +181,23 @@ class AsymmetricAttention(nn.Module):
if update_y if update_y
else nn.Identity() else nn.Identity()
) )
def run_qkv_y(self, y):
local_heads = self.num_heads
qkv_y = self.qkv_y(y) # (B, L, 3 * dim)
qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
q_y, k_y, v_y = qkv_y.unbind(2)
return q_y, k_y, v_y
def flash_attention(self, q, k ,v):
def prepare_qkv( q = q.permute(0, 2, 1, 3)
self, k = k.permute(0, 2, 1, 3)
x: torch.Tensor, # (B, N, dim_x) v = v.permute(0, 2, 1, 3)
y: torch.Tensor, # (B, L, dim_y) b, _, _, dim_head = q.shape
*,
scale_x: torch.Tensor,
scale_y: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
valid_token_indices: torch.Tensor,
):
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
# Process visual features
qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
#assert qkv_x.dtype == torch.bfloat16
# Move QKV dimension to the front.
# B M (3 H d) -> 3 B M H d
B, M, _ = qkv_x.size()
qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
qkv_x = qkv_x.permute(2, 0, 1, 3, 4)
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
# Split qkv_x into q, k, v
q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
# Unite streams
qkv = unify_streams(
q_x,
k_x,
v_x,
q_y,
k_y,
v_y,
valid_token_indices,
)
return qkv
def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
with torch.autocast(mm.get_autocast_device(self.device), enabled=False): with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out: torch.Tensor = flash_attn_varlen_qkvpacked_func( out: torch.Tensor = flash_attn_func( #q: (batch_size, seqlen, nheads, headdim)
qkv, q, k, v,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen_in_batch,
dropout_p=0.0, dropout_p=0.0,
softmax_scale=self.softmax_scale, softmax_scale=self.softmax_scale,
) # (total, local_heads, head_dim) ) # (total, local_heads, head_dim)
return out.view(total, local_dim) out = out.permute(0, 2, 1, 3)
return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def sdpa_attention(self, qkv): def sdpa_attention(self, q, k, v):
q, k, v = qkv.unbind(dim=1) b, _, _, dim_head = q.shape
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
with torch.autocast(mm.get_autocast_device(self.device), enabled=False): with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
with sdpa_kernel(backends): with sdpa_kernel(backends):
out = F.scaled_dot_product_attention( out = F.scaled_dot_product_attention(
@ -249,13 +208,10 @@ class AsymmetricAttention(nn.Module):
dropout_p=0.0, dropout_p=0.0,
is_causal=False is_causal=False
) )
return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1) return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def sage_attention(self, qkv): def sage_attention(self, q, k, v):
#q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1) b, _, _, dim_head = q.shape
q, k, v = qkv.unbind(dim=1)
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
with torch.autocast(mm.get_autocast_device(self.device), enabled=False): with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = sageattn( out = sageattn(
q, q,
@ -265,14 +221,9 @@ class AsymmetricAttention(nn.Module):
dropout_p=0.0, dropout_p=0.0,
is_causal=False is_causal=False
) )
#print(out.shape) return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
#out = rearrange(out, 'b h s d -> s (b h d)')
return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
def comfy_attention(self, qkv): def comfy_attention(self, q, k, v):
from comfy.ldm.modules.attention import optimized_attention
q, k, v = qkv.unbind(dim=1)
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
with torch.autocast(mm.get_autocast_device(self.device), enabled=False): with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = optimized_attention( out = optimized_attention(
q, q,
@ -281,41 +232,23 @@ class AsymmetricAttention(nn.Module):
heads = self.num_heads, heads = self.num_heads,
skip_reshape=True skip_reshape=True
) )
return out.squeeze(0) return out
@torch.compiler.disable()
def run_attention( def run_attention(
self, self,
qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim) q,
*, k,
B: int, v,
L: int, ):
M: int,
cu_seqlens: torch.Tensor,
max_seqlen_in_batch: int,
valid_token_indices: torch.Tensor,
):
local_dim = self.num_heads * self.head_dim
total = qkv.size(0)
if self.attention_mode == "flash_attn": if self.attention_mode == "flash_attn":
out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim) out = self.flash_attention(q, k ,v)
elif self.attention_mode == "sdpa": elif self.attention_mode == "sdpa":
out = self.sdpa_attention(qkv) out = self.sdpa_attention(q, k, v)
elif self.attention_mode == "sage_attn": elif self.attention_mode == "sage_attn":
out = self.sage_attention(qkv) out = self.sage_attention(q, k, v)
elif self.attention_mode == "comfy": elif self.attention_mode == "comfy":
out = self.comfy_attention(qkv) out = self.comfy_attention(q, k, v)
return out
x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
assert x.size() == (B, M, local_dim)
assert y.size() == (B, L, local_dim)
x = x.view(B, M, self.num_heads, self.head_dim)
x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
x = self.proj_x(x) # (B, M, dim_x)
y = self.proj_y(y) # (B, L, dim_y)
return x, y
def forward( def forward(
self, self,
@ -324,47 +257,44 @@ class AsymmetricAttention(nn.Module):
*, *,
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm. scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm. scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
packed_indices: Dict[str, torch.Tensor] = None, num_tokens,
**rope_rotation, **rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass of asymmetric multi-modal attention.
Args: rope_cos = rope_rotation.get("rope_cos")
x: (B, N, dim_x) tensor for visual tokens rope_sin = rope_rotation.get("rope_sin")
y: (B, L, dim_y) tensor of text token features
packed_indices: Dict with keys for Flash Attention # Pre-norm for visual features
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
Returns: q_y = self.q_norm_y(q_y)
x: (B, N, dim_x) tensor of visual tokens after multi-modal attention k_y = self.k_norm_y(k_y)
y: (B, L, dim_y) tensor of text token features after multi-modal attention
"""
B, L, _ = y.shape
_, M, _ = x.shape
# Predict a packed QKV tensor from visual and text features. # Split qkv_x into q, k, v
# Don't checkpoint the all_to_all. q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
qkv = self.prepare_qkv( q_x = self.q_norm_x(q_x)
x=x, q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
y=y, k_x = self.k_norm_x(k_x)
scale_x=scale_x, k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
scale_y=scale_y,
rope_cos=rope_rotation.get("rope_cos"),
rope_sin=rope_rotation.get("rope_sin"),
valid_token_indices=packed_indices["valid_token_indices_kv"],
) # (total <= B * (N + L), 3, local_heads, head_dim)
x, y = self.run_attention( q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
qkv, k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
B=B, v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)
L=L,
M=M, xy = self.run_attention(q, k, v)
cu_seqlens=packed_indices["cu_seqlens_kv"],
max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"], x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
valid_token_indices=packed_indices["valid_token_indices_kv"], x = self.proj_x(x)
) o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
o[:, :y.shape[1]] = y
y = self.proj_y(o)
return x, y return x, y
#region Blocks
class AsymmetricJointBlock(nn.Module): class AsymmetricJointBlock(nn.Module):
def __init__( def __init__(
self, self,
@ -377,6 +307,7 @@ class AsymmetricJointBlock(nn.Module):
update_y: bool = True, # Whether to update text tokens in this block. update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
attention_mode: str = "sdpa", attention_mode: str = "sdpa",
rms_norm_func: str = "default",
**block_kwargs, **block_kwargs,
): ):
super().__init__() super().__init__()
@ -401,6 +332,7 @@ class AsymmetricJointBlock(nn.Module):
update_y=update_y, update_y=update_y,
device=device, device=device,
attention_mode=attention_mode, attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
**block_kwargs, **block_kwargs,
) )
@ -432,7 +364,8 @@ class AsymmetricJointBlock(nn.Module):
c: torch.Tensor, c: torch.Tensor,
y: torch.Tensor, y: torch.Tensor,
fastercache_counter: Optional[int] = 0, fastercache_counter: Optional[int] = 0,
fastercache_start_step: Optional[int] = 15, fastercache_start_step: Optional[int] = 1000,
fastercache_device: Optional[torch.device] = None,
**attn_kwargs, **attn_kwargs,
): ):
"""Forward pass of a block. """Forward pass of a block.
@ -459,10 +392,11 @@ class AsymmetricJointBlock(nn.Module):
else: else:
scale_msa_y = mod_y scale_msa_y = mod_y
#fastercache #region fastercache
B = x.shape[0] B = x.shape[0]
#print("x", x.shape) #([1, 9540, 3072]) #print("x", x.shape) #([1, 9540, 3072])
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_x_attention[-1].shape[0] >= B: if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_x_attention[-1].shape[0] >= B:
print("using fastercache")
x_attn = ( x_attn = (
self.cached_x_attention[1][:B] + self.cached_x_attention[1][:B] +
(self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B]) (self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])
@ -483,22 +417,24 @@ class AsymmetricJointBlock(nn.Module):
**attn_kwargs, **attn_kwargs,
) )
if fastercache_counter == fastercache_start_step: if fastercache_counter == fastercache_start_step:
self.cached_x_attention = [x_attn, x_attn] print("caching attention")
self.cached_y_attention = [y_attn, y_attn] self.cached_x_attention = [x_attn.to(fastercache_device), x_attn.to(fastercache_device)]
self.cached_y_attention = [y_attn.to(fastercache_device), y_attn.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step: elif fastercache_counter > fastercache_start_step:
self.cached_x_attention[-1].copy_(x_attn) print("updating attention")
self.cached_y_attention[-1].copy_(y_attn) self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))
assert x_attn.size(1) == N assert x_attn.size(1) == N
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x) x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
if self.update_y: if self.update_y:
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y) y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
# MLP block. # MLP block.
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x) x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
if self.update_y: if self.update_y:
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y) y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
return x, y return x, y
def ff_block_x(self, x, scale_x, gate_x): def ff_block_x(self, x, scale_x, gate_x):
@ -542,7 +478,7 @@ class FinalLayer(nn.Module):
x = self.linear(x) x = self.linear(x)
return x return x
#region Model
class AsymmDiTJoint(nn.Module): class AsymmDiTJoint(nn.Module):
""" """
Diffusion model with a Transformer backbone. Diffusion model with a Transformer backbone.
@ -570,6 +506,7 @@ class AsymmDiTJoint(nn.Module):
rope_theta: float = 10000.0, rope_theta: float = 10000.0,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
attention_mode: str = "sdpa", attention_mode: str = "sdpa",
rms_norm_func: str = "default",
**block_kwargs, **block_kwargs,
): ):
super().__init__() super().__init__()
@ -638,6 +575,7 @@ class AsymmDiTJoint(nn.Module):
update_y=update_y, update_y=update_y,
device=device, device=device,
attention_mode=attention_mode, attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
**block_kwargs, **block_kwargs,
) )
@ -701,12 +639,11 @@ class AsymmDiTJoint(nn.Module):
x: torch.Tensor, x: torch.Tensor,
sigma: torch.Tensor, sigma: torch.Tensor,
fastercache_counter: int, fastercache_counter: int,
fastercache_start_step: int,
y_feat: List[torch.Tensor], y_feat: List[torch.Tensor],
y_mask: List[torch.Tensor], y_mask: List[torch.Tensor],
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None, rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None, rope_sin: torch.Tensor = None,
fastercache: Optional[Dict[str, torch.Tensor]] = None,
): ):
"""Forward pass of DiT. """Forward pass of DiT.
@ -715,18 +652,25 @@ class AsymmDiTJoint(nn.Module):
sigma: (B,) tensor of noise standard deviations sigma: (B,) tensor of noise standard deviations
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048) y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding) y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
""" """
B, _, T, H, W = x.shape B, _, T, H, W = x.shape
# Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask. # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
# Have to call sdpa_kernel outside of a torch.compile region. # Have to call sdpa_kernel outside of a torch.compile region.
num_tokens = max(1, torch.sum(y_mask[0]).item())
with sdpa_kernel(backends): with sdpa_kernel(backends):
x, c, y_feat, rope_cos, rope_sin = self.prepare( x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat[0], y_mask[0] x, sigma, y_feat[0], y_mask[0]
) )
del y_mask del y_mask
if fastercache is not None:
fastercache_start_step = fastercache["start_step"]
fastercache_device = fastercache["cache_device"]
else:
fastercache_start_step = 1000
fastercache_device = None
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
x, y_feat = block( x, y_feat = block(
x, x,
@ -734,22 +678,24 @@ class AsymmDiTJoint(nn.Module):
y_feat, y_feat,
rope_cos=rope_cos, rope_cos=rope_cos,
rope_sin=rope_sin, rope_sin=rope_sin,
packed_indices=packed_indices, num_tokens=num_tokens,
fastercache_counter = fastercache_counter, fastercache_counter = fastercache_counter,
fastercache_start_step = fastercache_start_step, fastercache_start_step = fastercache_start_step,
fastercache_device = fastercache_device,
) # (B, M, D), (B, L, D) ) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features. del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels) x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
x = rearrange(
x, hp = H // self.patch_size
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)", wp = W // self.patch_size
T=T, p1 = self.patch_size
hp=H // self.patch_size, p2 = self.patch_size
wp=W // self.patch_size, c = self.out_channels
p1=self.patch_size,
p2=self.patch_size, x = x.view(B, T, hp, wp, p1, p2, c)
c=self.out_channels, x = x.permute(0, 6, 1, 2, 4, 3, 5)
) x = x.reshape(B, c, T, hp * p1, wp * p2)
return x return x

View File

@ -1,5 +1,5 @@
import json
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from einops import rearrange
#temporary patch to fix torch compile bug in Windows #temporary patch to fix torch compile bug in Windows
def patched_write_atomic( def patched_write_atomic(
@ -33,11 +33,8 @@ except:
pass pass
import torch import torch
import torch.nn.functional as F
import torch.utils.data import torch.utils.data
from einops import rearrange, repeat
#from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file from comfy.utils import ProgressBar, load_torch_file
import comfy.model_management as mm import comfy.model_management as mm
@ -95,59 +92,17 @@ def unnormalize_latents(
assert z.size(1) == mean.size(0) == std.size(0) assert z.size(1) == mean.size(0) == std.size(0)
return z * std.to(z) + mean.to(z) return z * std.to(z) + mean.to(z)
def compute_packed_indices(
N: int,
text_mask: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""
Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80
Args:
N: Number of visual tokens.
text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.
Returns:
packed_indices: Dict with keys for Flash Attention:
- valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
in the packed sequence.
- cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
- max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
"""
# Create an expanded token mask saying which tokens are valid across both visual and text tokens.
assert N > 0 and len(text_mask) == 1
text_mask = text_mask[0]
mask = F.pad(text_mask, (N, 0), value=True) # (B, N + L)
seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,)
valid_token_indices = torch.nonzero(
mask.flatten(), as_tuple=False
).flatten() # up to (B * (N + L),)
assert valid_token_indices.size(0) >= text_mask.size(0) * N # At least (B * N,)
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
)
max_seqlen_in_batch = seqlens_in_batch.max().item()
return {
"cu_seqlens_kv": cu_seqlens,
"max_seqlen_in_batch_kv": max_seqlen_in_batch,
"valid_token_indices_kv": valid_token_indices,
}
class T2VSynthMochiModel: class T2VSynthMochiModel:
def __init__( def __init__(
self, self,
*, *,
device: torch.device, device: torch.device,
offload_device: torch.device, offload_device: torch.device,
vae_stats_path: str,
dit_checkpoint_path: str, dit_checkpoint_path: str,
weight_dtype: torch.dtype = torch.float8_e4m3fn, weight_dtype: torch.dtype = torch.float8_e4m3fn,
fp8_fastmode: bool = False, fp8_fastmode: bool = False,
attention_mode: str = "sdpa", attention_mode: str = "sdpa",
rms_norm_func: str = "default",
compile_args: Optional[Dict] = None, compile_args: Optional[Dict] = None,
cublas_ops: Optional[bool] = False, cublas_ops: Optional[bool] = False,
): ):
@ -177,11 +132,23 @@ class T2VSynthMochiModel:
t5_token_length=256, t5_token_length=256,
rope_theta=10000.0, rope_theta=10000.0,
attention_mode=attention_mode, attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
) )
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
logging.info(f"Loading model state_dict from {dit_checkpoint_path}...") logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
dit_sd = load_torch_file(dit_checkpoint_path) dit_sd = load_torch_file(dit_checkpoint_path)
#comfy format
prefix = "model.diffusion_model."
first_key = next(iter(dit_sd), None)
if first_key and first_key.startswith(prefix):
new_dit_sd = {
key[len(prefix):] if key.startswith(prefix) else key: value
for key, value in dit_sd.items()
}
dit_sd = new_dit_sd
if "gguf" in dit_checkpoint_path.lower(): if "gguf" in dit_checkpoint_path.lower():
logging.info("Loading GGUF model state_dict...") logging.info("Loading GGUF model state_dict...")
from .. import mz_gguf_loader from .. import mz_gguf_loader
@ -220,18 +187,10 @@ class T2VSynthMochiModel:
model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"]) model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])
self.dit = model self.dit = model
vae_stats = json.load(open(vae_stats_path))
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
def get_packed_indices(self, y_mask, *, lT, lW, lH): def get_packed_indices(self, y_mask, **latent_dims):
patch_size = 2 # temporary dummy func for compatibility
N = lT * lH * lW // (patch_size**2) return []
assert len(y_mask) == 1
packed_indices = compute_packed_indices(N, y_mask)
self.move_to_device_(packed_indices)
return packed_indices
def move_to_device_(self, sample): def move_to_device_(self, sample):
if isinstance(sample, dict): if isinstance(sample, dict):
@ -243,7 +202,7 @@ class T2VSynthMochiModel:
torch.manual_seed(args["seed"]) torch.manual_seed(args["seed"])
torch.cuda.manual_seed(args["seed"]) torch.cuda.manual_seed(args["seed"])
generator = torch.Generator(device=self.device) generator = torch.Generator(device=torch.device("cpu"))
generator.manual_seed(args["seed"]) generator.manual_seed(args["seed"])
num_frames = args["num_frames"] num_frames = args["num_frames"]
@ -275,50 +234,49 @@ class T2VSynthMochiModel:
T = (num_frames - 1) // temporal_downsample + 1 T = (num_frames - 1) // temporal_downsample + 1
H = height // spatial_downsample H = height // spatial_downsample
W = width // spatial_downsample W = width // spatial_downsample
latent_dims = dict(lT=T, lW=W, lH=H)
z = torch.randn( z = torch.randn(
(B, C, T, H, W), (B, C, T, H, W),
device=self.device, device=torch.device("cpu"),
generator=generator, generator=generator,
dtype=torch.float32, dtype=torch.float32,
) ).to(self.device)
if in_samples is not None: if in_samples is not None:
z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device) z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device)
sample = { sample = {
"y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)], "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["positive_embeds"]["embeds"].to(self.device)] "y_feat": [args["positive_embeds"]["embeds"].to(self.device)],
"fastercache": args["fastercache"] if args["fastercache"] is not None else None
} }
sample_null = { sample_null = {
"y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)], "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)] "y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
} "fastercache": args["fastercache"] if args["fastercache"] is not None else None
}
sample["packed_indices"] = self.get_packed_indices( if args["fastercache"]:
sample["y_mask"], **latent_dims self.fastercache_start_step = args["fastercache"]["start_step"]
) self.fastercache_lf_step = args["fastercache"]["lf_step"]
sample_null["packed_indices"] = self.get_packed_indices( self.fastercache_hf_step = args["fastercache"]["hf_step"]
sample_null["y_mask"], **latent_dims else:
) self.fastercache_start_step = 1000
self.use_fastercache = True
self.fastercache_counter = 0 self.fastercache_counter = 0
self.fastercache_start_step = 15
self.fastercache_lf_step = 40
self.fastercache_hf_step = 30
def model_fn(*, z, sigma, cfg_scale): def model_fn(*, z, sigma, cfg_scale):
nonlocal sample, sample_null nonlocal sample, sample_null
if self.use_fastercache: if args["fastercache"]:
self.fastercache_counter+=1 self.fastercache_counter+=1
if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0: if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
out_cond = self.dit(z, sigma,self.fastercache_counter, self.fastercache_start_step, **sample) out_cond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample)
(bb, cc, tt, hh, ww) = out_cond.shape (bb, cc, tt, hh, ww) = out_cond.shape
cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww) cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
lf_c, hf_c = fft(cond.float()) lf_c, hf_c = fft(cond.float())
#lf_step = 40
#hf_step = 30
if self.fastercache_counter <= self.fastercache_lf_step: if self.fastercache_counter <= self.fastercache_lf_step:
self.delta_lf = self.delta_lf * 1.1 self.delta_lf = self.delta_lf * 1.1
if self.fastercache_counter >= self.fastercache_hf_step: if self.fastercache_counter >= self.fastercache_hf_step:
@ -334,12 +292,19 @@ class T2VSynthMochiModel:
return recovered_uncond + cfg_scale * (out_cond - recovered_uncond) return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
else: else:
out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample) out_cond = self.dit(
out_uncond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample_null) z,
#print("out_cond.shape",out_cond.shape) #([1, 12, 3, 60, 106]) sigma,
self.fastercache_counter,
**sample)
out_uncond = self.dit(
z,
sigma,
self.fastercache_counter,
**sample_null)
if self.fastercache_counter >= self.fastercache_start_step + 1: if self.fastercache_counter >= self.fastercache_start_step + 1:
(bb, cc, tt, hh, ww) = out_cond.shape (bb, cc, tt, hh, ww) = out_cond.shape
cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww) cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww) uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@ -352,7 +317,6 @@ class T2VSynthMochiModel:
return out_uncond + cfg_scale * (out_cond - out_uncond) return out_uncond + cfg_scale * (out_cond - out_uncond)
comfy_pbar = ProgressBar(sample_steps) comfy_pbar = ProgressBar(sample_steps)
if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul: if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
@ -373,13 +337,19 @@ class T2VSynthMochiModel:
sigma=torch.full([B], sigma, device=z.device), sigma=torch.full([B], sigma, device=z.device),
cfg_scale=cfg_schedule[i], cfg_scale=cfg_schedule[i],
) )
pred = pred.to(z) z = z + dsigma * pred.to(z)
z = z + dsigma * pred
if callback is not None: if callback is not None:
callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps) callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
else: else:
comfy_pbar.update(1) comfy_pbar.update(1)
if args["fastercache"] is not None:
for block in self.dit.blocks:
if (hasattr, block, "cached_x_attention") and block.cached_x_attention is not None:
block.cached_x_attention = None
block.cached_y_attention = None
self.dit.to(self.offload_device) self.dit.to(self.offload_device)
mm.soft_empty_cache()
logging.info(f"samples shape: {z.shape}") logging.info(f"samples shape: {z.shape}")
return z return z

View File

@ -446,6 +446,36 @@ class MochiTextEncode:
} }
return (t5_embeds, clip,) return (t5_embeds, clip,)
#region FasterCache
class MochiFasterCache:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"start_step": ("INT", {"default": 10, "min": 0, "max": 1024, "step": 1, "tooltip": "The step to start caching, sigma schedule should be adjusted accordingly"}),
"hf_step": ("INT", {"default": 22, "min": 0, "max": 1024, "step": 1}),
"lf_step": ("INT", {"default": 28, "min": 0, "max": 1024, "step": 1}),
"cache_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
},
}
RETURN_TYPES = ("FASTERCACHEARGS",)
RETURN_NAMES = ("fastercache", )
FUNCTION = "args"
CATEGORY = "CogVideoWrapper"
DESCRIPTION = "FasterCache (https://github.com/Vchitect/FasterCache) settings for the MochiWrapper"
def args(self, start_step, hf_step, lf_step, cache_device):
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
fastercache = {
"start_step" : start_step,
"hf_step" : hf_step,
"lf_step" : lf_step,
"cache_device" : device if cache_device == "main_device" else offload_device
}
return (fastercache,)
#region Sampler #region Sampler
class MochiSampler: class MochiSampler:
@classmethod @classmethod
@ -466,6 +496,7 @@ class MochiSampler:
"cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}), "cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
"opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}), "opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}),
"samples": ("LATENT", ), "samples": ("LATENT", ),
"fastercache": ("FASTERCACHEARGS", {"tooltip": "Optional FasterCache settings"}),
} }
} }
@ -474,7 +505,7 @@ class MochiSampler:
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "MochiWrapper" CATEGORY = "MochiWrapper"
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None): def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None, fastercache=None):
mm.unload_all_models() mm.unload_all_models()
mm.soft_empty_cache() mm.soft_empty_cache()
@ -517,6 +548,7 @@ class MochiSampler:
"negative_embeds": negative, "negative_embeds": negative,
"seed": seed, "seed": seed,
"samples": samples["samples"] if samples is not None else None, "samples": samples["samples"] if samples is not None else None,
"fastercache": fastercache
} }
latents = model.run(args) latents = model.run(args)
@ -848,7 +880,8 @@ NODE_CLASS_MAPPINGS = {
"MochiTorchCompileSettings": MochiTorchCompileSettings, "MochiTorchCompileSettings": MochiTorchCompileSettings,
"MochiImageEncode": MochiImageEncode, "MochiImageEncode": MochiImageEncode,
"MochiLatentPreview": MochiLatentPreview, "MochiLatentPreview": MochiLatentPreview,
"MochiSigmaSchedule": MochiSigmaSchedule "MochiSigmaSchedule": MochiSigmaSchedule,
"MochiFasterCache": MochiFasterCache
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadMochiModel": "(Down)load Mochi Model", "DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@ -862,5 +895,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MochiTorchCompileSettings": "Mochi Torch Compile Settings", "MochiTorchCompileSettings": "Mochi Torch Compile Settings",
"MochiImageEncode": "Mochi Image Encode", "MochiImageEncode": "Mochi Image Encode",
"MochiLatentPreview": "Mochi Latent Preview", "MochiLatentPreview": "Mochi Latent Preview",
"MochiSigmaSchedule": "Mochi Sigma Schedule" "MochiSigmaSchedule": "Mochi Sigma Schedule",
"MochiFasterCache": "Mochi Faster Cache"
} }