mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 01:16:33 +08:00
[Bugfix][ROCm] Use chunked_prefill_paged_decode as fallback for V1 attention on ROCm (#18093)
Signed-off-by: kf <kuanfu.liu@embeddedllm.com>
This commit is contained in:
parent
4e1c6a0264
commit
ee659e3b60
@ -7,6 +7,9 @@ import torch
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import (
|
||||
chunked_prefill_paged_decode)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@ -162,19 +165,40 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_queries_per_kv = query.shape[1] // key.shape[1]
|
||||
use_prefill_decode_attn = (num_queries_per_kv &
|
||||
(num_queries_per_kv - 1)) != 0
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
if use_prefill_decode_attn:
|
||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
else:
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
@ -209,26 +233,47 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
if use_prefill_decode_attn:
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=block_table,
|
||||
query_start_loc=cu_seqlens_q,
|
||||
seq_lens=seqused_k,
|
||||
max_seq_len=max_seqlen_k,
|
||||
max_query_len=max_seqlen_q,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale)
|
||||
|
||||
unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
else:
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
|
||||
unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
q_descale=None, # Not supported
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user