diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py
index 4d60899a628a..e459bc539f2b 100644
--- a/tests/compile/test_decorator.py
+++ b/tests/compile/test_decorator.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
 import torch
 from torch import nn
 
@@ -14,6 +15,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer
 
 # This import automatically registers `torch.ops.silly.attention`
 from . import silly_attention  # noqa: F401
@@ -65,19 +67,40 @@ def run_model(
     return output.cpu()
 
 
-def test_ignore_torch_compile_decorator():
-    # vllmcompile
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
+def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch):
+    # disable compile cache so that we can count the number of compilations
+    # appropriately
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
+
+    # piecewise
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
            cudagraph_capture_sizes=[1, 2],
-            use_inductor_graph_partition=False,  # TODO test both?
+            use_inductor_graph_partition=use_inductor_graph_partition,
         )
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
 
+    expected_num_graphs_seen = 1
+    expected_num_cudagraph_captured = (
+        4  # num_cudagraph_sizes * num cudagraphs to capture
+    )
+    if use_inductor_graph_partition:
+        expected_num_piecewise_graphs_seen = 1
+        expected_num_piecewise_capturable_graphs_seen = 1
+        expected_num_backend_compilations = 1
+    else:
+        expected_num_piecewise_graphs_seen = 3
+        expected_num_piecewise_capturable_graphs_seen = 2
+        expected_num_backend_compilations = 2
+
     @support_torch_compile
     class A(nn.Module):
         def __init__(
@@ -104,12 +127,11 @@ def test_ignore_torch_compile_decorator():
 
     # A has support_torch_compile
     with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
     ):
         run_model(vllm_config, mod_A, cudagraph_runtime_mode)
 
@@ -131,12 +153,11 @@ def test_ignore_torch_compile_decorator():
 
     # C's support_torch_compile should override B's ignore_torch_compile
     with compilation_counter.expect(
-        num_graphs_seen=1,
-        num_piecewise_graphs_seen=3,
-        num_piecewise_capturable_graphs_seen=2,
-        num_backend_compilations=2,
-        num_cudagraph_captured=4,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        num_graphs_seen=expected_num_graphs_seen,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
+        num_cudagraph_captured=expected_num_cudagraph_captured,
    ):
         run_model(vllm_config, mod_C, cudagraph_runtime_mode)
 
@@ -179,7 +200,15 @@ class A(nn.Module):
         return x
 
 
-def test_conditional_compile_enable_if():
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
+def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch):
+    # disable compile cache so that we can count the number of compilations
+    # appropriately
+    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
+
     vllm_config = VllmConfig(
         cache_config=CacheConfig(
             kv_sharing_fast_prefill=True,
@@ -189,7 +218,7 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
-            use_inductor_graph_partition=False,  # TODO test both
+            use_inductor_graph_partition=use_inductor_graph_partition,
         ),
     )
     cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
@@ -197,17 +226,26 @@ def test_conditional_compile_enable_if():
     with set_current_vllm_config(vllm_config):
         mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
 
+    if use_inductor_graph_partition:
+        expected_num_piecewise_graphs_seen = 2
+        expected_num_piecewise_capturable_graphs_seen = 2
+        expected_num_backend_compilations = 2
+    else:
+        expected_num_piecewise_graphs_seen = 6
+        expected_num_piecewise_capturable_graphs_seen = 4
+        expected_num_backend_compilations = 4
+
     # A has support_torch_compile but enable_if fn returns False
     # enalbe_if will be True for B, so we expect mod1 and mod2
     # to be compiled
     with compilation_counter.expect(
         num_graphs_seen=2,
-        num_piecewise_graphs_seen=6,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
         # 3 piecewise graphs per instance of B()
-        num_piecewise_capturable_graphs_seen=4,
-        num_backend_compilations=4,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
         num_cudagraph_captured=8,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        # num_cudagraph_sizes * num cudagraphable graphs to capture
     ):
         run_model(vllm_config, mod_A, cudagraph_runtime_mode)
 
@@ -222,20 +260,30 @@ def test_conditional_compile_enable_if():
             use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
-            use_inductor_graph_partition=False,  # TODO test both?
+            use_inductor_graph_partition=use_inductor_graph_partition,
         ),
     )
 
     with set_current_vllm_config(vllm_config):
         mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda()
 
+    if use_inductor_graph_partition:
+        expected_num_piecewise_graphs_seen = 1
+        expected_num_piecewise_capturable_graphs_seen = 1
+        expected_num_backend_compilations = 1
+    else:
+        # 3 attn ops and 4 non-attn ops
+        expected_num_piecewise_graphs_seen = 7
+        expected_num_piecewise_capturable_graphs_seen = 4
+        expected_num_backend_compilations = 4
+
     with compilation_counter.expect(
         num_graphs_seen=1,
-        num_piecewise_graphs_seen=7,
+        num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
         # 3 attn ops and 4 non-attn ops
-        num_piecewise_capturable_graphs_seen=4,
-        num_backend_compilations=4,
+        num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen,
+        num_backend_compilations=expected_num_backend_compilations,
         num_cudagraph_captured=8,
-        # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+        # num_cudagraph_sizes * num cudagraphable graphs to capture
     ):
         run_model(vllm_config, mod_A, cudagraph_runtime_mode)