mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-15 15:20:12 +08:00
Update model_optimization_nodes.py
This commit is contained in:
parent
7ce88200c7
commit
ab89f1c6c9
@ -121,10 +121,10 @@ class BaseLoaderKJ:
|
|||||||
elif "sageattn3" in sage_attention:
|
elif "sageattn3" in sage_attention:
|
||||||
from sageattn3 import sageattn3_blackwell
|
from sageattn3 import sageattn3_blackwell
|
||||||
if sage_attention == "sageattn3_per_block_mean":
|
if sage_attention == "sageattn3_per_block_mean":
|
||||||
def func(q, k, v, is_causal=False, attn_mask=None):
|
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
||||||
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True)
|
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True)
|
||||||
else:
|
else:
|
||||||
def func(q, k, v, is_causal=False, attn_mask=None):
|
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
||||||
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False)
|
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user