Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 05:55:01 +08:00)
[CLI env var] Add VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH in env variables (#25274)

Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: qqma <qqma@amazon.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

Parent: 06a41334c7
Commit: cfbee3d0e7
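Usage sketch: the new variable is read lazily from the process environment, so setting it before the engine is constructed is enough to change the split cap used during CUDA graph capture. The snippet below is illustrative only; the value 32 and the model name are arbitrary examples, not recommendations from this commit.

import os

# Illustrative value; the default remains 16.
os.environ["VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"] = "32"

from vllm import LLM

# Any model works here; this name is only an example.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
print(llm.generate(["Hello"])[0].outputs[0].text)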
@@ -46,7 +46,10 @@ backend_configs = {
     # FA3 on Hopper
     "FA3":
     BackendConfig(name="FA3",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   },

@@ -66,6 +69,7 @@ backend_configs = {
     BackendConfig(name="FlashAttentionMLA",
                   env_vars={
                       "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                   },
                   comp_config={
                       "cudagraph_mode": "FULL_DECODE_ONLY",

@@ -89,7 +93,10 @@ backend_configs = {
     # FA2
     "FA2":
     BackendConfig(name="FA2",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   }),

@@ -47,7 +47,10 @@ backend_configs = {
     # FA3 on Hopper
     "FA3":
     BackendConfig(name="FA3",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL",
                   },

@@ -67,6 +70,7 @@ backend_configs = {
     BackendConfig(name="FlashAttentionMLA",
                   env_vars={
                       "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                   },
                   comp_config={
                       "cudagraph_mode": "FULL_DECODE_ONLY",

@@ -75,7 +79,10 @@ backend_configs = {
     # FA2
     "FA2":
     BackendConfig(name="FA2",
-                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  env_vars={
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
                   comp_config={
                       "cudagraph_mode": "FULL_AND_PIECEWISE",
                   }),

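How the per-backend env_vars in these BackendConfig entries get applied is outside this diff. For readers following along, the snippet below is a hypothetical helper (apply_env_vars is not a vLLM API) showing one way a test could scope such variables to a single run.

import os
from contextlib import contextmanager

@contextmanager
def apply_env_vars(env_vars: dict[str, str]):
    # Hypothetical helper: set the backend's env vars for the duration of a
    # test, then restore whatever was there before.
    saved = {k: os.environ.get(k) for k in env_vars}
    os.environ.update(env_vars)
    try:
        yield
    finally:
        for k, v in saved.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v

# Usage (illustrative), mirroring the FA2 entry above:
with apply_env_vars({
        "VLLM_FLASH_ATTN_VERSION": "2",
        "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
}):
    pass  # construct the engine / run the cudagraph test here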
@@ -119,6 +119,7 @@ if TYPE_CHECKING:
     VLLM_SERVER_DEV_MODE: bool = False
     VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
     VLLM_MLA_DISABLE: bool = False
+    VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH: int = 16
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
     VLLM_CUDART_SO_PATH: Optional[str] = None

@@ -946,6 +947,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_MLA_DISABLE":
     lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
 
+    # If set, vLLM will pick up the provided Flash Attention MLA
+    # max number splits for cuda graph decode
+    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
+    lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
+                          "16")),
+
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.

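vllm/envs.py resolves these entries lazily: each variable maps to a lambda that is evaluated on attribute access via a module-level __getattr__ (PEP 562), not at import time. A self-contained sketch of that pattern, simplified from the real module:

import os
from typing import Any, Callable

# Simplified stand-in for vllm/envs.py: a registry of lazily evaluated variables.
environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH":
    lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "16")),
}

def __getattr__(name: str) -> Any:
    # Module-level __getattr__: evaluate the lambda each time the attribute is read.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")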
@@ -1379,6 +1386,7 @@ def compute_hash() -> str:
     environment_variables_to_hash = [
         "VLLM_PP_LAYER_PARTITION",
         "VLLM_MLA_DISABLE",
+        "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
         "VLLM_USE_TRITON_FLASH_ATTN",
         "VLLM_USE_TRITON_AWQ",
         "VLLM_DP_RANK",

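Adding the variable to environment_variables_to_hash means a change in its value changes vLLM's cache key, which matters because the split cap is baked into captured CUDA graphs. As a simplified illustration of that idea (not vLLM's actual compute_hash), hashing the relevant values could look like this:

import hashlib
import os

def compute_env_hash(names: list[str]) -> str:
    # Hash the current values of cache-relevant env vars so that changing any
    # of them (e.g. the new split cap) yields a different key.
    payload = "\0".join(f"{n}={os.getenv(n, '')}" for n in sorted(names))
    return hashlib.sha256(payload.encode()).hexdigest()

key = compute_env_hash([
    "VLLM_MLA_DISABLE",
    "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH",
])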
@@ -8,6 +8,7 @@ import numpy as np
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)

@@ -33,9 +34,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 
 logger = init_logger(__name__)
 
-# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-
 
 class FlashAttentionBackend(AttentionBackend):
 

@@ -215,7 +213,8 @@ class FlashAttentionMetadataBuilder(
             # When using cuda graph, we need to set the upper bound of the
             # number of splits so that large enough intermediate buffers are
             # pre-allocated during capture.
-            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+            self.max_num_splits = (
+                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.

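The comment above is the heart of the change: with CUDA graphs, split-KV intermediate buffers must be sized at capture time, so the number of splits needs a fixed upper bound. The sketch below illustrates only the sizing consequence; it is a toy model with made-up names, not vLLM's buffer allocation code.

import torch

def alloc_partial_buffers(max_num_splits: int, num_heads: int, head_dim: int,
                          max_batch: int, device: str = "cpu"):
    # Toy illustration: split-KV attention writes one partial output (and one
    # log-sum-exp) per split, so capture-time buffers scale linearly with the
    # cap controlled by VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH.
    out_partial = torch.empty(max_num_splits, max_batch, num_heads, head_dim,
                              device=device, dtype=torch.float32)
    lse_partial = torch.empty(max_num_splits, max_batch, num_heads,
                              device=device, dtype=torch.float32)
    return out_partial, lse_partial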
@@ -6,6 +6,7 @@ from typing import ClassVar, Optional, Union
 
 import torch
 
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,

@@ -24,10 +25,6 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
 
 logger = init_logger(__name__)
 
-# NOTE(matt): This is an arbitrary number, copied from
-# woosuk's implementation in standard FlashAttention backend
-_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
-
 
 class FlashAttnMLABackend(MLACommonBackend):
 

@@ -97,7 +94,8 @@ class FlashAttnMLAMetadataBuilder(
             # When using cuda graph, we need to set the upper bound of the
             # number of splits so that large enough intermediate buffers are
             # pre-allocated during capture.
-            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+            self.max_num_splits = (
+                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)
 
             # TODO(lucas): Until we add support for the DCP custom masking we need
             # to restrict decodes to q_len == 1 when DCP is enabled.