diff --git a/mochi_preview/dit/joint_model/asymm_models_joint.py b/mochi_preview/dit/joint_model/asymm_models_joint.py
index b0f89cc..1de23ef 100644
--- a/mochi_preview/dit/joint_model/asymm_models_joint.py
+++ b/mochi_preview/dit/joint_model/asymm_models_joint.py
@@ -36,11 +36,6 @@ try:
     FLASH_ATTN_IS_AVAILABLE = True
 except ImportError:
     FLASH_ATTN_IS_AVAILABLE = False
-try:
-    from sageattention import sageattn
-    SAGEATTN_IS_AVAILABLE = True
-except ImportError:
-    SAGEATTN_IS_AVAILABLE = False
 
 COMPILE_FINAL_LAYER = False #os.environ.get("COMPILE_DIT") == "1"
 COMPILE_MMDIT_BLOCK = False #os.environ.get("COMPILE_DIT") == "1"
@@ -60,8 +55,9 @@ class AsymmetricAttention(nn.Module):
         attend_to_padding: bool = False,
         softmax_scale: Optional[float] = None,
         device: Optional[torch.device] = None,
-        attention_mode: str = "sdpa",
-
+        clip_feat_dim: Optional[int] = None,
+        pooled_caption_mlp_bias: bool = True,
+        use_transformer_engine: bool = False,
     ):
         super().__init__()
         self.dim_x = dim_x
@@ -72,7 +68,6 @@ class AsymmetricAttention(nn.Module):
         self.update_y = update_y
         self.attend_to_padding = attend_to_padding
         self.softmax_scale = softmax_scale
-        self.attention_mode = attention_mode
         if dim_x % num_heads != 0:
             raise ValueError(
                 f"dim_x={dim_x} should be divisible by num_heads={num_heads}"
@@ -165,43 +160,6 @@ class AsymmetricAttention(nn.Module):
         )
         return qkv
-
-    def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
-        with torch.autocast("cuda", enabled=False):
-            out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
-                qkv,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen_in_batch,
-                dropout_p=0.0,
-                softmax_scale=self.softmax_scale,
-            ) # (total, local_heads, head_dim)
-        return out.view(total, local_dim)
-
-    def sdpa_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        with torch.autocast("cuda", enabled=False):
-            out = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=None,
-                dropout_p=0.0,
-                is_causal=False
-            )
-        return rearrange(out, 'b h s d -> s (b h d)')
-
-    def sage_attention(self, qkv):
-        q, k, v = rearrange(qkv, '(b s) t h d -> t b h s d', b=1)
-        with torch.autocast("cuda", enabled=False):
-            out = sageattn(
-                q,
-                k,
-                v,
-                attn_mask=None,
-                dropout_p=0.0,
-                is_causal=False
-            )
-        return rearrange(out, 'b h s d -> s (b h d)')
 
     @torch.compiler.disable()
     def run_attention(
         self,
@@ -222,13 +180,34 @@ class AsymmetricAttention(nn.Module):
         local_dim = local_heads * self.head_dim
         total = qkv.size(0)
 
-        if self.attention_mode == "flash_attn":
-            out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
-        elif self.attention_mode == "sdpa":
-            out = self.sdpa_attention(qkv)
-        elif self.attention_mode == "sage_attn":
-            out = self.sage_attention(qkv)
-
+        if FLASH_ATTN_IS_AVAILABLE:
+            with torch.autocast("cuda", enabled=False):
+                out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
+                    qkv,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen_in_batch,
+                    dropout_p=0.0,
+                    softmax_scale=self.softmax_scale,
+                ) # (total, local_heads, head_dim)
+            out = out.view(total, local_dim)
+        else:
+            raise NotImplementedError("Flash attention is currently required.")
+            print("qkv: ",qkv.shape, qkv.dtype, qkv.device)
+            expected_size = 2 * 44520 * 3 * 24 * 128
+            actual_size = qkv.numel()
+            print(f"Expected size: {expected_size}, Actual size: {actual_size}")
+            q, k, v = qkv.reshape(B, N, 3, local_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            with torch.autocast("cuda", enabled=False):
+                out = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False
+                )
+            out = out.transpose(1, 2).reshape(B, -1, local_heads * self.head_dim)
+
         x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
         assert x.size() == (B, N, local_dim)
         assert y.size() == (B, L, local_dim)
@@ -305,20 +284,18 @@ class AsymmetricJointBlock(nn.Module):
         mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
         update_y: bool = True, # Whether to update text tokens in this block.
         device: Optional[torch.device] = None,
-        attention_mode: str = "sdpa",
         **block_kwargs,
     ):
         super().__init__()
         self.update_y = update_y
         self.hidden_size_x = hidden_size_x
         self.hidden_size_y = hidden_size_y
-        self.attention_mode = attention_mode
 
         self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
         if self.update_y:
             self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
         else:
             self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
-
+
         # Self-attention:
         self.attn = AsymmetricAttention(
@@ -326,7 +303,6 @@ class AsymmetricJointBlock(nn.Module):
             num_heads=num_heads,
             update_y=update_y,
             device=device,
-            attention_mode=attention_mode,
             **block_kwargs,
         )
 
@@ -473,7 +449,6 @@ class AsymmDiTJoint(nn.Module):
         use_extended_posenc: bool = False,
         rope_theta: float = 10000.0,
         device: Optional[torch.device] = None,
-        attention_mode: str = "sdpa",
         **block_kwargs,
     ):
         super().__init__()
@@ -541,7 +516,6 @@ class AsymmDiTJoint(nn.Module):
                 mlp_ratio_y=mlp_ratio_y,
                 update_y=update_y,
                 device=device,
-                attention_mode=attention_mode,
                 **block_kwargs,
             )
 
diff --git a/mochi_preview/t2v_synth_mochi.py b/mochi_preview/t2v_synth_mochi.py
index 3561859..2d7c318 100644
--- a/mochi_preview/t2v_synth_mochi.py
+++ b/mochi_preview/t2v_synth_mochi.py
@@ -119,7 +119,6 @@ class T2VSynthMochiModel:
         dit_checkpoint_path: str,
         weight_dtype: torch.dtype = torch.float8_e4m3fn,
         fp8_fastmode: bool = False,
-        attention_mode: str = "sdpa"
     ):
         super().__init__()
         self.device = device
@@ -145,7 +144,6 @@ class T2VSynthMochiModel:
             t5_feat_dim=4096,
             t5_token_length=256,
             rope_theta=10000.0,
-            attention_mode=attention_mode,
         )
 
         params_to_keep = {"t_embedder", "x_embedder", "pos_frequencies", "t5", "norm"}
diff --git a/nodes.py b/nodes.py
index 76ac104..3558955 100644
--- a/nodes.py
+++ b/nodes.py
@@ -60,8 +60,7 @@ class DownloadAndLoadMochiModel:
                 {"tooltip": "Downloads from 'https://huggingface.co/Kijai/Mochi_preview_comfy' to 'models/vae/mochi'", },
                 ),
                 "precision": (["fp8_e4m3fn","fp8_e4m3fn_fast","fp16", "fp32", "bf16"],
-                    {"default": "fp8_e4m3fn", }),
-                "attention_mode": (["sdpa","flash_attn","sage_attn"],
+                    {"default": "fp8_e4m3fn", }
                 ),
             },
         }
@@ -72,7 +71,7 @@ class DownloadAndLoadMochiModel:
     CATEGORY = "MochiWrapper"
     DESCRIPTION = "Downloads and loads the selected Mochi model from Huggingface"
 
-    def loadmodel(self, model, vae, precision, attention_mode):
+    def loadmodel(self, model, vae, precision):
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
 
@@ -116,7 +115,6 @@ class DownloadAndLoadMochiModel:
             dit_checkpoint_path=model_path,
             weight_dtype=dtype,
             fp8_fastmode = True if precision == "fp8_e4m3fn_fast" else False,
-            attention_mode=attention_mode
         )
         with (init_empty_weights() if is_accelerate_available else nullcontext()):
             vae = Decoder(