Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 09:06:03 +08:00)
use flash-attn via xformers (#877)

commit 75471386de
parent d2b2eed67c
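The diff below is small: judging from the hunk headers (each removes exactly one line and adds none) and the commit title, every hunk drops the explicitly pinned cutlass forward op — the `attn_op = xops.fmha.cutlass.FwOp()` assignments and the matching `op=...` arguments — from the calls to `xformers.ops.memory_efficient_attention_forward`, both in the `run_multi_query_kv_attention` helper and in the PagedAttention layers. With no `op=` given, xFormers is free to dispatch to its flash-attn backend when the inputs allow it. A hedged sketch of the two backend names involved (assuming an xFormers build that ships both; this is not code from the commit):

    import xformers.ops as xops

    cutlass_fw = xops.fmha.cutlass.FwOp   # the forward op the old code pinned
    flash_fw = xops.fmha.flash.FwOp       # the flash-attn op xFormers can now pick on its own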
@@ -266,7 +266,6 @@ def run_multi_query_kv_attention(
     qkv.uniform_(-1e-3, 1e-3)
     query, key, value = qkv.unbind(dim=1)
 
-    attn_op = xops.fmha.cutlass.FwOp()
     attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
     output = xops.memory_efficient_attention_forward(
         query.unsqueeze(0),
@@ -275,7 +274,6 @@ def run_multi_query_kv_attention(
         attn_bias=attn_bias,
         p=0.0,
         scale=scale,
-        op=attn_op,
     )
     output = output.squeeze(0)
 
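Read together, these two hunks make the helper call xFormers' memory-efficient attention without a pinned op. A minimal standalone sketch of the same call pattern (sizes and names here are illustrative, not the helper's actual parameters; requires xFormers and a CUDA device):

    import torch
    import xformers.ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    seq_lens = [7, 5, 12]                   # illustrative prompt lengths
    num_heads, head_size = 8, 64
    num_tokens = sum(seq_lens)
    scale = head_size ** -0.5

    # Packed QKV for all tokens of all sequences, as in the helper.
    qkv = torch.empty(num_tokens, 3, num_heads, head_size,
                      dtype=torch.half, device="cuda")
    qkv.uniform_(-1e-3, 1e-3)
    query, key, value = qkv.unbind(dim=1)

    # One block-diagonal causal mask keeps the packed sequences from attending to each other.
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)

    # No `op=...` here: xFormers selects a backend (e.g. flash-attn) itself.
    out = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),                 # [1, num_tokens, num_heads, head_size]
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    out = out.squeeze(0)                    # back to [num_tokens, num_heads, head_size]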
@@ -61,7 +61,6 @@ class PagedAttention(nn.Module):
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
-        self.attn_op = xops.fmha.cutlass.FwOp()
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
         assert self.num_heads % self.num_kv_heads == 0
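The surrounding constructor lines also show the grouped/multi-query attention bookkeeping: `num_kv_heads` may be smaller than `num_heads`, as long as it divides it evenly. A hedged sketch of what that invariant buys (illustrative names and sizes, not vLLM's actual code):

    import torch

    num_heads, num_kv_heads, head_size, num_tokens = 32, 8, 64, 10
    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads   # 4 query heads share each KV head

    # KV projections only produce num_kv_heads heads; repeating them lets a
    # standard multi-head attention kernel consume them alongside num_heads query heads.
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    key_for_attn = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
    assert key_for_attn.shape == (num_tokens, num_heads, head_size)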
@@ -115,7 +114,6 @@ class PagedAttention(nn.Module):
             attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
-            op=self.attn_op,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
         output.copy_(out.squeeze(0))
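The prompt path writes its result into a preallocated output buffer; the squeeze-and-copy is what the `TODO(woosuk)` comment refers to. A tiny sketch of that shape handling (illustrative sizes, placeholder result tensor):

    import torch

    num_tokens, num_heads, head_size = 10, 8, 64
    output = torch.empty(num_tokens, num_heads, head_size)    # preallocated buffer
    out = torch.zeros(1, num_tokens, num_heads, head_size)    # xFormers result carries a batch dim of 1
    output.copy_(out.squeeze(0))                               # drop the batch dim, copy in place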
@@ -404,7 +402,6 @@ class PagedAttentionWithALiBi(PagedAttention):
                attn_bias=input_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
-               op=self.attn_op,
            )
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
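Unlike the base class, which applies a single bias (`attn_bias[0]`) to all packed prompts at once, the ALiBi variant indexes `attn_bias[i]` and writes into `output[start:end]`: it runs attention one prompt at a time because each sequence gets its own ALiBi bias. A hedged sketch of that loop pattern (placeholder attention result, illustrative names):

    import torch

    seq_lens = [4, 6, 3]                      # illustrative prompt lengths
    num_heads, head_size = 8, 64
    output = torch.empty(sum(seq_lens), num_heads, head_size)

    start = 0
    for i, seq_len in enumerate(seq_lens):
        end = start + seq_len
        # In the real code this is memory_efficient_attention_forward over
        # query[start:end] with the i-th ALiBi bias; a zero tensor stands in here.
        out = torch.zeros(1, seq_len, num_heads, head_size)
        # Copy back into the packed output buffer, as in the diff above.
        output[start:end].copy_(out.squeeze(0))
        start = end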