from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import (
FeedForward,
PatchEmbed,
TimestepEmbedder,
)
from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm
from .rope_mixed import compute_mixed_rotation, create_position_matrix
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import pool_tokens, modulate
try:
from flash_attn import flash_attn_func
FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
FLASH_ATTN_IS_AVAILABLE = False
try:
from sageattention import sageattn
SAGEATTN_IS_AVAILABLE = True
except ImportError:
SAGEATTN_IS_AVAILABLE = False
from torch.nn.attention import sdpa_kernel, SDPBackend
backends = [
    SDPBackend.CUDNN_ATTENTION,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.MATH,
]
import comfy.model_management as mm
from comfy.ldm.modules.attention import optimized_attention
class AttentionPool(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
        output_dim: Optional[int] = None,
device: Optional[torch.device] = None,
):
"""
Args:
spatial_dim (int): Number of tokens in sequence length.
embed_dim (int): Dimensionality of input tokens.
num_heads (int): Number of attention heads.
output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
"""
super().__init__()
self.num_heads = num_heads
self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
def forward(self, x, mask):
"""
Args:
x (torch.Tensor): (B, L, D) tensor of input tokens.
mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
NOTE: We assume x does not require gradients.
Returns:
x (torch.Tensor): (B, D) tensor of pooled tokens.
"""
D = x.size(2)
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
# Average non-padding token features. These will be used as the query.
x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
# Concat pooled features to input sequence.
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
# Compute queries, keys, values. Only the mean token is used to create a query.
kv = self.to_kv(x) # (B, L+1, 2 * D)
q = self.to_q(x[:, 0]) # (B, D)
# Extract heads.
head_dim = D // self.num_heads
kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x
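# A minimal AttentionPool usage sketch, kept as a comment so nothing runs at import time.
# The shapes are illustrative assumptions (T5-XXL-sized features pooled down to a DiT width):
#   pool = AttentionPool(embed_dim=4096, num_heads=8, output_dim=3072)
#   feats = torch.randn(2, 256, 4096)            # (B, L, D) token features
#   mask = torch.ones(2, 256, dtype=torch.bool)  # (B, L) non-padding mask
#   pooled = pool(feats, mask)                   # (B, 3072) pooled embedding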
#region Attention
class AsymmetricAttention(nn.Module):
def __init__(
self,
dim_x: int,
dim_y: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.0,
update_y: bool = True,
out_bias: bool = True,
attend_to_padding: bool = False,
softmax_scale: Optional[float] = None,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
        rms_norm_func: str = "default",
):
super().__init__()
self.dim_x = dim_x
self.dim_y = dim_y
self.num_heads = num_heads
self.head_dim = dim_x // num_heads
self.attn_drop = attn_drop
self.update_y = update_y
self.attend_to_padding = attend_to_padding
self.softmax_scale = softmax_scale
self.attention_mode = attention_mode
self.device = device
if dim_x % num_heads != 0:
raise ValueError(
f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
)
# Input layers.
self.qkv_bias = qkv_bias
self.qkv_x = nn.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device)
# Project text features to match visual features (dim_y -> dim_x)
self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device)
# Query and key normalization for stability.
        if rms_norm_func == "flash_attn_triton":  # use the Triton fused RMSNorm from flash_attn
            try:
                from flash_attn.ops.triton.layer_norm import RMSNorm as FlashTritonRMSNorm  # slightly faster
                @torch.compiler.disable()  # causes NaNs when compiled for some reason
                class RMSNorm(FlashTritonRMSNorm):
                    pass
            except ImportError as e:
                raise ImportError("Flash Attention Triton RMSNorm is not available.") from e
        elif rms_norm_func == "flash_attn":
            try:
                from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm  # slightly faster
                @torch.compiler.disable()  # causes NaNs when compiled for some reason
                class RMSNorm(FlashRMSNorm):
                    pass
            except ImportError as e:
                raise ImportError("Flash Attention RMSNorm is not available.") from e
        elif rms_norm_func == "apex":
            from apex.normalization import FusedRMSNorm as ApexRMSNorm
            class RMSNorm(ApexRMSNorm):
                pass
        else:
            from .layers import RMSNorm
        # Apex FusedRMSNorm does not take a device argument.
        norm_kwargs = {}
        if rms_norm_func != "apex":
            norm_kwargs['device'] = device
self.q_norm_x = RMSNorm(self.head_dim, **norm_kwargs)
self.k_norm_x = RMSNorm(self.head_dim, **norm_kwargs)
self.q_norm_y = RMSNorm(self.head_dim, **norm_kwargs)
self.k_norm_y = RMSNorm(self.head_dim, **norm_kwargs)
# Output layers. y features go back down from dim_x -> dim_y.
self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device)
self.proj_y = (
nn.Linear(dim_x, dim_y, bias=out_bias, device=device)
if update_y
else nn.Identity()
)
    def flash_attention(self, q, k, v):
        # flash_attn_func expects (batch, seqlen, nheads, headdim); inputs arrive as (B, H, L, head_dim).
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        b, _, _, dim_head = q.shape
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out: torch.Tensor = flash_attn_func(
                q, k, v,
                dropout_p=0.0,
                softmax_scale=self.softmax_scale,
            )  # (B, L, H, head_dim)
        return out.reshape(b, -1, self.num_heads * dim_head)
def sdpa_attention(self, q, k, v):
b, _, _, dim_head = q.shape
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
with sdpa_kernel(backends):
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def sage_attention(self, q, k, v):
b, _, _, dim_head = q.shape
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = sageattn(
q,
k,
v,
attn_mask=None,
dropout_p=0.0,
is_causal=False
)
return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def comfy_attention(self, q, k, v):
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = optimized_attention(
q,
k,
v,
heads = self.num_heads,
skip_reshape=True
)
return out
    def run_attention(self, q, k, v):
        if self.attention_mode == "flash_attn":
            out = self.flash_attention(q, k, v)
        elif self.attention_mode == "sdpa":
            out = self.sdpa_attention(q, k, v)
        elif self.attention_mode == "sage_attn":
            out = self.sage_attention(q, k, v)
        elif self.attention_mode == "comfy":
            out = self.comfy_attention(q, k, v)
        else:
            raise ValueError(f"Unsupported attention_mode: {self.attention_mode}")
        return out
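    # Layout note: every backend above takes q/k/v shaped (B, num_heads, seq_len, head_dim)
    # and returns (B, seq_len, num_heads * head_dim), which forward() then splits back into
    # the visual and text streams.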
def forward(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
*,
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
num_tokens,
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
rope_sin = rope_rotation.get("rope_sin")
        # Pre-norm for visual features
        x = modulated_rmsnorm(x, scale_x)  # (B, N, dim_x)
        # Process text features
        y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
        q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2)  # (B, L, num_heads, head_dim)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
# Split qkv_x into q, k, v
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)
xy = self.run_attention(q, k, v)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
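        # Attention only saw the first num_tokens (non-padded) text tokens, so scatter the
        # attended text features back into a zero tensor of the full padded length before
        # projecting, keeping y at its original (B, L, dim_y) shape.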
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
o[:, :y.shape[1]] = y
y = self.proj_y(o)
return x, y
#region Blocks
class AsymmetricJointBlock(nn.Module):
def __init__(
self,
hidden_size_x: int,
hidden_size_y: int,
num_heads: int,
*,
mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
update_y: bool = True, # Whether to update text tokens in this block.
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
rms_norm_func: str = "default",
**block_kwargs,
):
super().__init__()
self.update_y = update_y
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.attention_mode = attention_mode
self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
if self.update_y:
self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
else:
self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
self.cached_x_attention = [None, None]
self.cached_y_attention = [None, None]
# Self-attention:
self.attn = AsymmetricAttention(
hidden_size_x,
hidden_size_y,
num_heads=num_heads,
update_y=update_y,
device=device,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
**block_kwargs,
)
# MLP.
mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
#assert mlp_hidden_dim_x == int(1536 * 8)
self.mlp_x = FeedForward(
in_features=hidden_size_x,
hidden_size=mlp_hidden_dim_x,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
)
# MLP for text not needed in last block.
if self.update_y:
mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
self.mlp_y = FeedForward(
in_features=hidden_size_y,
hidden_size=mlp_hidden_dim_y,
multiple_of=256,
ffn_dim_multiplier=None,
device=device,
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
fastercache_counter: Optional[int] = 0,
fastercache_start_step: Optional[int] = 1000,
fastercache_device: Optional[torch.device] = None,
**attn_kwargs,
):
"""Forward pass of a block.
Args:
x: (B, N, dim) tensor of visual tokens
c: (B, dim) tensor of conditioned features
y: (B, L, dim) tensor of text tokens
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
Returns:
x: (B, N, dim) tensor of visual tokens after block
y: (B, L, dim) tensor of text tokens after block
"""
N = x.size(1)
c = F.silu(c)
mod_x = self.mod_x(c)
scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
mod_y = self.mod_y(c)
if self.update_y:
scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
else:
scale_msa_y = mod_y
#region fastercache
B = x.shape[0]
#print("x", x.shape) #([1, 9540, 3072])
if fastercache_counter >= fastercache_start_step + 3 and fastercache_counter%3!=0 and self.cached_x_attention[-1].shape[0] >= B:
x_attn = (
self.cached_x_attention[1][:B] +
(self.cached_x_attention[1][:B] - self.cached_x_attention[0][:B])
* 0.3
).to(x.device, non_blocking=True)
y_attn = (
self.cached_y_attention[1][:B] +
(self.cached_y_attention[1][:B] - self.cached_y_attention[0][:B])
* 0.3
).to(x.device, non_blocking=True)
else:
# Self-attention block.
x_attn, y_attn = self.attn(
x,
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
**attn_kwargs,
)
if fastercache_counter == fastercache_start_step:
self.cached_x_attention = [x_attn.to(fastercache_device), x_attn.to(fastercache_device)]
self.cached_y_attention = [y_attn.to(fastercache_device), y_attn.to(fastercache_device)]
elif fastercache_counter > fastercache_start_step:
self.cached_x_attention[-1].copy_(x_attn.to(fastercache_device))
self.cached_y_attention[-1].copy_(y_attn.to(fastercache_device))
#assert x_attn.size(1) == N
x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
if self.update_y:
y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
# MLP block.
x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
if self.update_y:
y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
return x, y
def ff_block_x(self, x, scale_x, gate_x):
x_mod = modulated_rmsnorm(x, scale_x)
x_res = self.mlp_x(x_mod)
x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
return x
def ff_block_y(self, y, scale_y, gate_y):
y_mod = modulated_rmsnorm(y, scale_y)
y_res = self.mlp_y(y_mod)
y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
return y
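    # Both ff_block_x and ff_block_y follow the same sandwich-norm pattern as the attention
    # path above: the input is RMS-normalized and modulated by the conditioning scale before
    # the MLP, and the MLP output is folded back into the residual stream via the tanh-gated,
    # re-normalized update in residual_tanh_gated_rmsnorm.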
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(
self,
hidden_size,
patch_size,
out_channels,
device: Optional[torch.device] = None,
):
super().__init__()
self.norm_final = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, device=device
)
self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device)
self.linear = nn.Linear(
hidden_size, patch_size * patch_size * out_channels, device=device
)
def forward(self, x, c):
c = F.silu(c)
shift, scale = self.mod(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
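    # FinalLayer acts as the usual adaLN-style output head: `modulate` is assumed to apply
    # roughly x * (1 + scale) + shift, with shift/scale predicted per sample from c, after
    # which `linear` maps each token to patch_size**2 * out_channels values for unpatchifying.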
#region Model
class AsymmDiTJoint(nn.Module):
"""
Diffusion model with a Transformer backbone.
Ingests text embeddings instead of a label.
"""
def __init__(
self,
*,
patch_size=2,
in_channels=4,
hidden_size_x=1152,
hidden_size_y=1152,
depth=48,
num_heads=16,
mlp_ratio_x=8.0,
mlp_ratio_y=4.0,
t5_feat_dim: int = 4096,
t5_token_length: int = 256,
patch_embed_bias: bool = True,
timestep_mlp_bias: bool = True,
timestep_scale: Optional[float] = None,
use_extended_posenc: bool = False,
rope_theta: float = 10000.0,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
rms_norm_func: str = "default",
**block_kwargs,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size_x = hidden_size_x
self.hidden_size_y = hidden_size_y
self.head_dim = (
hidden_size_x // num_heads
) # Head dimension and count is determined by visual.
self.use_extended_posenc = use_extended_posenc
self.t5_token_length = t5_token_length
self.t5_feat_dim = t5_feat_dim
self.rope_theta = (
rope_theta # Scaling factor for frequency computation for temporal RoPE.
)
self.x_embedder = PatchEmbed(
patch_size=patch_size,
in_chans=in_channels,
embed_dim=hidden_size_x,
bias=patch_embed_bias,
device=device,
)
# Conditionings
# Timestep
self.t_embedder = TimestepEmbedder(
hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale
)
# Caption Pooling (T5)
self.t5_y_embedder = AttentionPool(
t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device
)
# Dense Embedding Projection (T5)
self.t5_yproj = nn.Linear(
t5_feat_dim, hidden_size_y, bias=True, device=device
)
# Initialize pos_frequencies as an empty parameter.
self.pos_frequencies = nn.Parameter(
torch.empty(3, self.num_heads, self.head_dim // 2, device=device)
)
# for depth 48:
# b = 0: AsymmetricJointBlock, update_y=True
# b = 1: AsymmetricJointBlock, update_y=True
# ...
# b = 46: AsymmetricJointBlock, update_y=True
# b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
blocks = []
for b in range(depth):
# Joint multi-modal block
update_y = b < depth - 1
block = AsymmetricJointBlock(
hidden_size_x,
hidden_size_y,
num_heads,
mlp_ratio_x=mlp_ratio_x,
mlp_ratio_y=mlp_ratio_y,
update_y=update_y,
device=device,
attention_mode=attention_mode,
rms_norm_func=rms_norm_func,
**block_kwargs,
)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
self.final_layer = FinalLayer(
hidden_size_x, patch_size, self.out_channels, device=device
)
def embed_x(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: (B, C=12, T, H, W) tensor of visual tokens
Returns:
            x: (B, N, D=3072) tensor of patch embeddings.
        """
        return self.x_embedder(x)  # Convert (B, C, T, H, W) to (B, N, D)
def prepare(
self,
x: torch.Tensor,
sigma: torch.Tensor,
t5_feat: torch.Tensor,
t5_mask: torch.Tensor,
):
"""Prepare input and conditioning embeddings."""
#("X", x.shape)
# Visual patch embeddings with positional encoding.
T, H, W = x.shape[-3:]
pH, pW = H // self.patch_size, W // self.patch_size
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
#assert x.ndim == 3
# Construct position array of size [N, 3].
# pos[:, 0] is the frame index for each location,
# pos[:, 1] is the row index for each location, and
# pos[:, 2] is the column index for each location.
N = T * pH * pW
#assert x.size(1) == N
pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32) # (N, 3)
rope_cos, rope_sin = compute_mixed_rotation(freqs=self.pos_frequencies, pos=pos) # Each are (N, num_heads, dim // 2)
# Global vector embedding for conditionings.
c_t = self.t_embedder(1 - sigma) # (B, D)
# Pool T5 tokens using attention pooler
# Note y_feat[1] contains T5 token features.
t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
c = c_t + t5_y_pool
y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
return x, c, y_feat, rope_cos, rope_sin
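    # Shapes handed from prepare() to forward():
    #   x: (B, N, hidden_size_x) patchified visual tokens
    #   c: (B, hidden_size_x) global conditioning (timestep embedding + pooled T5)
    #   y_feat: (B, L, hidden_size_y) projected dense T5 features
    #   rope_cos / rope_sin: (N, num_heads, head_dim // 2) mixed-RoPE rotations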
def forward(
self,
x: torch.Tensor,
sigma: torch.Tensor,
y_feat: List[torch.Tensor],
y_mask: List[torch.Tensor],
        rope_cos: Optional[torch.Tensor] = None,
        rope_sin: Optional[torch.Tensor] = None,
fastercache: Optional[Dict[str, torch.Tensor]] = None,
fastercache_counter: Optional[int]=0,
):
"""Forward pass of DiT.
Args:
x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
sigma: (B,) tensor of noise standard deviations
            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For T5-XXL: L=256, y_feat_dim=4096)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
"""
B, _, T, H, W = x.shape
# Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
# Have to call sdpa_kernel outside of a torch.compile region.
num_tokens = max(1, torch.sum(y_mask[0]).item())
with sdpa_kernel(backends):
x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat[0], y_mask[0]
)
del y_mask
if fastercache is not None:
fastercache_start_step = fastercache["start_step"]
fastercache_device = fastercache["cache_device"]
else:
fastercache_start_step = 1000
fastercache_device = None
#print(fastercache_counter)
for i, block in enumerate(self.blocks):
x, y_feat = block(
x,
c,
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
num_tokens=num_tokens,
fastercache_counter = fastercache_counter,
fastercache_start_step = fastercache_start_step,
fastercache_device = fastercache_device,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
hp = H // self.patch_size
wp = W // self.patch_size
p1 = self.patch_size
p2 = self.patch_size
c = self.out_channels
x = x.view(B, T, hp, wp, p1, p2, c)
x = x.permute(0, 6, 1, 2, 4, 3, 5)
x = x.reshape(B, c, T, hp * p1, wp * p2)
return x
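# End-to-end usage sketch (comment only). The configuration below is an illustrative,
# deliberately tiny assumption rather than the shipped Mochi config, and in practice the
# weights (including pos_frequencies, which starts uninitialized) come from a checkpoint:
#   model = AsymmDiTJoint(depth=2, hidden_size_x=256, hidden_size_y=128, num_heads=8,
#                         in_channels=12, t5_feat_dim=4096, t5_token_length=256)
#   x = torch.randn(1, 12, 3, 30, 40)                # (B, C, T, H, W) latent video
#   sigma = torch.rand(1)                            # (B,) noise levels
#   y_feat = [torch.randn(1, 256, 4096)]             # dense T5 token features
#   y_mask = [torch.ones(1, 256, dtype=torch.bool)]  # T5 non-padding mask
#   out = model(x, sigma, y_feat, y_mask)            # (1, 12, 3, 30, 40) prediction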