Don't use autocast with fp/bf16

This commit is contained in:
kijai 2024-11-20 14:22:10 +02:00
parent b9f7b6e338
commit b74aa75026
3 changed files with 11 additions and 4 deletions

View File

@ -98,6 +98,9 @@ class CogVideoXAttnProcessor2_0:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(attn.to_q.weight.dtype)
if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
@ -124,7 +127,7 @@ class CogVideoXAttnProcessor2_0:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@ -135,7 +138,7 @@ class CogVideoXAttnProcessor2_0:
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
elif attention_mode == "comfy":
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout

View File

@ -391,6 +391,7 @@ class DownloadAndLoadCogVideoModel:
pipeline = {
"pipe": pipe,
"dtype": dtype,
"quantization": quantization,
"base_path": base_path,
"onediff": True if compile == "onediff" else False,
"cpu_offloading": enable_sequential_cpu_offload,
@ -571,6 +572,7 @@ class DownloadAndLoadCogVideoGGUFModel:
pipeline = {
"pipe": pipe,
"dtype": vae_dtype,
"quantization": "GGUF",
"base_path": model,
"onediff": False,
"cpu_offloading": enable_sequential_cpu_offload,
@ -802,6 +804,7 @@ class CogVideoXModelLoader:
pipeline = {
"pipe": pipe,
"dtype": base_dtype,
"quantization": quantization,
"base_path": model,
"onediff": False,
"cpu_offloading": enable_sequential_cpu_offload,

View File

@ -689,8 +689,9 @@ class CogVideoSampler:
except:
pass
autocastcondition = not model["onediff"] or not dtype == torch.float32
autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
autocast_context = torch.autocast(
mm.get_autocast_device(device), dtype=dtype
) if any(q in model["quantization"] for q in ("e4m3fn", "GGUF")) else nullcontext()
with autocast_context:
latents = model["pipe"](
num_inference_steps=steps,