Clean up and align more with Comfy code, switch to using a CPU seed as well

kijai 2024-11-03 18:36:23 +02:00
parent 3cf9289e08
commit 0dc011d1b6
8 changed files with 100 additions and 652 deletions

View File

@@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .layers import (
FeedForward,
@@ -14,20 +13,12 @@ from .layers import (
from .mod_rmsnorm import modulated_rmsnorm
from .residual_tanh_gated_rmsnorm import (residual_tanh_gated_rmsnorm)
from .rope_mixed import (
compute_mixed_rotation,
create_position_matrix,
)
from .rope_mixed import (compute_mixed_rotation, create_position_matrix)
from .temporal_rope import apply_rotary_emb_qk_real
from .utils import (
pool_tokens,
modulate,
pad_and_split_xy,
unify_streams,
)
from .utils import (pool_tokens, modulate)
try:
from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn import flash_attn_func
FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
FLASH_ATTN_IS_AVAILABLE = False
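For reference, a minimal sketch of the optional-dependency pattern above: use flash-attn when it is installed, otherwise fall back to PyTorch SDPA. It also highlights the layout difference the wrapper methods below have to bridge: flash_attn_func takes (batch, seqlen, nheads, headdim), while SDPA takes (batch, nheads, seqlen, headdim). The attention helper here is illustrative, not part of this commit.

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # expects (batch, seqlen, nheads, headdim)
    FLASH_ATTN_IS_AVAILABLE = True
except ImportError:
    FLASH_ATTN_IS_AVAILABLE = False

def attention(q, k, v, softmax_scale=None):
    # q, k, v: (batch, seqlen, nheads, headdim)
    if FLASH_ATTN_IS_AVAILABLE:
        return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=softmax_scale)
    # SDPA wants heads in dim 1, so transpose in and out.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), scale=softmax_scale
    )
    return out.transpose(1, 2)
```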
@@ -45,6 +36,7 @@ backends.append(SDPBackend.EFFICIENT_ATTENTION)
backends.append(SDPBackend.MATH)
import comfy.model_management as mm
from comfy.ldm.modules.attention import optimized_attention
class AttentionPool(nn.Module):
def __init__(
@@ -103,10 +95,9 @@ class AttentionPool(nn.Module):
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
with sdpa_kernel(backends):
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=0.0
) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
@@ -167,78 +158,22 @@ class AsymmetricAttention(nn.Module):
else nn.Identity()
)
def run_qkv_y(self, y):
local_heads = self.num_heads
qkv_y = self.qkv_y(y) # (B, L, 3 * dim)
qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
q_y, k_y, v_y = qkv_y.unbind(2)
return q_y, k_y, v_y
def prepare_qkv(
self,
x: torch.Tensor, # (B, N, dim_x)
y: torch.Tensor, # (B, L, dim_y)
*,
scale_x: torch.Tensor,
scale_y: torch.Tensor,
rope_cos: torch.Tensor,
rope_sin: torch.Tensor,
valid_token_indices: torch.Tensor,
):
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
# Process visual features
qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
#assert qkv_x.dtype == torch.bfloat16
# Move QKV dimension to the front.
# B M (3 H d) -> 3 B M H d
B, M, _ = qkv_x.size()
qkv_x = qkv_x.view(B, M, 3, self.num_heads, -1)
qkv_x = qkv_x.permute(2, 0, 1, 3, 4)
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
# Split qkv_x into q, k, v
q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
# Unite streams
qkv = unify_streams(
q_x,
k_x,
v_x,
q_y,
k_y,
v_y,
valid_token_indices,
)
return qkv
def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
def flash_attention(self, q, k, v):
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
b, _, _, dim_head = q.shape
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen_in_batch,
out: torch.Tensor = flash_attn_func( #q: (batch_size, seqlen, nheads, headdim)
q, k, v,
dropout_p=0.0,
softmax_scale=self.softmax_scale,
) # (total, local_heads, head_dim)
return out.view(total, local_dim)
out = out.permute(0, 2, 1, 3)
return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def sdpa_attention(self, qkv):
q, k, v = qkv.unbind(dim=1)
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
def sdpa_attention(self, q, k, v):
b, _, _, dim_head = q.shape
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
with sdpa_kernel(backends):
out = F.scaled_dot_product_attention(
@@ -249,13 +184,10 @@ class AsymmetricAttention(nn.Module):
dropout_p=0.0,
is_causal=False
)
return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
def sage_attention(self, qkv):
#q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
q, k, v = qkv.unbind(dim=1)
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def sage_attention(self, q, k, v):
b, _, _, dim_head = q.shape
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = sageattn(
q,
@@ -265,14 +197,9 @@ class AsymmetricAttention(nn.Module):
dropout_p=0.0,
is_causal=False
)
#print(out.shape)
#out = rearrange(out, 'b h s d -> s (b h d)')
return out.permute(2, 0, 1, 3).reshape(out.shape[2], -1)
return out.transpose(1, 2).reshape(b, -1, self.num_heads * dim_head)
def comfy_attention(self, qkv):
from comfy.ldm.modules.attention import optimized_attention
q, k, v = qkv.unbind(dim=1)
q, k, v = [x.permute(1, 0, 2).unsqueeze(0) for x in (q, k, v)]
def comfy_attention(self, q, k, v):
with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
out = optimized_attention(
q,
@@ -281,41 +208,24 @@ class AsymmetricAttention(nn.Module):
heads = self.num_heads,
skip_reshape=True
)
return out.squeeze(0)
return out
@torch.compiler.disable()
def run_attention(
self,
qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim)
*,
B: int,
L: int,
M: int,
cu_seqlens: torch.Tensor,
max_seqlen_in_batch: int,
valid_token_indices: torch.Tensor,
q,
k,
v,
):
local_dim = self.num_heads * self.head_dim
total = qkv.size(0)
if self.attention_mode == "flash_attn":
out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
out = self.flash_attention(q, k, v)
elif self.attention_mode == "sdpa":
out = self.sdpa_attention(qkv)
out = self.sdpa_attention(q, k, v)
elif self.attention_mode == "sage_attn":
out = self.sage_attention(qkv)
out = self.sage_attention(q, k, v)
elif self.attention_mode == "comfy":
out = self.comfy_attention(qkv)
x, y = pad_and_split_xy(out, valid_token_indices, B, M, L, qkv.dtype)
assert x.size() == (B, M, local_dim)
assert y.size() == (B, L, local_dim)
x = x.view(B, M, self.num_heads, self.head_dim)
x = x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
x = self.proj_x(x) # (B, M, dim_x)
y = self.proj_y(y) # (B, L, dim_y)
return x, y
out = self.comfy_attention(q, k, v)
return out
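All four backends now share the same contract: q, k and v arrive as (B, num_heads, seq_len, head_dim) and the fused result comes back as (B, seq_len, num_heads * head_dim), ready for proj_x / proj_y. A self-contained sketch of that contract using plain SDPA, with illustrative sizes:

```python
import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v):
    # q, k, v: (B, num_heads, seq_len, head_dim)
    b, num_heads, _, head_dim = q.shape
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    # (B, num_heads, seq_len, head_dim) -> (B, seq_len, num_heads * head_dim)
    return out.transpose(1, 2).reshape(b, -1, num_heads * head_dim)

q = k = v = torch.randn(1, 24, 128, 64)
print(sdpa_attention(q, k, v).shape)  # torch.Size([1, 128, 1536])
```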
def forward(
self,
@@ -324,45 +234,41 @@ class AsymmetricAttention(nn.Module):
*,
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
packed_indices: Dict[str, torch.Tensor] = None,
num_tokens,
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass of asymmetric multi-modal attention.
Args:
x: (B, N, dim_x) tensor for visual tokens
y: (B, L, dim_y) tensor of text token features
packed_indices: Dict with keys for Flash Attention
num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
rope_cos = rope_rotation.get("rope_cos")
rope_sin = rope_rotation.get("rope_sin")
Returns:
x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
y: (B, L, dim_y) tensor of text token features after multi-modal attention
"""
B, L, _ = y.shape
_, M, _ = x.shape
# Pre-norm for visual features
x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
# Process text features
y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
q_y, k_y, v_y = self.qkv_y(y).view(y.shape[0], y.shape[1], 3, self.num_heads, -1).unbind(2) # (B, L, num_heads, head_dim)
# Predict a packed QKV tensor from visual and text features.
# Don't checkpoint the all_to_all.
qkv = self.prepare_qkv(
x=x,
y=y,
scale_x=scale_x,
scale_y=scale_y,
rope_cos=rope_rotation.get("rope_cos"),
rope_sin=rope_rotation.get("rope_sin"),
valid_token_indices=packed_indices["valid_token_indices_kv"],
) # (total <= B * (N + L), 3, local_heads, head_dim)
q_y = self.q_norm_y(q_y)
k_y = self.k_norm_y(k_y)
x, y = self.run_attention(
qkv,
B=B,
L=L,
M=M,
cu_seqlens=packed_indices["cu_seqlens_kv"],
max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
valid_token_indices=packed_indices["valid_token_indices_kv"],
)
# Split qkv_x into q, k, v
q_x, k_x, v_x = self.qkv_x(x).view(x.shape[0], x.shape[1], 3, self.num_heads, -1).unbind(2) # (B, N, local_h, head_dim)
q_x = self.q_norm_x(q_x)
q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
k_x = self.k_norm_x(k_x)
k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)
xy = self.run_attention(q, k, v)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
o = torch.zeros(y.shape[0], q_y.shape[1], y.shape[-1], device=y.device, dtype=y.dtype)
o[:, :y.shape[1]] = y
y = self.proj_y(o)
return x, y
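The rewritten forward replaces the packed-index machinery with a simpler joint-attention flow: trim the text stream to its num_tokens valid entries, concatenate with the visual stream, attend once, split the result at the visual length, and zero-pad the text stream back to L before projecting. A standalone sketch of that flow with made-up sizes:

```python
import torch

B, M, L_full, num_tokens, H, D = 1, 16, 256, 9, 4, 8

q_x = k_x = v_x = torch.randn(B, M, H, D)       # visual tokens
q_y = k_y = v_y = torch.randn(B, L_full, H, D)  # text tokens (mostly padding)

# Keep only the valid text tokens, join the streams, move heads to dim 1.
q = torch.cat([q_x, q_y[:, :num_tokens]], dim=1).transpose(1, 2)
k = torch.cat([k_x, k_y[:, :num_tokens]], dim=1).transpose(1, 2)
v = torch.cat([v_x, v_y[:, :num_tokens]], dim=1).transpose(1, 2)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(B, M + num_tokens, H * D)

# Split back into the two streams and restore the padded text length.
x, y = torch.tensor_split(out, (M,), dim=1)
y_padded = torch.zeros(B, L_full, H * D)
y_padded[:, :num_tokens] = y
print(x.shape, y_padded.shape)  # (1, 16, 32) (1, 256, 32)
```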
class AsymmetricJointBlock(nn.Module):
@@ -676,7 +582,6 @@ class AsymmDiTJoint(nn.Module):
sigma: torch.Tensor,
y_feat: List[torch.Tensor],
y_mask: List[torch.Tensor],
packed_indices: Dict[str, torch.Tensor] = None,
rope_cos: torch.Tensor = None,
rope_sin: torch.Tensor = None,
):
@@ -687,12 +592,12 @@ class AsymmDiTJoint(nn.Module):
sigma: (B,) tensor of noise standard deviations
y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
"""
B, _, T, H, W = x.shape
# Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
# Have to call sdpa_kernel outside of a torch.compile region.
num_tokens = max(1, torch.sum(y_mask[0]).item())
with sdpa_kernel(backends):
x, c, y_feat, rope_cos, rope_sin = self.prepare(
x, sigma, y_feat[0], y_mask[0]
@@ -706,20 +611,20 @@ class AsymmDiTJoint(nn.Module):
y_feat,
rope_cos=rope_cos,
rope_sin=rope_sin,
packed_indices=packed_indices,
num_tokens=num_tokens,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.
x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
x = rearrange(
x,
"B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
T=T,
hp=H // self.patch_size,
wp=W // self.patch_size,
p1=self.patch_size,
p2=self.patch_size,
c=self.out_channels,
)
hp = H // self.patch_size
wp = W // self.patch_size
p1 = self.patch_size
p2 = self.patch_size
c = self.out_channels
x = x.view(B, T, hp, wp, p1, p2, c)
x = x.permute(0, 6, 1, 2, 4, 3, 5)
x = x.reshape(B, c, T, hp * p1, wp * p2)
return x
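The einops rearrange used for unpatchifying was replaced by an equivalent view/permute/reshape; a quick check with toy sizes that the two produce identical results:

```python
import torch
from einops import rearrange

B, T, hp, wp, p1, p2, c = 1, 3, 4, 5, 2, 2, 12
x = torch.randn(B, T * hp * wp, p1 * p2 * c)

# Reference: the original einops pattern.
ref = rearrange(
    x, "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
    T=T, hp=hp, wp=wp, p1=p1, p2=p2, c=c,
)

# New code path: explicit view / permute / reshape.
out = (
    x.view(B, T, hp, wp, p1, p2, c)
    .permute(0, 6, 1, 2, 4, 3, 5)
    .reshape(B, c, T, hp * p1, wp * p2)
)
print(torch.equal(ref, out))  # True
```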

View File

@@ -1,163 +0,0 @@
import torch
import torch.distributed as dist
from einops import rearrange
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_RANK = None
_CONTEXT_PARALLEL_GROUP_SIZE = None
_CONTEXT_PARALLEL_GROUP_RANKS = None
def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
return x
cp_rank, cp_size = get_cp_rank_size()
return x.tensor_split(cp_size, dim=dim)[cp_rank]
def set_cp_group(cp_group, ranks, global_rank):
global \
_CONTEXT_PARALLEL_GROUP, \
_CONTEXT_PARALLEL_RANK, \
_CONTEXT_PARALLEL_GROUP_SIZE, \
_CONTEXT_PARALLEL_GROUP_RANKS
if _CONTEXT_PARALLEL_GROUP is not None:
raise RuntimeError("CP group already initialized.")
_CONTEXT_PARALLEL_GROUP = cp_group
_CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
_CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
_CONTEXT_PARALLEL_GROUP_RANKS = ranks
assert (
_CONTEXT_PARALLEL_RANK == ranks.index(global_rank)
), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
ranks
), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"
def get_cp_group():
if _CONTEXT_PARALLEL_GROUP is None:
raise RuntimeError("CP group not initialized")
return _CONTEXT_PARALLEL_GROUP
def is_cp_active():
return _CONTEXT_PARALLEL_GROUP is not None
def get_cp_rank_size():
if _CONTEXT_PARALLEL_GROUP:
return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
else:
return 0, 1
class AllGatherIntoTensorFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
ctx.reduce_dtype = reduce_dtype
ctx.group = group
ctx.batch_size = x.size(0)
group_size = dist.get_world_size(group)
x = x.contiguous()
output = torch.empty(
group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device
)
dist.all_gather_into_tensor(output, x, group=group)
return output
def all_gather(tensor: torch.Tensor) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
return tensor
return AllGatherIntoTensorFunction.apply(
tensor, torch.float32, _CONTEXT_PARALLEL_GROUP
)
@torch.compiler.disable()
def _all_to_all_single(output, input, group):
# Disable compilation since torch compile changes contiguity.
assert input.is_contiguous(), "Input tensor must be contiguous."
assert output.is_contiguous(), "Output tensor must be contiguous."
return dist.all_to_all_single(output, input, group=group)
class CollectTokens(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
"""Redistribute heads and receive tokens.
Args:
qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]
Returns:
qkv: shape: [3, B, N, local_heads, head_dim]
where M is the number of local tokens,
N = cp_size * M is the number of global tokens,
local_heads = num_heads // cp_size is the number of local heads.
"""
ctx.group = group
ctx.num_heads = num_heads
cp_size = dist.get_world_size(group)
assert num_heads % cp_size == 0
ctx.local_heads = num_heads // cp_size
qkv = rearrange(
qkv,
"B M (qkv G h d) -> G M h B (qkv d)",
qkv=3,
G=cp_size,
h=ctx.local_heads,
).contiguous()
output_chunks = torch.empty_like(qkv)
_all_to_all_single(output_chunks, qkv, group=group)
return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
# Move QKV dimension to the front.
# B M (3 H d) -> 3 B M H d
B, M, _ = x.size()
x = x.view(B, M, 3, num_heads, -1)
return x.permute(2, 0, 1, 3, 4)
return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)
class CollectHeads(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
"""Redistribute tokens and receive heads.
Args:
x: Output of attention. Shape: [B, N, local_heads, head_dim]
Returns:
Shape: [B, M, num_heads * head_dim]
"""
ctx.group = group
ctx.local_heads = x.size(2)
ctx.head_dim = x.size(3)
group_size = dist.get_world_size(group)
x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
output = torch.empty_like(x)
_all_to_all_single(output, x, group=group)
del x
return rearrange(output, "G h M B D -> B M (G h D)")
def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
if not _CONTEXT_PARALLEL_GROUP:
# Merge heads.
return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)

View File

@@ -130,8 +130,6 @@ class PatchEmbed(nn.Module):
x = F.pad(x, (0, pad_w, 0, pad_h))
x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
#print("x",x.dtype, x.device)
#print(self.proj.weight.dtype, self.proj.weight.device)
x = self.proj(x)
# Flatten temporal and spatial dimensions.

View File

@@ -1,4 +1,4 @@
import functools
#import functools
import math
import torch
@@ -21,7 +21,7 @@ def centers(start: float, stop, num, dtype=None, device=None):
return (edges[:-1] + edges[1:]) / 2
@functools.lru_cache(maxsize=1)
#@functools.lru_cache(maxsize=1)
def create_position_matrix(
T: int,
pH: int,

View File

@@ -29,94 +29,3 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
return pooled
class PadSplitXY(torch.autograd.Function):
"""
Merge heads, pad and extract visual and text tokens,
and split along the sequence length.
"""
@staticmethod
def forward(
ctx,
xy: torch.Tensor,
indices: torch.Tensor,
B: int,
N: int,
L: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
indices: Valid token indices out of unpacked tensor. Shape: (total,)
Returns:
x: Visual tokens. Shape: (B, N, num_heads * head_dim).
y: Text tokens. Shape: (B, L, num_heads * head_dim).
"""
ctx.save_for_backward(indices)
ctx.B, ctx.N, ctx.L = B, N, L
D = xy.size(1)
# Pad sequences to (B, N + L, dim).
assert indices.ndim == 1
output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
indices = indices.unsqueeze(1).expand(
-1, D
) # (total,) -> (total, num_heads * head_dim)
output.scatter_(0, indices, xy)
xy = output.view(B, N + L, D)
# Split visual and text tokens along the sequence length.
return torch.tensor_split(xy, (N,), dim=1)
def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
return PadSplitXY.apply(xy, indices, B, N, L, dtype)
class UnifyStreams(torch.autograd.Function):
"""Unify visual and text streams."""
@staticmethod
def forward(
ctx,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
q_y: torch.Tensor,
k_y: torch.Tensor,
v_y: torch.Tensor,
indices: torch.Tensor,
):
"""
Args:
q_x: (B, N, num_heads, head_dim)
k_x: (B, N, num_heads, head_dim)
v_x: (B, N, num_heads, head_dim)
q_y: (B, L, num_heads, head_dim)
k_y: (B, L, num_heads, head_dim)
v_y: (B, L, num_heads, head_dim)
indices: (total <= B * (N + L))
Returns:
qkv: (total <= B * (N + L), 3, num_heads, head_dim)
"""
ctx.save_for_backward(indices)
B, N, num_heads, head_dim = q_x.size()
ctx.B, ctx.N, ctx.L = B, N, q_y.size(1)
D = num_heads * head_dim
q = torch.cat([q_x, q_y], dim=1)
k = torch.cat([k_x, k_y], dim=1)
v = torch.cat([v_x, v_y], dim=1)
qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D)
indices = indices[:, None, None].expand(-1, 3, D)
qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim)
return qkv.unflatten(2, (num_heads, head_dim))
def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)

View File

@@ -33,9 +33,7 @@ except:
pass
import torch
import torch.nn.functional as F
import torch.utils.data
from einops import rearrange, repeat
#from .dit.joint_model.context_parallel import get_cp_rank_size
from tqdm import tqdm
@@ -79,48 +77,6 @@ def unnormalize_latents(
assert z.size(1) == mean.size(0) == std.size(0)
return z * std.to(z) + mean.to(z)
def compute_packed_indices(
N: int,
text_mask: List[torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""
Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80
Args:
N: Number of visual tokens.
text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.
Returns:
packed_indices: Dict with keys for Flash Attention:
- valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
in the packed sequence.
- cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
- max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
"""
# Create an expanded token mask saying which tokens are valid across both visual and text tokens.
assert N > 0 and len(text_mask) == 1
text_mask = text_mask[0]
mask = F.pad(text_mask, (N, 0), value=True) # (B, N + L)
seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,)
valid_token_indices = torch.nonzero(
mask.flatten(), as_tuple=False
).flatten() # up to (B * (N + L),)
assert valid_token_indices.size(0) >= text_mask.size(0) * N # At least (B * N,)
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
)
max_seqlen_in_batch = seqlens_in_batch.max().item()
return {
"cu_seqlens_kv": cu_seqlens,
"max_seqlen_in_batch_kv": max_seqlen_in_batch,
"valid_token_indices_kv": valid_token_indices,
}
class T2VSynthMochiModel:
def __init__(
self,
@@ -166,6 +122,17 @@ class T2VSynthMochiModel:
params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
logging.info(f"Loading model state_dict from {dit_checkpoint_path}...")
dit_sd = load_torch_file(dit_checkpoint_path)
# Comfy checkpoint format: strip the "model.diffusion_model." prefix if present
prefix = "model.diffusion_model."
first_key = next(iter(dit_sd), None)
if first_key and first_key.startswith(prefix):
new_dit_sd = {
key[len(prefix):] if key.startswith(prefix) else key: value
for key, value in dit_sd.items()
}
dit_sd = new_dit_sd
if "gguf" in dit_checkpoint_path.lower():
logging.info("Loading GGUF model state_dict...")
from .. import mz_gguf_loader
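A minimal sketch of the Comfy-checkpoint prefix normalization added a few lines above; the key names here are illustrative only:

```python
# Strip the "model.diffusion_model." prefix when the checkpoint uses Comfy's layout.
prefix = "model.diffusion_model."
dit_sd = {
    "model.diffusion_model.blocks.0.attn.qkv_x.weight": 1,  # hypothetical key
    "model.diffusion_model.pos_frequencies": 2,             # hypothetical key
}
first_key = next(iter(dit_sd), None)
if first_key and first_key.startswith(prefix):
    dit_sd = {k[len(prefix):]: v for k, v in dit_sd.items()}
print(list(dit_sd))  # ['blocks.0.attn.qkv_x.weight', 'pos_frequencies']
```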
@@ -209,14 +176,6 @@ class T2VSynthMochiModel:
self.vae_mean = torch.Tensor(vae_stats["mean"]).to(self.device)
self.vae_std = torch.Tensor(vae_stats["std"]).to(self.device)
def get_packed_indices(self, y_mask, *, lT, lW, lH):
patch_size = 2
N = lT * lH * lW // (patch_size**2)
assert len(y_mask) == 1
packed_indices = compute_packed_indices(N, y_mask)
self.move_to_device_(packed_indices)
return packed_indices
def move_to_device_(self, sample):
if isinstance(sample, dict):
for key in sample.keys():
@@ -227,7 +186,7 @@ class T2VSynthMochiModel:
torch.manual_seed(args["seed"])
torch.cuda.manual_seed(args["seed"])
generator = torch.Generator(device=self.device)
generator = torch.Generator(device=torch.device("cpu"))
generator.manual_seed(args["seed"])
num_frames = args["num_frames"]
@@ -259,14 +218,13 @@ class T2VSynthMochiModel:
T = (num_frames - 1) // temporal_downsample + 1
H = height // spatial_downsample
W = width // spatial_downsample
latent_dims = dict(lT=T, lW=W, lH=H)
z = torch.randn(
(B, C, T, H, W),
device=self.device,
device=torch.device("cpu"),
generator=generator,
dtype=torch.float32,
)
).to(self.device)
if in_samples is not None:
z = z * sigma_schedule[0] + (1 - sigma_schedule[0]) * in_samples.to(self.device)
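Per the "cpu seed" part of the commit message, the generator and the initial latent now live on the CPU and the noise is moved to the target device afterwards, so a given seed yields the same starting noise regardless of which GPU (or CPU) is used. A minimal sketch of the idea; the shape is illustrative:

```python
import torch

def seeded_noise(shape, seed, device="cpu"):
    # Sample on CPU with a seeded CPU generator, then move to the target device;
    # a CUDA generator would tie the noise to the specific GPU's RNG stream instead.
    generator = torch.Generator(device=torch.device("cpu"))
    generator.manual_seed(seed)
    z = torch.randn(shape, generator=generator, dtype=torch.float32)
    return z.to(device)

a = seeded_noise((1, 12, 7, 60, 106), seed=123)
b = seeded_noise((1, 12, 7, 60, 106), seed=123)
print(torch.equal(a, b))  # True: same seed, same noise
```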
@@ -279,13 +237,6 @@ class T2VSynthMochiModel:
"y_feat": [args["negative_embeds"]["embeds"].to(self.device)]
}
sample["packed_indices"] = self.get_packed_indices(
sample["y_mask"], **latent_dims
)
sample_null["packed_indices"] = self.get_packed_indices(
sample_null["y_mask"], **latent_dims
)
self.dit.to(self.device)
if hasattr(self.dit, "cublas_half_matmul") and self.dit.cublas_half_matmul:
autocast_dtype = torch.float16
@@ -314,8 +265,7 @@ class T2VSynthMochiModel:
sigma=torch.full([B], sigma, device=z.device),
cfg_scale=cfg_schedule[i],
)
pred = pred.to(z)
z = z + dsigma * pred
z = z + dsigma * pred.to(z)
if callback is not None:
callback(i, z.detach()[0].permute(1,0,2,3), None, sample_steps)
else:
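For context, the denoising update in the hunk above is a plain Euler step over the sigma schedule, with the prediction cast to the latent's dtype and device right at the update. A rough sketch of the loop structure; the dsigma definition and the model call are assumptions, not this repo's exact code:

```python
import torch

def euler_sample(model_fn, z, sigma_schedule):
    # sigma_schedule: 1D sequence of noise levels; model_fn stands in for the
    # CFG-wrapped DiT call and returns a tensor shaped like z.
    for i in range(len(sigma_schedule) - 1):
        sigma = sigma_schedule[i]
        dsigma = sigma_schedule[i + 1] - sigma  # assumed step definition
        pred = model_fn(z, sigma)
        z = z + dsigma * pred.to(z)  # same update as in the diff above
    return z
```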

View File

@@ -1,152 +0,0 @@
from typing import Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
from ..dit.joint_model.context_parallel import get_cp_group, get_cp_rank_size
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t,) * length)
def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
"""
Forward pass that handles communication between ranks for inference.
Args:
x: Tensor of shape (B, C, T, H, W)
frames_to_send: int, number of frames to communicate between ranks
Returns:
output: Tensor of shape (B, C, T', H, W)
"""
cp_rank, cp_world_size = cp.get_cp_rank_size()
if frames_to_send == 0 or cp_world_size == 1:
return x
group = get_cp_group()
global_rank = dist.get_rank()
# Send to next rank
if cp_rank < cp_world_size - 1:
assert x.size(2) >= frames_to_send
tail = x[:, :, -frames_to_send:].contiguous()
dist.send(tail, global_rank + 1, group=group)
# Receive from previous rank
if cp_rank > 0:
B, C, _, H, W = x.shape
recv_buffer = torch.empty(
(B, C, frames_to_send, H, W),
dtype=x.dtype,
device=x.device,
)
dist.recv(recv_buffer, global_rank - 1, group=group)
x = torch.cat([recv_buffer, x], dim=2)
return x
def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
if max_T > x.size(2):
pad_T = max_T - x.size(2)
pad_dims = (0, 0, 0, 0, 0, pad_T)
return F.pad(x, pad_dims)
return x
def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
"""
Gathers all frames from all processes for inference.
Args:
x: Tensor of shape (B, C, T, H, W)
Returns:
output: Tensor of shape (B, C, T_total, H, W)
"""
cp_rank, cp_size = get_cp_rank_size()
cp_group = get_cp_group()
# Ensure the tensor is contiguous for collective operations
x = x.contiguous()
# Get the local time dimension size
local_T = x.size(2)
local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)
# Gather all T sizes from all processes
all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
dist.all_gather(all_T, local_T_tensor, group=cp_group)
all_T = [t.item() for t in all_T]
# Pad the tensor at the end of the time dimension to match max_T
max_T = max(all_T)
x = _pad_to_max(x, max_T).contiguous()
# Prepare a list to hold the gathered tensors
gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
# Perform the all_gather operation
dist.all_gather(gathered_x, x, group=cp_group)
# Slice each gathered tensor back to its original T size
for idx, t_size in enumerate(all_T):
gathered_x[idx] = gathered_x[idx][:, :, :t_size]
return torch.cat(gathered_x, dim=2)
def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
"""Estimate memory usage based on input tensor size and data type."""
element_size = input.element_size() # Size in bytes of each element
memory_bytes = input.numel() * element_size
memory_gb = memory_bytes / 1024**3
return memory_gb > max_gb
class ContextParallelCausalConv3d(torch.nn.Conv3d):
def __init__(
self,
in_channels,
out_channels,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
**kwargs,
):
kernel_size = cast_tuple(kernel_size, 3)
stride = cast_tuple(stride, 3)
height_pad = (kernel_size[1] - 1) // 2
width_pad = (kernel_size[2] - 1) // 2
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=(1, 1, 1),
padding=(0, height_pad, width_pad),
**kwargs,
)
def forward(self, x: torch.Tensor):
cp_rank, cp_world_size = get_cp_rank_size()
context_size = self.kernel_size[0] - 1
if cp_rank == 0:
mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)
if cp_world_size == 1:
return super().forward(x)
if all(s == 1 for s in self.stride):
# Receive some frames from previous rank.
x = cp_pass_frames(x, context_size)
return super().forward(x)
# Less efficient implementation for strided convs.
# All gather x, infer and chunk.
x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
x = super().forward(x)
x_chunks = x.tensor_split(cp_world_size, dim=2)
assert len(x_chunks) == cp_world_size
return x_chunks[cp_rank]

View File

@@ -473,6 +473,7 @@ class MochiSampler:
CATEGORY = "MochiWrapper"
def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None, opt_sigmas=None, samples=None):
mm.unload_all_models()
mm.soft_empty_cache()
if opt_sigmas is not None: