mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-08 20:34:23 +08:00
add comfyui attention mode
This commit is contained in:
parent
cac1f81c51
commit
882faa6dea
@ -40,11 +40,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
|
||||
SAGEATTN_IS_AVAILABLE = True
|
||||
except:
|
||||
SAGEATTN_IS_AVAILABLE = False
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
@torch.compiler.disable()
|
||||
def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
|
||||
return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
|
||||
@ -126,12 +127,12 @@ class CogVideoXAttnProcessor2_0:
|
||||
|
||||
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
|
||||
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
||||
else:
|
||||
elif attention_mode == "sdpa":
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
#if torch.isinf(hidden_states).any():
|
||||
# raise ValueError(f"hidden_states after dot product has inf")
|
||||
elif attention_mode == "comfy":
|
||||
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
|
||||
@ -123,7 +123,7 @@ class DownloadAndLoadCogVideoModel:
|
||||
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
|
||||
"lora": ("COGLORA", {"default": None}),
|
||||
"compile_args":("COMPILEARGS", ),
|
||||
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
|
||||
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn", "comfy"], {"default": "sdpa"}),
|
||||
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -395,7 +395,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
self.num_frames = num_frames
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user