diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 12633b1..20615be 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -93,9 +93,14 @@ class CogVideoXAttnProcessor2_0:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+        if attention_mode not in ("fused_sdpa", "fused_sageattn"):
+            query = attn.to_q(hidden_states)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+        else:
+            qkv = attn.to_qkv(hidden_states)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -240,13 +245,14 @@ class CogVideoXBlock(nn.Module):
         fastercache_start_step=15,
         fastercache_device="cuda:0",
     ) -> torch.Tensor:
-
+        #print("hidden_states in block: ", hidden_states.shape) #1.5: torch.Size([2, 3200, 3072]) #1.0: torch.Size([2, 6400, 3072])
         text_seq_length = encoder_hidden_states.size(1)
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
             hidden_states, encoder_hidden_states, temb
         )
+        #print("norm_hidden_states in block: ", norm_hidden_states.shape) #torch.Size([2, 3200, 3072])
 
         # Tora Motion-guidance Fuser
         if video_flow_feature is not None:
@@ -587,13 +593,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         # 2. Patch embedding
         p = self.config.patch_size
        p_t = self.config.patch_size_t
+
+        #print("hidden_states before patch_embedding", hidden_states.shape) #torch.Size([2, 4, 16, 60, 90])
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        #print("hidden_states after patch_embedding", hidden_states.shape) #1.5: torch.Size([2, 2926, 3072]) #1.0: torch.Size([2, 5626, 3072])
         hidden_states = self.embedding_dropout(hidden_states)
 
         text_seq_length = encoder_hidden_states.shape[1]
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
+        #print("hidden_states after split", hidden_states.shape) #1.5: torch.Size([2, 2700, 3072]) #1.0: torch.Size([2, 5400, 3072])
 
         if self.use_fastercache:
             self.fastercache_counter+=1
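For reference, the fused branch added above assumes the attention module exposes a combined to_qkv projection, which diffusers creates when Attention.fuse_projections(fuse=True) is called (see the loader change further down). A minimal, self-contained sketch of that split-and-reshape step in plain PyTorch, with illustrative sizes only:

# Minimal sketch of the fused-QKV path, independent of diffusers; the sizes
# below are illustrative, and to_qkv stands in for the fused projection that
# Attention.fuse_projections(fuse=True) builds from to_q/to_k/to_v.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq_len, inner_dim, heads = 2, 3200, 3072, 48
head_dim = inner_dim // heads  # 64, matching the norm_q/norm_k shapes in the block dump below

to_qkv = nn.Linear(inner_dim, inner_dim * 3)
hidden_states = torch.randn(batch, seq_len, inner_dim)

qkv = to_qkv(hidden_states)                               # (batch, seq_len, 3 * inner_dim)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)  # three (batch, seq_len, inner_dim) tensors

query = query.view(batch, seq_len, heads, head_dim).transpose(1, 2)
key = key.view(batch, seq_len, heads, head_dim).transpose(1, 2)
value = value.view(batch, seq_len, heads, head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(query, key, value)   # (batch, heads, seq_len, head_dim)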
diff --git a/model_loading.py b/model_loading.py
index 45b6c1b..387de3f 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -89,13 +89,13 @@ class DownloadAndLoadCogVideoModel:
                 "precision": (["fp16", "fp32", "bf16"],
                     {"default": "bf16", "tooltip": "official recommendation is that 2b model should be fp16, 5b model should be bf16"}
                 ),
-                "fp8_transformer": (['disabled', 'enabled', 'fastmode'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
+                "fp8_transformer": (['disabled', 'enabled', 'fastmode', 'torchao_fp8dq', 'torchao_fp6'], {"default": 'disabled', "tooltip": "enabled casts the transformer to torch.float8_e4m3fn, fastmode is only for latest nvidia GPUs and requires torch 2.4.0 and cu124 minimum"}),
                 "compile": (["disabled","onediff","torch"], {"tooltip": "compile the model for faster inference, these are advanced options only available on Linux, see readme for more info"}),
                 "enable_sequential_cpu_offload": ("BOOLEAN", {"default": False, "tooltip": "significantly reducing memory usage and slows down the inference"}),
                 "block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
                 "lora": ("COGLORA", {"default": None}),
                 "compile_args":("COMPILEARGS", ),
-                "attention_mode": (["sdpa", "sageattn"], {"default": "sdpa"}),
+                "attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
                 "load_device": (["main_device", "offload_device"], {"default": "main_device"}),
             }
         }
@@ -111,10 +111,11 @@ class DownloadAndLoadCogVideoModel:
                   attention_mode="sdpa", load_device="main_device"):
 
         if precision == "fp16" and "1.5" in model:
-            raise ValueError("1.5 models do not work in fp16")
+            raise ValueError("1.5 models do not currently work in fp16")
 
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
+        manual_offloading = True
         transformer_load_device = device if load_device == "main_device" else offload_device
 
         mm.soft_empty_cache()
@@ -189,7 +190,6 @@ class DownloadAndLoadCogVideoModel:
             transformer = CogVideoXTransformer3DModel.from_pretrained(base_path, subfolder=subfolder)
 
         transformer = transformer.to(dtype).to(transformer_load_device)
-        transformer.attention_mode = attention_mode
 
         if "1.5" in model:
             transformer.config.sample_height = 300
@@ -202,7 +202,7 @@ class DownloadAndLoadCogVideoModel:
             scheduler_config = json.load(f)
         scheduler = CogVideoXDDIMScheduler.from_config(scheduler_config)
 
-        #VAE
+        # VAE
         if "Fun" in model:
             vae = AutoencoderKLCogVideoXFun.from_pretrained(base_path, subfolder="vae").to(dtype).to(offload_device)
             if "Pose" in model:
@@ -263,13 +263,75 @@ class DownloadAndLoadCogVideoModel:
             if "1.5" in model:
                 params_to_keep.update({"ff"}) #otherwise NaNs
             convert_fp8_linear(pipe.transformer, dtype, params_to_keep=params_to_keep)
+
+        elif "torchao" in fp8_transformer:
+            try:
+                from torchao.quantization import (
+                    quantize_,
+                    fpx_weight_only,
+                    float8_dynamic_activation_float8_weight
+                )
+            except ImportError:
+                raise ImportError("torchao is not installed, please install torchao to use the torchao quantization options")
+
+            def filter_fn(module: nn.Module, fqn: str) -> bool:
+                target_submodules = {'attn1', 'ff'} # avoid norm layers, 1.5 at least won't work with quantized norm1 #todo: test other models
+                if any(sub in fqn for sub in target_submodules):
+                    return isinstance(module, nn.Linear)
+                return False
+
+            if "fp6" in fp8_transformer: #slower for some reason on 4090
+                quant_func = fpx_weight_only(3, 2)
+            elif "fp8dq" in fp8_transformer: #very fast on 4090 when compiled
+                quant_func = float8_dynamic_activation_float8_weight()
+
+            for i, block in enumerate(pipe.transformer.transformer_blocks):
+                if "CogVideoXBlock" in str(block):
+                    quantize_(block, quant_func, filter_fn=filter_fn)
+
+            manual_offloading = False # to disable manual .to(device) calls
 
         if enable_sequential_cpu_offload:
             pipe.enable_sequential_cpu_offload()
+            manual_offloading = False
+
+        # CogVideoXBlock(
+        #   (norm1): CogVideoXLayerNormZero(
+        #     (silu): SiLU()
+        #     (linear): Linear(in_features=512, out_features=18432, bias=True)
+        #     (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
+        #   )
+        #   (attn1): Attention(
+        #     (norm_q): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
+        #     (norm_k): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
+        #     (to_q): Linear(in_features=3072, out_features=3072, bias=True)
+        #     (to_k): Linear(in_features=3072, out_features=3072, bias=True)
+        #     (to_v): Linear(in_features=3072, out_features=3072, bias=True)
+        #     (to_out): ModuleList(
+        #       (0): Linear(in_features=3072, out_features=3072, bias=True)
+        #       (1): Dropout(p=0.0, inplace=False)
+        #     )
+        #   )
+        #   (norm2): CogVideoXLayerNormZero(
+        #     (silu): SiLU()
+        #     (linear): Linear(in_features=512, out_features=18432, bias=True)
+        #     (norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
+        #   )
+        #   (ff): FeedForward(
+        #     (net): ModuleList(
+        #       (0): GELU(
+        #         (proj): Linear(in_features=3072, out_features=12288, bias=True)
+        #       )
+        #       (1): Dropout(p=0.0, inplace=False)
+        #       (2): Linear(in_features=12288, out_features=3072, bias=True)
+        #       (3): Dropout(p=0.0, inplace=False)
+        #     )
+        #   )
+        # )
 
         # compilation
         if compile == "torch":
-            pipe.transformer.to(memory_format=torch.channels_last)
+            #pipe.transformer.to(memory_format=torch.channels_last)
             if compile_args is not None:
                 torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
                 for i, block in enumerate(pipe.transformer.transformer_blocks):
@@ -279,7 +341,16 @@ class DownloadAndLoadCogVideoModel:
                 for i, block in enumerate(pipe.transformer.transformer_blocks):
                     if "CogVideoXBlock" in str(block):
                         pipe.transformer.transformer_blocks[i] = torch.compile(block, fullgraph=False, dynamic=False, backend="inductor")
-
+
+            transformer.attention_mode = attention_mode
+
+            if "fused" in attention_mode:
+                from diffusers.models.attention import Attention
+                transformer.fuse_qkv_projections = True
+                for module in transformer.modules():
+                    if isinstance(module, Attention):
+                        module.fuse_projections(fuse=True)
+
         elif compile == "onediff":
             from onediffx import compile_pipe
             os.environ['NEXFORT_FX_FORCE_TRITON_SDPA'] = '1'
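For orientation, the torchao branch above quantizes only the Linear layers under attn1 and ff inside each transformer block. A rough, self-contained sketch of the same quantize_ + filter_fn pattern on a toy module; ToyBlock and its layer sizes are made up, and the import assumes a recent torchao release that ships float8_dynamic_activation_float8_weight:

# Standalone illustration of the quantize_ + filter_fn pattern used above.
# ToyBlock is a stand-in for CogVideoXBlock, not the real class.
import torch.nn as nn
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = nn.ModuleDict({"to_q": nn.Linear(64, 64), "to_k": nn.Linear(64, 64)})
        self.ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
        self.norm1 = nn.LayerNorm(64)  # excluded by the filter, as in the diff

def filter_fn(module: nn.Module, fqn: str) -> bool:
    # Quantize only Linear layers whose qualified name falls under attn1/ff;
    # norm layers keep their original dtype.
    return isinstance(module, nn.Linear) and any(sub in fqn for sub in ("attn1", "ff"))

block = ToyBlock()
quantize_(block, float8_dynamic_activation_float8_weight(), filter_fn=filter_fn)
# Running the quantized block at speed still needs an fp8-capable GPU
# (the comments above reference a 4090).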
@@ -298,8 +369,9 @@ class DownloadAndLoadCogVideoModel:
             "base_path": base_path,
             "onediff": True if compile == "onediff" else False,
             "cpu_offloading": enable_sequential_cpu_offload,
+            "manual_offloading": manual_offloading,
             "scheduler_config": scheduler_config,
-            "model_name": model
+            "model_name": model,
         }
 
         return (pipeline,)
@@ -515,7 +587,8 @@ class DownloadAndLoadCogVideoGGUFModel:
             "onediff": False,
             "cpu_offloading": enable_sequential_cpu_offload,
             "scheduler_config": scheduler_config,
-            "model_name": model
+            "model_name": model,
+            "manual_offloading": True,
         }
 
         return (pipeline,)
diff --git a/nodes.py b/nodes.py
index 8feba51..8d7257e 100644
--- a/nodes.py
+++ b/nodes.py
@@ -819,7 +819,7 @@ class CogVideoSampler:
         dtype = pipeline["dtype"]
         scheduler_config = pipeline["scheduler_config"]
 
-        if not pipeline["cpu_offloading"]:
+        if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
            pipe.transformer.to(device)
 
         generator = torch.Generator(device=torch.device("cpu")).manual_seed(seed)
@@ -890,7 +890,7 @@ class CogVideoSampler:
                 controlnet=controlnet,
                 tora=tora_trajectory if tora_trajectory is not None else None,
             )
-            if not pipeline["cpu_offloading"]:
+            if not pipeline["cpu_offloading"] and pipeline["manual_offloading"]:
                 pipe.transformer.to(offload_device)
 
             if fastercache is not None:
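Taken together, the manual_offloading flag stored in the pipeline dict tells the sampler whether it may move the transformer between devices itself; the loader sets it to False for sequential CPU offload and for the torchao-quantized path. A tiny illustration of that gating; move_transformer is a hypothetical helper, not part of the repo:

# Hypothetical helper mirroring the gating added in nodes.py above; shown only
# to make the flag's intent explicit, not part of the repository.
import torch.nn as nn

def move_transformer(transformer: nn.Module, target_device, cpu_offloading: bool, manual_offloading: bool) -> nn.Module:
    # Sequential CPU offload manages placement via hooks, and the torchao path
    # opts out of manual moves, so only relocate the model when both checks pass.
    if not cpu_offloading and manual_offloading:
        transformer.to(target_device)
    return transformer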