Don't use autocast with fp/bf16

kijai 2024-11-20 14:22:10 +02:00
parent b9f7b6e338
commit b74aa75026
3 changed files with 11 additions and 4 deletions

@@ -98,6 +98,9 @@ class CogVideoXAttnProcessor2_0:
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+        if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(attn.to_q.weight.dtype)
         if attention_mode != "fused_sdpa" and attention_mode != "fused_sageattn":
             query = attn.to_q(hidden_states)
             key = attn.to_k(hidden_states)
@@ -124,7 +127,7 @@ class CogVideoXAttnProcessor2_0:
             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
             if not attn.is_cross_attention:
                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
         if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
             hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -135,7 +138,7 @@ class CogVideoXAttnProcessor2_0:
             hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         elif attention_mode == "comfy":
             hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
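
The dtype check added above is the heart of the commit: when the projection weights are already fp16 or bf16, the incoming activations are cast to match them once, up front, so the attention math no longer needs an autocast context around it. A minimal standalone sketch of the pattern, with a plain nn.Linear standing in for attn.to_q (the names here are illustrative, not the repo's):

import torch

proj = torch.nn.Linear(64, 64).to(torch.bfloat16)  # stand-in for attn.to_q
hidden_states = torch.randn(1, 16, 64)             # fp32 activations

# Match activations to half-precision weights explicitly, mirroring the
# check added in the hunk above; fp32 weights are left untouched.
if proj.weight.dtype in (torch.float16, torch.bfloat16):
    hidden_states = hidden_states.to(proj.weight.dtype)

query = proj(hidden_states)  # bf16 matmul, no autocast needed
print(query.dtype)           # torch.bfloat16

A single explicit cast avoids the per-op dtype decisions (and occasional fp32 round-trips) that torch.autocast would otherwise make on every forward pass.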

@@ -391,6 +391,7 @@ class DownloadAndLoadCogVideoModel:
         pipeline = {
             "pipe": pipe,
             "dtype": dtype,
+            "quantization": quantization,
             "base_path": base_path,
             "onediff": True if compile == "onediff" else False,
             "cpu_offloading": enable_sequential_cpu_offload,
@@ -571,6 +572,7 @@ class DownloadAndLoadCogVideoGGUFModel:
         pipeline = {
             "pipe": pipe,
             "dtype": vae_dtype,
+            "quantization": "GGUF",
             "base_path": model,
             "onediff": False,
             "cpu_offloading": enable_sequential_cpu_offload,
@@ -802,6 +804,7 @@ class CogVideoXModelLoader:
         pipeline = {
             "pipe": pipe,
             "dtype": base_dtype,
+            "quantization": quantization,
             "base_path": model,
             "onediff": False,
             "cpu_offloading": enable_sequential_cpu_offload,

@@ -689,8 +689,9 @@ class CogVideoSampler:
         except:
             pass

-        autocastcondition = not model["onediff"] or not dtype == torch.float32
-        autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
+        autocast_context = torch.autocast(
+            mm.get_autocast_device(device), dtype=dtype
+        ) if any(q in model["quantization"] for q in ("e4m3fn", "GGUF")) else nullcontext()
         with autocast_context:
             latents = model["pipe"](
                 num_inference_steps=steps,
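
The net effect of the new condition: fp16/bf16 (and fp32) models now run under nullcontext, while fp8 e4m3fn and GGUF quantized models keep autocast so their compute can still happen in a half-precision dtype. A runnable sketch of the conditional-context pattern, using "cpu" so it works anywhere (in the repo, mm.get_autocast_device picks the device string instead):

from contextlib import nullcontext
import torch

def matmul_dtype(quantization: str, dtype: torch.dtype) -> torch.dtype:
    # Same gate as the sampler: autocast only for quantized weight formats.
    use_autocast = any(q in quantization for q in ("e4m3fn", "GGUF"))
    ctx = torch.autocast("cpu", dtype=torch.bfloat16) if use_autocast else nullcontext()
    with ctx:
        x = torch.randn(2, 4, dtype=dtype)
        return (x @ x.T).dtype

print(matmul_dtype("disabled", torch.bfloat16))  # torch.bfloat16, no autocast
print(matmul_dtype("fp8_e4m3fn", torch.float32)) # torch.bfloat16, via autocast

Under nullcontext the math simply runs in whatever dtype the weights already have, which is exactly what the explicit cast in the first file guarantees for fp16/bf16 models.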