diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index dc5d5e8c4904a..ac19b7bd81e8d 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -110,7 +110,12 @@ def vit_flash_attn_wrapper( ) -def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, softmax_scale: float | None = None) -> torch.Tensor: +def apply_sdpa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: float | None = None, +) -> torch.Tensor: """ Input shape: (batch_size x seq_len x num_heads x head_size)