[ROCm] Support for Whisper v1 with Aiter Unified Attention and Aiter Flash Attention (#28376)

Signed-off-by: apinge <Tong.Qiu2@amd.com>
This commit is contained in:
tongqiu 2025-11-24 11:26:00 +08:00 committed by GitHub
parent 30854783ad
commit 5253f4276f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 15 deletions

View File

@ -517,12 +517,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type != AttentionType.DECODER:
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl"
"Encoder self-attention is not implemented for FlashAttentionImpl"
)
def extend_forward(
@ -678,7 +675,14 @@ class AiterFlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
@ -704,8 +708,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
# decode:extend:prefill
query = query[:num_actual_tokens]
key = key[:num_actual_tokens]
value = value[:num_actual_tokens]
if key is not None:
key = key[:num_actual_tokens]
if value is not None:
value = value[:num_actual_tokens]
output_actual_tokens = output[:num_actual_tokens]

View File

@ -142,7 +142,14 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
@ -169,7 +176,10 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
descale_shape = (
cu_seqlens_q.shape[0] - 1,
key.shape[1] if key is not None else self.num_kv_heads,
)
self.unified_attention(
q=query[:num_actual_tokens],

View File

@ -238,12 +238,9 @@ class RocmAttentionImpl(AttentionImpl):
RocmAttentionBackend.validate_head_size(head_size)
if attn_type != AttentionType.DECODER:
if attn_type not in [AttentionType.DECODER, AttentionType.ENCODER_DECODER]:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"RocmAttentionImpl"
"Encoder self-attention is not implemented for RocmAttentionImpl"
)
self.fp8_dtype = current_platform.fp8_dtype()