Merge 483ba1e98b0eb44fcd0995fd73604fcc116f218a into fd271dedfde6e192a1f1a025521070876e89e04a
Commit cdefeba7ed
@@ -112,6 +112,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

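For reference, `--use-sage-attention3` is a plain argparse store_true flag like its siblings: absent means False, present means True, and no value is accepted. A minimal, standalone sketch of that behaviour; the parser and group construction below are illustrative, not copied from comfy/cli_args.py beyond the flag itself:

import argparse

parser = argparse.ArgumentParser()
# Illustrative group; in cli_args.py the flag is added to an existing attn_group.
attn_group = parser.add_argument_group("attention")
attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.")

args = parser.parse_args(["--use-sage-attention3"])
print(args.use_sage_attention3)                    # True
print(parser.parse_args([]).use_sage_attention3)   # False when the flag is omitted
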
@@ -30,6 +30,18 @@ except ImportError as e:
            raise e
        exit(-1)

SAGE_ATTENTION3_IS_AVAILABLE = False
try:
    from sageattn3 import sageattn3_blackwell
    SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError as e:
    if model_management.sage_attention3_enabled():
        if e.name == "sageattn3":
            logging.error(f"\n\nTo use the `--use-sage-attention3` feature, the `sageattn3` package must be installed first.\nPlease check https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell")
        else:
            raise e
        exit(-1)

FLASH_ATTENTION_IS_AVAILABLE = False
try:
    from flash_attn import flash_attn_func

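The new block mirrors the existing guards for sageattention and flash_attn: the optional import is attempted once at module load, availability is recorded in a flag, and the failure is only reported (or re-raised) when the user explicitly asked for the backend. A generic sketch of that pattern, with placeholder names standing in for the real module and the model_management check:

import logging

BACKEND_REQUESTED = False      # stand-in for model_management.sage_attention3_enabled()
BACKEND_IS_AVAILABLE = False   # stand-in for SAGE_ATTENTION3_IS_AVAILABLE

try:
    from some_optional_kernel import fast_attention  # placeholder module/function names
    BACKEND_IS_AVAILABLE = True
except ImportError as e:
    if BACKEND_REQUESTED:
        if e.name == "some_optional_kernel":
            # The optional package itself is missing: tell the user how to install it.
            logging.error("some_optional_kernel is not installed; install it to use this backend.")
        else:
            # Something inside the package failed to import: surface the real error.
            raise e
        exit(-1)
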
@@ -563,6 +575,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
    out = out.reshape(b, -1, heads * dim_head)
    return out

@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    exception_fallback = False
    if (q.device.type != "cuda" or
            q.dtype not in (torch.float16, torch.bfloat16) or
            mask is not None):
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )

    if skip_reshape:
        B, H, L, D = q.shape
        if H != heads:
            return attention_pytorch(
                q, k, v, heads,
                mask=mask,
                attn_precision=attn_precision,
                skip_reshape=True,
                skip_output_reshape=skip_output_reshape,
                **kwargs
            )
        q_s, k_s, v_s = q, k, v
        N = q.shape[2]
        dim_head = D
    else:
        B, N, inner_dim = q.shape
        if inner_dim % heads != 0:
            return attention_pytorch(
                q, k, v, heads,
                mask=mask,
                attn_precision=attn_precision,
                skip_reshape=False,
                skip_output_reshape=skip_output_reshape,
                **kwargs
            )
        dim_head = inner_dim // heads

    if dim_head >= 256 or N <= 1024:
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )

    if not skip_reshape:
        q_s, k_s, v_s = map(
            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
            (q, k, v),
        )
        B, H, L, D = q_s.shape

    try:
        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
    except Exception as e:
        exception_fallback = True
        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)

    if exception_fallback:
        if not skip_reshape:
            del q_s, k_s, v_s
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )

    if skip_reshape:
        if not skip_output_reshape:
            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
    else:
        if skip_output_reshape:
            pass
        else:
            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)

    return out

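In the common skip_reshape=False case, attention3_sage only has to split (B, N, heads * dim_head) inputs into the (B, heads, N, dim_head) layout the kernel expects and merge the heads back afterwards. A self-contained sketch of that layout handling, with torch.nn.functional.scaled_dot_product_attention standing in for sageattn3_blackwell (which requires a Blackwell GPU and the sageattn3 package):

import torch
import torch.nn.functional as F

def split_merge_attention(q, k, v, heads):
    # q, k, v: (B, N, heads * dim_head), the skip_reshape=False layout
    B, N, inner_dim = q.shape
    dim_head = inner_dim // heads
    # split heads: (B, N, inner_dim) -> (B, heads, N, dim_head)
    q, k, v = (t.view(B, N, heads, dim_head).permute(0, 2, 1, 3).contiguous()
               for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)  # stand-in for the sage3 kernel
    # merge heads back: (B, heads, N, dim_head) -> (B, N, heads * dim_head)
    return out.permute(0, 2, 1, 3).reshape(B, N, heads * dim_head)

q = k = v = torch.randn(1, 2048, 8 * 64)   # N=2048 > 1024 and dim_head=64 < 256, so no fallback
print(split_merge_attention(q, k, v, heads=8).shape)  # torch.Size([1, 2048, 512])
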
try:
    @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())

@@ -627,6 +726,9 @@ optimized_attention = attention_basic
if model_management.sage_attention_enabled():
    logging.info("Using sage attention")
    optimized_attention = attention_sage
if model_management.sage_attention3_enabled():
    logging.info("Using sage attention 3")
    optimized_attention = attention3_sage
elif model_management.xformers_enabled():
    logging.info("Using xformers attention")
    optimized_attention = attention_xformers

@@ -650,6 +752,8 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
    register_attention_function("sage", attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
    register_attention_function("sage3", attention3_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
    register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():

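register_attention_function exposes the backend under a short name ("sage3" here) in addition to the optimized_attention default selected above. A minimal sketch of the kind of name-to-callable registry this implies; the structure and helper below are illustrative, not ComfyUI's actual implementation:

from typing import Callable, Dict

REGISTERED_ATTENTION: Dict[str, Callable] = {}

def register_attention_function(name: str, func: Callable) -> None:
    # Keyed by short backend name, e.g. "sage3" -> attention3_sage.
    REGISTERED_ATTENTION[name] = func

def get_attention_function(name: str, default: Callable) -> Callable:
    # Fall back to the default backend when the requested one was never registered.
    return REGISTERED_ATTENTION.get(name, default)
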
@@ -1189,6 +1189,9 @@ def unpin_memory(tensor):
def sage_attention_enabled():
    return args.use_sage_attention

def sage_attention3_enabled():
    return args.use_sage_attention3

def flash_attention_enabled():
    return args.use_flash_attention