diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 9aeb81a..7ac5208 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -1,4 +1,3 @@
-import os
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -8,7 +7,6 @@
 from einops import rearrange
 from torch.nn.attention import sdpa_kernel, SDPBackend
 
-from .context_parallel import all_to_all_collect_tokens, all_to_all_collect_heads, all_gather, get_cp_rank_size, is_cp_active
 from .layers import (
     FeedForward,
     PatchEmbed,
@@ -17,9 +15,7 @@ from .layers import (
 )
 
 from .mod_rmsnorm import modulated_rmsnorm
-from .residual_tanh_gated_rmsnorm import (
-    residual_tanh_gated_rmsnorm,
-)
+from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
 from .rope_mixed import (
     compute_mixed_rotation,
     create_position_matrix,
@@ -108,25 +104,13 @@ class AsymmetricAttention(nn.Module):
         )
 
     def run_qkv_y(self, y):
-        cp_rank, cp_size = get_cp_rank_size()
-        local_heads = self.num_heads // cp_size
-
-        if is_cp_active():
-            # Only predict local heads.
-            assert not self.qkv_bias
-            W_qkv_y = self.qkv_y.weight.view(
-                3, self.num_heads, self.head_dim, self.dim_y
-            )
-            W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads)
-            W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y)
-            qkv_y = F.linear(y, W_qkv_y, None)  # (B, L, 3 * local_h * head_dim)
-        else:
-            qkv_y = self.qkv_y(y)  # (B, L, 3 * dim)
-
+        local_heads = self.num_heads
+        qkv_y = self.qkv_y(y)  # (B, L, 3 * dim)
         qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
         q_y, k_y, v_y = qkv_y.unbind(2)
         return q_y, k_y, v_y
 
+
     def prepare_qkv(
         self,
         x: torch.Tensor,  # (B, N, dim_x)
@@ -144,9 +128,12 @@ class AsymmetricAttention(nn.Module):
         # Process visual features
         qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
         #assert qkv_x.dtype == torch.bfloat16
-        qkv_x = all_to_all_collect_tokens(
-            qkv_x, self.num_heads
-        )  # (3, B, N, local_h, head_dim)
+
+        # Move QKV dimension to the front.
+        # B M (3 H d) -> 3 B M H d
+        B, M, _ = qkv_x.size()
+        qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
+        qkv_x = qkv_x.permute(2, 0, 1, 3, 4)
 
         # Process text features
         y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
@@ -237,11 +224,7 @@ class AsymmetricAttention(nn.Module):
         max_seqlen_in_batch: int,
         valid_token_indices: torch.Tensor,
     ):
-        _, cp_size = get_cp_rank_size()
-        N = cp_size * M
-        assert self.num_heads % cp_size == 0
-        local_heads = self.num_heads // cp_size
-        local_dim = local_heads * self.head_dim
+        local_dim = self.num_heads * self.head_dim
         total = qkv.size(0)
 
         if self.attention_mode == "flash_attn":
@@ -253,19 +236,13 @@ class AsymmetricAttention(nn.Module):
         elif self.attention_mode == "comfy":
             out = self.comfy_attention(qkv)
 
-        x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
-        assert x.size() == (B, N, local_dim)
+        x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
+        assert x.size() == (B, M, local_dim)
         assert y.size() == (B, L, local_dim)
 
-        x = x.view(B, N, local_heads, self.head_dim)
-        x = all_to_all_collect_heads(x)  # (B, M, dim_x = num_heads * head_dim)
+        x = x.view(B, M, self.num_heads, self.head_dim)
+        x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
         x = self.proj_x(x)  # (B, M, dim_x)
-
-        if is_cp_active():
-            y = all_gather(y)  # (cp_size * B, L, local_heads * head_dim)
-            y = rearrange(
-                y, "(G B) L D -> B L (G D)", G=cp_size, D=local_dim
-            )  # (B, L, dim_x)
         y = self.proj_y(y)  # (B, L, dim_y)
         return x, y
 
@@ -593,46 +570,28 @@ class AsymmDiTJoint(nn.Module):
     ):
         """Prepare input and conditioning embeddings."""
         #("X", x.shape)
-        with torch.profiler.record_function("x_emb_pe"):
-            # Visual patch embeddings with positional encoding.
-            T, H, W = x.shape[-3:]
-            pH, pW = H // self.patch_size, W // self.patch_size
-            x = self.embed_x(x)  # (B, N, D), where N = T * H * W / patch_size ** 2
-            assert x.ndim == 3
-            B = x.size(0)
+        # Visual patch embeddings with positional encoding.
+        T, H, W = x.shape[-3:]
+        pH, pW = H // self.patch_size, W // self.patch_size
+        x = self.embed_x(x)  # (B, N, D), where N = T * H * W / patch_size ** 2
+        assert x.ndim == 3
 
-        with torch.profiler.record_function("rope_cis"):
-            # Construct position array of size [N, 3].
-            # pos[:, 0] is the frame index for each location,
-            # pos[:, 1] is the row index for each location, and
-            # pos[:, 2] is the column index for each location.
-            pH, pW = H // self.patch_size, W // self.patch_size
-            N = T * pH * pW
-            assert x.size(1) == N
-            pos = create_position_matrix(
-                T, pH=pH, pW=pW, device=x.device, dtype=torch.float32
-            )  # (N, 3)
-            rope_cos, rope_sin = compute_mixed_rotation(
-                freqs=self.pos_frequencies, pos=pos
-            )  # Each are (N, num_heads, dim // 2)
+        # Construct position array of size [N, 3].
+        # pos[:, 0] is the frame index for each location,
+        # pos[:, 1] is the row index for each location, and
+        # pos[:, 2] is the column index for each location.
+        pH, pW = H // self.patch_size, W // self.patch_size
+        N = T * pH * pW
+        assert x.size(1) == N
+        pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32)  # (N, 3)
+        rope_cos, rope_sin = compute_mixed_rotation(freqs=self.pos_frequencies, pos=pos)  # Each are (N, num_heads, dim // 2)
 
-        with torch.profiler.record_function("t_emb"):
-            # Global vector embedding for conditionings.
-            c_t = self.t_embedder(1 - sigma)  # (B, D)
+        # Global vector embedding for conditionings.
+        c_t = self.t_embedder(1 - sigma)  # (B, D)
 
-        with torch.profiler.record_function("t5_pool"):
-            # Pool T5 tokens using attention pooler
-            # Note y_feat[1] contains T5 token features.
-            # print("B", B)
-            # print("t5 feat shape",t5_feat.shape)
-            # print("t5 mask shape", t5_mask.shape)
-            assert (
-                t5_feat.size(1) == self.t5_token_length
-            ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat."
-            t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask)  # (B, D)
-            assert (
-                t5_y_pool.size(0) == B
-            ), f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool."
+        # Pool T5 tokens using attention pooler
+        # Note y_feat[1] contains T5 token features.
+        t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask)  # (B, D)
 
         c = c_t + t5_y_pool
 
@@ -669,21 +628,6 @@ class AsymmDiTJoint(nn.Module):
             )
         del y_mask
 
-        cp_rank, cp_size = get_cp_rank_size()
-        N = x.size(1)
-        M = N // cp_size
-        assert (
-            N % cp_size == 0
-        ), f"Visual sequence length ({x.shape[1]}) must be divisible by cp_size ({cp_size})."
-
-        if cp_size > 1:
-            x = x.narrow(1, cp_rank * M, M)
-
-        assert self.num_heads % cp_size == 0
-        local_heads = self.num_heads // cp_size
-        rope_cos = rope_cos.narrow(1, cp_rank * local_heads, local_heads)
-        rope_sin = rope_sin.narrow(1, cp_rank * local_heads, local_heads)
-
         for i, block in enumerate(self.blocks):
             x, y_feat = block(
                 x,
@@ -695,11 +639,7 @@ class AsymmDiTJoint(nn.Module):
             )  # (B, M, D), (B, L, D)
         del y_feat  # Final layers don't use dense text features.
 
-        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
-
-        patch = x.size(2)
-        x = all_gather(x)
-        x = rearrange(x, "(G B) M P -> B (G M) P", G=cp_size, P=patch)
+        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)
         x = rearrange(
             x,
             "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
diff --git a/mochi_preview/vae/model.py b/mochi_preview/vae/model.py
index 2eef51a..823343d 100644
--- a/mochi_preview/vae/model.py
+++ b/mochi_preview/vae/model.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
-from ..dit.joint_model.context_parallel import get_cp_rank_size, local_shard
+from ..dit.joint_model.context_parallel import get_cp_rank_size
 from ..vae.cp_conv import cp_pass_frames, gather_all_frames