From eca7be9077aa22e70da5c2ef04ff056e3c7bdc58 Mon Sep 17 00:00:00 2001 From: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Date: Tue, 23 Sep 2025 15:17:49 -0700 Subject: [PATCH] =?UTF-8?q?Add=20VLLM=5FENABLE=5FINDUCTOR=5FMAX=5FAUTOTUNE?= =?UTF-8?q?=20&=20VLLM=5FENABLE=5FINDUCTOR=5FCOORDINA=E2=80=A6=20(#25493)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: rouchenzi Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> --- vllm/compilation/compiler_interface.py | 5 +++-- vllm/envs.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 7158fd685964..eeca14d1296f 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -551,8 +551,9 @@ def set_inductor_config(config, runtime_shape): if isinstance(runtime_shape, int): # for a specific batchsize, tuning triton kernel parameters # can be beneficial - config["max_autotune"] = True - config["coordinate_descent_tuning"] = True + config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE + config["coordinate_descent_tuning"] = ( + envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING) class EagerAdaptor(CompilerInterface): diff --git a/vllm/envs.py b/vllm/envs.py index 50d58c5468f9..1c6c1e78ac9b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -193,6 +193,8 @@ if TYPE_CHECKING: VLLM_DBO_COMM_SMS: int = 20 GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = [] VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None + VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True + VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True VLLM_USE_NCCL_SYMM_MEM: bool = False VLLM_NCCL_INCLUDE_PATH: Optional[str] = None @@ -1413,6 +1415,17 @@ environment_variables: dict[str, Callable[[], Any]] = { "code_interpreter", "web_search_preview"]), + # Enable max_autotune & coordinate_descent_tuning in inductor_config + # to compile static shapes passed from compile_sizes in compilation_config + # If set to 1, enable max_autotune; By default, this is enabled (1) + "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE": + lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "1"))), + # If set to 1, enable coordinate_descent_tuning; + # By default, this is enabled (1) + "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING": + lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", + "1"))), + # Flag to enable NCCL symmetric memory allocation and registration "VLLM_USE_NCCL_SYMM_MEM": lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))), @@ -1513,6 +1526,8 @@ def compute_hash() -> str: "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", "VLLM_ROCM_FP8_MFMA_PAGE_ATTN", + "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", + "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", ] for key in environment_variables_to_hash: # if this goes out of sync with environment_variables,