Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI
Merge 0f6f3c93fda308e1c10badccf9498007946682ce into 76f18e955dcbc88ed13d6802194fd897927f93e5
Commit 2570c6c351
@@ -112,6 +112,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-aiter-attention", action="store_true", help="Use aiter attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
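For orientation, each of these switches is a plain store_true flag. A minimal sketch of the resulting behaviour, using a throwaway parser rather than ComfyUI's real one (the mutually exclusive group is an assumption suggested by the shared attn_group prefix):

import argparse

# Illustrative stand-in for the real parser in the hunk above; not ComfyUI's actual object.
parser = argparse.ArgumentParser()
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-aiter-attention", action="store_true", help="Use aiter attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

args = parser.parse_args(["--use-aiter-attention"])
print(args.use_aiter_attention)  # True -> aiter_attention_enabled() in model_management returns True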
@@ -39,6 +39,16 @@ except ImportError:
        logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
        exit(-1)

AITER_ATTENTION_IS_AVAILABLE = False
try:
    import aiter
    AITER_ATTENTION_IS_AVAILABLE = True
except ImportError:
    if model_management.aiter_attention_enabled():
        logging.error("\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.")
        logging.error("Installation instructions: https://github.com/ROCm/aiter/tree/main?tab=readme-ov-file#installation")
        exit(-1)

REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
    # avoid replacing existing functions
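The hunk is truncated right after the "avoid replacing existing functions" comment. A minimal, self-contained sketch of how such a dedupe guard and lookup typically work; the guard body and warning text here are assumptions, not the upstream code:

from typing import Callable
import logging

REGISTERED_ATTENTION_FUNCTIONS = {}

def register_attention_function(name: str, func: Callable):
    # avoid replacing existing functions (guard body is illustrative)
    if name in REGISTERED_ATTENTION_FUNCTIONS:
        logging.warning("Attention function %s already registered, skipping", name)
        return
    REGISTERED_ATTENTION_FUNCTIONS[name] = func

register_attention_function("flash", lambda *a, **kw: None)  # placeholder callable
register_attention_function("flash", lambda *a, **kw: None)  # second call is ignored by the guard
print(list(REGISTERED_ATTENTION_FUNCTIONS))                  # ['flash']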
@@ -615,6 +625,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
    except Exception as e:
        logging.warning(f"Flash Attention failed, using default SDPA: {e}")
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

    if not skip_output_reshape:
        out = (
            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
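For reference, torch's SDPA fallback operates on (batch, heads, seq, head_dim) tensors, which is why the hunk reshapes afterwards. A quick illustrative shape check:

import torch

b, heads, seq, dim_head = 2, 8, 77, 64
q = torch.randn(b, heads, seq, dim_head)
k = torch.randn(b, heads, seq, dim_head)
v = torch.randn(b, heads, seq, dim_head)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 77, 64])
# The same reshape used when skip_output_reshape is False:
print(out.transpose(1, 2).reshape(b, -1, heads * dim_head).shape)  # torch.Size([2, 77, 512])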
@@ -622,11 +633,86 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
    return out


@wrap_attn
def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    # Store original inputs for fallback
    q_orig, k_orig, v_orig, mask_orig = q, k, v, mask

    if skip_reshape:
        b, _, _, dim_head = q.shape
        q, k, v = map(
            lambda t: t.transpose(1, 2),
            (q, k, v),
        )
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head),
            (q, k, v),
        )

    # Convert mask to [sq, sk] format for aiter bias
    bias = None
    if mask is not None:
        if mask.ndim == 2:
            bias = mask
        elif mask.ndim == 3:
            seqlen_q = q.shape[1]
            if mask.shape[-2] == 1:
                # [1, 1, sk] -> expand to [sq, sk]
                bias = mask.squeeze(0).expand(seqlen_q, -1)
            else:
                # [batch, sq, sk] -> take first batch
                bias = mask[0]
        elif mask.ndim == 4:
            # [batch, heads, sq, sk] -> take first batch and head
            bias = mask[0, 0]

    try:
        # aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
        out = aiter.flash_attn_func(
            q,
            k,
            v,
            dropout_p=0.0,
            softmax_scale=None,
            causal=False,
            window_size=(-1, -1),
            bias=bias,
            alibi_slopes=None,
            deterministic=False,
            return_lse=False,
            return_attn_probs=False,
            cu_seqlens_q=None,
            cu_seqlens_kv=None,
        )

        if skip_output_reshape:
            # output is (batch, seqlen, nheads, headdim), need (batch, nheads, seqlen, headdim)
            out = out.transpose(1, 2)
        else:
            # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
            out = out.reshape(b, -1, heads * dim_head)
        return out

    except Exception as e:
        logging.warning(f"Aiter Attention failed, falling back to pytorch attention: {e}")
        # Fallback to attention_pytorch with original inputs
        return attention_pytorch(q_orig, k_orig, v_orig, heads, mask=mask_orig,
                                 attn_precision=attn_precision, skip_reshape=skip_reshape,
                                 skip_output_reshape=skip_output_reshape, **kwargs)


optimized_attention = attention_basic

if model_management.sage_attention_enabled():
    logging.info("Using sage attention")
    optimized_attention = attention_sage
elif model_management.aiter_attention_enabled():
    logging.info("Using aiter attention")
    optimized_attention = attention_aiter
elif model_management.xformers_enabled():
    logging.info("Using xformers attention")
    optimized_attention = attention_xformers
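The mask handling in attention_aiter flattens everything down to a 2-D [sq, sk] bias for aiter. A small worked example of the two non-trivial branches, with illustrative sizes:

import torch

seqlen_q, seqlen_k = 77, 77

# ndim == 3 with a singleton query dim: [1, 1, sk] -> expand to [sq, sk]
mask3 = torch.zeros(1, 1, seqlen_k)
bias = mask3.squeeze(0).expand(seqlen_q, -1)
print(bias.shape)         # torch.Size([77, 77])

# ndim == 4: [batch, heads, sq, sk] -> take first batch and head
mask4 = torch.zeros(2, 8, seqlen_q, seqlen_k)
print(mask4[0, 0].shape)  # torch.Size([77, 77])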
@@ -650,6 +736,8 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
    register_attention_function("sage", attention_sage)
if AITER_ATTENTION_IS_AVAILABLE:
    register_attention_function("aiter", attention_aiter)
if FLASH_ATTENTION_IS_AVAILABLE:
    register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
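Once registered, a backend can be pulled out of the dict by name. A hedged usage sketch; the module path is the usual one for this file, the shape expectations follow attention_aiter's default skip_reshape=False path, and whatever @wrap_attn adds is not shown here:

import torch
from comfy.ldm.modules.attention import REGISTERED_ATTENTION_FUNCTIONS  # assumed module path

fn = REGISTERED_ATTENTION_FUNCTIONS.get("aiter")  # present only if the aiter import above succeeded
if fn is not None:
    b, heads, seq, dim_head = 1, 8, 77, 64
    q = k = v = torch.randn(b, seq, heads * dim_head)  # layout for the default skip_reshape=False path
    out = fn(q, k, v, heads)
    print(out.shape)  # expected torch.Size([1, 77, 512])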
@@ -1093,5 +1181,3 @@ class SpatialVideoTransformer(SpatialTransformer):
        x = self.proj_out(x)
        out = x + x_in
        return out
@@ -1189,6 +1189,9 @@ def unpin_memory(tensor):
def sage_attention_enabled():
    return args.use_sage_attention

def aiter_attention_enabled():
    return args.use_aiter_attention

def flash_attention_enabled():
    return args.use_flash_attention
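These helpers simply surface the CLI flags; the attention module consults them at import time (to make a missing package fatal when its flag is set) and again when choosing optimized_attention. A trivial hedged check, assuming the usual comfy.model_management import path:

import comfy.model_management as model_management  # assumed import path

if model_management.aiter_attention_enabled():
    print("aiter attention requested via --use-aiter-attention")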