mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
Don't use autocast with fp16/bf16
This commit is contained in:
parent
b9f7b6e338
commit
b74aa75026
@ -98,6 +98,9 @@ class CogVideoXAttnProcessor2_0:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.to_q.weight.dtype == torch.float16 or attn.to_q.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(attn.to_q.weight.dtype)
|
||||
|
||||
if attention_mode != "fused_sdpa" or attention_mode != "fused_sageattn":
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
@ -124,7 +127,7 @@ class CogVideoXAttnProcessor2_0:
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
|
||||
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
|
||||
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
@ -135,7 +138,7 @@ class CogVideoXAttnProcessor2_0:
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
elif attention_mode == "comfy":
|
||||
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
|
||||
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
|
||||
@ -391,6 +391,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
pipeline = {
|
||||
"pipe": pipe,
|
||||
"dtype": dtype,
|
||||
"quantization": quantization,
|
||||
"base_path": base_path,
|
||||
"onediff": True if compile == "onediff" else False,
|
||||
"cpu_offloading": enable_sequential_cpu_offload,
|
||||
@ -571,6 +572,7 @@ class DownloadAndLoadCogVideoGGUFModel:
|
||||
pipeline = {
|
||||
"pipe": pipe,
|
||||
"dtype": vae_dtype,
|
||||
"quantization": "GGUF",
|
||||
"base_path": model,
|
||||
"onediff": False,
|
||||
"cpu_offloading": enable_sequential_cpu_offload,
|
||||
@ -802,6 +804,7 @@ class CogVideoXModelLoader:
|
||||
pipeline = {
|
||||
"pipe": pipe,
|
||||
"dtype": base_dtype,
|
||||
"quantization": quantization,
|
||||
"base_path": model,
|
||||
"onediff": False,
|
||||
"cpu_offloading": enable_sequential_cpu_offload,
|
||||
|
||||
5
nodes.py
5
nodes.py
@ -689,8 +689,9 @@ class CogVideoSampler:
|
||||
except:
|
||||
pass
|
||||
|
||||
autocastcondition = not model["onediff"] or not dtype == torch.float32
|
||||
autocast_context = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocastcondition else nullcontext()
|
||||
autocast_context = torch.autocast(
|
||||
mm.get_autocast_device(device), dtype=dtype
|
||||
) if any(q in model["quantization"] for q in ("e4m3fn", "GGUF")) else nullcontext()
|
||||
with autocast_context:
|
||||
latents = model["pipe"](
|
||||
num_inference_steps=steps,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user