[Platform] Add output for Attention Backend (#11981)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
parent 1f18adb245
commit 2e0e017610
@@ -31,6 +31,10 @@ class AttentionType:
 
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
+    # For some attention backends, we allocate an output tensor before
+    # calling the custom op. When piecewise cudagraph is enabled, this
+    # makes sure the output tensor is allocated inside the cudagraph.
+    accept_output_buffer: bool = False
 
     @staticmethod
     @abstractmethod
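The new class attribute defines an opt-in contract: a backend that sets accept_output_buffer = True declares that its attention call can write into an output tensor the caller allocated up front, so that under piecewise cudagraph the allocation happens inside the captured region. The snippet below is a minimal standalone sketch of that contract, not vLLM code; naive_attention is an invented stand-in for a backend kernel.

    # Sketch only: `naive_attention` is a toy stand-in, not a vLLM kernel.
    from typing import Optional

    import torch


    def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        output: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Kernel that can optionally write into a caller-provided buffer."""
        scores = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
        result = scores @ v
        if output is not None:
            output.copy_(result)   # fill the preallocated buffer in place
            return output
        return result              # otherwise allocate and return internally


    accept_output_buffer = True    # what an opted-in backend would declare

    q = k = v = torch.randn(2, 4, 8)
    if accept_output_buffer:
        out = torch.empty_like(q)  # caller-side allocation, e.g. inside a cudagraph
        naive_attention(q, k, v, output=out)
    else:
        out = naive_attention(q, k, v)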
@@ -29,6 +29,8 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func,
 
 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
@@ -110,11 +110,7 @@ class Attention(nn.Module):
         self.use_direct_call = not current_platform.is_cuda_alike(
         ) and not current_platform.is_cpu()
 
-        # For some attention backends, we allocate an output tensor before
-        # calling the custom op. When piecewise cudagraph is enabled, this
-        # makes sure the output tensor is allocated inside the cudagraph.
-        self.use_output = self.backend == _Backend.FLASH_ATTN or \
-            self.backend == _Backend.FLASH_ATTN_VLLM_V1
+        self.use_output = attn_backend.accept_output_buffer
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
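With the hard-coded _Backend check gone, the layer simply mirrors whatever the resolved backend class declares. The removed comment explains why the flag matters: when use_output is set, the layer allocates the result tensor itself before invoking the custom attention op, so the allocation is captured inside the piecewise cudagraph. Below is a hedged sketch of how such a gate could look in a forward path; attention_with_output and attention_returning_output are placeholder callables, not the actual vLLM custom ops.

    # Sketch only: the two callables stand in for the real custom ops.
    import torch


    def forward_sketch(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                       use_output: bool,
                       attention_with_output,
                       attention_returning_output) -> torch.Tensor:
        if use_output:
            # Pre-allocate the result here so the allocation is part of the
            # captured graph, then let the kernel fill it in place.
            output = torch.empty_like(query)
            attention_with_output(query, key, value, output)
            return output
        # Backends that do not accept an output buffer allocate and return
        # their own result tensor.
        return attention_returning_output(query, key, value)


    # Example wiring with toy callables:
    out = forward_sketch(
        torch.randn(2, 8), torch.randn(2, 8), torch.randn(2, 8),
        use_output=True,
        attention_with_output=lambda q, k, v, o: o.copy_(q),
        attention_returning_output=lambda q, k, v: q.clone(),
    )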
@@ -15,6 +15,8 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 class FlashAttentionBackend(AttentionBackend):
 
+    accept_output_buffer: bool = True
+
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
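Because accept_output_buffer is a plain class attribute with a conservative default of False on the ABC, backends that never override it keep the old non-preallocating path automatically, and the layer can read the flag straight off the backend class without instantiating anything. A standalone sketch of the pattern (toy classes, not the real vLLM backends):

    # Standalone sketch of the class-attribute pattern.
    class AttentionBackend:
        accept_output_buffer: bool = False   # conservative default


    class FlashAttentionLike(AttentionBackend):
        accept_output_buffer: bool = True    # explicit opt-in, as in the diff


    class LegacyBackend(AttentionBackend):
        pass                                 # no override -> inherits False


    # The layer only needs the class, not an instance:
    assert FlashAttentionLike.accept_output_buffer is True
    assert LegacyBackend.accept_output_buffer is False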