Support sageattn for Cosmos

This commit is contained in:
kijai 2025-01-12 17:49:55 +02:00
parent 31cb7c1d14
commit 833cda9fa2

View File

@ -46,7 +46,7 @@ class BaseLoaderKJ:
sage_func = set_sage_func(sage_attention) sage_func = set_sage_func(sage_attention)
@torch.compiler.disable() @torch.compiler.disable()
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
tensor_layout="HND" tensor_layout="HND"
@ -67,23 +67,29 @@ class BaseLoaderKJ:
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if tensor_layout == "HND": if tensor_layout == "HND":
out = ( if not skip_output_reshape:
out.transpose(1, 2).reshape(b, -1, heads * dim_head) out = (
) out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
else: else:
out = out.reshape(b, -1, heads * dim_head) if skip_output_reshape:
out = out.transpose(1, 2)
else:
out = out.reshape(b, -1, heads * dim_head)
return out return out
comfy_attention.optimized_attention = attention_sage comfy_attention.optimized_attention = attention_sage
comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
comfy.ldm.flux.math.optimized_attention = attention_sage comfy.ldm.flux.math.optimized_attention = attention_sage
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
else: else:
comfy_attention.optimized_attention = orig_attention comfy_attention.optimized_attention = orig_attention
comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
comfy.ldm.flux.math.optimized_attention = orig_attention comfy.ldm.flux.math.optimized_attention = orig_attention
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
comfy.ldm.cosmos.blocks.optimized_attention = orig_attention
if patch_cublaslinear: if patch_cublaslinear:
if not BaseLoaderKJ.cublas_patched: if not BaseLoaderKJ.cublas_patched: