mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2026-05-18 22:47:11 +08:00
Update model_optimization_nodes.py
This commit is contained in:
parent
08164edab3
commit
9954d6b599
@ -122,16 +122,10 @@ class BaseLoaderKJ:
|
|||||||
from sageattn3 import sageattn3_blackwell
|
from sageattn3 import sageattn3_blackwell
|
||||||
if sage_attention == "sageattn3_per_block_mean":
|
if sage_attention == "sageattn3_per_block_mean":
|
||||||
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
||||||
if q.shape == k.shape and q.shape == v.shape:
|
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True).transpose(1, 2)
|
||||||
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True)
|
|
||||||
else:
|
|
||||||
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout="NHD")
|
|
||||||
else:
|
else:
|
||||||
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
|
||||||
if q.shape == k.shape and q.shape == v.shape:
|
return sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False).transpose(1, 2)
|
||||||
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False)
|
|
||||||
else:
|
|
||||||
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout="NHD")
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
sage_func = set_sage_func(sage_attention)
|
sage_func = set_sage_func(sage_attention)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user