cleanup
commit e20eb66f93
parent e82e6ee3f7
@@ -1,4 +1,3 @@
-import os
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -8,7 +7,6 @@ from einops import rearrange

 from torch.nn.attention import sdpa_kernel, SDPBackend

-from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active
 from .layers import (
     FeedForward,
     PatchEmbed,
@@ -17,9 +15,7 @@ from .layers import (
 )

 from .mod_rmsnorm import modulated_rmsnorm
-from .residual_tanh_gated_rmsnorm import (
-    residual_tanh_gated_rmsnorm,
-)
+from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
 from .rope_mixed import (
     compute_mixed_rotation,
     create_position_matrix,
@@ -108,25 +104,13 @@ class AsymmetricAttention(nn.Module):
         )

     def run_qkv_y(self, y):
-        cp_rank, cp_size = get_cp_rank_size()
-        local_heads = self.num_heads // cp_size
-
-        if is_cp_active():
-            # Only predict local heads.
-            assert not self.qkv_bias
-            W_qkv_y = self.qkv_y.weight.view(
-                3, self.num_heads, self.head_dim, self.dim_y
-            )
-            W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads)
-            W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y)
-            qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim)
-        else:
-            qkv_y = self.qkv_y(y) # (B, L, 3 * dim)
+        local_heads = self.num_heads
+        qkv_y = self.qkv_y(y) # (B, L, 3 * dim)
         qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
         q_y, k_y, v_y = qkv_y.unbind(2)
         return q_y, k_y, v_y


     def prepare_qkv(
         self,
         x: torch.Tensor, # (B, N, dim_x)
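For reference, a minimal self-contained sketch of the fused-QKV split that the simplified run_qkv_y above performs; the linear layer and sizes here are illustrative stand-ins, not the model's own modules:

    import torch
    import torch.nn as nn

    # Illustrative sizes only.
    B, L, num_heads, head_dim = 2, 16, 4, 8
    dim_y = num_heads * head_dim

    qkv_y_proj = nn.Linear(dim_y, 3 * dim_y, bias=False)  # stand-in for self.qkv_y
    y = torch.randn(B, L, dim_y)

    qkv_y = qkv_y_proj(y)                             # (B, L, 3 * dim)
    qkv_y = qkv_y.view(B, L, 3, num_heads, head_dim)  # (B, L, 3, H, d)
    q_y, k_y, v_y = qkv_y.unbind(2)                   # each (B, L, H, d)
    assert q_y.shape == (B, L, num_heads, head_dim)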
@@ -144,9 +128,12 @@ class AsymmetricAttention(nn.Module):
         # Process visual features
         qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
         #assert qkv_x.dtype == torch.bfloat16
-        qkv_x = all_to_all_collect_tokens(
-            qkv_x, self.num_heads
-        ) # (3, B, N, local_h, head_dim)
+        # Move QKV dimension to the front.
+        # B M (3 H d) -> 3 B M H d
+        B, M, _ = qkv_x.size()
+        qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
+        qkv_x = qkv_x.permute(2, 0, 1, 3, 4)

         # Process text features
         y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
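The new comment "B M (3 H d) -> 3 B M H d" describes an einops-style reshape; a quick sketch (with made-up sizes) checking that the view + permute pair above produces exactly that layout:

    import torch
    from einops import rearrange

    # Made-up sizes for illustration.
    B, M, H, d = 2, 10, 4, 8
    qkv_x = torch.randn(B, M, 3 * H * d)  # (B, M, 3 * dim_x)

    # view + permute, as in the updated prepare_qkv
    a = qkv_x.view(B, M, 3, H, -1).permute(2, 0, 1, 3, 4)  # (3, B, M, H, d)

    # equivalent einops expression
    b = rearrange(qkv_x, "B M (three H d) -> three B M H d", three=3, H=H)

    assert torch.equal(a, b)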
@@ -237,11 +224,7 @@ class AsymmetricAttention(nn.Module):
         max_seqlen_in_batch: int,
         valid_token_indices: torch.Tensor,
     ):
-        _, cp_size = get_cp_rank_size()
-        N = cp_size * M
-        assert self.num_heads % cp_size == 0
-        local_heads = self.num_heads // cp_size
-        local_dim = local_heads * self.head_dim
+        local_dim = self.num_heads * self.head_dim
         total = qkv.size(0)

         if self.attention_mode == "flash_attn":
@@ -253,19 +236,13 @@ class AsymmetricAttention(nn.Module):
         elif self.attention_mode == "comfy":
             out = self.comfy_attention(qkv)

-        x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
-        assert x.size() == (B, N, local_dim)
+        x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
+        assert x.size() == (B, M, local_dim)
         assert y.size() == (B, L, local_dim)

-        x = x.view(B, N, local_heads, self.head_dim)
-        x = all_to_all_collect_heads(x) # (B, M, dim_x = num_heads * head_dim)
+        x = x.view(B, M, self.num_heads, self.head_dim)
+        x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
         x = self.proj_x(x) # (B, M, dim_x)
-
-        if is_cp_active():
-            y = all_gather(y) # (cp_size * B, L, local_heads * head_dim)
-            y = rearrange(
-                y, "(G B) L D -> B L (G D)", G=cp_size, D=local_dim
-            ) # (B, L, dim_x)
         y = self.proj_y(y) # (B, L, dim_y)
         return x, y

@@ -593,46 +570,28 @@ class AsymmDiTJoint(nn.Module):
     ):
         """Prepare input and conditioning embeddings."""
         #("X", x.shape)
-        with torch.profiler.record_function("x_emb_pe"):
-            # Visual patch embeddings with positional encoding.
-            T, H, W = x.shape[-3:]
-            pH, pW = H // self.patch_size, W // self.patch_size
-            x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
-            assert x.ndim == 3
-        B = x.size(0)
+        # Visual patch embeddings with positional encoding.
+        T, H, W = x.shape[-3:]
+        pH, pW = H // self.patch_size, W // self.patch_size
+        x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
+        assert x.ndim == 3

-        with torch.profiler.record_function("rope_cis"):
-            # Construct position array of size [N, 3].
-            # pos[:, 0] is the frame index for each location,
-            # pos[:, 1] is the row index for each location, and
-            # pos[:, 2] is the column index for each location.
-            pH, pW = H // self.patch_size, W // self.patch_size
-            N = T * pH * pW
-            assert x.size(1) == N
-            pos = create_position_matrix(
-                T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
-            ) # (N, 3)
-            rope_cos, rope_sin = compute_mixed_rotation(
-                freqs=self.pos_frequencies, pos=pos
-            ) # Each are (N, num_heads, dim // 2)
+        # Construct position array of size [N, 3].
+        # pos[:, 0] is the frame index for each location,
+        # pos[:, 1] is the row index for each location, and
+        # pos[:, 2] is the column index for each location.
+        pH, pW = H // self.patch_size, W // self.patch_size
+        N = T * pH * pW
+        assert x.size(1) == N
+        pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32) # (N, 3)
+        rope_cos, rope_sin = compute_mixed_rotation(freqs=self.pos_frequencies, pos=pos) # Each are (N, num_heads, dim // 2)

-        with torch.profiler.record_function("t_emb"):
-            # Global vector embedding for conditionings.
-            c_t = self.t_embedder(1 - sigma) # (B, D)
+        # Global vector embedding for conditionings.
+        c_t = self.t_embedder(1 - sigma) # (B, D)

-        with torch.profiler.record_function("t5_pool"):
-            # Pool T5 tokens using attention pooler
-            # Note y_feat[1] contains T5 token features.
-            # print("B", B)
-            # print("t5 feat shape",t5_feat.shape)
-            # print("t5 mask shape", t5_mask.shape)
-            assert (
-                t5_feat.size(1) == self.t5_token_length
-            ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat."
-            t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
-            assert (
-                t5_y_pool.size(0) == B
-            ), f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool."
+        # Pool T5 tokens using attention pooler
+        # Note y_feat[1] contains T5 token features.
+        t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)

         c = c_t + t5_y_pool

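The comments above specify pos as an (N, 3) tensor of (frame, row, column) indices with N = T * pH * pW. A hypothetical stand-in for create_position_matrix under that assumption (the repo's own implementation may differ) could be:

    import torch

    def make_position_matrix(T, pH, pW, device=None, dtype=torch.float32):
        # pos[:, 0] = frame index, pos[:, 1] = row index, pos[:, 2] = column index.
        t = torch.arange(T, device=device)
        h = torch.arange(pH, device=device)
        w = torch.arange(pW, device=device)
        tt, hh, ww = torch.meshgrid(t, h, w, indexing="ij")
        return torch.stack([tt, hh, ww], dim=-1).reshape(-1, 3).to(dtype)  # (T * pH * pW, 3)

    pos = make_position_matrix(T=4, pH=8, pW=8)
    assert pos.shape == (4 * 8 * 8, 3)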
@@ -669,21 +628,6 @@ class AsymmDiTJoint(nn.Module):
         )
         del y_mask

-        cp_rank, cp_size = get_cp_rank_size()
-        N = x.size(1)
-        M = N // cp_size
-        assert (
-            N % cp_size == 0
-        ), f"Visual sequence length ({x.shape[1]}) must be divisible by cp_size ({cp_size})."
-
-        if cp_size > 1:
-            x = x.narrow(1, cp_rank * M, M)
-
-        assert self.num_heads % cp_size == 0
-        local_heads = self.num_heads // cp_size
-        rope_cos = rope_cos.narrow(1, cp_rank * local_heads, local_heads)
-        rope_sin = rope_sin.narrow(1, cp_rank * local_heads, local_heads)
-
         for i, block in enumerate(self.blocks):
             x, y_feat = block(
                 x,
@@ -695,11 +639,7 @@ class AsymmDiTJoint(nn.Module):
             ) # (B, M, D), (B, L, D)
         del y_feat # Final layers don't use dense text features.

         x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
-
-        patch = x.size(2)
-        x = all_gather(x)
-        x = rearrange(x, "(G B) M P -> B (G M) P", G=cp_size, P=patch)
         x = rearrange(
             x,
             "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
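The closing rearrange (its keyword arguments fall outside this hunk) unpatchifies the token sequence back into a video tensor; a small sketch with illustrative sizes showing what that pattern does:

    import torch
    from einops import rearrange

    # Illustrative sizes: T frames, hp x wp patches per frame, p1 x p2 patch, c channels.
    B, T, hp, wp, p1, p2, c = 1, 4, 6, 6, 2, 2, 12
    x = torch.randn(B, T * hp * wp, p1 * p2 * c)  # (B, M, patch_size ** 2 * out_channels)

    video = rearrange(
        x,
        "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
        T=T, hp=hp, wp=wp, p1=p1, p2=p2, c=c,
    )
    assert video.shape == (B, c, T, hp * p1, wp * p2)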
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

-from ..dit.joint_model.context_parallel import get_cp_rank_size, local_shard
+from ..dit.joint_model.context_parallel import get_cp_rank_size
 from ..vae.cp_conv import cp_pass_frames, gather_all_frames

