diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index bc49123..152d62f 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -122,16 +122,10 @@ class BaseLoaderKJ: from sageattn3 import sageattn3_blackwell if sage_attention == "sageattn3_per_block_mean": def func(q, k, v, is_causal=False, attn_mask=None, **kwargs): - if q.shape == k.shape and q.shape == v.shape: - return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True) - else: - return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout="NHD") + return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2) else: def func(q, k, v, is_causal=False, attn_mask=None, **kwargs): - if q.shape == k.shape and q.shape == v.shape: - return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False) - else: - return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout="NHD") + return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2) return func sage_func = set_sage_func(sage_attention)