use flash-attn via xformers (#877)

Aman Gupta Karmani 2023-08-30 00:52:13 -04:00 committed by GitHub
parent d2b2eed67c
commit 75471386de
2 changed files with 0 additions and 5 deletions
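
The change itself: the five deleted lines pinned xformers' memory-efficient attention to the CUTLASS forward op (xops.fmha.cutlass.FwOp). With the op= argument dropped, xformers' dispatcher is free to select its FlashAttention backend whenever the installed build, hardware, and dtype support it. Below is a minimal sketch of the resulting call pattern; the shapes, dtype, and scale are illustrative assumptions, not values taken from this commit.

import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# Illustrative sizes only.
seq_lens = [7, 12, 5]            # variable-length sequences packed together
num_heads, head_size = 8, 64
num_tokens = sum(seq_lens)
scale = head_size ** -0.5

qkv = torch.empty(num_tokens, 3, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
qkv.uniform_(-1e-3, 1e-3)
query, key, value = qkv.unbind(dim=1)

# Block-diagonal causal mask keeps the packed sequences independent.
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)

# No `op=` argument: the backend (flash-attn, CUTLASS, ...) is chosen by xformers.
output = xops.memory_efficient_attention_forward(
    query.unsqueeze(0),
    key.unsqueeze(0),
    value.unsqueeze(0),
    attn_bias=attn_bias,
    p=0.0,
    scale=scale,
)
output = output.squeeze(0)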


@@ -266,7 +266,6 @@ def run_multi_query_kv_attention(
     qkv.uniform_(-1e-3, 1e-3)
     query, key, value = qkv.unbind(dim=1)
-    attn_op = xops.fmha.cutlass.FwOp()
     attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
     output = xops.memory_efficient_attention_forward(
         query.unsqueeze(0),
@@ -275,7 +274,6 @@ def run_multi_query_kv_attention(
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
-        op=attn_op,
     )
     output = output.squeeze(0)

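Before this change, the test pinned the backend by constructing xops.fmha.cutlass.FwOp() and passing it as op=. The same mechanism can still be used to force a particular backend when comparing implementations; the sketch below pins the FlashAttention forward op instead and is an illustration, not code from this commit. It assumes query, key, value, attn_bias, and scale as in the sketch above.

# Pinning an op bypasses the dispatcher; xformers raises if the pinned op
# cannot handle the inputs (dtype, head size, hardware), whereas op=None
# simply falls back to another available backend.
flash_out = xops.memory_efficient_attention_forward(
    query.unsqueeze(0),
    key.unsqueeze(0),
    value.unsqueeze(0),
    attn_bias=attn_bias,
    p=0.0,
    scale=scale,
    op=xops.fmha.flash.FwOp(),
)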

@@ -61,7 +61,6 @@ class PagedAttention(nn.Module):
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.attn_op = xops.fmha.cutlass.FwOp()
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         assert self.num_heads % self.num_kv_heads == 0
@@ -115,7 +114,6 @@ class PagedAttention(nn.Module):
             attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
-            op=self.attn_op,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
         output.copy_(out.squeeze(0))
@@ -404,7 +402,6 @@ class PagedAttentionWithALiBi(PagedAttention):
                 attn_bias=input_metadata.attn_bias[i],
                 p=0.0,
                 scale=self.scale,
-                op=self.attn_op,
             )
             # TODO(woosuk): Unnecessary copy. Optimize.
             output[start:end].copy_(out.squeeze(0))
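
For context (not part of the diff): the base PagedAttention packs all prompt tokens into a single call guarded by one block-diagonal causal mask (attn_bias[0]), while the ALiBi variant builds one bias per sequence and loops, which is why the last hunk indexes attn_bias[i] and copies into output[start:end]. A rough, hypothetical reconstruction of that loop is sketched below; only the xformers call and the names visible in the hunks are taken from the commit, everything else (prompt_lens, the surrounding buffers) is assumed.

# Hypothetical per-sequence loop mirroring the ALiBi path above.
start = 0
for i, prompt_len in enumerate(prompt_lens):
    end = start + prompt_len
    out = xops.memory_efficient_attention_forward(
        query[None, start:end],
        key[None, start:end],
        value[None, start:end],
        attn_bias=input_metadata.attn_bias[i],  # per-sequence ALiBi-aware bias
        p=0.0,
        scale=self.scale,
        # op is no longer passed; xformers selects flash-attn when available.
    )
    # Copy back into the packed output buffer.
    output[start:end].copy_(out.squeeze(0))
    start = end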