From 897d2662be3d593c78ee88a3eae24601cfef5499 Mon Sep 17 00:00:00 2001
From: vivienfanghua
Date: Tue, 28 Oct 2025 19:13:17 +0800
Subject: [PATCH] add

---
 comfy/cli_args.py              |  1 +
 comfy/ldm/modules/attention.py | 96 ++++++++++++++++++++++++++++++++++
 comfy/model_management.py      |  3 ++
 3 files changed, 100 insertions(+)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index cc1f12482..d4219cdf3 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -111,6 +111,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-aiter-attention", action="store_true", help="Use aiter attention.")
 attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 7437e0567..76bab22b0 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -39,6 +39,15 @@ except ImportError:
         logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
         exit(-1)
 
+AITER_ATTENTION_IS_AVAILABLE = False
+try:
+    import aiter
+    AITER_ATTENTION_IS_AVAILABLE = True
+except ImportError:
+    if model_management.aiter_attention_enabled():
+        logging.error(f"\n\nTo use the `--use-aiter-attention` feature, the `aiter` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install aiter")
+        exit(-1)
+
 REGISTERED_ATTENTION_FUNCTIONS = {}
 def register_attention_function(name: str, func: Callable):
     # avoid replacing existing functions
@@ -619,11 +628,96 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out
 
 
+try:
+    @torch.library.custom_op("aiter_attention::aiter_flash_attn", mutates_args=())
+    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
+                                 causal: bool = False, window_size: tuple = (-1, -1),
+                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
+                                 deterministic: bool = False) -> torch.Tensor:
+        return aiter.flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale,
+                                     causal=causal, window_size=window_size, bias=bias,
+                                     alibi_slopes=alibi_slopes, deterministic=deterministic,
+                                     return_lse=False, return_attn_probs=False,
+                                     cu_seqlens_q=None, cu_seqlens_kv=None)
+
+
+    @aiter_flash_attn_wrapper.register_fake
+    def aiter_flash_attn_fake(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
+                              window_size=(-1, -1), bias=None, alibi_slopes=None, deterministic=False):
+        # Output shape is the same as q
+        return q.new_empty(q.shape)
+except AttributeError as error:
+    AITER_ATTN_ERROR = error
+
+    def aiter_flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                                 dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
+                                 causal: bool = False, window_size: tuple = (-1, -1),
+                                 bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
+                                 deterministic: bool = False) -> torch.Tensor:
+        assert False, f"Could not define aiter_flash_attn_wrapper: {AITER_ATTN_ERROR}"
+
+@wrap_attn
+def attention_aiter(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+    # aiter expects (batch, seqlen, nheads, headdim); with skip_reshape the inputs are (batch, nheads, seqlen, headdim) and only need a transpose
+    q, k, v = map(
+        lambda t: t.transpose(1, 2) if skip_reshape else t.view(b, -1, heads, dim_head),
+        (q, k, v),
+    )
+
+    if mask is not None:
+        # add a batch dimension if there isn't already one
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        # add a heads dimension if there isn't already one
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(1)
+
+    try:
+        # aiter.flash_attn_func expects (batch, seqlen, nheads, headdim) format
+        out = aiter_flash_attn_wrapper(
+            q,
+            k,
+            v,
+            dropout_p=0.0,
+            softmax_scale=None,
+            causal=False,
+            window_size=(-1, -1),
+            bias=mask,
+            alibi_slopes=None,
+            deterministic=False,
+        )
+    except Exception as e:
+        logging.warning(f"Aiter Attention failed, using default SDPA: {e}")
+        # fallback needs (batch, nheads, seqlen, headdim) format
+        q_sdpa = q.transpose(1, 2)
+        k_sdpa = k.transpose(1, 2)
+        v_sdpa = v.transpose(1, 2)
+        out = torch.nn.functional.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        out = out.transpose(1, 2)
+
+    if skip_output_reshape:
+        # output is (batch, seqlen, nheads, headdim); the caller expects (batch, nheads, seqlen, headdim)
+        out = out.transpose(1, 2)
+    else:
+        # reshape from (batch, seqlen, nheads, headdim) to (batch, seqlen, nheads * headdim)
+        out = out.reshape(b, -1, heads * dim_head)
+    return out
+
+
 optimized_attention = attention_basic
 
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
+elif model_management.aiter_attention_enabled():
+    logging.info("Using aiter attention")
+    optimized_attention = attention_aiter
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
@@ -647,6 +741,8 @@ optimized_attention_masked = optimized_attention
 # register core-supported attention functions
 if SAGE_ATTENTION_IS_AVAILABLE:
     register_attention_function("sage", attention_sage)
+if AITER_ATTENTION_IS_AVAILABLE:
+    register_attention_function("aiter", attention_aiter)
 if FLASH_ATTENTION_IS_AVAILABLE:
     register_attention_function("flash", attention_flash)
 if model_management.xformers_enabled():
diff --git a/comfy/model_management.py b/comfy/model_management.py
index afe78f36e..350bee6c6 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1083,6 +1083,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
 def sage_attention_enabled():
     return args.use_sage_attention
 
+def aiter_attention_enabled():
+    return args.use_aiter_attention
+
 def flash_attention_enabled():
     return args.use_flash_attention
 
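
Note on the wrapper above (not part of the patch): torch.library.custom_op plus register_fake is what keeps the opaque aiter.flash_attn_func kernel usable under torch.compile; the fake implementation only reports the output's shape/dtype so the tracer never has to run the real kernel, and the except AttributeError branch covers PyTorch builds that predate this API (added around PyTorch 2.4). A minimal, self-contained sketch of the same pattern, with a made-up op name "mylib::scaled_copy":

import torch

# Register a custom op; torch.compile treats it as a black box and never traces into it.
@torch.library.custom_op("mylib::scaled_copy", mutates_args=())
def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    # real (eager) implementation
    return x * scale

# Fake implementation: only describes the output's shape/dtype for FakeTensor propagation.
@scaled_copy.register_fake
def _(x, scale):
    return x.new_empty(x.shape)

print(scaled_copy(torch.ones(3), 2.0))  # tensor([2., 2., 2.])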
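
Also for reference (not part of the patch): a self-contained sketch of the tensor-layout handling attention_aiter does on the default skip_reshape=False path, using the same SDPA fallback it drops to when the aiter kernel fails. The helper name aiter_layout_reference is made up for illustration; shapes follow comfy's (batch, seq_len, heads * dim_head) convention.

import torch
import torch.nn.functional as F

def aiter_layout_reference(q, k, v, heads):
    # inputs: (batch, seq_len, heads * dim_head), comfy's layout when skip_reshape=False
    b, _, dim = q.shape
    dim_head = dim // heads
    # aiter's flash kernel wants (batch, seq_len, heads, dim_head)
    q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))
    # SDPA (the fallback) wants (batch, heads, seq_len, dim_head), so transpose in and back out
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    out = out.transpose(1, 2)
    # back to (batch, seq_len, heads * dim_head) for the rest of the model
    return out.reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 16, 8 * 64)
print(aiter_layout_reference(q, k, v, heads=8).shape)  # torch.Size([2, 16, 512])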