add comfyui attention mode

This commit is contained in:
kijai 2024-11-19 19:55:51 +02:00
parent cac1f81c51
commit 882faa6dea
3 changed files with 7 additions and 7 deletions

View File

@@ -40,11 +40,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try:
from sageattention import sageattn
SAGEATTN_IS_AVAILABLE = True
except:
SAGEATTN_IS_AVAILABLE = False
from comfy.ldm.modules.attention import optimized_attention
@torch.compiler.disable()
def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
@@ -126,12 +127,12 @@ class CogVideoXAttnProcessor2_0:
if attention_mode == "sageattn" or attention_mode == "fused_sageattn":
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
else:
elif attention_mode == "sdpa":
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
#if torch.isinf(hidden_states).any():
# raise ValueError(f"hidden_states after dot product has inf")
elif attention_mode == "comfy":
hidden_states = optimized_attention(query, key, value, mask=attention_mask, heads=attn.heads, skip_reshape=True)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

View File

@@ -123,7 +123,7 @@ class DownloadAndLoadCogVideoModel:
"block_edit": ("TRANSFORMERBLOCKS", {"default": None}),
"lora": ("COGLORA", {"default": None}),
"compile_args":("COMPILEARGS", ),
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn"], {"default": "sdpa"}),
"attention_mode": (["sdpa", "sageattn", "fused_sdpa", "fused_sageattn", "comfy"], {"default": "sdpa"}),
"load_device": (["main_device", "offload_device"], {"default": "main_device"}),
}
}

View File

@@ -395,7 +395,6 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
self.num_frames = num_frames