diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index c65e5a25934d2..8fa305d6d72f5 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import functools +import multiprocessing import tempfile from contextlib import contextmanager @@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch): artifacts = compiled_mod.aot_compiled_fn._artifacts guards_string = artifacts.compiled_fn.shape_env.format_guards() assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" ) +@use_vllm_config(make_vllm_config()) +def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch): + """ + Test that compiling gpt2 twice results in a cache hit and + capture torch dynamic symbol creation to ensure make_symbol + is not called on a cache hit.
+ """ + + import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module + from torch.utils._sympy.symbol import make_symbol + + from vllm import LLM + + create_symbol_counter = multiprocessing.Value("i", 0) + original_make_symbol = make_symbol + + @functools.wraps(original_make_symbol) + def counting_make_symbol(prefix, idx, **kwargs): + with create_symbol_counter.get_lock(): + create_symbol_counter.value += 1 + return original_make_symbol(prefix, idx, **kwargs) + + symbolic_shapes_module.make_symbol = counting_make_symbol + try: + with monkeypatch.context() as m, tempfile.TemporaryDirectory() as tmpdirname: + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + # First compilation - initialize model and generate + llm_model = LLM( + model="gpt2", + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + ), + max_model_len=256, + ) + + llm_model.generate("Hello, my name is") + assert create_symbol_counter.value == 2 + create_symbol_counter.value = 0 + + # Clean up first model + del llm_model + + # Second compilation - should hit cache + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + llm_model = LLM( + model="gpt2", + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + ), + max_model_len=256, + ) + llm_model.generate("Hello, my name is") + + assert create_symbol_counter.value == 0 + + finally: + # Restore original method + symbolic_shapes_module.make_symbol = original_make_symbol diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1773913d0b6c6..b5b7fe2b76c27 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -402,6 +402,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): self.extra_traceback = False def run(self, *args): + # maybe instead just assert inputs are fake? 
fake_args = [ self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args ] @@ -416,11 +417,13 @@ kwargs: dict[str, Any], ) -> Any: assert isinstance(target, str) + output = super().call_module(target, args, kwargs) if target in self.compile_submod_names: index = self.compile_submod_names.index(target) submod = self.fetch_attr(target) + sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] @@ -746,11 +749,21 @@ class VllmBackend: if not item.is_splitting_graph ] + # Extract fake values from the graph to use them when needed. + all_fake_values = [] + for i in graph.graph.find_nodes(op="placeholder"): + all_fake_values.append(i.meta["example_value"]) + + fake_args = [ + all_fake_values[i] if isinstance(t, torch.Tensor) else t + for i, t in enumerate(example_inputs) + ] + + # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter( self.split_gm, submod_names_to_compile, self.vllm_config, self - ).run(*fake_args) + ).run(*fake_args) graph_path = os.path.join(local_cache_dir, "computation_graph.py") if not os.path.exists(graph_path): @@ -780,14 +793,7 @@ class VllmBackend: ) # if we need to copy input buffers for cudagraph - from torch._guards import detect_fake_mode - - fake_mode = detect_fake_mode() - fake_args = [ - fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t - for t in example_inputs - ] - + # fake_args is now computed earlier from the graph's placeholder metadata # index of tensors that have symbolic shapes (batch size) # for weights and static buffers, they will have concrete shapes. # symbolic shape only happens for input tensors.
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 6d9da1c488c6d..eed7795cdb349 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -433,7 +433,6 @@ def _support_torch_compile( return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # This is the path for the first compilation. - # the first compilation needs to have dynamic shapes marked _mark_dynamic_inputs( self,