diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index f5dcaea79af93..737559bfe70ca 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -31,6 +31,10 @@ class AttentionType:
 
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
+    # For some attention backends, we allocate an output tensor before
+    # calling the custom op. When piecewise cudagraph is enabled, this
+    # makes sure the output tensor is allocated inside the cudagraph.
+    accept_output_buffer: bool = False
 
     @staticmethod
     @abstractmethod
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 23ea244f07dfe..48b3e8d177ec9 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -29,6 +29,8 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
 
 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index a06db075f334d..a283e87d84070 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -110,11 +110,7 @@ class Attention(nn.Module):
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()
 
-        # For some attention backends, we allocate an output tensor before
-        # calling the custom op. When piecewise cudagraph is enabled, this
-        # makes sure the output tensor is allocated inside the cudagraph.
-        self.use_output = self.backend == _Backend.FLASH_ATTN or \
-            self.backend == _Backend.FLASH_ATTN_VLLM_V1
+        self.use_output = attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index b02bc9ffde538..7b0786261a6a6 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -15,6 +15,8 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
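
For context on how the new flag is meant to be consumed, here is a minimal, self-contained sketch of the pattern, not vLLM's actual code: `DummyFlashBackend`, `SimpleAttention`, and their `run`/`forward` signatures are hypothetical, and only `accept_output_buffer` and the `use_output` decision mirror the diff above.

```python
from abc import ABC

import torch


class AttentionBackend(ABC):
    """Abstract backend carrying the same flag as the diff above."""
    # Backends that can write into a caller-provided output tensor set
    # this to True, so the buffer is allocated by the caller and, under
    # piecewise cudagraph capture, inside the captured graph.
    accept_output_buffer: bool = False


class DummyFlashBackend(AttentionBackend):
    """Hypothetical backend that fills a pre-allocated output buffer."""
    accept_output_buffer: bool = True

    @staticmethod
    def run(query: torch.Tensor, output: torch.Tensor) -> None:
        # Stand-in for the real attention kernel: write into the buffer.
        output.copy_(query)


class SimpleAttention:
    """Hypothetical layer mirroring the change in vllm/attention/layer.py."""

    def __init__(self, backend: type) -> None:
        self.backend = backend
        # The decision is driven by the backend's class attribute instead
        # of comparing against a hard-coded list of backend enums.
        self.use_output = backend.accept_output_buffer

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        if self.use_output:
            # Allocate the output before calling the op so the allocation
            # happens inside the (piecewise) cudagraph when capture is on.
            output = torch.empty_like(query)
            self.backend.run(query, output)
            return output
        # Backends without the flag allocate and return their own output;
        # that path is omitted in this sketch.
        raise NotImplementedError


attn = SimpleAttention(DummyFlashBackend)
out = attn.forward(torch.randn(2, 8))
```

Keeping the flag on the backend class means a new backend opts in by setting a single attribute, and the layer needs no further per-backend checks.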