Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 06:04:57 +08:00)
fix cross attention (#28346)

Signed-off-by: fsx950223 <fsx950223@outlook.com>

This commit is contained in:
  parent 9452863088
  commit fc9f821d20
@@ -244,14 +244,11 @@ class TritonAttentionImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "TritonAttentionImpl"
-            )
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
+            raise NotImplementedError(
+                "Encoder self-attention is not implemented for TritonAttentionImpl"
+            )
+        self.attn_type = attn_type
         self.fp8_dtype = current_platform.fp8_dtype()

         self.sinks = sinks
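The first hunk widens the constructor's type check: AttentionType.ENCODER_DECODER is now accepted alongside DECODER, and the attention type is stored as self.attn_type for use in the forward pass. A minimal sketch of the resulting guard, using a stand-in class that mirrors vLLM's AttentionType string constants (the real class lives elsewhere in the tree):

    class AttentionType:
        # Stand-in mirroring vLLM's string constants.
        DECODER = "decoder"
        ENCODER = "encoder"
        ENCODER_DECODER = "encoder_decoder"

    def check_attn_type(attn_type: str) -> str:
        # Decoder self-attention and encoder/decoder cross-attention are
        # supported after this commit; encoder self-attention still raises.
        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
            raise NotImplementedError(
                "Encoder self-attention is not implemented for TritonAttentionImpl"
            )
        return attn_type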
@@ -312,7 +309,11 @@ class TritonAttentionImpl(AttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(1)

-        if self.kv_sharing_target_layer_name is None:
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             if self.kv_cache_dtype.startswith("fp8"):
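The second hunk guards the KV-cache write. In encoder/decoder cross-attention the encoder's keys and values are written to the cache once; on later decode steps the layer can be called with key=None and value=None, and the old unconditional path would dereference them. A sketch of the guarded write, with the actual reshape-and-cache kernel abstracted behind a caller-supplied store_fn (a hypothetical parameter, not vLLM's API):

    import torch
    from collections.abc import Callable

    StoreFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], None]

    def maybe_store_kv(
        key: torch.Tensor | None,
        value: torch.Tensor | None,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        kv_sharing_target_layer_name: str | None,
        store_fn: StoreFn,
    ) -> None:
        # Skip the write when sharing the cache with an earlier layer, and
        # also when key/value are absent: cross-attention decode steps reuse
        # the encoder K/V already in the cache and must not re-write it.
        if (
            kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            store_fn(key, value, key_cache, value_cache)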
@@ -346,7 +347,7 @@ class TritonAttentionImpl(AttentionImpl):
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table

-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])

         unified_attention(
             q=query[:num_actual_tokens],
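The third hunk fixes the FP8 descale shape. Its second dimension is the KV-head count, previously read from key.shape[1]; with key allowed to be None on cross-attention decode steps, the count now comes from the cache tensor, which is always present. A sketch, assuming the key cache produced by kv_cache.unbind(1) is laid out as (num_blocks, block_size, num_kv_heads, head_size):

    import torch

    def fp8_descale_shape(
        cu_seqlens_q: torch.Tensor, key_cache: torch.Tensor
    ) -> tuple[int, int]:
        # One descale factor per sequence and per KV head. cu_seqlens_q holds
        # num_seqs + 1 cumulative boundaries, and key_cache.shape[2] is the
        # KV-head count under the assumed layout. key.shape[1] gave the same
        # number, but key can be None during cross-attention decode, so the
        # cache is the safe source.
        return (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])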