mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 13:53:45 +08:00
Use FlashAttention kernel
This commit is contained in:
parent
7e3a230c38
commit
62b870fa07
@ -165,8 +165,8 @@ class Attention(nn.Module):
|
||||
value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2)
|
||||
key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2)
|
||||
|
||||
if False:
|
||||
# FIXME(woosuk)
|
||||
if True:
|
||||
# FlashAttention.
|
||||
output = flash_attn(
|
||||
query_proj,
|
||||
key_proj,
|
||||
@ -174,6 +174,7 @@ class Attention(nn.Module):
|
||||
self.sm_scale,
|
||||
)
|
||||
else:
|
||||
# Naive attention with masking.
|
||||
seq_len = query_proj.shape[1]
|
||||
attn_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
|
||||
|
||||
|
||||
@ -81,6 +81,10 @@ class TPUModelRunner:
|
||||
|
||||
max_prompt_len = max(prompt_lens)
|
||||
assert max_prompt_len > 0
|
||||
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
|
||||
# length to be a multiple of 16. We round the prompt length up to the next
|
||||
# multiple of 16. This is also good for performance.
|
||||
max_prompt_len = (max_prompt_len + 15) // 16 * 16
|
||||
|
||||
input_tokens = _make_array_with_pad(input_tokens,
|
||||
max_prompt_len,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user