Fix for breaking changes in xformers 0.0.21 (#834)

Woosuk Kwon, 2023-08-23 17:44:21 +09:00 (committed by GitHub)
parent 85ebcda94d
commit 2a4ec90854
2 changed files with 4 additions and 3 deletions

requirements.txt

@@ -5,7 +5,7 @@ sentencepiece # Required for LLaMA tokenizer.
 numpy
 torch >= 2.0.0
 transformers >= 4.31.0 # Required for LLaMA-2.
-xformers >= 0.0.19
+xformers >= 0.0.21
 fastapi
 uvicorn
 pydantic < 2 # Required for OpenAI server.
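
The requirements bump alone does not stop an older, already-installed xformers from being imported at runtime. A minimal sketch of a runtime guard, not part of this commit (the error message and placement are hypothetical, and it assumes the `packaging` package is available):

from packaging import version

import xformers

# Fail fast if the installed xformers predates the 0.0.21 attn_bias change.
if version.parse(xformers.__version__) < version.parse("0.0.21"):
    raise ImportError(
        f"xformers >= 0.0.21 is required, but found {xformers.__version__}.")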

vllm/model_executor/layers/attention.py

@@ -357,11 +357,12 @@ class PagedAttentionWithALiBi(PagedAttention):
             # be sliced from a tensor whose length is a multiple of 8.
             padded_len = (prompt_len + 7) // 8 * 8
             bias = torch.empty(
+                1,  # batch_size
                 self.num_heads,
-                padded_len,
+                prompt_len,
                 padded_len,
                 device=self.alibi_slopes.device,
-            )[:, :prompt_len, :prompt_len].copy_(bias)
+            )[:, :, :, :prompt_len].copy_(bias)
             bias.mul_(self.alibi_slopes[:, None, None])
             attn_bias = LowerTriangularMaskWithTensorBias(bias)
             input_metadata.attn_bias.append(attn_bias)
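
Judging by the diff, xformers 0.0.21 expects the ALiBi tensor bias in 4D (batch, heads, query_len, key_len) form rather than the earlier 3D form, while still requiring the last dimension to be sliced from storage whose length is a multiple of 8. A self-contained sketch of the new construction; the helper name `make_alibi_bias` and the arange-based setup are illustrative (in the patched method, `bias` already holds the relative-position matrix):

import torch

def make_alibi_bias(alibi_slopes: torch.Tensor, prompt_len: int) -> torch.Tensor:
    # Relative positions j - i; xformers applies the causal (lower-triangular)
    # mask itself, so only the lower triangle of this matrix is ultimately used.
    pos = torch.arange(prompt_len)
    bias = pos[None, :] - pos[:, None]  # (prompt_len, prompt_len)

    # Allocate padded storage (last dim a multiple of 8) and slice back down,
    # mirroring the patched code above.
    padded_len = (prompt_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        1,  # batch_size
        num_heads,
        prompt_len,
        padded_len,
        device=alibi_slopes.device,
    )[:, :, :, :prompt_len].copy_(bias)  # copy_ broadcasts over batch and heads
    bias.mul_(alibi_slopes[:, None, None])  # scale each head by its slope
    return bias

The result is then wrapped in `LowerTriangularMaskWithTensorBias` (imported in vLLM from `xformers.ops.fmha.attn_bias`) and appended to `input_metadata.attn_bias`, exactly as in the unchanged trailing lines of the hunk.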