update, works
This commit is contained in:
parent 3535a846a8
commit 24a7edfca6
@@ -3,31 +3,22 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange

 from .layers import (
     FeedForward,
     PatchEmbed,
-    RMSNorm,
     TimestepEmbedder,
 )

 from .mod_rmsnorm import modulated_rmsnorm
-from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
-from .rope_mixed import (
-    compute_mixed_rotation,
-    create_position_matrix,
-)
+from .residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm
+from .rope_mixed import compute_mixed_rotation, create_position_matrix
 from .temporal_rope import apply_rotary_emb_qk_real
-from .utils import (
-    pool_tokens,
-    modulate,
-    pad_and_split_xy,
-    unify_streams,
-)
+from .utils import pool_tokens, modulate

 try:
-    from flash_attn import flash_attn_varlen_qkvpacked_func
+    from flash_attn import flash_attn_func
     FLASH_ATTN_IS_AVAILABLE = True
 except ImportError:
     FLASH_ATTN_IS_AVAILABLE = False
@@ -45,6 +36,7 @@ backends.append(SDPBackend.EFFICIENT_ATTENTION)
 backends.append(SDPBackend.MATH)

 import comfy.model_management as mm
+from comfy.ldm.modules.attention import optimized_attention

 class AttentionPool(nn.Module):
     def __init__(
@@ -103,16 +95,16 @@ class AttentionPool(nn.Module):
         q = q.unsqueeze(2)  # (B, H, 1, head_dim)

         # Compute attention.
-        with sdpa_kernel(backends):
-            x = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask, dropout_p=0.0
-            )  # (B, H, 1, head_dim)
+        x = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, dropout_p=0.0
+        )  # (B, H, 1, head_dim)

         # Concatenate heads and run output.
         x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
         x = self.to_out(x)
         return x

+#region Attention
 class AsymmetricAttention(nn.Module):
     def __init__(
         self,
@@ -128,6 +120,7 @@ class AsymmetricAttention(nn.Module):
         softmax_scale: Optional[float] = None,
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: bool = False,
     ):
         super().__init__()
@@ -154,6 +147,28 @@ class AsymmetricAttention(nn.Module):

         # Query and key normalization for stability.
         assert qk_norm
+        if rms_norm_func == "flash_attn_triton": #use the same rms_norm_func
+            try:
+                from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm #slightly faster
+                @torch.compiler.disable() #cause NaNs when compiled for some reason
+                class RMSNorm(FlashTritonRMSNorm):
+                    pass
+            except:
+                raise ImportError("Flash Triton RMSNorm not available.")
+        elif rms_norm_func == "flash_attn":
+            try:
+                from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm #slightly faster
+                @torch.compiler.disable() #cause NaNs when compiled for some reason
+                class RMSNorm(FlashRMSNorm):
+                    pass
+            except:
+                raise ImportError("Flash RMSNorm not available.")
+        elif rms_norm_func == "apex":
+            from apex.normalization import FusedRMSNorm as ApexRMSNorm
+            class RMSNorm(ApexRMSNorm):
+                pass
+        else:
+            from .layers import RMSNorm
         self.q_norm_x = RMSNorm(self.head_dim, device=device)
         self.k_norm_x = RMSNorm(self.head_dim, device=device)
         self.q_norm_y = RMSNorm(self.head_dim, device=device)
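Note: all of the selectable backends above (flash-attn Triton, flash-attn CUDA, Apex, or the bundled .layers.RMSNorm) are drop-in implementations of the same operation. For reference, a minimal PyTorch sketch of that operation (the eps value is an assumption, not taken from this repo):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension, then a learned scale.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight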
@@ -166,79 +181,23 @@ class AsymmetricAttention(nn.Module):
             if update_y
             else nn.Identity()
         )

-    def run_qkv_y(self, y):
-        local_heads = self.num_heads
-        qkv_y = self.qkv_y(y)  # (B, L, 3 * dim)
-        qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
-        q_y, k_y, v_y = qkv_y.unbind(2)
-        return q_y, k_y, v_y
-
-    def prepare_qkv(
-        self,
-        x: torch.Tensor,  # (B, N, dim_x)
-        y: torch.Tensor,  # (B, L, dim_y)
-        *,
-        scale_x: torch.Tensor,
-        scale_y: torch.Tensor,
-        rope_cos: torch.Tensor,
-        rope_sin: torch.Tensor,
-        valid_token_indices: torch.Tensor,
-    ):
-        # Pre-norm for visual features
-        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
-
-        # Process visual features
-        qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
-        #assert qkv_x.dtype == torch.bfloat16
-
-        # Move QKV dimension to the front.
-        # B M (3 H d) -> 3 B M H d
-        B, M, _ = qkv_x.size()
-        qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
-        qkv_x = qkv_x.permute(2, 0, 1, 3, 4)
-
-        # Process text features
-        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
-        q_y, k_y, v_y = self.run_qkv_y(y)  # (B, L, local_heads, head_dim)
-        q_y = self.q_norm_y(q_y)
-        k_y = self.k_norm_y(k_y)
-
-        # Split qkv_x into q, k, v
-        q_x, k_x, v_x = qkv_x.unbind(0)  # (B, N, local_h, head_dim)
-        q_x = self.q_norm_x(q_x)
-        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
-        k_x = self.k_norm_x(k_x)
-        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
-
-        # Unite streams
-        qkv = unify_streams(
-            q_x,
-            k_x,
-            v_x,
-            q_y,
-            k_y,
-            v_y,
-            valid_token_indices,
-        )
-
-        return qkv
-
-    def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
+    def flash_attention(self, q, k, v):
+        q = q.permute(0, 2, 1, 3)
+        k = k.permute(0, 2, 1, 3)
+        v = v.permute(0, 2, 1, 3)
+        b, _, _, dim_head = q.shape
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
-            out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
-                qkv,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen_in_batch,
+            out: torch.Tensor = flash_attn_func( #q: (batch_size, seqlen, nheads, headdim)
+                q, k, v,
                 dropout_p=0.0,
                 softmax_scale=self.softmax_scale,
             )  # (total, local_heads, head_dim)
-        return out.view(total, local_dim)
+        out = out.permute(0, 2, 1, 3)
+        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

-    def sdpa_attention(self, qkv):
-        q, k, v = qkv.unbind(dim=1)
-        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
+    def sdpa_attention(self, q, k, v):
+        b, _, _, dim_head = q.shape
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             with sdpa_kernel(backends):
                 out = F.scaled_dot_product_attention(
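A shape sketch (illustrative sizes, not from the repo) of the layout handling in the new flash_attention path: run_attention hands it (B, heads, seq, head_dim) tensors, flash_attn_func consumes (batch, seqlen, nheads, headdim) as noted in the inline comment, and the result is flattened back to (B, seq, heads * head_dim):

import torch

B, H, S, D = 1, 2, 16, 8
q = torch.randn(B, H, S, D)      # layout produced by the rewritten forward()

q_fa = q.permute(0, 2, 1, 3)     # (B, S, H, D), the layout flash_attn_func expects
out = q_fa                       # stand-in for the attention output (same layout)
out = out.permute(0, 2, 1, 3)    # back to (B, H, S, D)
out = out.transpose(1, 2).reshape(B, -1, H * D)
assert out.shape == (B, S, H * D)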
@@ -249,13 +208,10 @@ class AsymmetricAttention(nn.Module):
                     dropout_p=0.0,
                     is_causal=False
                 )
-        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
+        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

-    def sage_attention(self, qkv):
-        #q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        q, k, v = qkv.unbind(dim=1)
-        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
+    def sage_attention(self, q, k, v):
+        b, _, _, dim_head = q.shape

         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = sageattn(
                 q,
@@ -265,14 +221,9 @@ class AsymmetricAttention(nn.Module):
                 dropout_p=0.0,
                 is_causal=False
             )
-        #print(out.shape)
-        #out = rearrange(out, 'b h s d -> s (b h d)')
-        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
+        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

-    def comfy_attention(self, qkv):
-        from comfy.ldm.modules.attention import optimized_attention
-        q, k, v = qkv.unbind(dim=1)
-        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
+    def comfy_attention(self, q, k, v):
         with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = optimized_attention(
                 q,
@@ -281,41 +232,23 @@ class AsymmetricAttention(nn.Module):
                 heads = self.num_heads,
                 skip_reshape=True
             )
-        return out.squeeze(0)
+        return out

-    @torch.compiler.disable()
     def run_attention(
         self,
-        qkv: torch.Tensor,  # (total <= B * (N + L), 3, local_heads, head_dim)
-        *,
-        B: int,
-        L: int,
-        M: int,
-        cu_seqlens: torch.Tensor,
-        max_seqlen_in_batch: int,
-        valid_token_indices: torch.Tensor,
-    ):
-        local_dim = self.num_heads * self.head_dim
-        total = qkv.size(0)
+        q,
+        k,
+        v,
+    ):

         if self.attention_mode == "flash_attn":
-            out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
+            out = self.flash_attention(q, k, v)
         elif self.attention_mode == "sdpa":
-            out = self.sdpa_attention(qkv)
+            out = self.sdpa_attention(q, k, v)
         elif self.attention_mode == "sage_attn":
-            out = self.sage_attention(qkv)
+            out = self.sage_attention(q, k, v)
         elif self.attention_mode == "comfy":
-            out = self.comfy_attention(qkv)
+            out = self.comfy_attention(q, k, v)
+        return out

-        x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
-        assert x.size() == (B, M, local_dim)
-        assert y.size() == (B, L, local_dim)
-
-        x = x.view(B, M, self.num_heads, self.head_dim)
-        x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
-        x = self.proj_x(x)  # (B, M, dim_x)
-        y = self.proj_y(y)  # (B, L, dim_y)
-        return x, y

     def forward(
         self,
@@ -324,47 +257,44 @@ class AsymmetricAttention(nn.Module):
         *,
         scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
         scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
-        packed_indices: Dict[str, torch.Tensor] = None,
+        num_tokens,
         **rope_rotation,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward pass of asymmetric multi-modal attention.
-
-        Args:
-            x: (B, N, dim_x) tensor for visual tokens
-            y: (B, L, dim_y) tensor of text token features
-            packed_indices: Dict with keys for Flash Attention
-            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
-
-        Returns:
-            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
-            y: (B, L, dim_y) tensor of text token features after multi-modal attention
-        """
-        B, L, _ = y.shape
-        _, M, _ = x.shape
-
-        # Predict a packed QKV tensor from visual and text features.
-        # Don't checkpoint the all_to_all.
-        qkv = self.prepare_qkv(
-            x=x,
-            y=y,
-            scale_x=scale_x,
-            scale_y=scale_y,
-            rope_cos=rope_rotation.get("rope_cos"),
-            rope_sin=rope_rotation.get("rope_sin"),
-            valid_token_indices=packed_indices["valid_token_indices_kv"],
-        )  # (total <= B * (N + L), 3, local_heads, head_dim)
-
-        x, y = self.run_attention(
-            qkv,
-            B=B,
-            L=L,
-            M=M,
-            cu_seqlens=packed_indices["cu_seqlens_kv"],
-            max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
-            valid_token_indices=packed_indices["valid_token_indices_kv"],
-        )
+        rope_cos = rope_rotation.get("rope_cos")
+        rope_sin = rope_rotation.get("rope_sin")
+
+        # Pre-norm for visual features
+        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
+        # Process text features
+        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
+        q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)
+
+        q_y = self.q_norm_y(q_y)
+        k_y = self.k_norm_y(k_y)
+
+        # Split qkv_x into q, k, v
+        q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, N, local_h, head_dim)
+        q_x = self.q_norm_x(q_x)
+        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
+        k_x = self.k_norm_x(k_x)
+        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
+
+        q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
+        k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
+        v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)
+
+        xy = self.run_attention(q, k, v)
+
+        x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
+        x = self.proj_x(x)
+        o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
+        o[:, :y.shape[1]] = y
+
+        y = self.proj_y(o)
         return x, y

+#region Blocks
 class AsymmetricJointBlock(nn.Module):
     def __init__(
         self,
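The rewritten forward above packs the visual tokens and the first num_tokens (non-padding) text tokens into a single attention sequence, splits the output back at the visual length, and zero-pads the text half to its original length before proj_y. A small shape sketch with made-up sizes:

import torch

B, M, L, H, D = 1, 8, 6, 2, 4        # batch, visual tokens, text tokens, heads, head_dim
num_tokens = 3                       # valid (non-padding) text tokens

q_x = torch.randn(B, M, H, D)
q_y = torch.randn(B, L, H, D)

q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)   # (B, H, M + num_tokens, D)
out = q.transpose(1, 2).reshape(B, -1, H * D)                      # stand-in for run_attention output

x, y = torch.tensor_split(out, (q_x.shape[1],), dim=1)             # (B, M, H*D), (B, num_tokens, H*D)
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1])
o[:, :y.shape[1]] = y                                              # pad the text tokens back to length L
assert x.shape == (B, M, H * D) and o.shape == (B, L, H * D)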
@@ -377,6 +307,7 @@ class AsymmetricJointBlock(nn.Module):
         update_y: bool = True,  # Whether to update text tokens in this block.
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         **block_kwargs,
     ):
         super().__init__()
@@ -401,6 +332,7 @@ class AsymmetricJointBlock(nn.Module):
             update_y=update_y,
             device=device,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
             **block_kwargs,
         )

@@ -432,7 +364,8 @@ class AsymmetricJointBlock(nn.Module):
         c: torch.Tensor,
         y: torch.Tensor,
         fastercache_counter: Optional[int] = 0,
-        fastercache_start_step: Optional[int] = 15,
+        fastercache_start_step: Optional[int] = 1000,
+        fastercache_device: Optional[torch.device] = None,
         **attn_kwargs,
     ):
         """Forward pass of a block.
@@ -459,10 +392,11 @@ class AsymmetricJointBlock(nn.Module):
         else:
             scale_msa_y = mod_y

-        #fastercache
+        #region fastercache
         B = x.shape[0]
         #print("x", x.shape) #([1, 9540, 3072])
         if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_x_attention[-1].shape[0] >= B:
+            print("using fastercache")
             x_attn = (
                 self.cached_x_attention[1][:B] +
                 (self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])
@@ -483,22 +417,24 @@ class AsymmetricJointBlock(nn.Module):
                 **attn_kwargs,
             )
             if fastercache_counter == fastercache_start_step:
-                self.cached_x_attention = [x_attn, x_attn]
-                self.cached_y_attention = [y_attn, y_attn]
+                print("caching attention")
+                self.cached_x_attention = [x_attn.to(fastercache_device), x_attn.to(fastercache_device)]
+                self.cached_y_attention = [y_attn.to(fastercache_device), y_attn.to(fastercache_device)]
             elif fastercache_counter > fastercache_start_step:
-                self.cached_x_attention[-1].copy_(x_attn)
-                self.cached_y_attention[-1].copy_(y_attn)
+                print("updating attention")
+                self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
+                self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))

         assert x_attn.size(1) == N
         x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
         if self.update_y:
             y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)

         # MLP block.
         x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
         if self.update_y:
             y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)

         return x, y

     def ff_block_x(self, x, scale_x, gate_x):
@@ -542,7 +478,7 @@ class FinalLayer(nn.Module):
         x = self.linear(x)
         return x

+#region Model
 class AsymmDiTJoint(nn.Module):
     """
     Diffusion model with a Transformer backbone.
@@ -570,6 +506,7 @@ class AsymmDiTJoint(nn.Module):
         rope_theta: float = 10000.0,
         device: Optional[torch.device] = None,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         **block_kwargs,
     ):
         super().__init__()
@@ -638,6 +575,7 @@ class AsymmDiTJoint(nn.Module):
                 update_y=update_y,
                 device=device,
                 attention_mode=attention_mode,
+                rms_norm_func=rms_norm_func,
                 **block_kwargs,
             )

@@ -701,12 +639,11 @@ class AsymmDiTJoint(nn.Module):
         x: torch.Tensor,
         sigma: torch.Tensor,
         fastercache_counter: int,
-        fastercache_start_step: int,
         y_feat: List[torch.Tensor],
         y_mask: List[torch.Tensor],
-        packed_indices: Dict[str, torch.Tensor] = None,
         rope_cos: torch.Tensor = None,
         rope_sin: torch.Tensor = None,
+        fastercache: Optional[Dict[str, torch.Tensor]] = None,
     ):
         """Forward pass of DiT.

@@ -715,18 +652,25 @@ class AsymmDiTJoint(nn.Module):
             sigma: (B,) tensor of noise standard deviations
             y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
             y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
-            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
         """
         B, _, T, H, W = x.shape

         # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
         # Have to call sdpa_kernel outside of a torch.compile region.
+        num_tokens = max(1, torch.sum(y_mask[0]).item())
         with sdpa_kernel(backends):
             x, c, y_feat, rope_cos, rope_sin = self.prepare(
                 x, sigma, y_feat[0], y_mask[0]
             )
         del y_mask

+        if fastercache is not None:
+            fastercache_start_step = fastercache["start_step"]
+            fastercache_device = fastercache["cache_device"]
+        else:
+            fastercache_start_step = 1000
+            fastercache_device = None
+
         for i, block in enumerate(self.blocks):
             x, y_feat = block(
                 x,
@@ -734,22 +678,24 @@ class AsymmDiTJoint(nn.Module):
                 y_feat,
                 rope_cos=rope_cos,
                 rope_sin=rope_sin,
-                packed_indices=packed_indices,
+                num_tokens=num_tokens,
                 fastercache_counter = fastercache_counter,
                 fastercache_start_step = fastercache_start_step,
+                fastercache_device = fastercache_device,
             )  # (B, M, D), (B, L, D)
         del y_feat  # Final layers don't use dense text features.

         x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
-        x = rearrange(
-            x,
-            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
-            T=T,
-            hp=H // self.patch_size,
-            wp=W // self.patch_size,
-            p1=self.patch_size,
-            p2=self.patch_size,
-            c=self.out_channels,
-        )
+        hp = H // self.patch_size
+        wp = W // self.patch_size
+        p1 = self.patch_size
+        p2 = self.patch_size
+        c = self.out_channels
+
+        x = x.view(B, T, hp, wp, p1, p2, c)
+        x = x.permute(0, 6, 1, 2, 4, 3, 5)
+        x = x.reshape(B, c, T, hp * p1, wp * p2)

         return x
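The last hunk replaces the einops.rearrange unpatchify with an equivalent view/permute/reshape chain, consistent with the einops import being dropped from this file. A standalone check of the equivalence, with made-up sizes:

import torch
from einops import rearrange

B, T, hp, wp, p1, p2, c = 1, 3, 4, 5, 2, 2, 12
x = torch.randn(B, T * hp * wp, p1 * p2 * c)

ref = rearrange(
    x, "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
    T=T, hp=hp, wp=wp, p1=p1, p2=p2, c=c,
)

out = x.view(B, T, hp, wp, p1, p2, c)
out = out.permute(0, 6, 1, 2, 4, 3, 5)          # B c T hp p1 wp p2
out = out.reshape(B, c, T, hp * p1, wp * p2)

assert torch.allclose(ref, out)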
@@ -1,5 +1,5 @@
-import json
 from typing import Dict, List, Optional, Union
+from einops import rearrange

 #temporary patch to fix torch compile bug in Windows
 def patched_write_atomic(
@@ -33,11 +33,8 @@ except:
     pass

 import torch
-import torch.nn.functional as F
 import torch.utils.data
-from einops import rearrange, repeat

-#from .dit.joint_model.context_parallel import get_cp_rank_size
 from tqdm import tqdm
 from comfy.utils import ProgressBar, load_torch_file
 import comfy.model_management as mm
@@ -95,59 +92,17 @@ def unnormalize_latents(
     assert z.size(1) == mean.size(0) == std.size(0)
     return z * std.to(z) + mean.to(z)


-def compute_packed_indices(
-    N: int,
-    text_mask: List[torch.Tensor],
-) -> Dict[str, torch.Tensor]:
-    """
-    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80
-
-    Args:
-        N: Number of visual tokens.
-        text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.
-
-    Returns:
-        packed_indices: Dict with keys for Flash Attention:
-            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
-              in the packed sequence.
-            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
-            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
-    """
-    # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
-    assert N > 0 and len(text_mask) == 1
-    text_mask = text_mask[0]
-
-    mask = F.pad(text_mask, (N, 0), value=True)  # (B, N + L)
-    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
-    valid_token_indices = torch.nonzero(
-        mask.flatten(), as_tuple=False
-    ).flatten()  # up to (B * (N + L),)
-
-    assert valid_token_indices.size(0) >= text_mask.size(0) * N  # At least (B * N,)
-    cu_seqlens = F.pad(
-        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
-    )
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-
-    return {
-        "cu_seqlens_kv": cu_seqlens,
-        "max_seqlen_in_batch_kv": max_seqlen_in_batch,
-        "valid_token_indices_kv": valid_token_indices,
-    }

 class T2VSynthMochiModel:
     def __init__(
         self,
         *,
         device: torch.device,
         offload_device: torch.device,
-        vae_stats_path: str,
         dit_checkpoint_path: str,
         weight_dtype: torch.dtype = torch.float8_e4m3fn,
         fp8_fastmode: bool = False,
         attention_mode: str = "sdpa",
+        rms_norm_func: str = "default",
         compile_args: Optional[Dict] = None,
         cublas_ops: Optional[bool] = False,
     ):
@@ -177,11 +132,23 @@ class T2VSynthMochiModel:
             t5_token_length=256,
             rope_theta=10000.0,
             attention_mode=attention_mode,
+            rms_norm_func=rms_norm_func,
         )

         params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
         logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
         dit_sd = load_torch_file(dit_checkpoint_path)

+        #comfy format
+        prefix = "model.diffusion_model."
+        first_key = next(iter(dit_sd), None)
+        if first_key and first_key.startswith(prefix):
+            new_dit_sd = {
+                key[len(prefix):] if key.startswith(prefix) else key: value
+                for key, value in dit_sd.items()
+            }
+            dit_sd = new_dit_sd
+
         if "gguf" in dit_checkpoint_path.lower():
             logging.info("Loading GGUF model state_dict...")
             from .. import mz_gguf_loader
@@ -220,18 +187,10 @@ class T2VSynthMochiModel:
             model.final_layer = torch.compile(model.final_layer, fullgraph=compile_args["fullgraph"], dynamic=False, backend=compile_args["backend"])

         self.dit = model

-        vae_stats = json.load(open(vae_stats_path))
-        self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
-        self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
-
-    def get_packed_indices(self, y_mask, *, lT, lW, lH):
-        patch_size = 2
-        N = lT * lH * lW // (patch_size**2)
-        assert len(y_mask) == 1
-        packed_indices = compute_packed_indices(N, y_mask)
-        self.move_to_device_(packed_indices)
-        return packed_indices
+    def get_packed_indices(self, y_mask, **latent_dims):
+        # temporary dummy func for compatibility
+        return []

     def move_to_device_(self, sample):
         if isinstance(sample, dict):
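The new "comfy format" branch above only strips the prefix when the first checkpoint key carries it; a toy illustration of that behaviour (keys made up for the example):

dit_sd = {
    "model.diffusion_model.blocks.0.attn.qkv_x.weight": "w0",
    "model.diffusion_model.pos_frequencies": "w1",
}
prefix = "model.diffusion_model."
first_key = next(iter(dit_sd), None)
if first_key and first_key.startswith(prefix):
    dit_sd = {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in dit_sd.items()}
assert list(dit_sd) == ["blocks.0.attn.qkv_x.weight", "pos_frequencies"]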
@@ -243,7 +202,7 @@ class T2VSynthMochiModel:
         torch.manual_seed(args["seed"])
         torch.cuda.manual_seed(args["seed"])

-        generator = torch.Generator(device=self.device)
+        generator = torch.Generator(device=torch.device("cpu"))
         generator.manual_seed(args["seed"])

         num_frames = args["num_frames"]
@@ -275,50 +234,49 @@ class T2VSynthMochiModel:
         T = (num_frames - 1) // temporal_downsample + 1
         H = height // spatial_downsample
         W = width // spatial_downsample
-        latent_dims = dict(lT=T, lW=W, lH=H)

         z = torch.randn(
             (B, C, T, H, W),
-            device=self.device,
+            device=torch.device("cpu"),
             generator=generator,
             dtype=torch.float32,
-        )
+        ).to(self.device)
         if in_samples is not None:
             z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device)

         sample = {
             "y_mask": [args["positive_embeds"]["attention_mask"].to(self.device)],
-            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)]
+            "y_feat": [args["positive_embeds"]["embeds"].to(self.device)],
+            "fastercache": args["fastercache"] if args["fastercache"] is not None else None
         }
         sample_null = {
             "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
-            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
-        }
+            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)],
+            "fastercache": args["fastercache"] if args["fastercache"] is not None else None
+        }

-        sample["packed_indices"] = self.get_packed_indices(
-            sample["y_mask"], **latent_dims
-        )
-        sample_null["packed_indices"] = self.get_packed_indices(
-            sample_null["y_mask"], **latent_dims
-        )
-        self.use_fastercache = True
+        if args["fastercache"]:
+            self.fastercache_start_step = args["fastercache"]["start_step"]
+            self.fastercache_lf_step = args["fastercache"]["lf_step"]
+            self.fastercache_hf_step = args["fastercache"]["hf_step"]
+        else:
+            self.fastercache_start_step = 1000
         self.fastercache_counter = 0
-        self.fastercache_start_step = 15
-        self.fastercache_lf_step = 40
-        self.fastercache_hf_step = 30

         def model_fn(*, z, sigma, cfg_scale):
             nonlocal sample, sample_null
-            if self.use_fastercache:
+            if args["fastercache"]:
                 self.fastercache_counter+=1
             if self.fastercache_counter >= self.fastercache_start_step + 3 and self.fastercache_counter % 5 !=0:
-                out_cond = self.dit(z, sigma,self.fastercache_counter, self.fastercache_start_step, **sample)
+                out_cond = self.dit(
+                    z,
+                    sigma,
+                    self.fastercache_counter,
+                    **sample)

                 (bb, cc, tt, hh, ww) = out_cond.shape
                 cond = rearrange(out_cond, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
                 lf_c, hf_c = fft(cond.float())
-                #lf_step = 40
-                #hf_step = 30
                 if self.fastercache_counter <= self.fastercache_lf_step:
                     self.delta_lf = self.delta_lf * 1.1
                 if self.fastercache_counter >= self.fastercache_hf_step:
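With fastercache enabled, model_fn above increments a counter on every call; once the counter reaches start_step + 3 it skips the unconditional pass and reconstructs it from the cached low/high-frequency components, except on every 5th call (and on any call before start_step + 3), where both passes still run. A tiny schedule sketch, start_step chosen arbitrarily:

start_step = 10
for counter in range(1, 21):
    reuse_cache = counter >= start_step + 3 and counter % 5 != 0
    print(counter, "cond only, uncond recovered from cache" if reuse_cache else "full cond + uncond")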
@@ -334,12 +292,19 @@ class T2VSynthMochiModel:

                 return recovered_uncond + cfg_scale * (out_cond - recovered_uncond)
             else:
-                out_cond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample)
-                out_uncond = self.dit(z, sigma, self.fastercache_counter, self.fastercache_start_step,**sample_null)
-                #print("out_cond.shape",out_cond.shape) #([1, 12, 3, 60, 106])
+                out_cond = self.dit(
+                    z,
+                    sigma,
+                    self.fastercache_counter,
+                    **sample)
+
+                out_uncond = self.dit(
+                    z,
+                    sigma,
+                    self.fastercache_counter,
+                    **sample_null)

                 if self.fastercache_counter >= self.fastercache_start_step + 1:

                     (bb, cc, tt, hh, ww) = out_cond.shape
                     cond = rearrange(out_cond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
                     uncond = rearrange(out_uncond.float(), "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
@@ -352,7 +317,6 @@ class T2VSynthMochiModel:

                 return out_uncond + cfg_scale * (out_cond - out_uncond)

-
         comfy_pbar = ProgressBar(sample_steps)

         if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
@@ -373,13 +337,19 @@ class T2VSynthMochiModel:
                 sigma=torch.full([B], sigma, device=z.device),
                 cfg_scale=cfg_schedule[i],
             )
-            pred = pred.to(z)
-            z = z + dsigma * pred
+            z = z + dsigma * pred.to(z)
             if callback is not None:
                 callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
             else:
                 comfy_pbar.update(1)

+        if args["fastercache"] is not None:
+            for block in self.dit.blocks:
+                if hasattr(block, "cached_x_attention") and block.cached_x_attention is not None:
+                    block.cached_x_attention = None
+                    block.cached_y_attention = None
+
         self.dit.to(self.offload_device)
+        mm.soft_empty_cache()
         logging.info(f"samples shape: {z.shape}")
         return z
nodes.py (40)
@@ -446,6 +446,36 @@ class MochiTextEncode:
         }
         return (t5_embeds, clip,)

+#region FasterCache
+class MochiFasterCache:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "start_step": ("INT", {"default": 10, "min": 0, "max": 1024, "step": 1, "tooltip": "The step to start caching, sigma schedule should be adjusted accordingly"}),
+                "hf_step": ("INT", {"default": 22, "min": 0, "max": 1024, "step": 1}),
+                "lf_step": ("INT", {"default": 28, "min": 0, "max": 1024, "step": 1}),
+                "cache_device": (["main_device", "offload_device"], {"default": "main_device", "tooltip": "The device to use for the cache, main_device is on GPU and uses a lot of VRAM"}),
+            },
+        }
+
+    RETURN_TYPES = ("FASTERCACHEARGS",)
+    RETURN_NAMES = ("fastercache", )
+    FUNCTION = "args"
+    CATEGORY = "CogVideoWrapper"
+    DESCRIPTION = "FasterCache (https://github.com/Vchitect/FasterCache) settings for the MochiWrapper"
+
+    def args(self, start_step, hf_step, lf_step, cache_device):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+        fastercache = {
+            "start_step" : start_step,
+            "hf_step" : hf_step,
+            "lf_step" : lf_step,
+            "cache_device" : device if cache_device == "main_device" else offload_device
+        }
+        return (fastercache,)

 #region Sampler
 class MochiSampler:
     @classmethod
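For reference, the tuple returned by MochiFasterCache.args wraps a plain dict like the following (illustrative values), which MochiSampler passes through to the model as the optional fastercache input:

import comfy.model_management as mm

fastercache = {
    "start_step": 10,                       # sampling step at which attention caching begins
    "hf_step": 22,                          # high-frequency update threshold
    "lf_step": 28,                          # low-frequency update threshold
    "cache_device": mm.get_torch_device(),  # or mm.unet_offload_device() to save VRAM
}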
@@ -466,6 +496,7 @@ class MochiSampler:
                 "cfg_schedule": ("FLOAT", {"forceInput": True, "tooltip": "Override cfg schedule with a list of ints"}),
                 "opt_sigmas": ("SIGMAS", {"tooltip": "Override sigma schedule and steps"}),
                 "samples": ("LATENT", ),
+                "fastercache": ("FASTERCACHEARGS", {"tooltip": "Optional FasterCache settings"}),
             }
         }

@@ -474,7 +505,7 @@ class MochiSampler:
     FUNCTION = "process"
     CATEGORY = "MochiWrapper"

-    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
+    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None, fastercache=None):
         mm.unload_all_models()
         mm.soft_empty_cache()

@@ -517,6 +548,7 @@ class MochiSampler:
             "negative_embeds": negative,
             "seed": seed,
             "samples": samples["samples"] if samples is not None else None,
+            "fastercache": fastercache
         }
         latents = model.run(args)

@@ -848,7 +880,8 @@ NODE_CLASS_MAPPINGS = {
     "MochiTorchCompileSettings": MochiTorchCompileSettings,
     "MochiImageEncode": MochiImageEncode,
     "MochiLatentPreview": MochiLatentPreview,
-    "MochiSigmaSchedule": MochiSigmaSchedule
+    "MochiSigmaSchedule": MochiSigmaSchedule,
+    "MochiFasterCache": MochiFasterCache
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "DownloadAndLoadMochiModel": "(Down)load Mochi Model",
@@ -862,5 +895,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "MochiTorchCompileSettings": "Mochi Torch Compile Settings",
     "MochiImageEncode": "Mochi Image Encode",
     "MochiLatentPreview": "Mochi Latent Preview",
-    "MochiSigmaSchedule": "Mochi Sigma Schedule"
+    "MochiSigmaSchedule": "Mochi Sigma Schedule",
+    "MochiFasterCache": "Mochi Faster Cache"
     }