update, works

Parent: 3535a846a8
Commit: 24a7edfca6
@@ -3,31 +3,22 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .layers import (
    FeedForward,
    PatchEmbed,
    RMSNorm,
    TimestepEmbedder,
)

from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
from .rope_mixed import (
    compute_mixed_rotation,
    create_position_matrix,
)
from .residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm
from .rope_mixed import compute_mixed_rotation, create_position_matrix
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
    pool_tokens,
    modulate,
    pad_and_split_xy,
    unify_streams,
)
from .utils import pool_tokens, modulate

try:
    from flash_attn import flash_attn_varlen_qkvpacked_func
    from flash_attn import flash_attn_func
    FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
    FLASH_ATTN_IS_AVAILABLE = False
@@ -45,6 +36,7 @@ backends.append(SDPBackend.EFFICIENT_ATTENTION)
backends.append(SDPBackend.MATH)

import comfy.model_management as mm
from comfy.ldm.modules.attention import optimized_attention

class AttentionPool(nn.Module):
    def __init__(
@@ -103,16 +95,16 @@ class AttentionPool(nn.Module):
        q = q.unsqueeze(2)  # (B, H, 1, head_dim)

        # Compute attention.
        with sdpa_kernel(backends):
            x = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0
            )  # (B, H, 1, head_dim)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0
        )  # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
        x = self.to_out(x)
        return x

#region Attention
class AsymmetricAttention(nn.Module):
    def __init__(
        self,
@@ -128,6 +120,7 @@ class AsymmetricAttention(nn.Module):
        softmax_scale: Optional[float] = None,
        device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
        rms_norm_func: bool = False,

    ):
        super().__init__()
@@ -154,6 +147,28 @@ class AsymmetricAttention(nn.Module):

        # Query and key normalization for stability.
        assert qk_norm
        if rms_norm_func == "flash_attn_triton":  # use the same rms_norm_func
            try:
                from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm  # slightly faster
                @torch.compiler.disable()  # causes NaNs when compiled for some reason
                class RMSNorm(FlashTritonRMSNorm):
                    pass
            except:
                raise ImportError("Flash Triton RMSNorm not available.")
        elif rms_norm_func == "flash_attn":
            try:
                from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm  # slightly faster
                @torch.compiler.disable()  # causes NaNs when compiled for some reason
                class RMSNorm(FlashRMSNorm):
                    pass
            except:
                raise ImportError("Flash RMSNorm not available.")
        elif rms_norm_func == "apex":
            from apex.normalization import FusedRMSNorm as ApexRMSNorm
            class RMSNorm(ApexRMSNorm):
                pass
        else:
            from .layers import RMSNorm
        self.q_norm_x = RMSNorm(self.head_dim, device=device)
        self.k_norm_x = RMSNorm(self.head_dim, device=device)
        self.q_norm_y = RMSNorm(self.head_dim, device=device)
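
Editor's note: all four branches above bind a drop-in RMSNorm class and differ only in which fused kernel backs it. For reference, a minimal sketch of the normalization they all compute (the standard RMSNorm formula, not the repo's .layers implementation):

import torch
import torch.nn as nn

class ReferenceRMSNorm(nn.Module):
    # y = x / sqrt(mean(x^2) + eps) * weight, taken over the last dimension.
    def __init__(self, dim: int, eps: float = 1e-6, device=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight
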
@@ -166,79 +181,23 @@ class AsymmetricAttention(nn.Module):
            if update_y
            else nn.Identity()
        )

    def run_qkv_y(self, y):
        local_heads = self.num_heads
        qkv_y = self.qkv_y(y)  # (B, L, 3 * dim)
        qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
        q_y, k_y, v_y = qkv_y.unbind(2)
        return q_y, k_y, v_y

    def prepare_qkv(
        self,
        x: torch.Tensor,  # (B, N, dim_x)
        y: torch.Tensor,  # (B, L, dim_y)
        *,
        scale_x: torch.Tensor,
        scale_y: torch.Tensor,
        rope_cos: torch.Tensor,
        rope_sin: torch.Tensor,
        valid_token_indices: torch.Tensor,
    ):
        # Pre-norm for visual features
        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size

        # Process visual features
        qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
        #assert qkv_x.dtype == torch.bfloat16

        # Move QKV dimension to the front.
        # B M (3 H d) -> 3 B M H d
        B, M, _ = qkv_x.size()
        qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
        qkv_x = qkv_x.permute(2, 0, 1, 3, 4)

        # Process text features
        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
        q_y, k_y, v_y = self.run_qkv_y(y)  # (B, L, local_heads, head_dim)
        q_y = self.q_norm_y(q_y)
        k_y = self.k_norm_y(k_y)

        # Split qkv_x into q, k, v
        q_x, k_x, v_x = qkv_x.unbind(0)  # (B, N, local_h, head_dim)
        q_x = self.q_norm_x(q_x)
        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
        k_x = self.k_norm_x(k_x)
        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

        # Unite streams
        qkv = unify_streams(
            q_x,
            k_x,
            v_x,
            q_y,
            k_y,
            v_y,
            valid_token_indices,
        )

        return qkv

    def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
    def flash_attention(self, q, k, v):
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        b, _, _, dim_head = q.shape
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen_in_batch,
            out: torch.Tensor = flash_attn_func(  # q: (batch_size, seqlen, nheads, headdim)
                q, k, v,
                dropout_p=0.0,
                softmax_scale=self.softmax_scale,
            )  # (total, local_heads, head_dim)
        return out.view(total, local_dim)
        out = out.permute(0, 2, 1, 3)
        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

    def sdpa_attention(self, qkv):
        q, k, v = qkv.unbind(dim=1)
        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
    def sdpa_attention(self, q, k, v):
        b, _, _, dim_head = q.shape
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            with sdpa_kernel(backends):
                out = F.scaled_dot_product_attention(
@@ -249,13 +208,10 @@ class AsymmetricAttention(nn.Module):
                    dropout_p=0.0,
                    is_causal=False
                )
        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

    def sage_attention(self, qkv):
        #q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
        q, k, v = qkv.unbind(dim=1)
        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]

    def sage_attention(self, q, k, v):
        b, _, _, dim_head = q.shape
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out = sageattn(
                q,
@@ -265,14 +221,9 @@ class AsymmetricAttention(nn.Module):
                dropout_p=0.0,
                is_causal=False
            )
            #print(out.shape)
        #out = rearrange(out, 'b h s d -> s (b h d)')
        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

    def comfy_attention(self, qkv):
        from comfy.ldm.modules.attention import optimized_attention
        q, k, v = qkv.unbind(dim=1)
        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
    def comfy_attention(self, q, k, v):
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out = optimized_attention(
                q,
@@ -281,41 +232,23 @@ class AsymmetricAttention(nn.Module):
                heads = self.num_heads,
                skip_reshape=True
            )
        return out.squeeze(0)
        return out

    @torch.compiler.disable()
    def run_attention(
        self,
        qkv: torch.Tensor,  # (total <= B * (N + L), 3, local_heads, head_dim)
        *,
        B: int,
        L: int,
        M: int,
        cu_seqlens: torch.Tensor,
        max_seqlen_in_batch: int,
        valid_token_indices: torch.Tensor,
    ):
        local_dim = self.num_heads * self.head_dim
        total = qkv.size(0)

        q,
        k,
        v,
    ):
        if self.attention_mode == "flash_attn":
            out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
            out = self.flash_attention(q, k, v)
        elif self.attention_mode == "sdpa":
            out = self.sdpa_attention(qkv)
            out = self.sdpa_attention(q, k, v)
        elif self.attention_mode == "sage_attn":
            out = self.sage_attention(qkv)
            out = self.sage_attention(q, k, v)
        elif self.attention_mode == "comfy":
            out = self.comfy_attention(qkv)

        x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
        assert x.size() == (B, M, local_dim)
        assert y.size() == (B, L, local_dim)

        x = x.view(B, M, self.num_heads, self.head_dim)
        x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
        x = self.proj_x(x)  # (B, M, dim_x)
        y = self.proj_y(y)  # (B, L, dim_y)
        return x, y
            out = self.comfy_attention(q, k, v)
        return out
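
Editor's note on layouts in the wrappers above: run_attention now receives q, k, v already shaped (B, heads, seq, head_dim), which SDPA, sageattn and comfy's optimized_attention consume directly, while flash_attn_func wants (B, seq, heads, head_dim), hence the extra permutes inside flash_attention(). A small shape sketch with toy sizes (SDPA only, not repo code):

import torch
import torch.nn.functional as F

B, H, S, D = 1, 4, 16, 8                          # toy sizes
q = torch.randn(B, H, S, D)                       # layout handed to run_attention
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

out = F.scaled_dot_product_attention(q, k, v)     # (B, H, S, D)
out = out.transpose(1, 2).reshape(B, S, H * D)    # (B, S, H*D), the shape the wrappers return
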

    def forward(
        self,
@@ -324,47 +257,44 @@ class AsymmetricAttention(nn.Module):
        *,
        scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
        scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
        packed_indices: Dict[str, torch.Tensor] = None,
        num_tokens,
        **rope_rotation,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of asymmetric multi-modal attention.

        Args:
            x: (B, N, dim_x) tensor for visual tokens
            y: (B, L, dim_y) tensor of text token features
            packed_indices: Dict with keys for Flash Attention
            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
        rope_cos = rope_rotation.get("rope_cos")
        rope_sin = rope_rotation.get("rope_sin")

        # Pre-norm for visual features
        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
        # Process text features
        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
        q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)

        Returns:
            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
            y: (B, L, dim_y) tensor of text token features after multi-modal attention
        """
        B, L, _ = y.shape
        _, M, _ = x.shape
        q_y = self.q_norm_y(q_y)
        k_y = self.k_norm_y(k_y)

        # Predict a packed QKV tensor from visual and text features.
        # Don't checkpoint the all_to_all.
        qkv = self.prepare_qkv(
            x=x,
            y=y,
            scale_x=scale_x,
            scale_y=scale_y,
            rope_cos=rope_rotation.get("rope_cos"),
            rope_sin=rope_rotation.get("rope_sin"),
            valid_token_indices=packed_indices["valid_token_indices_kv"],
        )  # (total <= B * (N + L), 3, local_heads, head_dim)
        # Split qkv_x into q, k, v
        q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)
        q_x = self.q_norm_x(q_x)
        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
        k_x = self.k_norm_x(k_x)
        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

        x, y = self.run_attention(
            qkv,
            B=B,
            L=L,
            M=M,
            cu_seqlens=packed_indices["cu_seqlens_kv"],
            max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
            valid_token_indices=packed_indices["valid_token_indices_kv"],
        )
        q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
        k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
        v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)

        xy = self.run_attention(q, k, v)

        x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
        x = self.proj_x(x)
        o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
        o[:, :y.shape[1]] = y

        y = self.proj_y(o)
        return x, y
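
Editor's note: the rewritten forward above replaces the packed/varlen path. Padded text tokens beyond num_tokens are sliced off before the joint attention, the joint output is split back into visual and text parts with tensor_split, and the text part is zero-padded back to the full caption length before proj_y. A toy sketch of that trim/split/re-pad round trip (hypothetical sizes, not repo code):

import torch

B, M, L_full, num_tokens, dim = 1, 6, 8, 3, 4     # toy sizes
xy = torch.randn(B, M + num_tokens, dim)          # joint attention output: visual + valid text tokens

x, y = torch.tensor_split(xy, (M,), dim=1)        # x: (B, M, dim), y: (B, num_tokens, dim)
o = torch.zeros(B, L_full, dim, dtype=y.dtype)    # re-pad text back to the full caption length
o[:, :y.shape[1]] = y
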

#region Blocks
class AsymmetricJointBlock(nn.Module):
    def __init__(
        self,
@@ -377,6 +307,7 @@ class AsymmetricJointBlock(nn.Module):
        update_y: bool = True,  # Whether to update text tokens in this block.
        device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
        rms_norm_func: str = "default",
        **block_kwargs,
    ):
        super().__init__()
@@ -401,6 +332,7 @@ class AsymmetricJointBlock(nn.Module):
            update_y=update_y,
            device=device,
            attention_mode=attention_mode,
            rms_norm_func=rms_norm_func,
            **block_kwargs,
        )

@@ -432,7 +364,8 @@ class AsymmetricJointBlock(nn.Module):
        c: torch.Tensor,
        y: torch.Tensor,
        fastercache_counter: Optional[int] = 0,
        fastercache_start_step: Optional[int] = 15,
        fastercache_start_step: Optional[int] = 1000,
        fastercache_device: Optional[torch.device] = None,
        **attn_kwargs,
    ):
        """Forward pass of a block.
@@ -459,10 +392,11 @@ class AsymmetricJointBlock(nn.Module):
        else:
            scale_msa_y = mod_y

        #fastercache
        #region fastercache
        B = x.shape[0]
        #print("x", x.shape) #([1, 9540, 3072])
        if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_x_attention[-1].shape[0] >= B:
            print("using fastercache")
            x_attn = (
                self.cached_x_attention[1][:B] +
                (self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])
@@ -483,22 +417,24 @@ class AsymmetricJointBlock(nn.Module):
                **attn_kwargs,
            )
            if fastercache_counter == fastercache_start_step:
                self.cached_x_attention = [x_attn, x_attn]
                self.cached_y_attention = [y_attn, y_attn]
                print("caching attention")
                self.cached_x_attention = [x_attn.to(fastercache_device), x_attn.to(fastercache_device)]
                self.cached_y_attention = [y_attn.to(fastercache_device), y_attn.to(fastercache_device)]
            elif fastercache_counter > fastercache_start_step:
                self.cached_x_attention[-1].copy_(x_attn)
                self.cached_y_attention[-1].copy_(y_attn)
                print("updating attention")
                self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
                self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))

        assert x_attn.size(1) == N
        x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
        if self.update_y:
            y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)

        # MLP block.
        x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
        if self.update_y:
            y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)

        return x, y
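
Editor's note: the caching above keeps two attention snapshots per block. Both slots are seeded at fastercache_start_step, only the newest slot is refreshed on later computed steps, and skipped steps reuse the linear extrapolation cached[1] + (cached[1] - cached[0]). A minimal stand-alone sketch of that update/reuse pattern (dummy tensors, not repo code):

import torch

start_attn = torch.randn(2, 8)                     # attention output at fastercache_start_step
cache = [start_attn.clone(), start_attn.clone()]   # both slots seeded with the same snapshot

new_attn = torch.randn(2, 8)                       # a later, freshly computed attention output
cache[-1].copy_(new_attn)                          # only the newest slot gets refreshed

reused = cache[1] + (cache[1] - cache[0])          # extrapolated output used on skipped steps
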

    def ff_block_x(self, x, scale_x, gate_x):
@@ -542,7 +478,7 @@ class FinalLayer(nn.Module):
        x = self.linear(x)
        return x

#region Model
class AsymmDiTJoint(nn.Module):
    """
    Diffusion model with a Transformer backbone.
@@ -570,6 +506,7 @@ class AsymmDiTJoint(nn.Module):
        rope_theta: float = 10000.0,
        device: Optional[torch.device] = None,
        attention_mode: str = "sdpa",
        rms_norm_func: str = "default",
        **block_kwargs,
    ):
        super().__init__()
@@ -638,6 +575,7 @@ class AsymmDiTJoint(nn.Module):
                update_y=update_y,
                device=device,
                attention_mode=attention_mode,
                rms_norm_func=rms_norm_func,
                **block_kwargs,
            )

@@ -701,12 +639,11 @@ class AsymmDiTJoint(nn.Module):
        x: torch.Tensor,
        sigma: torch.Tensor,
        fastercache_counter: int,
        fastercache_start_step: int,
        y_feat: List[torch.Tensor],
        y_mask: List[torch.Tensor],
        packed_indices: Dict[str, torch.Tensor] = None,
        rope_cos: torch.Tensor = None,
        rope_sin: torch.Tensor = None,
        fastercache: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """Forward pass of DiT.

@@ -715,18 +652,25 @@ class AsymmDiTJoint(nn.Module):
            sigma: (B,) tensor of noise standard deviations
            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
            y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
        """
        B, _, T, H, W = x.shape

        # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
        # Have to call sdpa_kernel outside of a torch.compile region.
        num_tokens = max(1, torch.sum(y_mask[0]).item())
        with sdpa_kernel(backends):
            x, c, y_feat, rope_cos, rope_sin = self.prepare(
                x, sigma, y_feat[0], y_mask[0]
            )
        del y_mask

        if fastercache is not None:
            fastercache_start_step = fastercache["start_step"]
            fastercache_device = fastercache["cache_device"]
        else:
            fastercache_start_step = 1000
            fastercache_device = None

        for i, block in enumerate(self.blocks):
            x, y_feat = block(
                x,
@@ -734,22 +678,24 @@ class AsymmDiTJoint(nn.Module):
                y_feat,
                rope_cos=rope_cos,
                rope_sin=rope_sin,
                packed_indices=packed_indices,
                num_tokens=num_tokens,
                fastercache_counter = fastercache_counter,
                fastercache_start_step = fastercache_start_step,
                fastercache_device = fastercache_device,

            )  # (B, M, D), (B, L, D)
        del y_feat  # Final layers don't use dense text features.

        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
        x = rearrange(
            x,
            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
            T=T,
            hp=H // self.patch_size,
            wp=W // self.patch_size,
            p1=self.patch_size,
            p2=self.patch_size,
            c=self.out_channels,
        )
        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)

        hp = H // self.patch_size
        wp = W // self.patch_size
        p1 = self.patch_size
        p2 = self.patch_size
        c = self.out_channels

        x = x.view(B, T, hp, wp, p1, p2, c)
        x = x.permute(0, 6, 1, 2, 4, 3, 5)
        x = x.reshape(B, c, T, hp * p1, wp * p2)

        return x
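
Editor's note: the einops rearrange above is replaced by an explicit view/permute/reshape chain (presumably friendlier to torch.compile). A quick sanity check with toy sizes showing the two unpatchify paths agree (not repo code):

import torch
from einops import rearrange

B, T, hp, wp, p1, p2, c = 1, 2, 3, 4, 2, 2, 5
x = torch.randn(B, T * hp * wp, p1 * p2 * c)

ref = rearrange(
    x, "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
    T=T, hp=hp, wp=wp, p1=p1, p2=p2, c=c,
)
out = x.view(B, T, hp, wp, p1, p2, c).permute(0, 6, 1, 2, 4, 3, 5).reshape(B, c, T, hp * p1, wp * p2)
assert torch.allclose(ref, out)
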

@@ -1,5 +1,5 @@
import json
from typing import Dict, List, Optional, Union
from einops import rearrange

#temporary patch to fix torch compile bug in Windows
def patched_write_atomic(
@@ -33,11 +33,8 @@ except:
    pass

import torch
import torch.nn.functional as F
import torch.utils.data
from einops import rearrange, repeat

#from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
from comfy.utils import ProgressBar, load_torch_file
import comfy.model_management as mm
@@ -95,59 +92,17 @@ def unnormalize_latents(
    assert z.size(1) == mean.size(0) == std.size(0)
    return z * std.to(z) + mean.to(z)


def compute_packed_indices(
    N: int,
    text_mask: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """
    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80

    Args:
        N: Number of visual tokens.
        text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.

    Returns:
        packed_indices: Dict with keys for Flash Attention:
            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
              in the packed sequence.
            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
    """
    # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
    assert N > 0 and len(text_mask) == 1
    text_mask = text_mask[0]

    mask = F.pad(text_mask, (N, 0), value=True)  # (B, N + L)
    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
    valid_token_indices = torch.nonzero(
        mask.flatten(), as_tuple=False
    ).flatten()  # up to (B * (N + L),)

    assert valid_token_indices.size(0) >= text_mask.size(0) * N  # At least (B * N,)
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
    )
    max_seqlen_in_batch = seqlens_in_batch.max().item()

    return {
        "cu_seqlens_kv": cu_seqlens,
        "max_seqlen_in_batch_kv": max_seqlen_in_batch,
        "valid_token_indices_kv": valid_token_indices,
    }
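
Editor's note: for reference, a toy run of the helper above (which the dummy get_packed_indices further down suggests is being phased out in this commit). With N visual tokens, every visual position is marked valid and only non-padding text tokens contribute:

import torch
import torch.nn.functional as F

N = 4                                                      # visual tokens
text_mask = torch.tensor([[True, True, False]])            # (B=1, L=3), last caption token is padding

mask = F.pad(text_mask, (N, 0), value=True)                # (1, N + L); visual tokens always valid
seqlens = mask.sum(dim=-1, dtype=torch.int32)              # tensor([6], dtype=torch.int32)
valid_idx = torch.nonzero(mask.flatten()).flatten()        # tensor([0, 1, 2, 3, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 6])
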

class T2VSynthMochiModel:
    def __init__(
        self,
        *,
        device: torch.device,
        offload_device: torch.device,
        vae_stats_path: str,
        dit_checkpoint_path: str,
        weight_dtype: torch.dtype = torch.float8_e4m3fn,
        fp8_fastmode: bool = False,
        attention_mode: str = "sdpa",
        rms_norm_func: str = "default",
        compile_args: Optional[Dict] = None,
        cublas_ops: Optional[bool] = False,
    ):
@@ -177,11 +132,23 @@ class T2VSynthMochiModel:
            t5_token_length=256,
            rope_theta=10000.0,
            attention_mode=attention_mode,
            rms_norm_func=rms_norm_func,
        )

        params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
        logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
        dit_sd = load_torch_file(dit_checkpoint_path)

        #comfy format
        prefix = "model.diffusion_model."
        first_key = next(iter(dit_sd), None)
        if first_key and first_key.startswith(prefix):
            new_dit_sd = {
                key[len(prefix):] if key.startswith(prefix) else key: value
                for key, value in dit_sd.items()
            }
            dit_sd = new_dit_sd

        if "gguf" in dit_checkpoint_path.lower():
            logging.info("Loading GGUF model state_dict...")
            from .. import mz_gguf_loader
@@ -220,18 +187,10 @@ class T2VSynthMochiModel:
            model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])

        self.dit = model

        vae_stats = json.load(open(vae_stats_path))
        self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
        self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)

    def get_packed_indices(self, y_mask, *, lT, lW, lH):
        patch_size = 2
        N = lT * lH * lW // (patch_size**2)
        assert len(y_mask) == 1
        packed_indices = compute_packed_indices(N, y_mask)
        self.move_to_device_(packed_indices)
        return packed_indices
    def get_packed_indices(self, y_mask, **latent_dims):
        # temporary dummy func for compatibility
        return []

    def move_to_device_(self, sample):
        if isinstance(sample, dict):
@@ -243,7 +202,7 @@ class T2VSynthMochiModel:
        torch.manual_seed(args["seed"])
        torch.cuda.manual_seed(args["seed"])

        generator = torch.Generator(device=self.device)
        generator = torch.Generator(device=torch.device("cpu"))
        generator.manual_seed(args["seed"])

        num_frames = args["num_frames"]
@@ -275,50 +234,49 @@ class T2VSynthMochiModel:
        T = (num_frames - 1) // temporal_downsample + 1
        H = height // spatial_downsample
        W = width // spatial_downsample
        latent_dims = dict(lT=T, lW=W, lH=H)

        z = torch.randn(
            (B, C, T, H, W),
            device=self.device,
            device=torch.device("cpu"),
            generator=generator,
            dtype=torch.float32,
        )
        ).to(self.device)
        if in_samples is not None:
            z = z * sigma_schedule[0] + (1 - sigma_schedule[0]) * in_samples.to(self.device)

        sample = {
            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
            "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)],
            "fastercache": args["fastercache"] if args["fastercache"] is not None else None
        }
        sample_null = {
            "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
        }
            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
            "fastercache": args["fastercache"] if args["fastercache"] is not None else None
        }

        sample["packed_indices"] = self.get_packed_indices(
            sample["y_mask"], **latent_dims
        )
        sample_null["packed_indices"] = self.get_packed_indices(
            sample_null["y_mask"], **latent_dims
        )
        self.use_fastercache = True
        if args["fastercache"]:
            self.fastercache_start_step = args["fastercache"]["start_step"]
            self.fastercache_lf_step = args["fastercache"]["lf_step"]
            self.fastercache_hf_step = args["fastercache"]["hf_step"]
        else:
            self.fastercache_start_step = 1000
        self.fastercache_counter = 0
        self.fastercache_start_step = 15
        self.fastercache_lf_step = 40
        self.fastercache_hf_step = 30

        def model_fn(*, z, sigma, cfg_scale):
            nonlocal sample, sample_null
            if self.use_fastercache:
            if args["fastercache"]:
                self.fastercache_counter+=1
            if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
                out_cond = self.dit(z, sigma,self.fastercache_counter, self.fastercache_start_step, **sample)
                out_cond = self.dit(
                    z,
                    sigma,
                    self.fastercache_counter,
                    **sample)

                (bb, cc, tt, hh, ww) = out_cond.shape
                cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
                lf_c, hf_c = fft(cond.float())
                #lf_step = 40
                #hf_step = 30
                if self.fastercache_counter <= self.fastercache_lf_step:
                    self.delta_lf = self.delta_lf * 1.1
                if self.fastercache_counter >= self.fastercache_hf_step:
@@ -334,12 +292,19 @@ class T2VSynthMochiModel:

                return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
            else:
                out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample)
                out_uncond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample_null)
                #print("out_cond.shape",out_cond.shape) #([1, 12, 3, 60, 106])
                out_cond = self.dit(
                    z,
                    sigma,
                    self.fastercache_counter,
                    **sample)

                out_uncond = self.dit(
                    z,
                    sigma,
                    self.fastercache_counter,
                    **sample_null)

                if self.fastercache_counter >= self.fastercache_start_step + 1:

                    (bb, cc, tt, hh, ww) = out_cond.shape
                    cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
                    uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@@ -352,7 +317,6 @@ class T2VSynthMochiModel:

                return out_uncond + cfg_scale * (out_cond - out_uncond)

        comfy_pbar = ProgressBar(sample_steps)

        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
@@ -373,13 +337,19 @@ class T2VSynthMochiModel:
                sigma=torch.full([B], sigma, device=z.device),
                cfg_scale=cfg_schedule[i],
            )
            pred = pred.to(z)
            z = z + dsigma * pred
            z = z + dsigma * pred.to(z)
            if callback is not None:
                callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
            else:
                comfy_pbar.update(1)

        if args["fastercache"] is not None:
            for block in self.dit.blocks:
                if hasattr(block, "cached_x_attention") and block.cached_x_attention is not None:
                    block.cached_x_attention = None
                    block.cached_y_attention = None

        self.dit.to(self.offload_device)
        mm.soft_empty_cache()
        logging.info(f"samples shape: {z.shape}")
        return z

nodes.py (40 changes)
@ -446,6 +446,36 @@ class MochiTextEncode:
|
||||
}
|
||||
return (t5_embeds, clip,)
|
||||
|
||||
#region FasterCache
|
||||
class MochiFasterCache:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"start_step": ("INT", {"default": 10, "min": 0, "max": 1024, "step": 1, "tooltip": "The step to start caching, sigma schedule should be adjusted accordingly"}),
|
||||
"hf_step": ("INT", {"default": 22, "min": 0, "max": 1024, "step": 1}),
|
||||
"lf_step": ("INT", {"default": 28, "min": 0, "max": 1024, "step": 1}),
|
||||
"cache_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("FASTERCACHEARGS",)
|
||||
RETURN_NAMES = ("fastercache", )
|
||||
FUNCTION = "args"
|
||||
CATEGORY = "CogVideoWrapper"
|
||||
DESCRIPTION = "FasterCache (https://github.com/Vchitect/FasterCache) settings for the MochiWrapper"
|
||||
|
||||
def args(self, start_step, hf_step, lf_step, cache_device):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
fastercache = {
|
||||
"start_step" : start_step,
|
||||
"hf_step" : hf_step,
|
||||
"lf_step" : lf_step,
|
||||
"cache_device" : device if cache_device == "main_device" else offload_device
|
||||
}
|
||||
return (fastercache,)
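
Editor's note: the FASTERCACHEARGS dict built here is what T2VSynthMochiModel.run reads back (start_step, lf_step, hf_step and cache_device, as seen in the model changes above). A hand-built equivalent, assuming the node's default values and the main device (sketch, not repo code):

import comfy.model_management as mm

fastercache = {
    "start_step": 10,                       # step at which attention caching begins
    "hf_step": 22,                          # threshold controlling the high-frequency residual scaling
    "lf_step": 28,                          # threshold controlling the low-frequency residual scaling
    "cache_device": mm.get_torch_device(),  # or mm.unet_offload_device() to save VRAM
}
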

#region Sampler
class MochiSampler:
    @classmethod
@@ -466,6 +496,7 @@ class MochiSampler:
                "cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
                "opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}),
                "samples": ("LATENT", ),
                "fastercache": ("FASTERCACHEARGS", {"tooltip": "Optional FasterCache settings"}),
            }
        }

@@ -474,7 +505,7 @@ class MochiSampler:
    FUNCTION = "process"
    CATEGORY = "MochiWrapper"

    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None, fastercache=None):
        mm.unload_all_models()
        mm.soft_empty_cache()

@@ -517,6 +548,7 @@ class MochiSampler:
            "negative_embeds": negative,
            "seed": seed,
            "samples": samples["samples"] if samples is not None else None,
            "fastercache": fastercache
        }
        latents = model.run(args)

@@ -848,7 +880,8 @@ NODE_CLASS_MAPPINGS = {
    "MochiTorchCompileSettings": MochiTorchCompileSettings,
    "MochiImageEncode": MochiImageEncode,
    "MochiLatentPreview": MochiLatentPreview,
    "MochiSigmaSchedule": MochiSigmaSchedule
    "MochiSigmaSchedule": MochiSigmaSchedule,
    "MochiFasterCache": MochiFasterCache
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@@ -862,5 +895,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "MochiTorchCompileSettings": "Mochi Torch Compile Settings",
    "MochiImageEncode": "Mochi Image Encode",
    "MochiLatentPreview": "Mochi Latent Preview",
    "MochiSigmaSchedule": "Mochi Sigma Schedule"
    "MochiSigmaSchedule": "Mochi Sigma Schedule",
    "MochiFasterCache": "Mochi Faster Cache"
}