diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 54ffd5c45ff91..a281c9771a82e 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -345,10 +345,10 @@ class CommonAttentionState(AttentionState):
         if is_encoder_decoder_model:
-            # The encoder decoder model works only with XFormers and
-            # Flash Attention backend. Assert the same.
-            assert self.runner.attn_backend.get_name() in\
-                ["XFORMERS", "FLASH_ATTN"], \
-                f"Expected attn_backend name to be either 'XFORMERS' or " \
-                f"'FLASH_ATTN', but "\
+            # The encoder decoder model works only with the XFormers, Flash
+            # Attention and ROCm Flash Attention backends. Assert the same.
+            assert self.runner.attn_backend.get_name() in \
+                ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
+                f"Expected attn_backend name to be either 'XFORMERS', " \
+                f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
                 f"got '{self.runner.attn_backend.get_name()}'"
             self._update_captured_metadata_for_enc_dec_model(
                 batch_size=batch_size, attn_metadata=attn_metadata)
@@ -367,10 +367,10 @@ class CommonAttentionState(AttentionState):
         if is_encoder_decoder_model:
-            # The encoder decoder model works only with XFormers and
-            # Flash Attention backend. Assert the same.
-            assert self.runner.attn_backend.get_name() in\
-                ["XFORMERS", "FLASH_ATTN"], \
-                f"Expected attn_backend name to be either 'XFORMERS' or "\
-                f"'FLASH_ATTN', but "\
+            # The encoder decoder model works only with the XFormers, Flash
+            # Attention and ROCm Flash Attention backends. Assert the same.
+            assert self.runner.attn_backend.get_name() in \
+                ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
+                f"Expected attn_backend name to be either 'XFORMERS', " \
+                f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
                 f"got '{self.runner.attn_backend.get_name()}'"
             self._add_additonal_input_buffers_for_enc_dec_model(
                 attn_metadata=attn_metadata, input_buffers=input_buffers)