exclude sageattn from compile

kijai 2024-11-17 22:11:32 +02:00
parent eebdc412f9
commit e70da23ac2


@@ -40,10 +40,15 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 try:
     from sageattention import sageattn
     SAGEATTN_IS_AVAILABLE = True
 except:
     SAGEATTN_IS_AVAILABLE = False
 
+@torch.compiler.disable()
+def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+
+
 def fft(tensor):
     tensor_fft = torch.fft.fft2(tensor)
     tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
@@ -121,7 +126,7 @@ class CogVideoXAttnProcessor2_0:
             key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
         if attention_mode == "sageattn":
-            hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
+            hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
         else:
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
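
For reference, the effect of the change: the sageattn call is moved into a small wrapper decorated with @torch.compiler.disable(), likely because torch.compile cannot trace sageattention's custom kernels. The decorator makes torch.compile insert a graph break around the wrapper and run it eagerly, while the surrounding model code still compiles. Below is a minimal standalone sketch of that pattern (not part of the commit; eager_attention and attention_block are illustrative names):

import torch
import torch.nn.functional as F

@torch.compiler.disable()
def eager_attention(query, key, value):
    # Excluded from compilation: torch.compile inserts a graph break
    # around this call and executes it eagerly.
    return F.scaled_dot_product_attention(query, key, value)

@torch.compile
def attention_block(query, key, value):
    # The surrounding arithmetic is traced and compiled; the attention
    # call above runs outside the compiled graph.
    return eager_attention(query, key, value) * 2.0

q = k = v = torch.randn(1, 8, 16, 64)
print(attention_block(q, k, v).shape)  # torch.Size([1, 8, 16, 64])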