diff --git a/custom_cogvideox_transformer_3d.py b/custom_cogvideox_transformer_3d.py
index fb859a5..45f53d3 100644
--- a/custom_cogvideox_transformer_3d.py
+++ b/custom_cogvideox_transformer_3d.py
@@ -63,25 +63,25 @@ def set_attention_func(attention_mode, heads):
     elif attention_mode == "sageattn" or attention_mode == "fused_sageattn":
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+            return sageattn(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
         return func
     elif attention_mode == "sageattn_qk_int8_pv_fp16_cuda":
         from sageattention import sageattn_qk_int8_pv_fp16_cuda
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
+            return sageattn_qk_int8_pv_fp16_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32")
         return func
     elif attention_mode == "sageattn_qk_int8_pv_fp16_triton":
         from sageattention import sageattn_qk_int8_pv_fp16_triton
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask)
+            return sageattn_qk_int8_pv_fp16_triton(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask)
         return func
     elif attention_mode == "sageattn_qk_int8_pv_fp8_cuda":
         from sageattention import sageattn_qk_int8_pv_fp8_cuda
         @torch.compiler.disable()
         def func(q, k, v, is_causal=False, attn_mask=None):
-            return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
+            return sageattn_qk_int8_pv_fp8_cuda(q.to(v), k.to(v), v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32")
         return func

 def fft(tensor):