diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index 75a89d692fa8f..500cca87d96ed 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -337,9 +337,8 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
 def test_toy_llama(
     backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
 ):
-    # We disable the vLLM compile cache into a new tmp dir for 2 reasons:
+    # We disable the vLLM compile cache into a new tmp dir for 1 reason:
     # 1. To make sure we can properly track the number of Inductor compilations.
-    # 2. Inductor partitioning does not play nicely with Autograd cache (below)
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
 
     if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
@@ -369,15 +368,6 @@ def test_toy_llama(
         cudagraph_capture_sizes=[1, 2],
     )
 
-    # FIXME(luka/boyuan): the graph from the previous test case
-    # (no inductor partition) gets cached by AotAutograd so then the
-    # compilation with inductor partitioning incorrectly loads an unpartitioned
-    # graph and never partitions. I think this is a bug with custom inductor
-    # partitioning but does not affect vLLM more generally as vLLM uses its own
-    # cache (which takes inductor partitioning into account).
-    if use_inductor_graph_partition:
-        compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
-
     compile_config_split = deepcopy(compile_config_no_split)
     compile_config_split.splitting_ops = ["silly::attention"]
 
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index 55fe235e2d2c1..343297e944684 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -110,6 +110,27 @@ class PostGradPassManager(CustomGraphPass):
         self.post_cleanup = PostCleanupPass(config)
         self.fix_functionalization = FixFunctionalizationPass(config)
 
+        # [HACK: Bug with Inductor graph partition and torch.compile cache]
+        # In PyTorch 2.9, torch.compile has a bug where the graph
+        # partition is not taken into account during caching.
+        # Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
+        # Inductor graph partition, and VLLM_COMPILE implies there
+        # is a PostGradPassManager, we put the list of operators to graph
+        # partition into the PostGradPassManager's uuid (which
+        # then gets incorporated into Inductor's FX graph cache key).
+        # Remove this hack whenever torch.compile fixes it.
+
+        # This is the list of operators that vLLM asks Inductor to split.
+        self.inductor_splitting_ops = []
+        if (
+            config.compilation_config.use_inductor_graph_partition
+            and config.compilation_config.splitting_ops is not None
+        ):
+            # Sort them so we're not dependent on the ordering.
+            self.inductor_splitting_ops = sorted(
+                config.compilation_config.splitting_ops
+            )
+
     def add(self, pass_: InductorPass):
         assert isinstance(pass_, InductorPass)
         self.passes.append(pass_)
@@ -120,8 +141,16 @@ class PostGradPassManager(CustomGraphPass):
         affects compilation caching. Its uuid depends on the UUIDs of all
         dependent passes and the pass config. See InductorPass for more info.
         """
-        state = {"pass_config": self.pass_config.uuid(), "passes": []}
+        state = {
+            "pass_config": self.pass_config.uuid(),
+            "passes": [],
+            "inductor_splitting_ops": [],
+        }
         for pass_ in self.passes:
             state["passes"].append(pass_.uuid())
         state["passes"].append(self.fix_functionalization.uuid())
+
+        # See [HACK: Bug with Inductor graph partition and torch.compile cache]
+        state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
+
         return InductorPass.hash_dict(state)