From 833cda9fa2c46adfa659e3dda0d8b2b35c111ab5 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 12 Jan 2025 17:49:55 +0200
Subject: [PATCH] Support sageattn for Cosmos

---
 nodes/model_optimization_nodes.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 8593bf6..0f39db1 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -46,7 +46,7 @@ class BaseLoaderKJ:
             sage_func = set_sage_func(sage_attention)
 
             @torch.compiler.disable()
-            def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+            def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
                 if skip_reshape:
                     b, _, _, dim_head = q.shape
                     tensor_layout="HND"
@@ -67,23 +67,29 @@ class BaseLoaderKJ:
                     mask = mask.unsqueeze(1)
                 out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
                 if tensor_layout == "HND":
-                    out = (
-                        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-                    )
+                    if not skip_output_reshape:
+                        out = (
+                            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+                        )
                 else:
-                    out = out.reshape(b, -1, heads * dim_head)
+                    if skip_output_reshape:
+                        out = out.transpose(1, 2)
+                    else:
+                        out = out.reshape(b, -1, heads * dim_head)
                 return out
 
             comfy_attention.optimized_attention = attention_sage
             comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage
             comfy.ldm.flux.math.optimized_attention = attention_sage
             comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage
+            comfy.ldm.cosmos.blocks.optimized_attention = attention_sage
 
         else:
             comfy_attention.optimized_attention = orig_attention
             comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
             comfy.ldm.flux.math.optimized_attention = orig_attention
             comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
+            comfy.ldm.cosmos.blocks.optimized_attention = orig_attention
 
         if patch_cublaslinear:
             if not BaseLoaderKJ.cublas_patched:
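
Note (not part of the patch): the snippet below is a minimal sketch of the output-layout branching the new skip_output_reshape argument introduces, using dummy tensors in place of the sageattn call. The helper name reshape_attn_output and the example shapes are assumptions for illustration only; the tensor layouts follow the patch, where "HND" output is (b, heads, seq, dim_head) and "NHD" output is (b, seq, heads, dim_head).

import torch

def reshape_attn_output(out, heads, dim_head, tensor_layout, skip_output_reshape):
    # Mirrors only the post-attention reshape logic from the patched attention_sage.
    b = out.shape[0]
    if tensor_layout == "HND":
        if not skip_output_reshape:
            out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    else:  # "NHD"
        if skip_output_reshape:
            out = out.transpose(1, 2)  # -> (b, heads, seq, dim_head)
        else:
            out = out.reshape(b, -1, heads * dim_head)
    return out

b, heads, seq, dim_head = 1, 8, 16, 64
hnd = torch.randn(b, heads, seq, dim_head)  # stands in for sageattn output with tensor_layout="HND"
nhd = torch.randn(b, seq, heads, dim_head)  # stands in for sageattn output with tensor_layout="NHD"

print(reshape_attn_output(hnd, heads, dim_head, "HND", False).shape)  # torch.Size([1, 16, 512])
print(reshape_attn_output(hnd, heads, dim_head, "HND", True).shape)   # torch.Size([1, 8, 16, 64])
print(reshape_attn_output(nhd, heads, dim_head, "NHD", True).shape)   # torch.Size([1, 8, 16, 64])

With skip_output_reshape=True the output keeps the unflattened (b, heads, seq, dim_head) layout instead of being collapsed to (b, seq, heads * dim_head), which is what allows the Cosmos blocks to be pointed at attention_sage in this patch.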