[CLI env var] Add VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH in env variables (#25274)

Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: qqma <qqma@amazon.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Daisy-Ma-coder 2025-09-22 10:37:43 -07:00 committed by GitHub
parent 06a41334c7
commit cfbee3d0e7
5 changed files with 32 additions and 13 deletions
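
This change promotes the hard-coded constant _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16 in the FlashAttention and FlashAttention MLA backends to a configurable environment variable, VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH, which caps the number of splits used for CUDA graph decode (default 16). A minimal usage sketch, not taken from this diff: the value 32 and the model name are illustrative, and the only behavior the diff guarantees is that vllm.envs reads the variable lazily with a default of "16".

import os

# Hedged usage sketch: override the split cap before vLLM builds its
# attention metadata builders. 32 is an arbitrary illustrative value.
os.environ["VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"] = "32"

from vllm import LLM  # public entry point; vllm.envs resolves variables lazily

llm = LLM(model="facebook/opt-125m")  # illustrative model
print(llm.generate(["Hello"])[0].outputs[0].text)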

File 1 of 5

@@ -46,7 +46,10 @@ backend_configs = {
     # FA3 on Hopper
     "FA3":
     BackendConfig(name="FA3",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   },
@@ -66,6 +69,7 @@ backend_configs = {
     BackendConfig(name="FlashAttentionMLA",
                   env_vars={
                       "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                   },
                   comp_config={
                       "cudagraph_mode": "FULL_DECODE_ONLY",
@@ -89,7 +93,10 @@ backend_configs = {
     # FA2
     "FA2":
     BackendConfig(name="FA2",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   }),
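
The backend_configs entries above pin the new variable to its default of 16 so the full-cudagraph tests keep exercising the same split bound as before. Below is a minimal sketch of how such a config could be applied per test; the BackendConfig dataclass and the pytest monkeypatch helper are assumptions for illustration, not necessarily how the test file defines them.

from dataclasses import dataclass, field

import pytest


@dataclass
class BackendConfig:
    # Stand-in for the test suite's config object (assumed shape).
    name: str
    env_vars: dict[str, str] = field(default_factory=dict)
    comp_config: dict[str, str] = field(default_factory=dict)


def apply_backend_env(monkeypatch: pytest.MonkeyPatch,
                      cfg: BackendConfig) -> None:
    # Set each env var for the duration of one test, including
    # VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH=16.
    for key, value in cfg.env_vars.items():
        monkeypatch.setenv(key, value)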

File 2 of 5

@@ -47,7 +47,10 @@ backend_configs = {
     # FA3 on Hopper
     "FA3":
     BackendConfig(name="FA3",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   },
@@ -67,6 +70,7 @@ backend_configs = {
     BackendConfig(name="FlashAttentionMLA",
                   env_vars={
                       "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                   },
                   comp_config={
                       "cudagraph_mode": "FULL_DECODE_ONLY",
@@ -75,7 +79,10 @@ backend_configs = {
     # FA2
     "FA2":
     BackendConfig(name="FA2",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL_AND_PIECEWISE",
                   }),

File 3 of 5

@@ -119,6 +119,7 @@ if TYPE_CHECKING:
     VLLM_SERVER_DEV_MODE: bool = False
     VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
     VLLM_MLA_DISABLE: bool = False
+    VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None
@@ -946,6 +947,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_MLA_DISABLE":
     lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),

+    # If set, vLLM will pick up the provided Flash Attention MLA
+    # max number splits for cuda graph decode
+    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
+    lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
+                          "16")),
+
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.
@@ -1379,6 +1386,7 @@ def compute_hash() -> str:
     environment_variables_to_hash = [
         "VLLM_PP_LAYER_PARTITION",
         "VLLM_MLA_DISABLE",
+        "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
        "VLLM_USE_TRITON_FLASH_ATTN",
         "VLLM_USE_TRITON_AWQ",
         "VLLM_DP_RANK",

File 4 of 5

@@ -8,6 +8,7 @@ import numpy as np
 import torch

 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
@@ -33,9 +34,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 logger = init_logger(__name__)

-# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16

 class FlashAttentionBackend(AttentionBackend):
@@ -215,7 +213,8 @@ class FlashAttentionMetadataBuilder(
             # When using cuda graph, we need to set the upper bound of the
             # number of splits so that large enough intermediate buffers are
             # pre-allocated during capture.
-            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+            self.max_num_splits = (
+                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
             # Sliding window size to be used with the AOT scheduler will be
             # populated on first build() call.
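
The comment above states the reason for a fixed bound: CUDA graphs replay with fixed shapes, so the split-KV scratch buffers must already exist at their maximum size when the graph is captured. A purely illustrative sketch of that constraint; the buffer names and shapes are assumptions, not vLLM's actual allocation.

import torch

# Illustration only: a larger cap means larger pre-allocated scratch buffers;
# a smaller cap limits how many KV splits a decode can use.
max_num_splits = 16  # envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
max_num_seqs, num_heads, head_size = 256, 16, 128

# Split-KV decode writes per-split partial outputs plus log-sum-exp terms that
# are reduced afterwards; a captured graph cannot reallocate, so the buffers
# are sized for the worst case up front.
partial_out = torch.empty(max_num_splits, max_num_seqs, num_heads, head_size)
partial_lse = torch.empty(max_num_splits, max_num_seqs, num_heads)
print(partial_out.shape, partial_lse.shape)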

File 5 of 5

@@ -6,6 +6,7 @@ from typing import ClassVar, Optional, Union
 import torch

+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
@@ -24,10 +25,6 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
 logger = init_logger(__name__)

-# NOTE(matt): This is an arbitrary number, copied from
-# woosuk's implementation in standard FlashAttention backend
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16

 class FlashAttnMLABackend(MLACommonBackend):
@@ -97,7 +94,8 @@ class FlashAttnMLAMetadataBuilder(
             # When using cuda graph, we need to set the upper bound of the
             # number of splits so that large enough intermediate buffers are
             # pre-allocated during capture.
-            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+            self.max_num_splits = (
+                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
             # TODO(lucas): Until we add support for the DCP custom masking we need
             # to restrict decodes to q_len == 1 when DCP is enabled.