mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 15:25:28 +08:00
Add VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE & VLLM_ENABLE_INDUCTOR_COORDINA… (#25493)
Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com>
This commit is contained in:
parent
969b4da3a6
commit
eca7be9077
@ -551,8 +551,9 @@ def set_inductor_config(config, runtime_shape):
|
|||||||
if isinstance(runtime_shape, int):
|
if isinstance(runtime_shape, int):
|
||||||
# for a specific batchsize, tuning triton kernel parameters
|
# for a specific batchsize, tuning triton kernel parameters
|
||||||
# can be beneficial
|
# can be beneficial
|
||||||
config["max_autotune"] = True
|
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
|
||||||
config["coordinate_descent_tuning"] = True
|
config["coordinate_descent_tuning"] = (
|
||||||
|
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING)
|
||||||
|
|
||||||
|
|
||||||
class EagerAdaptor(CompilerInterface):
|
class EagerAdaptor(CompilerInterface):
|
||||||
|
|||||||
15
vllm/envs.py
15
vllm/envs.py
@ -193,6 +193,8 @@ if TYPE_CHECKING:
|
|||||||
VLLM_DBO_COMM_SMS: int = 20
|
VLLM_DBO_COMM_SMS: int = 20
|
||||||
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
|
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
|
||||||
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
|
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
|
||||||
|
VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True
|
||||||
|
VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
|
||||||
VLLM_USE_NCCL_SYMM_MEM: bool = False
|
VLLM_USE_NCCL_SYMM_MEM: bool = False
|
||||||
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
|
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
|
||||||
|
|
||||||
@ -1413,6 +1415,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"code_interpreter",
|
"code_interpreter",
|
||||||
"web_search_preview"]),
|
"web_search_preview"]),
|
||||||
|
|
||||||
|
# Enable max_autotune & coordinate_descent_tuning in inductor_config
|
||||||
|
# to compile static shapes passed from compile_sizes in compilation_config
|
||||||
|
# If set to 1, enable max_autotune; By default, this is enabled (1)
|
||||||
|
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "1"))),
|
||||||
|
# If set to 1, enable coordinate_descent_tuning;
|
||||||
|
# By default, this is enabled (1)
|
||||||
|
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
||||||
|
"1"))),
|
||||||
|
|
||||||
# Flag to enable NCCL symmetric memory allocation and registration
|
# Flag to enable NCCL symmetric memory allocation and registration
|
||||||
"VLLM_USE_NCCL_SYMM_MEM":
|
"VLLM_USE_NCCL_SYMM_MEM":
|
||||||
lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
|
lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
|
||||||
@ -1513,6 +1526,8 @@ def compute_hash() -> str:
|
|||||||
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
|
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
|
||||||
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
|
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
|
||||||
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
|
||||||
|
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
|
||||||
|
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
|
||||||
]
|
]
|
||||||
for key in environment_variables_to_hash:
|
for key in environment_variables_to_hash:
|
||||||
# if this goes out of sync with environment_variables,
|
# if this goes out of sync with environment_variables,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user