cleanup and align more to comfy code, switch to using cpu seed as well

parent 3cf9289e08
commit 0dc011d1b6
@@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .layers import (
    FeedForward,
@@ -14,20 +13,12 @@ from .layers import (

from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
from .rope_mixed import (
    compute_mixed_rotation,
    create_position_matrix,
)
from .rope_mixed import (compute_mixed_rotation, create_position_matrix)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
    pool_tokens,
    modulate,
    pad_and_split_xy,
    unify_streams,
)
from .utils import (pool_tokens, modulate)

try:
    from flash_attn import flash_attn_varlen_qkvpacked_func
    from flash_attn import flash_attn_func
    FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
    FLASH_ATTN_IS_AVAILABLE = False
@@ -45,6 +36,7 @@ backends.append(SDPBackend.EFFICIENT_ATTENTION)
backends.append(SDPBackend.MATH)

import comfy.model_management as mm
from comfy.ldm.modules.attention import optimized_attention

class AttentionPool(nn.Module):
    def __init__(
@@ -103,10 +95,9 @@ class AttentionPool(nn.Module):
        q = q.unsqueeze(2) # (B, H, 1, head_dim)

        # Compute attention.
        with sdpa_kernel(backends):
            x = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0
            ) # (B, H, 1, head_dim)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0
        ) # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
@@ -166,79 +157,23 @@ class AsymmetricAttention(nn.Module):
            if update_y
            else nn.Identity()
        )

    def run_qkv_y(self, y):
        local_heads = self.num_heads
        qkv_y = self.qkv_y(y) # (B, L, 3 * dim)
        qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
        q_y, k_y, v_y = qkv_y.unbind(2)
        return q_y, k_y, v_y


    def prepare_qkv(
        self,
        x: torch.Tensor, # (B, N, dim_x)
        y: torch.Tensor, # (B, L, dim_y)
        *,
        scale_x: torch.Tensor,
        scale_y: torch.Tensor,
        rope_cos: torch.Tensor,
        rope_sin: torch.Tensor,
        valid_token_indices: torch.Tensor,
    ):
        # Pre-norm for visual features
        x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size

        # Process visual features
        qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
        #assert qkv_x.dtype == torch.bfloat16

        # Move QKV dimension to the front.
        # B M (3 H d) -> 3 B M H d
        B, M, _ = qkv_x.size()
        qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
        qkv_x = qkv_x.permute(2, 0, 1, 3, 4)

        # Process text features
        y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
        q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim)
        q_y = self.q_norm_y(q_y)
        k_y = self.k_norm_y(k_y)

        # Split qkv_x into q, k, v
        q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim)
        q_x = self.q_norm_x(q_x)
        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
        k_x = self.k_norm_x(k_x)
        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

        # Unite streams
        qkv = unify_streams(
            q_x,
            k_x,
            v_x,
            q_y,
            k_y,
            v_y,
            valid_token_indices,
        )

        return qkv

    def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
    def flash_attention(self, q, k ,v):
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        b, _, _, dim_head = q.shape
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen_in_batch,
            out: torch.Tensor = flash_attn_func( #q: (batch_size, seqlen, nheads, headdim)
                q, k, v,
                dropout_p=0.0,
                softmax_scale=self.softmax_scale,
            ) # (total, local_heads, head_dim)
        return out.view(total, local_dim)
        out = out.permute(0, 2, 1, 3)
        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

    def sdpa_attention(self, qkv):
        q, k, v = qkv.unbind(dim=1)
        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
    def sdpa_attention(self, q, k, v):
        b, _, _, dim_head = q.shape
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            with sdpa_kernel(backends):
                out = F.scaled_dot_product_attention(
@@ -249,13 +184,10 @@ class AsymmetricAttention(nn.Module):
                    dropout_p=0.0,
                    is_causal=False
                )
        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

    def sage_attention(self, qkv):
        #q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
        q, k, v = qkv.unbind(dim=1)
        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]

    def sage_attention(self, q, k, v):
        b, _, _, dim_head = q.shape
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out = sageattn(
                q,
@@ -265,14 +197,9 @@ class AsymmetricAttention(nn.Module):
                dropout_p=0.0,
                is_causal=False
            )
            #print(out.shape)
            #out = rearrange(out, 'b h s d -> s (b h d)')
        return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
        return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)

    def comfy_attention(self, qkv):
        from comfy.ldm.modules.attention import optimized_attention
        q, k, v = qkv.unbind(dim=1)
        q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
    def comfy_attention(self, q, k, v):
        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
            out = optimized_attention(
                q,
@@ -281,41 +208,24 @@ class AsymmetricAttention(nn.Module):
                heads = self.num_heads,
                skip_reshape=True
            )
        return out.squeeze(0)
        return out

    @torch.compiler.disable()
    def run_attention(
        self,
        qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim)
        *,
        B: int,
        L: int,
        M: int,
        cu_seqlens: torch.Tensor,
        max_seqlen_in_batch: int,
        valid_token_indices: torch.Tensor,
    ):
        local_dim = self.num_heads * self.head_dim
        total = qkv.size(0)

        q,
        k,
        v,
    ):
        if self.attention_mode == "flash_attn":
            out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
            out = self.flash_attention(q, k ,v)
        elif self.attention_mode == "sdpa":
            out = self.sdpa_attention(qkv)
            out = self.sdpa_attention(q, k, v)
        elif self.attention_mode == "sage_attn":
            out = self.sage_attention(qkv)
            out = self.sage_attention(q, k, v)
        elif self.attention_mode == "comfy":
            out = self.comfy_attention(qkv)

        x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
        assert x.size() == (B, M, local_dim)
        assert y.size() == (B, L, local_dim)

        x = x.view(B, M, self.num_heads, self.head_dim)
        x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
        x = self.proj_x(x) # (B, M, dim_x)
        y = self.proj_y(y) # (B, L, dim_y)
        return x, y
            out = self.comfy_attention(q, k, v)
        return out
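For reference, the refactor above drops the packed varlen path: every backend helper now takes plain (q, k, v) tensors of shape (B, heads, seq_len, head_dim) and returns (B, seq_len, heads * head_dim). A minimal sketch of the comfy path, assuming only a head count and the optimized_attention import added above (the wrapper name here is made up, not the actual method):

import torch
from comfy.ldm.modules.attention import optimized_attention

def comfy_attention_sketch(q, k, v, num_heads):
    # q, k, v: (B, heads, seq_len, head_dim); skip_reshape=True signals that
    # the head dimension is already split out, so no extra reshape is applied.
    out = optimized_attention(q, k, v, heads=num_heads, skip_reshape=True)
    return out  # (B, seq_len, heads * head_dim)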

    def forward(
        self,
@@ -324,45 +234,41 @@ class AsymmetricAttention(nn.Module):
        *,
        scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
        scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
        packed_indices: Dict[str, torch.Tensor] = None,
        num_tokens,
        **rope_rotation,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of asymmetric multi-modal attention.

        Args:
            x: (B, N, dim_x) tensor for visual tokens
            y: (B, L, dim_y) tensor of text token features
            packed_indices: Dict with keys for Flash Attention
            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
        rope_cos = rope_rotation.get("rope_cos")
        rope_sin = rope_rotation.get("rope_sin")

        # Pre-norm for visual features
        x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
        # Process text features
        y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
        q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)

        Returns:
            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
            y: (B, L, dim_y) tensor of text token features after multi-modal attention
        """
        B, L, _ = y.shape
        _, M, _ = x.shape
        q_y = self.q_norm_y(q_y)
        k_y = self.k_norm_y(k_y)

        # Predict a packed QKV tensor from visual and text features.
        # Don't checkpoint the all_to_all.
        qkv = self.prepare_qkv(
            x=x,
            y=y,
            scale_x=scale_x,
            scale_y=scale_y,
            rope_cos=rope_rotation.get("rope_cos"),
            rope_sin=rope_rotation.get("rope_sin"),
            valid_token_indices=packed_indices["valid_token_indices_kv"],
        ) # (total <= B * (N + L), 3, local_heads, head_dim)
        # Split qkv_x into q, k, v
        q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
        q_x = self.q_norm_x(q_x)
        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
        k_x = self.k_norm_x(k_x)
        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

        x, y = self.run_attention(
            qkv,
            B=B,
            L=L,
            M=M,
            cu_seqlens=packed_indices["cu_seqlens_kv"],
            max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
            valid_token_indices=packed_indices["valid_token_indices_kv"],
        )
        q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
        k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
        v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)

        xy = self.run_attention(q, k, v)

        x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
        x = self.proj_x(x)
        o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
        o[:, :y.shape[1]] = y

        y = self.proj_y(o)
        return x, y
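The new forward replaces the packed-index path with a plain concatenate, attend, split sequence over the visual and unpadded text streams. A self-contained sketch of that pattern with toy sizes (every number below is made up for illustration):

import torch

B, N, L_valid, L_full, H, D = 1, 2, 3, 4, 1, 4   # toy sizes
q_x = torch.randn(B, N, H, D)                    # visual queries
q_y = torch.randn(B, L_full, H, D)               # text queries incl. padding

# Concatenate visual tokens with the first num_tokens text tokens along the
# sequence axis, then move heads in front as run_attention now expects.
q = torch.cat([q_x, q_y[:, :L_valid]], dim=1).transpose(1, 2)  # (B, H, N + L_valid, D)

# After attention, split back at the visual length and zero-pad the text
# stream to the full caption length before the output projection.
xy = torch.randn(B, N + L_valid, H * D)          # stand-in for the attention output
x, y = torch.tensor_split(xy, (N,), dim=1)       # (B, N, H*D), (B, L_valid, H*D)
o = torch.zeros(B, L_full, H * D)
o[:, :y.shape[1]] = y                            # (B, L_full, H*D)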

class AsymmetricJointBlock(nn.Module):
@@ -453,7 +359,7 @@ class AsymmetricJointBlock(nn.Module):
            scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
        else:
            scale_msa_y = mod_y


        # Self-attention block.
        x_attn, y_attn = self.attn(
            x,
@@ -467,12 +373,12 @@ class AsymmetricJointBlock(nn.Module):
        x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
        if self.update_y:
            y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)


        # MLP block.
        x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
        if self.update_y:
            y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)


        return x, y

    def ff_block_x(self, x, scale_x, gate_x):
@@ -676,7 +582,6 @@ class AsymmDiTJoint(nn.Module):
        sigma: torch.Tensor,
        y_feat: List[torch.Tensor],
        y_mask: List[torch.Tensor],
        packed_indices: Dict[str, torch.Tensor] = None,
        rope_cos: torch.Tensor = None,
        rope_sin: torch.Tensor = None,
    ):
@@ -687,18 +592,18 @@
            sigma: (B,) tensor of noise standard deviations
            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
            y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
        """
        B, _, T, H, W = x.shape

        # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
        # Have to call sdpa_kernel outside of a torch.compile region.
        num_tokens = max(1, torch.sum(y_mask[0]).item())
        with sdpa_kernel(backends):
            x, c, y_feat, rope_cos, rope_sin = self.prepare(
                x, sigma, y_feat[0], y_mask[0]
            )
        del y_mask


        for i, block in enumerate(self.blocks):
            x, y_feat = block(
                x,
@@ -706,20 +611,20 @@
                y_feat,
                rope_cos=rope_cos,
                rope_sin=rope_sin,
                packed_indices=packed_indices,
                num_tokens=num_tokens,
            ) # (B, M, D), (B, L, D)
        del y_feat # Final layers don't use dense text features.

        x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
        x = rearrange(
            x,
            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
            T=T,
            hp=H // self.patch_size,
            wp=W // self.patch_size,
            p1=self.patch_size,
            p2=self.patch_size,
            c=self.out_channels,
        )
        x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)

        hp = H // self.patch_size
        wp = W // self.patch_size
        p1 = self.patch_size
        p2 = self.patch_size
        c = self.out_channels

        x = x.view(B, T, hp, wp, p1, p2, c)
        x = x.permute(0, 6, 1, 2, 4, 3, 5)
        x = x.reshape(B, c, T, hp * p1, wp * p2)

        return x
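The einops rearrange above is replaced by an explicit view/permute/reshape; the two produce identical tensors, which a toy-sized check confirms (shapes below are arbitrary):

import torch
from einops import rearrange

# Toy sizes: T=2 latent frames, a 2x2 grid of 2x2 patches, 3 output channels.
B, T, hp, wp, p1, p2, c = 1, 2, 2, 2, 2, 2, 3
x = torch.randn(B, T * hp * wp, p1 * p2 * c)

ref = rearrange(
    x, "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
    T=T, hp=hp, wp=wp, p1=p1, p2=p2, c=c,
)

out = x.view(B, T, hp, wp, p1, p2, c)
out = out.permute(0, 6, 1, 2, 4, 3, 5)
out = out.reshape(B, c, T, hp * p1, wp * p2)

assert torch.equal(ref, out)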

@@ -1,163 +0,0 @@
import torch
import torch.distributed as dist
from einops import rearrange

_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_RANK = None
_CONTEXT_PARALLEL_GROUP_SIZE = None
_CONTEXT_PARALLEL_GROUP_RANKS = None


def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        return x

    cp_rank, cp_size = get_cp_rank_size()
    return x.tensor_split(cp_size, dim=dim)[cp_rank]


def set_cp_group(cp_group, ranks, global_rank):
    global \
        _CONTEXT_PARALLEL_GROUP, \
        _CONTEXT_PARALLEL_RANK, \
        _CONTEXT_PARALLEL_GROUP_SIZE, \
        _CONTEXT_PARALLEL_GROUP_RANKS
    if _CONTEXT_PARALLEL_GROUP is not None:
        raise RuntimeError("CP group already initialized.")
    _CONTEXT_PARALLEL_GROUP = cp_group
    _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
    _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
    _CONTEXT_PARALLEL_GROUP_RANKS = ranks

    assert (
        _CONTEXT_PARALLEL_RANK == ranks.index(global_rank)
    ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
    assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
        ranks
    ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"


def get_cp_group():
    if _CONTEXT_PARALLEL_GROUP is None:
        raise RuntimeError("CP group not initialized")
    return _CONTEXT_PARALLEL_GROUP


def is_cp_active():
    return _CONTEXT_PARALLEL_GROUP is not None


def get_cp_rank_size():
    if _CONTEXT_PARALLEL_GROUP:
        return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
    else:
        return 0, 1


class AllGatherIntoTensorFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
        ctx.reduce_dtype = reduce_dtype
        ctx.group = group
        ctx.batch_size = x.size(0)
        group_size = dist.get_world_size(group)

        x = x.contiguous()
        output = torch.empty(
            group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device
        )
        dist.all_gather_into_tensor(output, x, group=group)
        return output


def all_gather(tensor: torch.Tensor) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        return tensor

    return AllGatherIntoTensorFunction.apply(
        tensor, torch.float32, _CONTEXT_PARALLEL_GROUP
    )


@torch.compiler.disable()
def _all_to_all_single(output, input, group):
    # Disable compilation since torch compile changes contiguity.
    assert input.is_contiguous(), "Input tensor must be contiguous."
    assert output.is_contiguous(), "Output tensor must be contiguous."
    return dist.all_to_all_single(output, input, group=group)


class CollectTokens(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
        """Redistribute heads and receive tokens.

        Args:
            qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]

        Returns:
            qkv: shape: [3, B, N, local_heads, head_dim]

        where M is the number of local tokens,
        N = cp_size * M is the number of global tokens,
        local_heads = num_heads // cp_size is the number of local heads.
        """
        ctx.group = group
        ctx.num_heads = num_heads
        cp_size = dist.get_world_size(group)
        assert num_heads % cp_size == 0
        ctx.local_heads = num_heads // cp_size

        qkv = rearrange(
            qkv,
            "B M (qkv G h d) -> G M h B (qkv d)",
            qkv=3,
            G=cp_size,
            h=ctx.local_heads,
        ).contiguous()

        output_chunks = torch.empty_like(qkv)
        _all_to_all_single(output_chunks, qkv, group=group)

        return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)


def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        # Move QKV dimension to the front.
        # B M (3 H d) -> 3 B M H d
        B, M, _ = x.size()
        x = x.view(B, M, 3, num_heads, -1)
        return x.permute(2, 0, 1, 3, 4)

    return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)


class CollectHeads(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
        """Redistribute tokens and receive heads.

        Args:
            x: Output of attention. Shape: [B, N, local_heads, head_dim]

        Returns:
            Shape: [B, M, num_heads * head_dim]
        """
        ctx.group = group
        ctx.local_heads = x.size(2)
        ctx.head_dim = x.size(3)
        group_size = dist.get_world_size(group)
        x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
        output = torch.empty_like(x)
        _all_to_all_single(output, x, group=group)
        del x
        return rearrange(output, "G h M B D -> B M (G h D)")


def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        # Merge heads.
        return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))

    return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
@@ -130,8 +130,6 @@ class PatchEmbed(nn.Module):
        x = F.pad(x, (0, pad_w, 0, pad_h))

        x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
        #print("x",x.dtype, x.device)
        #print(self.proj.weight.dtype, self.proj.weight.device)
        x = self.proj(x)

        # Flatten temporal and spatial dimensions.

@@ -1,4 +1,4 @@
import functools
#import functools
import math

import torch
@@ -21,7 +21,7 @@ def centers(start: float, stop, num, dtype=None, device=None):
    return (edges[:-1] + edges[1:]) / 2


@functools.lru_cache(maxsize=1)
#@functools.lru_cache(maxsize=1)
def create_position_matrix(
    T: int,
    pH: int,

@@ -28,95 +28,4 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
    mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
    pooled = (x * mask).sum(dim=1, keepdim=keepdim)
    return pooled


class PadSplitXY(torch.autograd.Function):
    """
    Merge heads, pad and extract visual and text tokens,
    and split along the sequence length.
    """

    @staticmethod
    def forward(
        ctx,
        xy: torch.Tensor,
        indices: torch.Tensor,
        B: int,
        N: int,
        L: int,
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
            indices: Valid token indices out of unpacked tensor. Shape: (total,)

        Returns:
            x: Visual tokens. Shape: (B, N, num_heads * head_dim).
            y: Text tokens. Shape: (B, L, num_heads * head_dim).
        """
        ctx.save_for_backward(indices)
        ctx.B, ctx.N, ctx.L = B, N, L
        D = xy.size(1)

        # Pad sequences to (B, N + L, dim).
        assert indices.ndim == 1
        output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
        indices = indices.unsqueeze(1).expand(
            -1, D
        ) # (total,) -> (total, num_heads * head_dim)
        output.scatter_(0, indices, xy)
        xy = output.view(B, N + L, D)

        # Split visual and text tokens along the sequence length.
        return torch.tensor_split(xy, (N,), dim=1)


def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
    return PadSplitXY.apply(xy, indices, B, N, L, dtype)


class UnifyStreams(torch.autograd.Function):
    """Unify visual and text streams."""

    @staticmethod
    def forward(
        ctx,
        q_x: torch.Tensor,
        k_x: torch.Tensor,
        v_x: torch.Tensor,
        q_y: torch.Tensor,
        k_y: torch.Tensor,
        v_y: torch.Tensor,
        indices: torch.Tensor,
    ):
        """
        Args:
            q_x: (B, N, num_heads, head_dim)
            k_x: (B, N, num_heads, head_dim)
            v_x: (B, N, num_heads, head_dim)
            q_y: (B, L, num_heads, head_dim)
            k_y: (B, L, num_heads, head_dim)
            v_y: (B, L, num_heads, head_dim)
            indices: (total <= B * (N + L))

        Returns:
            qkv: (total <= B * (N + L), 3, num_heads, head_dim)
        """
        ctx.save_for_backward(indices)
        B, N, num_heads, head_dim = q_x.size()
        ctx.B, ctx.N, ctx.L = B, N, q_y.size(1)
        D = num_heads * head_dim

        q = torch.cat([q_x, q_y], dim=1)
        k = torch.cat([k_x, k_y], dim=1)
        v = torch.cat([v_x, v_y], dim=1)
        qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D)

        indices = indices[:, None, None].expand(-1, 3, D)
        qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim)
        return qkv.unflatten(2, (num_heads, head_dim))


def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
    return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)

@@ -33,9 +33,7 @@ except:
    pass

import torch
import torch.nn.functional as F
import torch.utils.data
from einops import rearrange, repeat

#from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
@@ -79,48 +77,6 @@ def unnormalize_latents(
    assert z.size(1) == mean.size(0) == std.size(0)
    return z * std.to(z) + mean.to(z)



def compute_packed_indices(
    N: int,
    text_mask: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """
    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80

    Args:
        N: Number of visual tokens.
        text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.

    Returns:
        packed_indices: Dict with keys for Flash Attention:
            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
              in the packed sequence.
            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
    """
    # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
    assert N > 0 and len(text_mask) == 1
    text_mask = text_mask[0]

    mask = F.pad(text_mask, (N, 0), value=True) # (B, N + L)
    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,)
    valid_token_indices = torch.nonzero(
        mask.flatten(), as_tuple=False
    ).flatten() # up to (B * (N + L),)

    assert valid_token_indices.size(0) >= text_mask.size(0) * N # At least (B * N,)
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
    )
    max_seqlen_in_batch = seqlens_in_batch.max().item()

    return {
        "cu_seqlens_kv": cu_seqlens,
        "max_seqlen_in_batch_kv": max_seqlen_in_batch,
        "valid_token_indices_kv": valid_token_indices,
    }

class T2VSynthMochiModel:
    def __init__(
        self,
@@ -166,6 +122,17 @@ class T2VSynthMochiModel:
        params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
        logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
        dit_sd = load_torch_file(dit_checkpoint_path)

        #comfy format
        prefix = "model.diffusion_model."
        first_key = next(iter(dit_sd), None)
        if first_key and first_key.startswith(prefix):
            new_dit_sd = {
                key[len(prefix):] if key.startswith(prefix) else key: value
                for key, value in dit_sd.items()
            }
            dit_sd = new_dit_sd

        if "gguf" in dit_checkpoint_path.lower():
            logging.info("Loading GGUF model state_dict...")
            from .. import mz_gguf_loader
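The block above strips ComfyUI's checkpoint prefix so the keys line up with the wrapper's module names. A tiny illustration with made-up keys (a real state dict has many more entries):

# Hypothetical keys, only to show the prefix handling.
dit_sd = {
    "model.diffusion_model.blocks.0.attn.qkv_x.weight": "w0",
    "model.diffusion_model.final_layer.linear.weight": "w1",
}
prefix = "model.diffusion_model."
first_key = next(iter(dit_sd), None)
if first_key and first_key.startswith(prefix):
    dit_sd = {
        key[len(prefix):] if key.startswith(prefix) else key: value
        for key, value in dit_sd.items()
    }
print(list(dit_sd))  # ['blocks.0.attn.qkv_x.weight', 'final_layer.linear.weight']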
@@ -209,14 +176,6 @@ class T2VSynthMochiModel:
        self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
        self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)

    def get_packed_indices(self, y_mask, *, lT, lW, lH):
        patch_size = 2
        N = lT * lH * lW // (patch_size**2)
        assert len(y_mask) == 1
        packed_indices = compute_packed_indices(N, y_mask)
        self.move_to_device_(packed_indices)
        return packed_indices

    def move_to_device_(self, sample):
        if isinstance(sample, dict):
            for key in sample.keys():
@@ -227,7 +186,7 @@
        torch.manual_seed(args["seed"])
        torch.cuda.manual_seed(args["seed"])

        generator = torch.Generator(device=self.device)
        generator = torch.Generator(device=torch.device("cpu"))
        generator.manual_seed(args["seed"])

        num_frames = args["num_frames"]
@@ -259,14 +218,13 @@
        T = (num_frames - 1) // temporal_downsample + 1
        H = height // spatial_downsample
        W = width // spatial_downsample
        latent_dims = dict(lT=T, lW=W, lH=H)

        z = torch.randn(
            (B, C, T, H, W),
            device=self.device,
            device=torch.device("cpu"),
            generator=generator,
            dtype=torch.float32,
        )
        ).to(self.device)
        if in_samples is not None:
            z = z * sigma_schedule[0] + (1 -sigma_schedule[0]) * in_samples.to(self.device)
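The two hunks above seed a CPU generator and draw the initial latent on the CPU, moving only the finished tensor to the compute device, so a given seed reproduces the same noise regardless of the GPU backend. A minimal sketch of that pattern (the helper name and toy shape are assumptions for illustration):

import torch

def cpu_seeded_noise(seed: int, shape, device):
    # Seed a CPU generator, sample on CPU, then move the latent to the device.
    generator = torch.Generator(device=torch.device("cpu"))
    generator.manual_seed(seed)
    z = torch.randn(shape, device=torch.device("cpu"), generator=generator, dtype=torch.float32)
    return z.to(device)

# Same seed -> identical latents, independent of the target device.
z1 = cpu_seeded_noise(42, (1, 12, 2, 8, 8), torch.device("cpu"))
z2 = cpu_seeded_noise(42, (1, 12, 2, 8, 8), torch.device("cpu"))
assert torch.equal(z1, z2)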

@@ -277,14 +235,7 @@
        sample_null = {
            "y_mask": [args["negative_embeds"]["attention_mask"].to(self.device)],
            "y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
        }

        sample["packed_indices"] = self.get_packed_indices(
            sample["y_mask"], **latent_dims
        )
        sample_null["packed_indices"] = self.get_packed_indices(
            sample_null["y_mask"], **latent_dims
        )
        }

        self.dit.to(self.device)
        if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
@@ -292,7 +243,7 @@
        else:
            autocast_dtype = torch.bfloat16

        def model_fn(*, z, sigma, cfg_scale):
        def model_fn(*, z, sigma, cfg_scale):
            nonlocal sample, sample_null
            if cfg_scale > 1.0:
                out_cond = self.dit(z, sigma, **sample)
@@ -314,8 +265,7 @@
                sigma=torch.full([B], sigma, device=z.device),
                cfg_scale=cfg_schedule[i],
            )
            pred = pred.to(z)
            z = z + dsigma * pred
            z = z + dsigma * pred.to(z)
            if callback is not None:
                callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
            else:

@@ -1,152 +0,0 @@
from typing import Tuple, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F

from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
    """
    Forward pass that handles communication between ranks for inference.
    Args:
        x: Tensor of shape (B, C, T, H, W)
        frames_to_send: int, number of frames to communicate between ranks
    Returns:
        output: Tensor of shape (B, C, T', H, W)
    """
    cp_rank, cp_world_size = cp.get_cp_rank_size()
    if frames_to_send == 0 or cp_world_size == 1:
        return x

    group = get_cp_group()
    global_rank = dist.get_rank()

    # Send to next rank
    if cp_rank < cp_world_size - 1:
        assert x.size(2) >= frames_to_send
        tail = x[:, :, -frames_to_send:].contiguous()
        dist.send(tail, global_rank + 1, group=group)

    # Receive from previous rank
    if cp_rank > 0:
        B, C, _, H, W = x.shape
        recv_buffer = torch.empty(
            (B, C, frames_to_send, H, W),
            dtype=x.dtype,
            device=x.device,
        )
        dist.recv(recv_buffer, global_rank - 1, group=group)
        x = torch.cat([recv_buffer, x], dim=2)

    return x


def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
    if max_T > x.size(2):
        pad_T = max_T - x.size(2)
        pad_dims = (0, 0, 0, 0, 0, pad_T)
        return F.pad(x, pad_dims)
    return x


def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
    """
    Gathers all frames from all processes for inference.
    Args:
        x: Tensor of shape (B, C, T, H, W)
    Returns:
        output: Tensor of shape (B, C, T_total, H, W)
    """
    cp_rank, cp_size = get_cp_rank_size()
    cp_group = get_cp_group()

    # Ensure the tensor is contiguous for collective operations
    x = x.contiguous()

    # Get the local time dimension size
    local_T = x.size(2)
    local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)

    # Gather all T sizes from all processes
    all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
    dist.all_gather(all_T, local_T_tensor, group=cp_group)
    all_T = [t.item() for t in all_T]

    # Pad the tensor at the end of the time dimension to match max_T
    max_T = max(all_T)
    x = _pad_to_max(x, max_T).contiguous()

    # Prepare a list to hold the gathered tensors
    gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]

    # Perform the all_gather operation
    dist.all_gather(gathered_x, x, group=cp_group)

    # Slice each gathered tensor back to its original T size
    for idx, t_size in enumerate(all_T):
        gathered_x[idx] = gathered_x[idx][:, :, :t_size]

    return torch.cat(gathered_x, dim=2)


def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
    """Estimate memory usage based on input tensor size and data type."""
    element_size = input.element_size() # Size in bytes of each element
    memory_bytes = input.numel() * element_size
    memory_gb = memory_bytes / 1024**3
    return memory_gb > max_gb


class ContextParallelCausalConv3d(torch.nn.Conv3d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        **kwargs,
    ):
        kernel_size = cast_tuple(kernel_size, 3)
        stride = cast_tuple(stride, 3)
        height_pad = (kernel_size[1] - 1) // 2
        width_pad = (kernel_size[2] - 1) // 2

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(1, 1, 1),
            padding=(0, height_pad, width_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        cp_rank, cp_world_size = get_cp_rank_size()

        context_size = self.kernel_size[0] - 1
        if cp_rank == 0:
            mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
            x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)

        if cp_world_size == 1:
            return super().forward(x)

        if all(s == 1 for s in self.stride):
            # Receive some frames from previous rank.
            x = cp_pass_frames(x, context_size)
            return super().forward(x)

        # Less efficient implementation for strided convs.
        # All gather x, infer and chunk.
        x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
        x = super().forward(x)
        x_chunks = x.tensor_split(cp_world_size, dim=2)
        assert len(x_chunks) == cp_world_size
        return x_chunks[cp_rank]