From 00a550e81c2f0fb5f8c4235272b8485da6c09478 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 24 Oct 2024 02:25:57 +0300 Subject: [PATCH] Should work without flash_attn (thanks @logtd), add sage_attn tested to work in Linux at least --- .../dit/joint_model/asymm_models_joint.py | 90 ++++++++++++------- mochi_preview/t2v_synth_mochi.py | 2 + nodes.py | 6 +- 3 files changed, 64 insertions(+), 34 deletions(-) diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py index 1de23ef..b0f89cc 100644 --- a/mochi_preview/dit/joint_model/asymm_models_joint.py +++ b/mochi_preview/dit/joint_model/asymm_models_joint.py @@ -36,6 +36,11 @@ try: FLASH_ATTN_IS_AVAILABLE = True except ImportError: FLASH_ATTN_IS_AVAILABLE = False +try: + from sageattention import sageattn + SAGEATTN_IS_AVAILABLE = True +except ImportError: + SAGEATTN_IS_AVAILABLE = False COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1" COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1" @@ -55,9 +60,8 @@ class AsymmetricAttention(nn.Module): attend_to_padding: bool = False, softmax_scale: Optional[float] = None, device: Optional[torch.device] = None, - clip_feat_dim: Optional[int] = None, - pooled_caption_mlp_bias: bool = True, - use_transformer_engine: bool = False, + attention_mode: str = "sdpa", + ): super().__init__() self.dim_x = dim_x @@ -68,6 +72,7 @@ class AsymmetricAttention(nn.Module): self.update_y = update_y self.attend_to_padding = attend_to_padding self.softmax_scale = softmax_scale + self.attention_mode = attention_mode if dim_x % num_heads != 0: raise ValueError( f"dim_x={dim_x} should be divisible by num_heads={num_heads}" @@ -160,6 +165,43 @@ class AsymmetricAttention(nn.Module): ) return qkv + + def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim): + with torch.autocast("cuda", enabled=False): + out: torch.Tensor = flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen_in_batch, + dropout_p=0.0, + softmax_scale=self.softmax_scale, + ) # (total, local_heads, head_dim) + return out.view(total, local_dim) + + def sdpa_attention(self, qkv): + q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1) + with torch.autocast("cuda", enabled=False): + out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + return rearrange(out, 'b h s d -> s (b h d)') + + def sage_attention(self, qkv): + q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1) + with torch.autocast("cuda", enabled=False): + out = sageattn( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=False + ) + return rearrange(out, 'b h s d -> s (b h d)') @torch.compiler.disable() def run_attention( @@ -180,34 +222,13 @@ class AsymmetricAttention(nn.Module): local_dim = local_heads * self.head_dim total = qkv.size(0) - if FLASH_ATTN_IS_AVAILABLE: - with torch.autocast("cuda", enabled=False): - out: torch.Tensor = flash_attn_varlen_qkvpacked_func( - qkv, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen_in_batch, - dropout_p=0.0, - softmax_scale=self.softmax_scale, - ) # (total, local_heads, head_dim) - out = out.view(total, local_dim) - else: - raise NotImplementedError("Flash attention is currently required.") - print("qkv: ",qkv.shape, qkv.dtype, qkv.device) - expected_size = 2 * 44520 * 3 * 24 * 128 - actual_size = qkv.numel() - print(f"Expected size: {expected_size}, Actual size: 
{actual_size}") - q, k, v = qkv.reshape(B, N, 3, local_heads, self.head_dim).permute(2, 0, 3, 1, 4) - with torch.autocast("cuda", enabled=False): - out = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - dropout_p=0.0, - is_causal=False - ) - out = out.transpose(1, 2).reshape(B, -1, local_heads * self.head_dim) - + if self.attention_mode == "flash_attn": + out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim) + elif self.attention_mode == "sdpa": + out = self.sdpa_attention(qkv) + elif self.attention_mode == "sage_attn": + out = self.sage_attention(qkv) + x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype) assert x.size() == (B, N, local_dim) assert y.size() == (B, L, local_dim) @@ -284,18 +305,20 @@ class AsymmetricJointBlock(nn.Module): mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens. update_y: bool = True, # Whether to update text tokens in this block. device: Optional[torch.device] = None, + attention_mode: str = "sdpa", **block_kwargs, ): super().__init__() self.update_y = update_y self.hidden_size_x = hidden_size_x self.hidden_size_y = hidden_size_y + self.attention_mode = attention_mode self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device) if self.update_y: self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device) else: self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device) - + # Self-attention: self.attn = AsymmetricAttention( hidden_size_x, @@ -303,6 +326,7 @@ class AsymmetricJointBlock(nn.Module): num_heads=num_heads, update_y=update_y, device=device, + attention_mode=attention_mode, **block_kwargs, ) @@ -449,6 +473,7 @@ class AsymmDiTJoint(nn.Module): use_extended_posenc: bool = False, rope_theta: float = 10000.0, device: Optional[torch.device] = None, + attention_mode: str = "sdpa", **block_kwargs, ): super().__init__() @@ -516,6 +541,7 @@ class AsymmDiTJoint(nn.Module): mlp_ratio_y=mlp_ratio_y, update_y=update_y, device=device, + attention_mode=attention_mode, **block_kwargs, ) diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py index 2d7c318..3561859 100644 --- a/mochi_preview/t2v_synth_mochi.py +++ b/mochi_preview/t2v_synth_mochi.py @@ -119,6 +119,7 @@ class T2VSynthMochiModel: dit_checkpoint_path: str, weight_dtype: torch.dtype = torch.float8_e4m3fn, fp8_fastmode: bool = False, + attention_mode: str = "sdpa" ): super().__init__() self.device = device @@ -144,6 +145,7 @@ class T2VSynthMochiModel: t5_feat_dim=4096, t5_token_length=256, rope_theta=10000.0, + attention_mode=attention_mode, ) params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"} diff --git a/nodes.py b/nodes.py index 3558955..76ac104 100644 --- a/nodes.py +++ b/nodes.py @@ -60,7 +60,8 @@ class DownloadAndLoadMochiModel: {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", }, ), "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"], - {"default": "fp8_e4m3fn", } + {"default": "fp8_e4m3fn", }), + "attention_mode": (["sdpa","flash_attn","sage_attn"], ), }, } @@ -71,7 +72,7 @@ class DownloadAndLoadMochiModel: CATEGORY = "MochiWrapper" DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface" - def loadmodel(self, model, vae, precision): + def loadmodel(self, model, vae, precision, attention_mode): device = mm.get_torch_device() offload_device = mm.unet_offload_device() @@ -115,6 +116,7 @@ class DownloadAndLoadMochiModel: 
dit_checkpoint_path=model_path, weight_dtype=dtype, fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False, + attention_mode=attention_mode ) with (init_empty_weights() if is_accelerate_available else nullcontext()): vae = Decoder(
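
Note (not part of the patch): attention_mode defaults to "sdpa" throughout the changes above, and the node exposes "sdpa" / "flash_attn" / "sage_attn" as a manual choice. Below is a minimal sketch of how a caller could derive a sensible default from the same availability probes this commit adds to asymm_models_joint.py; the helper name pick_attention_mode is hypothetical and not part of the commit.

    # Hypothetical helper (not in this commit): reuse the availability probes
    # from asymm_models_joint.py to choose a default attention_mode string.
    try:
        from flash_attn import flash_attn_varlen_qkvpacked_func  # noqa: F401
        FLASH_ATTN_IS_AVAILABLE = True
    except ImportError:
        FLASH_ATTN_IS_AVAILABLE = False

    try:
        from sageattention import sageattn  # noqa: F401
        SAGEATTN_IS_AVAILABLE = True
    except ImportError:
        SAGEATTN_IS_AVAILABLE = False

    def pick_attention_mode() -> str:
        """Prefer flash_attn, then sage_attn, falling back to PyTorch SDPA."""
        if FLASH_ATTN_IS_AVAILABLE:
            return "flash_attn"
        if SAGEATTN_IS_AVAILABLE:
            return "sage_attn"
        return "sdpa"  # F.scaled_dot_product_attention ships with PyTorch >= 2.0

Usage would mirror the node wiring above: pass the returned string as attention_mode to T2VSynthMochiModel, or use it as the default for the "attention_mode" widget in DownloadAndLoadMochiModel.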