diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index eed3cba9a2ca7..eca83b6d2ee45 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -405,13 +405,6 @@ class FlashAttentionImpl(AttentionImpl):
 
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type not in [
-                AttentionType.DECODER, AttentionType.ENCODER_ONLY
-        ]:
-            raise NotImplementedError("Encoder/decoder cross-attention "
-                                      "is not implemented for "
-                                      "FlashAttentionImpl")
-
         self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
@@ -477,7 +470,7 @@ class FlashAttentionImpl(AttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens
 
         # Handle encoder attention differently - no KV cache needed
-        if attn_type in (AttentionType.ENCODER_ONLY, ):
+        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
             # For encoder attention,
             # we use direct Q, K, V tensors without caching
             return self._forward_encoder_attention(query[:num_actual_tokens],
@@ -489,7 +482,11 @@ class FlashAttentionImpl(AttentionImpl):
 
         # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)
-        if self.kv_sharing_target_layer_name is None:
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (self.kv_sharing_target_layer_name is None and key is not None
+                and value is not None):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping is
@@ -528,7 +525,7 @@ class FlashAttentionImpl(AttentionImpl):
             block_table = attn_metadata.block_table
             scheduler_metadata = attn_metadata.scheduler_metadata
 
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
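
The sketch below (outside the diff) illustrates the cache-write guard these hunks introduce: on the cross-attention path, key and value are computed once from the encoder output, written to the KV cache, and are None on later calls, so the write must be skipped rather than attempted on None tensors. ToyKVCache and maybe_write_kv are hypothetical names for illustration only, not vLLM APIs. The same reasoning motivates the descale_shape change: key.shape[1] is replaced with self.num_kv_heads because key may now be None when that shape is needed.

# Illustrative sketch only; not the vLLM backend code.
from typing import Optional

import torch


class ToyKVCache:
    """Minimal per-layer KV store keyed by slot index (hypothetical helper)."""

    def __init__(self, num_slots: int, num_kv_heads: int, head_size: int):
        self.key_cache = torch.zeros(num_slots, num_kv_heads, head_size)
        self.value_cache = torch.zeros(num_slots, num_kv_heads, head_size)

    def maybe_write_kv(
        self,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        slot_mapping: torch.Tensor,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        # Mirrors the guard added in the diff: skip the write when this layer
        # shares KV with an earlier layer, or when key/value are None
        # (cross-attention after the first pass, where the encoder-derived
        # KV is already in the cache).
        if (kv_sharing_target_layer_name is None and key is not None
                and value is not None):
            self.key_cache[slot_mapping] = key
            self.value_cache[slot_mapping] = value


if __name__ == "__main__":
    cache = ToyKVCache(num_slots=8, num_kv_heads=2, head_size=4)
    slots = torch.tensor([0, 1, 2])

    # First cross-attention pass: encoder-derived K/V are written to the cache.
    k = torch.randn(3, 2, 4)
    v = torch.randn(3, 2, 4)
    cache.maybe_write_kv(k, v, slots)

    # Later decode steps: key/value are None, so the cached entries are reused
    # unchanged instead of failing on a None tensor.
    cache.maybe_write_kv(None, None, slots)
    assert torch.equal(cache.key_cache[slots], k)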