From 2a4ec90854ae3ad08a3593cb4896dfce601974c3 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 23 Aug 2023 17:44:21 +0900
Subject: [PATCH] Fix for breaking changes in xformers 0.0.21 (#834)

---
 requirements.txt                        | 2 +-
 vllm/model_executor/layers/attention.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index f9f3c787bd32..38010648e26a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ sentencepiece # Required for LLaMA tokenizer.
 numpy
 torch >= 2.0.0
 transformers >= 4.31.0 # Required for LLaMA-2.
-xformers >= 0.0.19
+xformers >= 0.0.21
 fastapi
 uvicorn
 pydantic < 2 # Required for OpenAI server.

diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index a8e064339f68..a9bbb64b7eb5 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -357,11 +357,12 @@ class PagedAttentionWithALiBi(PagedAttention):
             # be sliced from a tensor whose length is a multiple of 8.
             padded_len = (prompt_len + 7) // 8 * 8
             bias = torch.empty(
+                1,  # batch_size
                 self.num_heads,
-                padded_len,
+                prompt_len,
                 padded_len,
                 device=self.alibi_slopes.device,
-            )[:, :prompt_len, :prompt_len].copy_(bias)
+            )[:, :, :, :prompt_len].copy_(bias)
             bias.mul_(self.alibi_slopes[:, None, None])
             attn_bias = LowerTriangularMaskWithTensorBias(bias)
             input_metadata.attn_bias.append(attn_bias)
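
As the hunk above shows, xformers 0.0.21 takes the tensor wrapped in
LowerTriangularMaskWithTensorBias as a 4-D (batch_size, num_heads,
query_len, key_len) bias whose key dimension is sliced from a buffer
padded to a multiple of 8; the 3-D (num_heads, padded_len, padded_len)
layout accepted by 0.0.19 no longer works. Below is a minimal standalone
sketch of the new construction, assuming a hypothetical alibi_slopes
tensor of shape (num_heads,); it mirrors the patched code but is not
part of the patch.

import torch

def build_alibi_bias(alibi_slopes: torch.Tensor,
                     prompt_len: int) -> torch.Tensor:
    # Relative-position ALiBi bias: bias[i, j] = j - i.
    positions = torch.arange(prompt_len, device=alibi_slopes.device)
    bias = (positions[None, :] - positions[:, None]).to(torch.float32)

    # xformers requires the bias to be sliced from a tensor whose last
    # dimension is a multiple of 8; since 0.0.21 it also expects a
    # leading batch dimension, and only the key dimension is padded.
    num_heads = alibi_slopes.shape[0]
    padded_len = (prompt_len + 7) // 8 * 8
    out = torch.empty(
        1,            # batch_size (new in 0.0.21)
        num_heads,
        prompt_len,   # query length: no padding needed
        padded_len,   # key length: padded to a multiple of 8
        device=alibi_slopes.device,
    )[:, :, :, :prompt_len].copy_(bias)
    # Scale each head's bias by its ALiBi slope (broadcasts over the
    # batch and sequence dimensions).
    out.mul_(alibi_slopes[:, None, None])
    return out

The result can then be wrapped in
xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias, as the
patched set_attn_bias path does.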