From 176a95c670f676e88175c6d3a507ace0b1c35f3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Tue, 13 May 2025 22:31:42 -0400 Subject: [PATCH] [Fix] Support CUDAGraph capture for encoder-decoder on ROCm (#18104) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič --- vllm/attention/backends/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 54ffd5c45ff91..a281c9771a82e 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -345,10 +345,10 @@ class CommonAttentionState(AttentionState): if is_encoder_decoder_model: # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in\ - ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or " \ - f"'FLASH_ATTN', but "\ + assert self.runner.attn_backend.get_name() in \ + ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ + f"Expected attn_backend name to be either 'XFORMERS'," \ + f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ f"got '{self.runner.attn_backend.get_name()}'" self._update_captured_metadata_for_enc_dec_model( batch_size=batch_size, attn_metadata=attn_metadata) @@ -367,10 +367,10 @@ class CommonAttentionState(AttentionState): if is_encoder_decoder_model: # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. - assert self.runner.attn_backend.get_name() in\ - ["XFORMERS", "FLASH_ATTN"], \ - f"Expected attn_backend name to be either 'XFORMERS' or "\ - f"'FLASH_ATTN', but "\ + assert self.runner.attn_backend.get_name() in \ + ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ + f"Expected attn_backend name to be either 'XFORMERS'," \ + f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ f"got '{self.runner.attn_backend.get_name()}'" self._add_additonal_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers)