From 762be26a8ee0de15638fa21a59d85efedacec847 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Fri, 11 Jul 2025 03:15:22 -0400 Subject: [PATCH] [Bugfix] Upgrade depyf to 0.19 and streamline custom pass logging (#20777) Signed-off-by: Luka Govedic Signed-off-by: luka --- requirements/common.txt | 2 +- tests/compile/test_full_graph.py | 6 ++++++ vllm/compilation/vllm_inductor_pass.py | 28 ++++---------------------- vllm/config.py | 13 ++---------- 4 files changed, 13 insertions(+), 36 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index 0af7478daa8b6..f97fe35d28b30 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -40,7 +40,7 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. compressed-tensors == 0.10.2 # required for compressed-tensors -depyf==0.18.0 # required for profiling and debugging with compilation config +depyf==0.19.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/others/logging_configuration.md diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 1d000fe00c598..72f962ed7484c 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -3,6 +3,7 @@ from __future__ import annotations +import tempfile from typing import Any, Optional, Union import pytest @@ -111,6 +112,11 @@ def test_full_graph( pass_config=PassConfig(enable_fusion=True, enable_noop=True)), model) for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) + ] + [ + # Test depyf integration works + (CompilationConfig(level=CompilationLevel.PIECEWISE, + debug_dump_path=tempfile.gettempdir()), + ("facebook/opt-125m", {})), ]) # only test some of the models @create_new_process_for_each_test() diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 628e9e204c552..b822b05b0f1ec 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -6,13 +6,7 @@ import time import torch from torch._dynamo.utils import lazy_format_graph_code -from vllm.config import PassConfig, VllmConfig -# yapf: disable -from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank -from vllm.distributed import ( - get_tensor_model_parallel_world_size as get_tp_world_size) -from vllm.distributed import model_parallel_is_initialized as p_is_init -# yapf: enable +from vllm.config import VllmConfig from vllm.logger import init_logger from .inductor_pass import InductorPass @@ -34,22 +28,9 @@ class VllmInductorPass(InductorPass): else None self.pass_name = self.__class__.__name__ - def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): + def dump_graph(self, graph: torch.fx.Graph, stage: str): lazy_format_graph_code(stage, graph.owning_module) - if stage in self.pass_config.dump_graph_stages or always: - # Make sure filename includes rank in the distributed setting - parallel = p_is_init() and get_tp_world_size() > 1 - rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py" - - logger.info("%s printing graph to %s", self.pass_name, filepath) - with open(filepath, "w") as f: - src = graph.python_code(root_module="self", verbose=True).src - # Add imports so it's not full of errors - print("import torch; from torch import device", file=f) - print(src, file=f) - def begin(self): self._start_time = time.perf_counter_ns() @@ -61,10 +42,9 @@ class VllmInductorPass(InductorPass): class PrinterInductorPass(VllmInductorPass): - def __init__(self, name: str, config: PassConfig, always=False): + def __init__(self, name: str, config: VllmConfig): super().__init__(config) self.name = name - self.always = always def __call__(self, graph: torch.fx.Graph): - self.dump_graph(graph, self.name, always=self.always) + self.dump_graph(graph, self.name) diff --git a/vllm/config.py b/vllm/config.py index ad40fcba4f065..b1f7f9e57a79b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -16,7 +16,6 @@ from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, replace) from functools import cached_property from importlib.util import find_spec -from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, Protocol, TypeVar, Union, cast, get_args) @@ -3953,11 +3952,6 @@ class PassConfig: don't all have access to full configuration - that would create a cycle as the `PassManager` is set as a property of config.""" - dump_graph_stages: list[str] = field(default_factory=list) - """List of stages for which we want to dump the graph. Each pass defines - its own stages (before, after, maybe in-between).""" - dump_graph_dir: Path = Path(".") - """Directory to dump the graphs.""" enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" enable_attn_fusion: bool = False @@ -3975,12 +3969,9 @@ class PassConfig: """ Produces a hash unique to the pass configuration. Any new fields that affect compilation should be added to the hash. - Do not include dump_graph_* in the hash - they don't affect - compilation. + Any future fields that don't affect compilation should be excluded. """ - exclude = {"dump_graph_stages", "dump_graph_dir"} - dict_ = {k: v for k, v in asdict(self).items() if k not in exclude} - return InductorPass.hash_dict(dict_) + return InductorPass.hash_dict(asdict(self)) def __post_init__(self) -> None: if not self.enable_noop: