From 51daeef1b7a0e61054730d5a45b4b5a0ec8cf4e5 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 20 Jan 2025 11:23:02 +0200
Subject: [PATCH] sageattn fp8/GGUF fix

---
 custom_cogvideox_transformer_3d.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index fb859a5..45f53d3 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -63,25 +63,25 @@ def set_attention_func(attention_mode, heads):
     elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+            return sageattn(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
         return func
     elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
         from sageattention import sageattn_qk_int8_pv_fp16_cuda
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
+            return sageattn_qk_int8_pv_fp16_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
         return func
     elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
         from sageattention import sageattn_qk_int8_pv_fp16_triton
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+            return sageattn_qk_int8_pv_fp16_triton(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
         return func
     elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
         from sageattention import sageattn_qk_int8_pv_fp8_cuda
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
+            return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
         return func

 def fft(tensor):
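
With fp8-quantized or GGUF models, v can arrive in a different dtype than q and k, and the sageattn calls then fail, presumably because the kernels expect matching dtypes; the patch casts q and k to v's dtype (and device) via q.to(v) / k.to(v) before each call. Below is a minimal sketch of the same dtype-alignment pattern, using PyTorch's scaled_dot_product_attention as a stand-in for sageattn so it runs without the sageattention package; the concrete float16/bfloat16 mismatch is assumed purely for illustration.

import torch
import torch.nn.functional as F

def attention(q, k, v, is_causal=False, attn_mask=None):
    # Same fix as the patch: align q and k with v's dtype and device
    # before calling the attention kernel, instead of assuming they match.
    q, k = q.to(v), k.to(v)
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, is_causal=is_causal
    )

# Hypothetical mismatch: q/k in float16, v dequantized to bfloat16.
q = torch.randn(1, 8, 16, 64, dtype=torch.float16)
k = torch.randn(1, 8, 16, 64, dtype=torch.float16)
v = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
out = attention(q, k, v)  # without the cast, SDPA rejects mismatched dtypes
print(out.dtype)          # torch.bfloat16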