From 08164edab30bd6161f76519dd54fceb5ab85775b Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 6 Oct 2025 21:23:47 +0300 Subject: [PATCH] Fix sage3 for Wan --- nodes/model_optimization_nodes.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index bd1705d..bc49123 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -122,10 +122,16 @@ class BaseLoaderKJ: from sageattn3 import sageattn3_blackwell if sage_attention == "sageattn3_per_block_mean": def func(q, k, v, is_causal=False, attn_mask=None, **kwargs): - return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True) + if q.shape == k.shape and q.shape == v.shape: + return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True) + else: + return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout="NHD") else: def func(q, k, v, is_causal=False, attn_mask=None, **kwargs): - return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False) + if q.shape == k.shape and q.shape == v.shape: + return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False) + else: + return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout="NHD") return func sage_func = set_sage_func(sage_attention)