mirror of
https://git.datalinker.icu/kijai/ComfyUI-CogVideoXWrapper.git
synced 2025-12-09 04:44:22 +08:00
exclude sageattn from compile
This commit is contained in:
parent
eebdc412f9
commit
e70da23ac2
@ -40,10 +40,15 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
try:
|
||||
from sageattention import sageattn
|
||||
|
||||
SAGEATTN_IS_AVAILABLE = True
|
||||
except:
|
||||
SAGEATTN_IS_AVAILABLE = False
|
||||
|
||||
@torch.compiler.disable()
|
||||
def sageattn_func(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False):
|
||||
return sageattn(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,is_causal=is_causal)
|
||||
|
||||
def fft(tensor):
|
||||
tensor_fft = torch.fft.fft2(tensor)
|
||||
tensor_fft_shifted = torch.fft.fftshift(tensor_fft)
|
||||
@ -121,7 +126,7 @@ class CogVideoXAttnProcessor2_0:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
if attention_mode == "sageattn":
|
||||
hidden_states = sageattn(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
||||
hidden_states = sageattn_func(query, key, value, attn_mask=attention_mask, dropout_p=0.0,is_causal=False)
|
||||
else:
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user