[Platform] Add output for Attention Backend (#11981)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan 2025-01-14 21:27:04 +08:00 committed by GitHub
parent 1f18adb245
commit 2e0e017610
4 changed files with 9 additions and 5 deletions


@@ -31,6 +31,10 @@ class AttentionType:
 
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
+    # For some attention backends, we allocate an output tensor before
+    # calling the custom op. When piecewise cudagraph is enabled, this
+    # makes sure the output tensor is allocated inside the cudagraph.
+    accept_output_buffer: bool = False
 
     @staticmethod
     @abstractmethod
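
The new class attribute is the whole contract: a backend whose kernel can write into a caller-provided buffer overrides it to True, and callers may then pre-allocate the output. A minimal sketch of that contract, using a hypothetical MyPagedBackend and a simplified stand-in for the ABC rather than vLLM's real class:

from abc import ABC


class AttentionBackend(ABC):
    """Simplified stand-in for the ABC patched above."""
    # Backends whose kernel can write into a caller-provided output buffer
    # opt in by overriding this to True; everyone else inherits False.
    accept_output_buffer: bool = False


class MyPagedBackend(AttentionBackend):
    """Hypothetical backend whose kernel accepts an out= argument."""
    accept_output_buffer: bool = True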


@@ -29,6 +29,8 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
 
 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]


@@ -110,11 +110,7 @@ class Attention(nn.Module):
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()
 
-        # For some attention backends, we allocate an output tensor before
-        # calling the custom op. When piecewise cudagraph is enabled, this
-        # makes sure the output tensor is allocated inside the cudagraph.
-        self.use_output = self.backend == _Backend.FLASH_ATTN or \
-            self.backend == _Backend.FLASH_ATTN_VLLM_V1
+        self.use_output = attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
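
On the layer side the change reduces to reading that attribute instead of matching specific backend enums. A rough, self-contained sketch of the resulting control flow; attn_write_into and attn_alloc are made-up stand-ins for the real custom ops the layer actually calls:

import torch


def attn_write_into(q: torch.Tensor, out: torch.Tensor) -> None:
    """Made-up kernel that writes its result into a caller-provided buffer."""
    out.copy_(q)


def attn_alloc(q: torch.Tensor) -> torch.Tensor:
    """Made-up kernel that allocates and returns its own output."""
    return q.clone()


class AttentionLayerSketch:
    def __init__(self, attn_backend_cls: type) -> None:
        # No hard-coded list of backend enums: the backend class itself
        # advertises whether it can fill a pre-allocated output buffer.
        self.use_output = getattr(attn_backend_cls, "accept_output_buffer", False)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        if self.use_output:
            # Allocated here, i.e. inside the captured region when piecewise
            # cudagraph is enabled for this call.
            output = torch.empty_like(query)
            attn_write_into(query, output)
            return output
        return attn_alloc(query)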


@@ -15,6 +15,8 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]