diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py
index 84194f3ed01e..e01b58220959 100644
--- a/tests/compile/piecewise/test_full_cudagraph.py
+++ b/tests/compile/piecewise/test_full_cudagraph.py
@@ -11,6 +11,7 @@ from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
+from vllm.utils import is_torch_equal_or_newer


 @contextlib.contextmanager
@@ -32,13 +33,13 @@ def temporary_environ(env_vars):
                 os.environ[k] = v


-test_params_full_cudagraph = []
+model_backends_full_cudagraph = []

 # deepseek-ai/DeepSeek-V2-Lite with MLA
 MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
 for mla_backend in MLA_backends:
-    test_params_full_cudagraph.append(
-        pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))
+    model_backends_full_cudagraph.append(
+        ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])
     )

 # Qwen/Qwen2-1.5B-Instruct with other backends
@@ -46,14 +47,18 @@ other_backend_configs = [
     backend_configs[c] for c in backend_configs if c not in MLA_backends
 ]
 for backend_config in other_backend_configs:
-    test_params_full_cudagraph.append(
-        pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))
-    )
+    model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config))


 @pytest.fixture(scope="class")
 def llm_pair(request):
-    model, backend_config = request.param
+    model, backend_config, use_inductor_graph_partition = request.param
+    backend_config.comp_config["use_inductor_graph_partition"] = (
+        use_inductor_graph_partition
+    )
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition only supported in torch>=2.9")

     # Dynamically skip test if GPU capability is not met
     if (
@@ -104,7 +109,15 @@ def llm_pair(request):
     )


-@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
+@pytest.mark.parametrize(
+    "llm_pair",
+    [
+        pytest.param((model, backend_config, use_inductor_graph_partition))
+        for model, backend_config in model_backends_full_cudagraph
+        for use_inductor_graph_partition in [True, False]
+    ],
+    indirect=True,
+)
 class TestFullCUDAGraph:
     """
     Use a class such that an llm pair is constructed once for all
diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py
index d88645e3bfd6..0d265bc59638 100644
--- a/tests/compile/piecewise/test_multiple_graphs.py
+++ b/tests/compile/piecewise/test_multiple_graphs.py
@@ -5,6 +5,7 @@ Test (piecewise) compilation with a simple model where multiple submodules
 are compiled and graph captured separately.
 """

+import pytest
 import torch
 from torch import nn

@@ -190,7 +191,12 @@ def run_model(
     return output.cpu()


-def test_multi_graph_piecewise_compile_outputs_equal():
+@pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
+def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
+    if use_inductor_graph_partition:
+        # FIXME(luka/boyuan): this currently fails
+        pytest.skip("Inductor graph partition not supported with multi-graph")
+
     outputs = []

     # piecewise compile
@@ -200,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,16 +227,24 @@ def test_multi_graph_piecewise_compile_outputs_equal():
     # static tensor addresses
     inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()

-    with compilation_counter.expect(
-        num_graphs_seen=2,  # two graphs for the model
-        num_piecewise_graphs_seen=6,
+    if use_inductor_graph_partition:
+        # Splitting happens at Inductor lowering level,
+        # total piecewise fx graphs is equal to total graphs
+        num_piecewise_fx = 2
+        num_piecewise_capturable_fx = 2
+    else:
         # attn_one, attn_two each has 3 piecewise graphs
         # (pre attn, post attn, silly_attention) each
-        num_piecewise_capturable_graphs_seen=4,
+        num_piecewise_fx = 6
         # attn_one, attn_two has pre attn and post attn each, total=4
-        num_backend_compilations=4,  # num_piecewise_capturable_graphs_seen
-        num_cudagraph_captured=8,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_piecewise_capturable_fx = 4
+
+    with compilation_counter.expect(
+        num_graphs_seen=2,  # two graphs for the model
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
+        num_cudagraph_captured=8,  # num_cudagraph_sizes * num_partitions
     ):
         outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))

@@ -268,6 +283,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=False,
             splitting_ops=["silly::attention"],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -286,9 +302,9 @@ def test_multi_graph_piecewise_compile_outputs_equal():

     with compilation_counter.expect(
         num_graphs_seen=2,
-        num_piecewise_graphs_seen=6,
-        num_piecewise_capturable_graphs_seen=4,
-        num_backend_compilations=4,
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
         num_cudagraph_captured=0,  # no cudagraph captured
     ):
         outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index eaf0a15479e9..7ab610fa7811 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -9,6 +9,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are
 initialized randomly with a fixed seed.
 """

+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any

@@ -26,6 +27,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer

 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
@@ -257,27 +259,13 @@ def tractable_computation(


 @torch.inference_mode
-def run_model(
-    llama_config, use_compile: bool, backend: str, split_attn: bool = False
-) -> torch.Tensor:
-    if use_compile:
-        compilation_config = CompilationConfig(
-            level=CompilationLevel.PIECEWISE,
-            use_cudagraph=True,
-            backend=backend,
-            cudagraph_capture_sizes=[1, 2],
-        )
-        if split_attn:
-            compilation_config.splitting_ops = ["silly::attention"]
-        cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
-    else:
-        compilation_config = CompilationConfig(
-            level=CompilationLevel.NO_COMPILATION,
-        )
-        cudagraph_runtime_mode = CUDAGraphMode.NONE
+def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
+    # Start with a fresh copy to make sure there's no cache dir sharing
+    compile_config = deepcopy(compile_config)
+    cudagraph_runtime_mode = compile_config.cudagraph_mode

     vllm_config = VllmConfig(
-        compilation_config=compilation_config, additional_config=llama_config
+        compilation_config=compile_config, additional_config=llama_config
     )
     with set_current_vllm_config(vllm_config):
         model = (
@@ -338,8 +326,25 @@ def run_model(
     return output.cpu()


-@pytest.mark.parametrize("backend", ["inductor", "eager"])
-def test_toy_llama(backend: str):
+@pytest.mark.parametrize(
+    "backend, use_inductor_graph_partition",
+    [
+        ("eager", False),  # No inductor
+        ("inductor", False),  # Inductor, Dynamo partition
+        ("inductor", True),  # Inductor, Inductor partition
+    ],
+)
+def test_toy_llama(
+    backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
+):
+    # We disable the vLLM compile cache into a new tmp dir for 2 reasons:
+    # 1. To make sure we can properly track the number of Inductor compilations.
+    # 2. Inductor partitioning does not play nicely with Autograd cache (below)
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("Inductor graph partition only supported in torch>=2.9")
+
     # compare output with and without piecewise compilation

     llama_config = LlamaConfig(
@@ -350,6 +355,32 @@
         hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True
     )

+    compile_config_no_compile = CompilationConfig(
+        level=CompilationLevel.NO_COMPILATION,
+        cudagraph_mode=CUDAGraphMode.NONE,
+        backend="eager",
+    )
+
+    compile_config_no_split = CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_inductor_graph_partition=use_inductor_graph_partition,
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        backend=backend,
+        cudagraph_capture_sizes=[1, 2],
+    )
+
+    # FIXME(luka/boyuan): the graph from the previous test case
+    # (no inductor partition) gets cached by AotAutograd so then the
+    # compilation with inductor partitioning incorrectly loads an unpartitioned
+    # graph and never partitions. I think this is a bug with custom inductor
+    # partitioning but does not affect vLLM more generally as vLLM uses its own
+    # cache (which takes inductor partitioning into account).
+    if use_inductor_graph_partition:
+        compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
+
+    compile_config_split = deepcopy(compile_config_no_split)
+    compile_config_split.splitting_ops = ["silly::attention"]
+
     outputs = []
     with compilation_counter.expect(
         num_graphs_seen=0,
@@ -358,8 +389,9 @@
         num_backend_compilations=0,
         num_cudagraph_captured=0,
     ):
-        outputs.append(run_model(llama_config, backend="eager", use_compile=False))
-    run_model(tractable_config, backend="eager", use_compile=False)
+        outputs.append(run_model(llama_config, compile_config_no_compile))
+
+    run_model(tractable_config, compile_config_no_compile)

     if backend == "inductor":
         kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
@@ -367,35 +399,34 @@
         kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}

     with compilation_counter.expect(
-        # One graph for the model
-        num_graphs_seen=1,
+        num_graphs_seen=1,  # one graph for the model
         num_piecewise_graphs_seen=1,
         num_piecewise_capturable_graphs_seen=1,
-        # num_piecewise_capturable_graphs_seen
-        num_backend_compilations=1,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
         num_cudagraph_captured=2,
         **kwargs,
     ):
-        outputs.append(run_model(llama_config, backend=backend, use_compile=True))
-    run_model(tractable_config, backend=backend, use_compile=True)
+        outputs.append(run_model(llama_config, compile_config_no_split))
+
+    run_model(tractable_config, compile_config_no_split)
+
+    if use_inductor_graph_partition:
+        num_piecewise_fx = 1
+        num_piecewise_capturable_fx = 1
+    else:
+        num_piecewise_fx = 2 * llama_config.num_layers + 1
+        num_piecewise_capturable_fx = 1 + llama_config.num_layers

     with compilation_counter.expect(
         num_graphs_seen=1,  # one graph for the model
-        num_piecewise_graphs_seen=2 * llama_config.num_layers + 1,  # 2 * num_layers + 1
-        num_piecewise_capturable_graphs_seen=1
-        + llama_config.num_layers,  # 1 + num_layers
-        num_backend_compilations=1
-        + llama_config.num_layers,  # num_piecewise_capturable_graphs_seen
-        num_cudagraph_captured=2
-        * (
-            1 + llama_config.num_layers
-        ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_piecewise_graphs_seen=num_piecewise_fx,
+        num_piecewise_capturable_graphs_seen=num_piecewise_capturable_fx,
+        num_backend_compilations=num_piecewise_capturable_fx,
+        # num_cudagraph_sizes * num_partitions
+        num_cudagraph_captured=2 * (1 + llama_config.num_layers),
     ):
-        outputs.append(
-            run_model(llama_config, backend=backend, use_compile=True, split_attn=True)
-        )
-    run_model(tractable_config, backend=backend, use_compile=True, split_attn=True)
+        outputs.append(run_model(llama_config, compile_config_split))
+    run_model(tractable_config, compile_config_split)

     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py
index c0d3f908149f..f33c5772906a 100644
--- a/tests/compile/silly_attention.py
+++ b/tests/compile/silly_attention.py
@@ -62,5 +62,4 @@ direct_register_custom_op(
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
     target_lib=silly_lib,
-    tags=(torch._C.Tag.cudagraph_unsafe,),
 )
diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py
index 6b050207ec41..63cb266094a1 100644
--- a/tests/compile/test_decorator.py
+++ b/tests/compile/test_decorator.py
@@ -73,6 +73,7 @@ def test_ignore_torch_compile_decorator():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=False,  # TODO test both?
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -188,6 +189,7 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=False,  # TODO test both
         ),
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -220,6 +222,7 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
+            use_inductor_graph_partition=False,  # TODO test both?
         ),
     )

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index fe9de65b52c6..8b5b87cba404 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -38,10 +38,6 @@ from vllm.utils import GiB_bytes, direct_register_custom_op
 logger = init_logger(__name__)

 USE_XFORMERS_OPS = None
-try:
-    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,)
-except AttributeError:
-    tag_cudagraph_unsafe = ()  # type: ignore[assignment]


 def check_xformers_availability():
@@ -879,7 +875,6 @@ direct_register_custom_op(
     op_name="unified_attention",
     op_func=unified_attention,
     fake_impl=unified_attention_fake,
-    tags=tag_cudagraph_unsafe,
 )


@@ -931,7 +926,6 @@ direct_register_custom_op(
     op_func=unified_attention_with_output,
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
-    tags=tag_cudagraph_unsafe,
 )


diff --git a/vllm/compilation/partition_rules.py b/vllm/compilation/partition_rules.py
index 5ea1b30860f5..cea4f9a81637 100644
--- a/vllm/compilation/partition_rules.py
+++ b/vllm/compilation/partition_rules.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import contextlib
+import logging
 from typing import TYPE_CHECKING

 from torch._library.utils import lookup_op
@@ -38,8 +39,16 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
             resolved.append(lookup_op(op_name))
         except Exception:
             # Skip operators that don't exist (e.g., model-specific ops)
-            logger.warning(
-                "Failed to resolve operator for Inductor partition: %s", op_name
+            # Do not warn for attention ops, warn for others
+            # (most likely manually specified)
+            from vllm.config import CompilationConfig
+
+            logger.log(
+                logging.DEBUG
+                if op_name in CompilationConfig._attention_ops
+                else logging.WARNING,
+                "Failed to resolve operator for CUDAGraph partition: %s",
+                op_name,
             )
             continue
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 60aef2f6f7e1..fb80835ba48a 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -201,7 +201,7 @@ class CompilationConfig:
     (it sees a part of the graph). The backend can not be custom for compilation
     level 3, i.e. the backend must be either eager or inductor. Furthermore,
     compilation is only piecewise if splitting ops is set accordingly and
-    use_inductor_cudagraphs_partition is off. Note that the default options for
+    use_inductor_graph_partition is off. Note that the default options for
     splitting ops are sufficient for piecewise compilation.
     """
     custom_ops: list[str] = field(default_factory=list)
@@ -431,6 +431,7 @@ class CompilationConfig:
         factors.append(self.custom_ops)
         factors.append(self.splitting_ops)
         factors.append(self.use_inductor)
+        factors.append(self.use_inductor_graph_partition)
         factors.append(self.inductor_compile_config)
         factors.append(self.inductor_passes)
         factors.append(self.pass_config.uuid())