From a4f59bc65ed29ecab11a7e7e20c63aa6de03f6e8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 18 Dec 2024 01:30:20 -0500 Subject: [PATCH 1/2] Pick attention implementation based on device in llama code. --- comfy/text_encoders/llama.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 037dbf280..ad4b4623e 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from dataclasses import dataclass from typing import Optional, Any -from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit @@ -81,6 +81,7 @@ class Attention(nn.Module): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, + optimized_attention=None, ): batch_size, seq_length, _ = hidden_states.shape @@ -124,6 +125,7 @@ class TransformerBlock(nn.Module): x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, + optimized_attention=None, ): # Self Attention residual = x @@ -132,6 +134,7 @@ class TransformerBlock(nn.Module): hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, + optimized_attention=optimized_attention, ) x = residual + x @@ -180,6 +183,7 @@ class Llama2_(nn.Module): mask += causal_mask else: mask = causal_mask + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None if intermediate_output is not None: @@ -191,6 +195,7 @@ class Llama2_(nn.Module): x=x, attention_mask=mask, freqs_cis=freqs_cis, + optimized_attention=optimized_attention, ) if i == intermediate_output: intermediate = x.clone() From 37e5390f5ff01ae367ac37d62377bbedff2a68da Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 18 Dec 2024 01:56:10 -0500 Subject: [PATCH 2/2] Add: --use-sage-attention to enable SageAttention. You need to have the library installed first. --- comfy/cli_args.py | 1 + comfy/ldm/modules/attention.py | 47 ++++++++++++++++++++++++++++++---- comfy/model_management.py | 2 ++ 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 847f35abd..4c6545011 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -104,6 +104,7 @@ attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") +attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index e60d1ab25..0d54e6bec 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -15,6 +15,9 @@ if model_management.xformers_enabled(): import xformers import xformers.ops +if model_management.sage_attention_enabled(): + from sageattention import sageattn + from comfy.cli_args import args import comfy.ops ops = comfy.ops.disable_weight_init @@ -447,20 +450,54 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha return out +def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): + if skip_reshape: + b, _, _, dim_head = q.shape + tensor_layout="HND" + else: + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map( + lambda t: t.view(b, -1, heads, dim_head), + (q, k, v), + ) + tensor_layout="NHD" + + if mask is not None: + # add a batch dimension if there isn't already one + if mask.ndim == 2: + mask = mask.unsqueeze(0) + # add a heads dimension if there isn't already one + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + if tensor_layout == "HND": + out = ( + out.transpose(1, 2).reshape(b, -1, heads * dim_head) + ) + else: + out = out.reshape(b, -1, heads * dim_head) + return out + + optimized_attention = attention_basic -if model_management.xformers_enabled(): - logging.info("Using xformers cross attention") +if model_management.sage_attention_enabled(): + logging.info("Using sage attention") + optimized_attention = attention_sage +elif model_management.xformers_enabled(): + logging.info("Using xformers attention") optimized_attention = attention_xformers elif model_management.pytorch_attention_enabled(): - logging.info("Using pytorch cross attention") + logging.info("Using pytorch attention") optimized_attention = attention_pytorch else: if args.use_split_cross_attention: - logging.info("Using split optimization for cross attention") + logging.info("Using split optimization for attention") optimized_attention = attention_split else: - logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") + logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad optimized_attention_masked = optimized_attention diff --git a/comfy/model_management.py b/comfy/model_management.py index 177c7998b..f6ca252e3 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -837,6 +837,8 @@ def cast_to_device(tensor, device, dtype, copy=False): non_blocking = device_supports_non_blocking(device) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) +def sage_attention_enabled(): + return args.use_sage_attention def xformers_enabled(): global directml_enabled