Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 04:54:58 +08:00)
Commit message: use flash-attn via xformers (#877)
This commit is contained in: (branch list not captured)
Parent: d2b2eed67c
Commit: 75471386de
@@ -266,7 +266,6 @@ def run_multi_query_kv_attention(
|
|||||||
qkv.uniform_(-1e-3, 1e-3)
|
qkv.uniform_(-1e-3, 1e-3)
|
||||||
query, key, value = qkv.unbind(dim=1)
|
query, key, value = qkv.unbind(dim=1)
|
||||||
|
|
||||||
attn_op = xops.fmha.cutlass.FwOp()
|
|
||||||
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
|
||||||
output = xops.memory_efficient_attention_forward(
|
output = xops.memory_efficient_attention_forward(
|
||||||
query.unsqueeze(0),
|
query.unsqueeze(0),
|
||||||
@@ -275,7 +274,6 @@ def run_multi_query_kv_attention(
|
|||||||
attn_bias=attn_bias,
|
attn_bias=attn_bias,
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
op=attn_op,
|
|
||||||
)
|
)
|
||||||
output = output.squeeze(0)
|
output = output.squeeze(0)
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,6 @@ class PagedAttention(nn.Module):
|
|||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.attn_op = xops.fmha.cutlass.FwOp()
|
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
|
|
||||||
assert self.num_heads % self.num_kv_heads == 0
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
@@ -115,7 +114,6 @@ class PagedAttention(nn.Module):
|
|||||||
attn_bias=input_metadata.attn_bias[0],
|
attn_bias=input_metadata.attn_bias[0],
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
op=self.attn_op,
|
|
||||||
)
|
)
|
||||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||||
output.copy_(out.squeeze(0))
|
output.copy_(out.squeeze(0))
|
||||||
@@ -404,7 +402,6 @@ class PagedAttentionWithALiBi(PagedAttention):
|
|||||||
attn_bias=input_metadata.attn_bias[i],
|
attn_bias=input_metadata.attn_bias[i],
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
op=self.attn_op,
|
|
||||||
)
|
)
|
||||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||||
output[start:end].copy_(out.squeeze(0))
|
output[start:end].copy_(out.squeeze(0))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user