From b74aa75026d92b9d4101c0d4dc5010914fab07b4 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 20 Nov 2024 14:22:10 +0200
Subject: [PATCH] Don't use autocast with fp/bf16

---
 custom_cogvideox_transformer_3d.py | 7 +++++--
 model_loading.py                   | 3 +++
 nodes.py                           | 5 +++--
 3 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index 27f2d5f..7401cd9 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -98,6 +98,9 @@ class CogVideoXAttnProcessor2_0:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
+        if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(attn.to_q.weight.dtype)
+
         if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
             query = attn.to_q(hidden_states)
             key = attn.to_k(hidden_states)
@@ -124,7 +127,7 @@ class CogVideoXAttnProcessor2_0:
             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-        
+
         if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
             hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -135,7 +138,7 @@ class CogVideoXAttnProcessor2_0:
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         elif attention_mode == "comfy":
             hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
-        
+
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
diff --git a/model_loading.py b/model_loading.py
index 1402b59..805b087 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -391,6 +391,7 @@ class DownloadAndLoadCogVideoModel:
         pipeline = {
             "pipe": pipe,
             "dtype": dtype,
+            "quantization": quantization,
             "base_path": base_path,
             "onediff": True if compile == "onediff" else False,
             "cpu_offloading": enable_sequential_cpu_offload,
@@ -571,6 +572,7 @@ class DownloadAndLoadCogVideoGGUFModel:
         pipeline = {
             "pipe": pipe,
             "dtype": vae_dtype,
+            "quantization": "GGUF",
             "base_path": model,
             "onediff": False,
             "cpu_offloading": enable_sequential_cpu_offload,
@@ -802,6 +804,7 @@ class CogVideoXModelLoader:
         pipeline = {
             "pipe": pipe,
             "dtype": base_dtype,
+            "quantization": quantization,
             "base_path": model,
             "onediff": False,
             "cpu_offloading": enable_sequential_cpu_offload,
diff --git a/nodes.py b/nodes.py
index 1fea2e1..c9cc493 100644
--- a/nodes.py
+++ b/nodes.py
@@ -689,8 +689,9 @@ class CogVideoSampler:
             except:
                 pass
 
-        autocastcondition = not model["onediff"] or not dtype == torch.float32
-        autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
+        autocast_context = torch.autocast(
+            mm.get_autocast_device(device), dtype=dtype
+        ) if any(q in model["quantization"] for q in ("e4m3fn", "GGUF")) else nullcontext()
         with autocast_context:
             latents = model["pipe"](
                 num_inference_steps=steps,
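
For reference, a minimal standalone sketch of the autocast gating this patch introduces in nodes.py. It is an illustration, not part of the patch: the make_autocast_context helper name is hypothetical (the sampler inlines the expression and resolves the device via mm.get_autocast_device(device)), "model" stands in for the pipeline dicts that model_loading.py now builds with a "quantization" key, and the example quantization strings are assumed values.

import torch
from contextlib import nullcontext

def make_autocast_context(model: dict, device: str, dtype: torch.dtype):
    # Enter autocast only for quantized weights (fp8 "e4m3fn" variants or
    # "GGUF"); plain fp16/bf16/fp32 pipelines skip it, per the commit
    # subject "Don't use autocast with fp/bf16".
    if any(q in model["quantization"] for q in ("e4m3fn", "GGUF")):
        return torch.autocast(device, dtype=dtype)
    return nullcontext()

# A GGUF-quantized pipeline gets a torch.autocast context...
ctx = make_autocast_context({"quantization": "GGUF"}, "cpu", torch.bfloat16)
print(type(ctx).__name__)  # autocast

# ...while an unquantized bf16 pipeline gets a no-op context.
ctx = make_autocast_context({"quantization": "disabled"}, "cpu", torch.bfloat16)
print(type(ctx).__name__)  # nullcontext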