diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index be4dc3eb3c0da..bb05b468fd102 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -399,7 +399,8 @@ class MultiHeadAttention(nn.Module):
                                                           key,
                                                           value,
                                                           scale=self.scale)
-        elif self.attn_backend == _Backend.TORCH_SDPA:
+        elif (self.attn_backend == _Backend.TORCH_SDPA
+              or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1):
             query, key, value = (x.transpose(1, 2)
                                  for x in (query, key, value))
             out = F.scaled_dot_product_attention(query,
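
For reference (not part of the patch), below is a minimal, self-contained sketch of the torch-SDPA path that the widened condition now reaches for both `_Backend.TORCH_SDPA` and `_Backend.TORCH_SDPA_VLLM_V1`. The tensor shapes and the explicit `scale` value are illustrative assumptions, not taken from the diff.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes and scale; in MultiHeadAttention the scale comes from self.scale.
batch, seq_len, num_heads, head_dim = 2, 16, 4, 64
scale = head_dim ** -0.5

# MultiHeadAttention holds tensors as (batch, seq, heads, head_dim);
# F.scaled_dot_product_attention expects (batch, heads, seq, head_dim),
# hence the transpose(1, 2) before and after the call.
query, key, value = (torch.randn(batch, seq_len, num_heads, head_dim)
                     for _ in range(3))
q, k, v = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention(q, k, v, scale=scale)
out = out.transpose(1, 2)
assert out.shape == (batch, seq_len, num_heads, head_dim)
```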