mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 06:04:26 +08:00
[V1] Optimize the CPU overheads in FlashAttention custom op (#10733)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
8c1e77fb58
commit
98f47f2a40
@@ -135,6 +135,13 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||||
"key/v_scale is not supported in FlashAttention.")
|
"key/v_scale is not supported in FlashAttention.")
|
||||||
|
|
||||||
|
# Reshape the query, key, and value tensors.
|
||||||
|
# NOTE(woosuk): We do this outside the custom op to minimize the CPU
|
||||||
|
# overheads from the non-CUDA-graph regions.
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
torch.ops.vllm.unified_v1_flash_attention(
|
torch.ops.vllm.unified_v1_flash_attention(
|
||||||
output,
|
output,
|
||||||
@@ -153,7 +160,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.logits_soft_cap,
|
self.logits_soft_cap,
|
||||||
)
|
)
|
||||||
return output
|
return output.view(-1, self.num_heads * self.head_size)
|
||||||
|
|
||||||
|
|
||||||
def unified_v1_flash_attention(
|
def unified_v1_flash_attention(
|
||||||
@@ -184,11 +191,6 @@ def unified_v1_flash_attention(
|
|||||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
attn_metadata: FlashAttentionMetadata = current_metadata
|
||||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
|
|
||||||
# Reshape the query, key, and value tensors.
|
|
||||||
query = query.view(-1, num_heads, head_size)
|
|
||||||
key = key.view(-1, num_kv_heads, head_size)
|
|
||||||
value = value.view(-1, num_kv_heads, head_size)
|
|
||||||
|
|
||||||
# Reshape the input keys and values and store them in the cache.
|
# Reshape the input keys and values and store them in the cache.
|
||||||
key_cache = kv_cache[0]
|
key_cache = kv_cache[0]
|
||||||
value_cache = kv_cache[1]
|
value_cache = kv_cache[1]
|
||||||
@@ -218,8 +220,7 @@ def unified_v1_flash_attention(
|
|||||||
block_table=attn_metadata.block_table,
|
block_table=attn_metadata.block_table,
|
||||||
softcap=logits_soft_cap,
|
softcap=logits_soft_cap,
|
||||||
)
|
)
|
||||||
attn_output = attn_output.view(num_actual_tokens, -1)
|
# TODO(woosuk): Remove this unnecessary copy.
|
||||||
# TODO(woosuk): Optimize this.
|
|
||||||
output[:num_actual_tokens].copy_(attn_output)
|
output[:num_actual_tokens].copy_(attn_output)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user