From ba8c300018e18dd4cdd2b7d904086feec5a79287 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 14 Jul 2025 21:26:18 -0400 Subject: [PATCH] [BugFix] VLLM_DISABLE_COMPILE_CACHE=1 should disable all reads and writes from the cache (#20942) Signed-off-by: Richard Zou --- tests/compile/test_config.py | 24 ++++++++++++++++++++++++ vllm/compilation/backends.py | 3 ++- vllm/compilation/compiler_interface.py | 4 +++- vllm/compilation/counter.py | 4 ++++ 4 files changed, 33 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 8679d5c3019b..0ba59f4b5a05 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -26,6 +26,30 @@ def test_use_cudagraphs_dynamic(monkeypatch): assert not vllm_config.compilation_config.use_cudagraph +# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends +# on the state of the cache directory on the current machine, which +# may be influenced by other tests. +@pytest.mark.parametrize("val", ["1"]) +def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val): + assert vllm.envs.VLLM_USE_V1 + + # spawn means that the counters are in the same process. 
+ monkeypatch.setenv('VLLM_WORKER_MULTIPROC_METHOD', "spawn") + monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val) + + compilation_config = { + "use_cudagraph": False, # speed things up a bit + } + with ( + compilation_counter.expect(num_cache_entries_updated=0, + num_compiled_artifacts_saved=0), + # loading the model causes compilation (if enabled) to happen + vllm_runner('facebook/opt-125m', + compilation_config=compilation_config, + gpu_memory_utilization=0.4) as _): + pass + + @pytest.mark.parametrize("enabled", [True, False]) def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): assert vllm.envs.VLLM_USE_V1 diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5148c289d865..673fb5866234 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -183,9 +183,10 @@ class CompilerManager: assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache - if handle is not None: + if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index fd39a6127d00..b529f84b7987 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -213,7 +213,9 @@ class InductorStandaloneAdaptor(CompilerInterface): # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) - compiled_graph.save(path=path, format="unpacked") + if not envs.VLLM_DISABLE_COMPILE_CACHE: + compiled_graph.save(path=path, format="unpacked") + compilation_counter.num_compiled_artifacts_saved += 1 return compiled_graph, (key, path) def load(self, diff --git a/vllm/compilation/counter.py 
b/vllm/compilation/counter.py index 9d7a25689b56..6acb8abb3deb 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -23,6 +23,10 @@ class CompilationCounter: num_inductor_compiles: int = 0 # EagerAdapter.compile calls num_eager_compiles: int = 0 + # The number of times vLLM's compiler cache entry was updated + num_cache_entries_updated: int = 0 + # The number of standalone_compile compiled artifacts saved + num_compiled_artifacts_saved: int = 0 def clone(self) -> "CompilationCounter": return copy.deepcopy(self)