From 3613700752243f51788c19acc08fbb3f6d90e067 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 27 Oct 2024 12:23:01 +0200
Subject: [PATCH] possible sdpa kernel fixes and add optional cfg scheduling

---
 .../dit/joint_model/asymm_models_joint.py | 96 +++++++++++++++----
 mochi_preview/dit/joint_model/utils.py    | 67 -------------
 mochi_preview/t2v_synth_mochi.py          |  5 +-
 nodes.py                                  | 10 +-
 4 files changed, 90 insertions(+), 88 deletions(-)

diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index 7ac5208..81c2332 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -5,8 +5,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from torch.nn.attention import sdpa_kernel, SDPBackend
-
 from .layers import (
     FeedForward,
     PatchEmbed,
@@ -22,7 +20,7 @@ from .rope_mixed import (
 )
 from .temporal_rope import apply_rotary_emb_qk_real
 from .utils import (
-    AttentionPool,
+    pool_tokens,
     modulate,
     pad_and_split_xy,
     unify_streams,
@@ -39,16 +37,81 @@ try:
 except ImportError:
     SAGEATTN_IS_AVAILABLE = False
 
-backends = []
-if torch.cuda.get_device_properties(0).major < 7:
-    backends.append(SDPBackend.MATH)
-    backends.append(SDPBackend.EFFICIENT_ATTENTION)
-if torch.cuda.get_device_properties(0).major >= 9.0:
-    backends.append(SDPBackend.EFFICIENT_ATTENTION)
-    backends.append(SDPBackend.CUDNN_ATTENTION)
-else:
-    backends.append(SDPBackend.EFFICIENT_ATTENTION)
+from torch.nn.attention import sdpa_kernel, SDPBackend
+backends = []
+backends.append(SDPBackend.CUDNN_ATTENTION)
+backends.append(SDPBackend.EFFICIENT_ATTENTION)
+backends.append(SDPBackend.MATH)
+
+import comfy.model_management as mm
+
+class AttentionPool(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        output_dim: int = None,
+        device: Optional[torch.device] = None,
+    ):
+        """
+        Args:
+            spatial_dim (int): Number of tokens in sequence length.
+            embed_dim (int): Dimensionality of input tokens.
+            num_heads (int): Number of attention heads.
+            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
+        self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
+        self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
+
+    def forward(self, x, mask):
+        """
+        Args:
+            x (torch.Tensor): (B, L, D) tensor of input tokens.
+            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
+
+        NOTE: We assume x does not require gradients.
+
+        Returns:
+            x (torch.Tensor): (B, D) tensor of pooled tokens.
+        """
+        D = x.size(2)
+
+        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
+        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
+        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).
+
+        # Average non-padding token features. These will be used as the query.
+        x_pool = pool_tokens(x, mask, keepdim=True)  # (B, 1, D)
+
+        # Concat pooled features to input sequence.
+        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)
+
+        # Compute queries, keys, values. Only the mean token is used to create a query.
+        kv = self.to_kv(x)  # (B, L+1, 2 * D)
+        q = self.to_q(x[:, 0])  # (B, D)
+
+        # Extract heads.
+        head_dim = D // self.num_heads
+        kv = kv.unflatten(2, (2, self.num_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
+        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
+        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
+        q = q.unflatten(1, (self.num_heads, head_dim))  # (B, H, head_dim)
+        q = q.unsqueeze(2)  # (B, H, 1, head_dim)
+
+        # Compute attention.
+        with sdpa_kernel(backends):
+            x = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=attn_mask, dropout_p=0.0
+            )  # (B, H, 1, head_dim)
+
+        # Concatenate heads and run output.
+        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
+        x = self.to_out(x)
+        return x
 
 
 class AsymmetricAttention(nn.Module):
     def __init__(
@@ -77,6 +140,7 @@ class AsymmetricAttention(nn.Module):
         self.attend_to_padding = attend_to_padding
         self.softmax_scale = softmax_scale
         self.attention_mode = attention_mode
+        self.device = device
         if dim_x % num_heads != 0:
             raise ValueError(
                 f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
             )
@@ -162,7 +226,7 @@ class AsymmetricAttention(nn.Module):
         return qkv
 
     def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
-        with torch.autocast("cuda", enabled=False):
+        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
                 qkv,
                 cu_seqlens=cu_seqlens,
@@ -174,7 +238,7 @@ class AsymmetricAttention(nn.Module):
         return out
 
     def sdpa_attention(self, qkv):
         q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        with torch.autocast("cuda", enabled=False):
+        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             with sdpa_kernel(backends):
                 out = F.scaled_dot_product_attention(
                     q,
@@ -188,7 +252,7 @@ class AsymmetricAttention(nn.Module):
         return out
 
     def sage_attention(self, qkv):
         q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        with torch.autocast("cuda", enabled=False):
+        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = sageattn(
                 q,
                 k,
@@ -202,7 +266,7 @@ class AsymmetricAttention(nn.Module):
     def comfy_attention(self, qkv):
         from comfy.ldm.modules.attention import optimized_attention
         q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        with torch.autocast("cuda", enabled=False):
+        with torch.autocast(mm.get_autocast_device(self.device), enabled=False):
             out = optimized_attention(
                 q,
                 k,
diff --git a/mochi_preview/dit/joint_model/utils.py b/mochi_preview/dit/joint_model/utils.py
index 502e3ec..85fd2df 100644
--- a/mochi_preview/dit/joint_model/utils.py
+++ b/mochi_preview/dit/joint_model/utils.py
@@ -30,73 +30,6 @@ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.
     return pooled
 
 
-class AttentionPool(nn.Module):
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        output_dim: int = None,
-        device: Optional[torch.device] = None,
-    ):
-        """
-        Args:
-            spatial_dim (int): Number of tokens in sequence length.
-            embed_dim (int): Dimensionality of input tokens.
-            num_heads (int): Number of attention heads.
-            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
-        """
-        super().__init__()
-        self.num_heads = num_heads
-        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
-        self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
-        self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
-
-    def forward(self, x, mask):
-        """
-        Args:
-            x (torch.Tensor): (B, L, D) tensor of input tokens.
-            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
-
-        NOTE: We assume x does not require gradients.
-
-        Returns:
-            x (torch.Tensor): (B, D) tensor of pooled tokens.
-        """
-        D = x.size(2)
-
-        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
-        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
-        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).
-
-        # Average non-padding token features. These will be used as the query.
-        x_pool = pool_tokens(x, mask, keepdim=True)  # (B, 1, D)
-
-        # Concat pooled features to input sequence.
-        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)
-
-        # Compute queries, keys, values. Only the mean token is used to create a query.
-        kv = self.to_kv(x)  # (B, L+1, 2 * D)
-        q = self.to_q(x[:, 0])  # (B, D)
-
-        # Extract heads.
-        head_dim = D // self.num_heads
-        kv = kv.unflatten(2, (2, self.num_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
-        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
-        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
-        q = q.unflatten(1, (self.num_heads, head_dim))  # (B, H, head_dim)
-        q = q.unsqueeze(2)  # (B, H, 1, head_dim)
-
-        # Compute attention.
-        x = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, dropout_p=0.0
-        )  # (B, H, 1, head_dim)
-
-        # Concatenate heads and run output.
-        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
-        x = self.to_out(x)
-        return x
-
-
 class PadSplitXY(torch.autograd.Function):
     """
     Merge heads, pad and extract visual and text tokens,
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 61b632f..6b598b7 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -352,11 +352,10 @@ class T2VSynthMochiModel:
                 z = z + dsigma * pred
                 comfy_pbar.update(1)
 
-        #cp_rank, cp_size = get_cp_rank_size()
         if batch_cfg:
             z = z[:B]
-        #z = z.tensor_split(cp_size, dim=2)[cp_rank]  # split along temporal dim
-        self.dit.to(self.offload_device, non_blocking=True)
+
+        self.dit.to(self.offload_device)
         samples = unnormalize_latents(z.float(), self.vae_mean, self.vae_std)
         logging.info(f"samples shape: {samples.shape}")
 
diff --git a/nodes.py b/nodes.py
index 50156f1..1eeaa40 100644
--- a/nodes.py
+++ b/nodes.py
@@ -354,6 +354,9 @@ class MochiSampler:
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                 #"batch_cfg": ("BOOLEAN", {"default": False, "tooltip": "Enable batched cfg"}),
             },
+            "optional": {
+                "cfg_schedule": ("FLOAT", {"forceInput": True,}),
+            }
         }
 
     RETURN_TYPES = ("LATENT",)
@@ -361,19 +364,22 @@ class MochiSampler:
     FUNCTION = "process"
     CATEGORY = "MochiWrapper"
 
-    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames):
+    def process(self, model, positive, negative, steps, cfg, seed, height, width, num_frames, cfg_schedule=None):
         mm.soft_empty_cache()
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
 
+        cfg_schedule = cfg_schedule or [cfg] * steps
+        logging.info(f"Using cfg schedule: {cfg_schedule}")
+
         args = {
             "height": height,
             "width": width,
             "num_frames": num_frames,
             "mochi_args": {
                 "sigma_schedule": linear_quadratic_schedule(steps, 0.025),
-                "cfg_schedule": [cfg] * steps,
+                "cfg_schedule": cfg_schedule,
                 "num_inference_steps": steps,
                 "batch_cfg": False,
             },
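
The optional cfg_schedule input added to MochiSampler above is consumed as one guidance value per sampling step, with cfg_schedule = cfg_schedule or [cfg] * steps as the fallback when the socket is left unconnected. Below is a minimal sketch of how such a schedule could be built and passed in, assuming the upstream node or script can hand the forceInput FLOAT socket a plain Python list of floats; the helper name and the linear ramp shape are illustrative only and not part of this patch.

    # Hypothetical helper, not part of the patch: ramp guidance linearly from
    # cfg_start down to cfg_end. The sampler's fallback is [cfg] * steps, so one
    # float per sampling step is assumed here.
    def make_cfg_schedule(steps: int, cfg_start: float, cfg_end: float) -> list:
        if steps <= 1:
            return [cfg_start]
        step_size = (cfg_end - cfg_start) / (steps - 1)
        return [cfg_start + i * step_size for i in range(steps)]

    # Example: 30 steps, easing guidance from 6.0 down to 4.0 over the run,
    # then fed to MochiSampler's optional cfg_schedule input.
    cfg_schedule = make_cfg_schedule(30, 6.0, 4.0)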