From 634c22db505716a8a828846b3086ed3267746f5a Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 9 Nov 2024 17:05:55 +0200
Subject: [PATCH] sageattn

---
 custom_cogvideox_transformer_3d.py | 45 +++++++++++++++++-------------
 fp8_optimization.py                |  2 --
 model_loading.py                   | 10 +++++--
 3 files changed, 32 insertions(+), 25 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 1003aa7..79e2ebb 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -32,6 +32,7 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
 from diffusers.loaders import PeftAdapterMixin
+from diffusers.models.embeddings import apply_rotary_emb
 
 from .embeddings import CogVideoXPatchEmbed
 
@@ -40,9 +41,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 try:
     from sageattention import sageattn
     SAGEATTN_IS_AVAILABLE = True
-    logger.info("Using sageattn")
 except:
-    logger.info("sageattn not found, using sdpa")
     SAGEATTN_IS_AVAILABLE = False
 
 def fft(tensor):
@@ -73,7 +72,6 @@ class CogVideoXAttnProcessor2_0:
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     @torch.compiler.disable()
-
     def __call__(
         self,
         attn: Attention,
@@ -81,6 +79,7 @@ class CogVideoXAttnProcessor2_0:
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mode: Optional[str] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -112,20 +111,21 @@ class CogVideoXAttnProcessor2_0:
 
         # Apply RoPE if needed
         if image_rotary_emb is not None:
-            from diffusers.models.embeddings import apply_rotary_emb
-
             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
-                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-        #if SAGEATTN_IS_AVAILABLE:
-        #    hidden_states = sageattn(query, key, value, is_causal=False)
-        #else:
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        if torch.isinf(hidden_states).any():
-            raise ValueError(f"hidden_states after dot product has inf")
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        if attention_mode == "sageattn":
+            if SAGEATTN_IS_AVAILABLE:
+                hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
+            else:
+                raise ImportError("sageattn not found")
+        else:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+        #if torch.isinf(hidden_states).any():
+        #    raise ValueError(f"hidden_states after dot product has inf")
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
 
@@ -193,6 +193,7 @@ class CogVideoXBlock(nn.Module):
         ff_inner_dim: Optional[int] = None,
         ff_bias: bool = True,
         attention_out_bias: bool = True,
+        attention_mode: Optional[str] = None,
     ):
         super().__init__()
 
@@ -224,6 +225,7 @@ class CogVideoXBlock(nn.Module):
         )
         self.cached_hidden_states = []
         self.cached_encoder_hidden_states = []
+        self.attention_mode = attention_mode
 
     def forward(
         self,
@@ -235,7 +237,7 @@ class CogVideoXBlock(nn.Module):
         fuser=None,
         fastercache_counter=0,
         fastercache_start_step=15,
-        fastercache_device="cuda:0"
+        fastercache_device="cuda:0",
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -271,7 +273,8 @@ class CogVideoXBlock(nn.Module):
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
-            image_rotary_emb=image_rotary_emb
+            image_rotary_emb=image_rotary_emb,
+            attention_mode=self.attention_mode,
         )
         if fastercache_counter == fastercache_start_step:
             self.cached_hidden_states = [attn_hidden_states.to(fastercache_device), attn_hidden_states.to(fastercache_device)]
@@ -386,6 +389,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
         patch_bias: bool = True,
+        attention_mode: Optional[str] = None,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -471,6 +475,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self.fastercache_lf_step = 40
         self.fastercache_hf_step = 30
         self.fastercache_device = "cuda"
+        self.attention_mode = attention_mode
 
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
@@ -667,9 +672,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     fastercache_counter = self.fastercache_counter,
                     fastercache_device = self.fastercache_device
                 )
-                has_nan = torch.isnan(hidden_states).any()
-                if has_nan:
-                    raise ValueError(f"block output hidden_states has nan: {has_nan}")
+                #has_nan = torch.isnan(hidden_states).any()
+                #if has_nan:
+                #    raise ValueError(f"block output hidden_states has nan: {has_nan}")
 
             if (controlnet_states is not None) and (i < len(controlnet_states)):
                 controlnet_states_block = controlnet_states[i]
diff --git a/fp8_optimization.py b/fp8_optimization.py
index 05b0146..09f026d 100644
--- a/fp8_optimization.py
+++ b/fp8_optimization.py
@@ -39,11 +39,9 @@ def fp8_linear_forward(cls, original_dtype, input):
 
 def convert_fp8_linear(module, original_dtype, params_to_keep={}):
     setattr(module, "fp8_matmul_enabled", True)
-
     for name, module in module.named_modules():
         if not any(keyword in name for keyword in params_to_keep):
             if isinstance(module, nn.Linear):
-                print(name)
                 original_forward = module.forward
                 setattr(module, "original_forward", original_forward)
                 setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input))
diff --git a/model_loading.py b/model_loading.py
index 15166c6..dbc5804 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -60,6 +60,7 @@ class CogVideoLoraSelect:
             cog_loras_list.append(cog_lora)
         print(cog_loras_list)
         return (cog_loras_list,)
+
 #region DownloadAndLoadCogVideoModel
 class DownloadAndLoadCogVideoModel:
     @classmethod
@@ -98,6 +99,7 @@ class DownloadAndLoadCogVideoModel:
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
                 "compile_args":("COMPILEARGS", ),
+                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
             }
         }
@@ -108,9 +110,9 @@ class DownloadAndLoadCogVideoModel:
     CATEGORY = "CogVideoWrapper"
     DESCRIPTION = "Downloads and loads the selected CogVideo model from Huggingface to 'ComfyUI/models/CogVideo'"
 
-    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None, load_device="main_device"):
-
-        check_diffusers_version()
+    def loadmodel(self, model, precision, fp8_transformer="disabled", compile="disabled",
+                  enable_sequential_cpu_offload=False, pab_config=None, block_edit=None, lora=None, compile_args=None,
+                  attention_mode="sdpa", load_device="main_device"):
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -195,6 +197,8 @@ class DownloadAndLoadCogVideoModel:
 
         transformer = transformer.to(dtype).to(transformer_load_device)
 
+        transformer.attention_mode = attention_mode
+
        if block_edit is not None:
            transformer = remove_specific_blocks(transformer, block_edit)
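
Below is a short, self-contained sketch of the attention dispatch this patch wires up, for readers following the change outside the diff. It assumes only the names used above; dispatch_attention is an illustrative helper, not a function in the repository, and its sageattn call mirrors the simpler form previously commented out in the processor rather than the full keyword arguments used in the patched line.

import torch.nn.functional as F

try:
    from sageattention import sageattn
    SAGEATTN_IS_AVAILABLE = True
except ImportError:
    SAGEATTN_IS_AVAILABLE = False

def dispatch_attention(query, key, value, attention_mode="sdpa", attention_mask=None):
    # query/key/value: (batch, heads, seq_len, head_dim), as in CogVideoXAttnProcessor2_0
    if attention_mode == "sageattn":
        # "sageattn" is selected via the new loader option and stored on the transformer,
        # then passed from CogVideoXBlock into the attention processor
        if not SAGEATTN_IS_AVAILABLE:
            raise ImportError("sageattn not found")
        return sageattn(query, key, value, is_causal=False)
    # default "sdpa" path: same arguments as the patched F.scaled_dot_product_attention call
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )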