diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 15c0ce33e9659..8d5ebd93e063d 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -430,9 +430,11 @@ class MultiHeadAttention(nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
     ) -> torch.Tensor:
-        """Input shape: batch_size x seq_len x hidden_size"""
-        # TODO(Isotr0py): Use existing backend implementations and support FA3
-        bsz, q_len, _ = query.size()
+        """Input shape:
+        (batch_size x seq_len x hidden_size) or
+        (batch_size x seq_len x num_heads x head_size)
+        """
+        bsz, q_len = query.size()[:2]
         kv_len = key.size(1)
 
         query = query.view(bsz, q_len, self.num_heads, self.head_size)
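
A minimal standalone sketch (not part of the diff) of why `query.size()[:2]` is sufficient for both accepted input shapes: the first two dimensions are the same in either case, and the subsequent `view` collapses both layouts to the same 4-D shape. The `num_heads`/`head_size` values below are made up for illustration.

```python
import torch

num_heads, head_size = 8, 64            # illustrative values, not from the diff
hidden_size = num_heads * head_size

q3d = torch.randn(2, 5, hidden_size)            # (batch_size, seq_len, hidden_size)
q4d = torch.randn(2, 5, num_heads, head_size)   # (batch_size, seq_len, num_heads, head_size)

for q in (q3d, q4d):
    # Only the leading two dims are needed, so both 3-D and 4-D inputs work.
    bsz, q_len = q.size()[:2]
    # Reshapes the 3-D input and is a no-op for the already 4-D input.
    q = q.view(bsz, q_len, num_heads, head_size)
    assert q.shape == (2, 5, num_heads, head_size)
```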