Don't use autocast with fp/bf16

This commit is contained in:
kijai 2024-11-20 14:22:10 +02:00
parent b9f7b6e338
commit b74aa75026
3 changed files with 11 additions and 4 deletions

View File

@@ -98,6 +98,9 @@ class CogVideoXAttnProcessor2_0:
     attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
     attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+    if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
+        hidden_states = hidden_states.to(attn.to_q.weight.dtype)
     if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)

View File

@@ -391,6 +391,7 @@ class DownloadAndLoadCogVideoModel:
     pipeline = {
         "pipe": pipe,
         "dtype": dtype,
+        "quantization": quantization,
         "base_path": base_path,
         "onediff": True if compile == "onediff" else False,
         "cpu_offloading": enable_sequential_cpu_offload,
@@ -571,6 +572,7 @@ class DownloadAndLoadCogVideoGGUFModel:
     pipeline = {
         "pipe": pipe,
         "dtype": vae_dtype,
+        "quantization": "GGUF",
         "base_path": model,
         "onediff": False,
         "cpu_offloading": enable_sequential_cpu_offload,
@@ -802,6 +804,7 @@ class CogVideoXModelLoader:
     pipeline = {
         "pipe": pipe,
         "dtype": base_dtype,
+        "quantization": quantization,
         "base_path": model,
         "onediff": False,
         "cpu_offloading": enable_sequential_cpu_offload,

View File

@@ -689,8 +689,9 @@ class CogVideoSampler:
     except:
         pass
-    autocastcondition = not model["onediff"] or not dtype == torch.float32
-    autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
+    autocast_context = torch.autocast(
+        mm.get_autocast_device(device), dtype=dtype
+    ) if any(q in model["quantization"] for q in ("e4m3fn", "GGUF")) else nullcontext()
     with autocast_context:
         latents = model["pipe"](
             num_inference_steps=steps,