From ecff8309a3ca5159ac09ac9a7976516b9301f64d Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Thu, 27 Mar 2025 01:46:12 -0400 Subject: [PATCH] [ROCm] Env variable to trigger custom PA (#15557) Signed-off-by: Gregory Shtrasberg --- vllm/attention/backends/rocm_flash_attn.py | 3 ++- vllm/envs.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 34f5fedcf36e8..f19773bb2843a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -908,4 +908,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) diff --git a/vllm/envs.py b/vllm/envs.py index 46c5b3a1dc5d0..e16753191c6e2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -78,6 +78,7 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True + VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -541,6 +542,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ROCM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), + # custom paged attention kernel for MI3* cards + "VLLM_ROCM_CUSTOM_PAGED_ATTN": + lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in + ("true", "1")), + # Divisor for dynamic query scale factor calculation for FP8 KV Cache "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),