fix cross attention (#28346)

Signed-off-by: fsx950223 <fsx950223@outlook.com>
Committed via GitHub on 2025-11-21 20:55:43 +08:00
parent 9452863088
commit fc9f821d20


@@ -244,14 +244,11 @@ class TritonAttentionImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        if attn_type != AttentionType.DECODER:
+        if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
             raise NotImplementedError(
-                "Encoder self-attention and "
-                "encoder/decoder cross-attention "
-                "are not implemented for "
-                "TritonAttentionImpl"
+                "Encoder self-attention is not implemented for TritonAttentionImpl"
             )
         self.attn_type = attn_type
         self.fp8_dtype = current_platform.fp8_dtype()
         self.sinks = sinks
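The hunk above relaxes the constructor guard: AttentionType.ENCODER_DECODER now joins AttentionType.DECODER in the allow-list, and the error message is trimmed to the one case that remains unsupported. A minimal standalone sketch of the new behavior, using string-valued stand-ins for vLLM's AttentionType constants:

class AttentionType:
    # Illustrative stand-in for vLLM's AttentionType constants.
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_DECODER = "encoder_decoder"

def check_attn_type(attn_type: str) -> None:
    # Mirrors the updated guard: only encoder self-attention is rejected.
    if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
        raise NotImplementedError(
            "Encoder self-attention is not implemented for TritonAttentionImpl"
        )

check_attn_type(AttentionType.ENCODER_DECODER)  # accepted after this change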
@@ -312,7 +309,11 @@ class TritonAttentionImpl(AttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(1)
-        if self.kv_sharing_target_layer_name is None:
+        if (
+            self.kv_sharing_target_layer_name is None
+            and key is not None
+            and value is not None
+        ):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             if self.kv_cache_dtype.startswith("fp8"):
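The hunk above carries the substance of the cross-attention fix: with encoder/decoder attention, the encoder's keys and values are written to the KV cache once at prefill, so decode steps invoke the layer with key=None and value=None and must skip the cache write rather than dereference them. A hedged sketch of the guard's decision, with a hypothetical helper name; only the condition itself mirrors the diff:

from typing import Optional

import torch

def should_store_kv(
    key: Optional[torch.Tensor],
    value: Optional[torch.Tensor],
    kv_sharing_target_layer_name: Optional[str],
) -> bool:
    # Hypothetical helper: True when this layer should write K/V into its
    # paged cache. Cross-attention decode steps pass key=None / value=None
    # because the encoder K/V was already cached during prefill.
    return (
        kv_sharing_target_layer_name is None
        and key is not None
        and value is not None
    )

assert should_store_kv(torch.ones(1), torch.ones(1), None)  # prefill: write
assert not should_store_kv(None, None, None)                # decode: skip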
@@ -346,7 +347,7 @@ class TritonAttentionImpl(AttentionImpl):
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+        descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
         unified_attention(
             q=query[:num_actual_tokens],
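The hunk above follows from the same invariant: key may be None on cross-attention decode steps, so key.shape[1] can no longer supply the number of KV heads for the FP8 descale shape; key_cache, which always exists, supplies it instead. A small sketch of the shape bookkeeping, assuming the paged-cache layout (num_blocks, block_size, num_kv_heads, head_size) implied by the key_cache.shape[2] index:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64
key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_size)
cu_seqlens_q = torch.tensor([0, 3, 7])  # two sequences in the batch

# key.shape[1] would also yield num_kv_heads, but key is None on
# cross-attention decode steps; the cache shape is always available.
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
assert descale_shape == (2, num_kv_heads)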