Update model_optimization_nodes.py

This commit is contained in:
kijai 2025-10-06 21:37:56 +03:00
parent 08164edab3
commit 9954d6b599

View File

@ -122,16 +122,10 @@ class BaseLoaderKJ:
from sageattn3 import sageattn3_blackwell
if sage_attention == "sageattn3_per_block_mean":
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
if q.shape == k.shape and q.shape == v.shape:
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True)
else:
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout="NHD")
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2)
else:
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
if q.shape == k.shape and q.shape == v.shape:
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False)
else:
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout="NHD")
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2)
return func
sage_func = set_sage_func(sage_attention)