mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-10 05:14:22 +08:00
add comfyui attention mode
This commit is contained in:
parent
cac1f81c51
commit
882faa6dea
@ -40,11 +40,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from sageattention import sageattn
|
from sageattention import sageattn
|
||||||
|
|
||||||
SAGEATTN_IS_AVAILABLE = True
|
SAGEATTN_IS_AVAILABLE = True
|
||||||
except:
|
except:
|
||||||
SAGEATTN_IS_AVAILABLE = False
|
SAGEATTN_IS_AVAILABLE = False
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
@torch.compiler.disable()
|
@torch.compiler.disable()
|
||||||
def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
|
def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
|
||||||
return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
|
return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
|
||||||
@ -126,12 +127,12 @@ class CogVideoXAttnProcessor2_0:
|
|||||||
|
|
||||||
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
|
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
|
||||||
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
||||||
else:
|
elif attention_mode == "sdpa":
|
||||||
hidden_states = F.scaled_dot_product_attention(
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
)
|
)
|
||||||
#if torch.isinf(hidden_states).any():
|
elif attention_mode == "comfy":
|
||||||
# raise ValueError(f"hidden_states after dot product has inf")
|
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
|
||||||
|
|
||||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
|
||||||
|
|||||||
@ -123,7 +123,7 @@ class DownloadAndLoadCogVideoModel:
|
|||||||
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
||||||
"lora": ("COGLORA", {"default": None}),
|
"lora": ("COGLORA", {"default": None}),
|
||||||
"compile_args":("COMPILEARGS", ),
|
"compile_args":("COMPILEARGS", ),
|
||||||
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
|
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn", "comfy"], {"default": "sdpa"}),
|
||||||
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
|
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -395,7 +395,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
|||||||
|
|
||||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
||||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
||||||
num_videos_per_prompt = 1
|
|
||||||
|
|
||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user