Allow fp32 input for sageattn function

kijai 2025-11-27 13:33:41 +02:00
parent acdd16a973
commit f0ed965cd9


@@ -73,6 +73,9 @@ def get_sage_func(sage_attention, allow_compile=False):
     @wrap_attn
     def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+        in_dtype = v.dtype
+        if q.dtype == torch.float32 or k.dtype == torch.float32 or v.dtype == torch.float32:
+            q, k, v = q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
         if skip_reshape:
             b, _, _, dim_head = q.shape
             tensor_layout="HND"
@@ -91,7 +94,7 @@ def get_sage_func(sage_attention, allow_compile=False):
             # add a heads dimension if there isn't already one
             if mask.ndim == 3:
                 mask = mask.unsqueeze(1)
-        out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+        out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout).to(in_dtype)
         if tensor_layout == "HND":
             if not skip_output_reshape:
                 out = (
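In short, the change wraps the SageAttention call in a cast-in, cast-out pattern: fp32 inputs are downcast to fp16 before the kernel runs, and the output is cast back to the caller's original dtype. Below is a minimal standalone sketch of that pattern; the function name and the sage_func callable are illustrative placeholders, not part of the repository's API.

    import torch

    def sage_attention_fp32_safe(q, k, v, sage_func, **kwargs):
        # Remember the caller's dtype so the output can be cast back to it.
        in_dtype = v.dtype
        # The SageAttention kernel path does not take fp32 inputs, so downcast
        # any fp32 tensor to fp16 first (assumption: fp16 precision is
        # acceptable for the model being run).
        if torch.float32 in (q.dtype, k.dtype, v.dtype):
            q, k, v = q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
        out = sage_func(q, k, v, **kwargs)
        # Cast the result back so downstream layers see the original dtype.
        return out.to(in_dtype)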