[torch.compile]: Add VLLM_DEBUG_DUMP_PATH environment variable (#25651)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
commit c0ec81836f
parent b65e56babe
Author: Jiangyun Zhu
Date: 2025-09-28 00:09:00 +08:00 (committed by GitHub)
6 changed files with 44 additions and 17 deletions
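
The variable is read lazily and, when set, overrides any config-specified
dump path. A minimal usage sketch (the path and model name below are
illustrative, not part of this commit):

# Route all torch.compile debug dumps under one directory tree; vLLM
# itself creates the per-rank subdirectories (rank_<tp>, or
# rank_<tp>_dp_<dp> when data parallelism is enabled).
import os
os.environ["VLLM_DEBUG_DUMP_PATH"] = "/tmp/vllm_compile_debug"

from vllm import LLM
llm = LLM(model="facebook/opt-125m")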

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import os
 import time

 from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
@@ -18,13 +17,12 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
     torch_compile_start_time = time.time()

     compilation_config: CompilationConfig = vllm_config.compilation_config
-    if compilation_config.level == CompilationLevel.PIECEWISE and \
-            compilation_config.debug_dump_path:
+    path = vllm_config.compile_debug_dump_path()
+    if compilation_config.level == CompilationLevel.PIECEWISE and path:
         import depyf
-        path = os.path.join(compilation_config.debug_dump_path,
-                            f"rank_{vllm_config.parallel_config.rank}")
+        path.mkdir(parents=True, exist_ok=True)
         global context_manager
-        context_manager = depyf.prepare_debug(path)
+        context_manager = depyf.prepare_debug(path.as_posix())
         context_manager.__enter__()
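
As a standalone sketch, the depyf flow this hunk wires up looks like the
following (the directory is illustrative; the real one comes from
vllm_config.compile_debug_dump_path()):

from pathlib import Path
import depyf

path = Path("/tmp/vllm_compile_debug/rank_0")
path.mkdir(parents=True, exist_ok=True)
context_manager = depyf.prepare_debug(path.as_posix())
context_manager.__enter__()
# ... torch.compile runs; depyf dumps transformed sources under path ...
context_manager.__exit__(None, None, None)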

View File

@@ -3,7 +3,6 @@
 import functools
 import operator
 import time
-from pathlib import Path
 from typing import ClassVar, Optional

 import regex as re
@@ -96,12 +95,10 @@ class VllmPatternMatcherPass(VllmInductorPass):
         TODO(luka): use pattern object to manually produce pattern graph
         """
-        debug_dump_path = config.compilation_config.debug_dump_path
+        debug_dump_path = config.compile_debug_dump_path()
         if not debug_dump_path:
             return

-        rank = config.parallel_config.rank
-        debug_dump_path = Path(debug_dump_path) / f"rank_{rank}"
         debug_dump_path.mkdir(parents=True, exist_ok=True)

         from vllm.utils import unique_filepath
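
The pass keeps relying on vllm.utils.unique_filepath so repeated dumps do
not clobber each other. A hypothetical re-implementation of that pattern
(the name and signature here are assumptions, not the exact vLLM API):

from pathlib import Path
from typing import Callable

def unique_filepath(make: Callable[[int], Path]) -> Path:
    # Return the first candidate path that does not exist yet.
    i = 0
    while (candidate := make(i)).exists():
        i += 1
    return candidate

target = unique_filepath(
    lambda i: Path("/tmp/vllm_compile_debug/rank_0") / f"patterns_{i}.py")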

View File

@@ -92,12 +92,11 @@ class TorchCompileWrapperWithCustomDispatcher:
             return
         self.compiled_codes.append(new_code)

-        debug_dump_dir = self.vllm_config.compilation_config.debug_dump_path
-        if isinstance(debug_dump_dir, str) and debug_dump_dir != "":
-            rank = self.vllm_config.parallel_config.rank
-            decompiled_file = os.path.join(debug_dump_dir, f"rank_{rank}",
-                                           "transformed_code.py")
-            if not os.path.exists(decompiled_file):
+        path = self.vllm_config.compile_debug_dump_path()
+        if path:
+            decompiled_file = path / "transformed_code.py"
+            if not decompiled_file.exists():
                 try:
                     # usually the decompilation will succeed for most models,
                     # as we guarantee a full-graph compilation in Dynamo.
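
Stripped of the decompilation itself, the guarded one-time dump reduces
to this sketch (the source string is a placeholder for the decompiled
Dynamo bytecode):

from pathlib import Path

def dump_transformed_code(path: Path, decompiled_src: str) -> None:
    decompiled_file = path / "transformed_code.py"
    if not decompiled_file.exists():  # only the first compiled code is kept
        decompiled_file.write_text(decompiled_src)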

View File

@@ -12,6 +12,7 @@ import textwrap
 from contextlib import contextmanager
 from dataclasses import field, fields, is_dataclass, replace
 from functools import cached_property, lru_cache
+from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Literal, Optional, Protocol, TypeVar,
                     Union, cast)
@@ -541,6 +542,17 @@ class VllmConfig:
                 # local attention.
                 self.scheduler_config.disable_hybrid_kv_cache_manager = True

+        if self.compilation_config.debug_dump_path:
+            self.compilation_config.debug_dump_path = \
+                self.compilation_config.debug_dump_path.absolute().expanduser()
+        if envs.VLLM_DEBUG_DUMP_PATH is not None:
+            env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser()
+            if self.compilation_config.debug_dump_path:
+                logger.warning(
+                    "Config-specified debug dump path is overridden"
+                    " by VLLM_DEBUG_DUMP_PATH to %s", env_path)
+            self.compilation_config.debug_dump_path = env_path
+
     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
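
The override boils down to "the environment variable wins". A
pure-function sketch of the precedence, mirroring the normalization in
the hunk above (simplified; logging omitted):

import os
from pathlib import Path
from typing import Optional

def resolve_debug_dump_path(config_path: Optional[Path]) -> Optional[Path]:
    env = os.environ.get("VLLM_DEBUG_DUMP_PATH")
    if env is not None:
        return Path(env).absolute().expanduser()
    return config_path.absolute().expanduser() if config_path else None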
@@ -672,6 +684,20 @@ class VllmConfig:
                 f"but got '{self.load_config.load_format}'. "
                 f"Model: {self.model_config.model}")

+    def compile_debug_dump_path(self) -> Optional[Path]:
+        """Returns a rank-aware path for dumping
+        torch.compile debug information.
+        """
+        if self.compilation_config.debug_dump_path is None:
+            return None
+        tp_rank = self.parallel_config.rank
+        dp_rank = self.parallel_config.data_parallel_rank
+        data_parallel_size = self.parallel_config.data_parallel_size
+        append_path = f"rank_{tp_rank}" if data_parallel_size == 1 \
+            else f"rank_{tp_rank}_dp_{dp_rank}"
+        path = self.compilation_config.debug_dump_path / append_path
+        return path
+
     def __str__(self):
         return (
             f"model={self.model_config.model!r}, "

View File

@@ -5,6 +5,7 @@ import enum
 import hashlib
 from collections import Counter
 from dataclasses import asdict, field
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union

 from pydantic import TypeAdapter, field_validator
@@ -169,7 +170,7 @@ class CompilationConfig:
     - 1: dynamo as is.
     - 2: dynamo once.
     - 3: piecewise compilation."""
-    debug_dump_path: str = ""
+    debug_dump_path: Optional[Path] = None
     """The path to dump the debug information."""
     cache_dir: str = ""
     """The directory to store the compiled graph, to accelerate Inductor

View File

@@ -199,6 +199,7 @@ if TYPE_CHECKING:
     VLLM_DBO_COMM_SMS: int = 20
     GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
     VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
+    VLLM_DEBUG_DUMP_PATH: Optional[str] = None
     VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True
     VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING: bool = True
     VLLM_USE_NCCL_SYMM_MEM: bool = False
@@ -513,6 +514,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_PATTERN_MATCH_DEBUG":
     lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),

+    # Dump fx graphs to the given directory.
+    # It will override CompilationConfig.debug_dump_path if set.
+    "VLLM_DEBUG_DUMP_PATH":
+    lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None),
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":