[Bugfix][ROCm] Use chunked_prefill_paged_decode as fallback for V1 attention on ROCm (#18093)
Signed-off-by: kf <kuanfu.liu@embeddedllm.com>
commit ee659e3b60
parent 4e1c6a0264
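With this change, the V1 Triton attention backend first computes the GQA ratio (query heads per KV head). When the ratio is not a power of two, both the KV-cache write and the attention computation are routed through chunked_prefill_paged_decode; when it is a power of two, the existing reshape_and_cache_flash + unified_attention path is kept unchanged.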
@@ -7,6 +7,9 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.ops.chunked_prefill_paged_decode import (
+    chunked_prefill_paged_decode)
+from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
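The new imports provide the chunked_prefill_paged_decode kernel and the PagedAttention helpers (split_kv_cache, write_to_paged_cache) that the fallback path below relies on.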
@@ -162,19 +165,40 @@ class TritonAttentionImpl(AttentionImpl):
         # Whenever making a change in this method, please benchmark the
         # performance to make sure it does not introduce any overhead.
 
+        num_queries_per_kv = query.shape[1] // key.shape[1]
+        use_prefill_decode_attn = (num_queries_per_kv &
+                                   (num_queries_per_kv - 1)) != 0
+
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        key_cache, value_cache = kv_cache.unbind(0)
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if use_prefill_decode_attn:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            PagedAttention.write_to_paged_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
+        else:
+            key_cache, value_cache = kv_cache.unbind(0)
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
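As a side note, here is a minimal sketch (not part of the diff) of the dispatch predicate introduced in the hunk above: `n & (n - 1)` is non-zero exactly when n is not a power of two, so the fallback only triggers for non-power-of-two GQA ratios. The helper name and the example head counts are illustrative only.

```python
def needs_prefill_decode_fallback(num_query_heads: int, num_kv_heads: int) -> bool:
    # Same bit trick as in the diff: n & (n - 1) clears the lowest set bit,
    # so the result is 0 exactly when n is a power of two.
    num_queries_per_kv = num_query_heads // num_kv_heads
    return (num_queries_per_kv & (num_queries_per_kv - 1)) != 0


assert not needs_prefill_decode_fallback(32, 8)   # ratio 4 -> unified_attention
assert needs_prefill_decode_fallback(40, 8)       # ratio 5 -> chunked_prefill_paged_decode
assert not needs_prefill_decode_fallback(16, 16)  # ratio 1 -> unified_attention
```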
@@ -209,26 +233,47 @@ class TritonAttentionImpl(AttentionImpl):
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
-
-        unified_attention(
-            q=query[:num_actual_tokens],
-            k=key_cache,
-            v=value_cache,
-            out=output[:num_actual_tokens],
-            cu_seqlens_q=cu_seqlens_q,
-            max_seqlen_q=max_seqlen_q,
-            seqused_k=seqused_k,
-            max_seqlen_k=max_seqlen_k,
-            softmax_scale=self.scale,
-            causal=True,
-            alibi_slopes=self.alibi_slopes,
-            window_size=self.sliding_window,
-            block_table=block_table,
-            softcap=self.logits_soft_cap,
-            q_descale=None,  # Not supported
-            k_descale=layer._k_scale.expand(descale_shape),
-            v_descale=layer._v_scale.expand(descale_shape),
-        )
+        if use_prefill_decode_attn:
+            # Compute attention and update output up to `num_actual_tokens`.
+            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
+                                         key=key[:num_actual_tokens],
+                                         value=value[:num_actual_tokens],
+                                         output=output[:num_actual_tokens],
+                                         kv_cache_dtype=self.kv_cache_dtype,
+                                         key_cache=key_cache,
+                                         value_cache=value_cache,
+                                         block_table=block_table,
+                                         query_start_loc=cu_seqlens_q,
+                                         seq_lens=seqused_k,
+                                         max_seq_len=max_seqlen_k,
+                                         max_query_len=max_seqlen_q,
+                                         k_scale=layer._k_scale,
+                                         v_scale=layer._v_scale,
+                                         alibi_slopes=self.alibi_slopes,
+                                         sliding_window=self.sliding_window[0],
+                                         sm_scale=self.scale)
+
+        else:
+            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+
+            unified_attention(
+                q=query[:num_actual_tokens],
+                k=key_cache,
+                v=value_cache,
+                out=output[:num_actual_tokens],
+                cu_seqlens_q=cu_seqlens_q,
+                max_seqlen_q=max_seqlen_q,
+                seqused_k=seqused_k,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                window_size=self.sliding_window,
+                block_table=block_table,
+                softcap=self.logits_soft_cap,
+                q_descale=None,  # Not supported
+                k_descale=layer._k_scale.expand(descale_shape),
+                v_descale=layer._v_scale.expand(descale_shape),
+            )
 
         return output
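For context on the split cache-write logic, below is a small illustration, using hypothetical shapes, of why the two branches obtain key_cache and value_cache differently: the unified path keeps K and V stacked in a single tensor and splits it with unbind(0), while the fallback path uses PagedAttention.split_kv_cache to produce the layout expected by write_to_paged_cache. The actual cache layout is defined elsewhere in the backend.

```python
import torch

# Hypothetical shapes for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64

# Unified path: K and V stacked along dim 0; unbind(0) returns two views
# (no copy) that reshape_and_cache_flash writes into.
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)
key_cache, value_cache = kv_cache.unbind(0)

assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)
assert key_cache.data_ptr() == kv_cache.data_ptr()  # views share storage
```

Since the predicate is false for every power-of-two GQA ratio, existing configurations keep the original unified_attention code path; the new branch only activates as a fallback.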