sageattn fp8/GGUF fix

kijai 2025-01-20 11:23:02 +02:00
parent 5bca0548d9
commit 51daeef1b7


@@ -63,25 +63,25 @@ def set_attention_func(attention_mode, heads):
     elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+            return sageattn(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
         return func
     elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
         from sageattention import sageattn_qk_int8_pv_fp16_cuda
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
+            return sageattn_qk_int8_pv_fp16_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
         return func
     elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
         from sageattention import sageattn_qk_int8_pv_fp16_triton
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+            return sageattn_qk_int8_pv_fp16_triton(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
         return func
     elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
         from sageattention import sageattn_qk_int8_pv_fp8_cuda
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
+            return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
         return func
 
 def fft(tensor):
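
Note on the change: every SageAttention branch now casts q and k to v's dtype (q.to(v), k.to(v)) before the kernel call, since fp8/GGUF model paths can hand these wrappers query/key tensors in a dtype that no longer matches v. Below is a minimal sketch of the same pattern in isolation; it assumes sageattention is installed, a CUDA device, and the library's default (batch, heads, seq_len, head_dim) tensor layout. The wrapper name and the fp32/bf16 mismatch are illustrative, not taken from the repo.

import torch
from sageattention import sageattn

def sageattn_matched_dtype(q, k, v, is_causal=False, attn_mask=None):
    # Cast q and k to v's dtype (a no-op when they already match),
    # mirroring the q.to(v) / k.to(v) change in this commit.
    return sageattn(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)

# Hypothetical mismatch: fp32 q/k paired with a bf16 v.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
out = sageattn_matched_dtype(q, k, v, is_causal=False)
print(out.shape, out.dtype)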