precommit

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
NickLucche 2025-12-19 12:45:12 +00:00
parent 7c95fd8279
commit 1429a5e9a8

View File

@ -110,7 +110,12 @@ def vit_flash_attn_wrapper(
)
def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, softmax_scale: float | None = None) -> torch.Tensor:
def apply_sdpa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: float | None = None,
) -> torch.Tensor:
"""
Input shape:
(batch_size x seq_len x num_heads x head_size)