Support sageattn for Cosmos

This commit is contained in:
kijai 2025-01-12 17:49:55 +02:00
parent 31cb7c1d14
commit 833cda9fa2

View File

@ -46,7 +46,7 @@ class BaseLoaderKJ:
sage_func = set_sage_func(sage_attention)
@torch.compiler.disable()
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
@ -67,23 +67,29 @@ class BaseLoaderKJ:
mask = mask.unsqueeze(1)
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND":
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else:
out = out.reshape(b, -1, heads * dim_head)
if skip_output_reshape:
out = out.transpose(1, 2)
else:
out = out.reshape(b, -1, heads * dim_head)
return out
comfy_attention.optimized_attention = attention_sage
comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
comfy.ldm.flux.math.optimized_attention = attention_sage
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
else:
comfy_attention.optimized_attention = orig_attention
comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
comfy.ldm.flux.math.optimized_attention = orig_attention
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
comfy.ldm.cosmos.blocks.optimized_attention = orig_attention
if patch_cublaslinear:
if not BaseLoaderKJ.cublas_patched: