diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py
index 143cb49697f5..5ce520a44025 100644
--- a/tests/compile/piecewise/test_simple.py
+++ b/tests/compile/piecewise/test_simple.py
@@ -74,11 +74,12 @@ class SillyModel(nn.Module):
         return x
 
 
-def test_simple_piecewise_compile():
+def _test_simple_piecewise_compile(*, use_inductor):
 
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
+        use_inductor=use_inductor,
         splitting_ops=["silly.attention"],
         cudagraph_copy_inputs=True,
         cudagraph_capture_sizes=[1, 2],
@@ -108,3 +109,11 @@ def test_simple_piecewise_compile():
         output = model(input)
         assert global_counter == 2
         assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
+
+
+def test_simple_piecewise_compile_inductor():
+    _test_simple_piecewise_compile(use_inductor=True)
+
+
+def test_simple_piecewise_compile_no_inductor():
+    _test_simple_piecewise_compile(use_inductor=False)
diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index d4551b1cc3ae..22560befcbd5 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -261,12 +261,14 @@ def tractable_computation(input_ids: torch.Tensor,
 
 @torch.inference_mode
 def run_model(llama_config,
               use_compile: bool,
+              use_inductor: bool,
               split_attn: bool = False) -> torch.Tensor:
     if use_compile:
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
+            use_inductor=use_inductor,
             cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
@@ -304,7 +306,7 @@ def run_model(llama_config,
     return output.cpu()
 
 
-def test_toy_llama():
+def _test_toy_llama(*, use_inductor):
     # compare output with and without piecewise compilation
 
     llama_config = LlamaConfig(hidden_size=128,
@@ -326,8 +328,14 @@ def test_toy_llama():
             num_backend_compilations=0,
             num_cudagraph_caputured=0,
     ):
-        outputs.append(run_model(llama_config, use_compile=False))
-        run_model(tractable_config, use_compile=False)
+        outputs.append(
+            run_model(llama_config, use_inductor=False, use_compile=False))
+        run_model(tractable_config, use_inductor=False, use_compile=False)
+
+    if use_inductor:
+        kwargs = {"num_inductor_compiles": 1, "num_eager_compiles": 0}
+    else:
+        kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0}
 
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
@@ -336,9 +344,13 @@
             num_piecewise_graphs_seen=1,
             num_piecewise_capturable_graphs_seen=1,
             num_backend_compilations=1,  # num_piecewise_capturable_graphs_seen
             num_cudagraph_caputured=
             2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+            **kwargs,
     ):
-        outputs.append(run_model(llama_config, use_compile=True))
-        run_model(tractable_config, use_compile=True)
+        outputs.append(
+            run_model(llama_config,
+                      use_inductor=use_inductor,
+                      use_compile=True))
+        run_model(tractable_config, use_inductor=use_inductor, use_compile=True)
 
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
@@ -353,13 +365,27 @@
              ),  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):
         outputs.append(
-            run_model(llama_config, use_compile=True, split_attn=True))
-        run_model(tractable_config, use_compile=True, split_attn=True)
+            run_model(llama_config,
+                      use_inductor=use_inductor,
+                      use_compile=True,
+                      split_attn=True))
+        run_model(tractable_config,
+                  use_inductor=use_inductor,
+                  use_compile=True,
+                  split_attn=True)
 
     for i in range(1, len(outputs)):
         assert torch.allclose(outputs[0], outputs[i])
 
 
+def test_toy_llama_inductor():
+    _test_toy_llama(use_inductor=True)
+
+
+def test_toy_no_inductor():
+    _test_toy_llama(use_inductor=False)
+
+
 @torch.inference_mode
 def benchmark():
     from triton.testing import do_bench
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index 21af5eb76ee8..7e9186f8613c 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -12,6 +12,7 @@ import torch._inductor.compile_fx
 import torch.fx as fx
 
 import vllm.envs as envs
+from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
 from vllm.utils import is_torch_equal_or_newer
 
@@ -175,6 +176,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
+        compilation_counter.num_inductor_compiles += 1
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
@@ -262,6 +264,7 @@ class InductorAdaptor(CompilerInterface):
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
+        compilation_counter.num_inductor_compiles += 1
         from torch._inductor.compile_fx import compile_fx
         current_config = {}
         if compiler_config is not None:
@@ -528,6 +531,7 @@ class EagerAdaptor(CompilerInterface):
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
+        compilation_counter.num_eager_compiles += 1
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
         return graph, None
diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py
index 5be452593c62..2200671b8848 100644
--- a/vllm/compilation/counter.py
+++ b/vllm/compilation/counter.py
@@ -15,6 +15,10 @@ class CompilationCounter:
     num_piecewise_capturable_graphs_seen: int = 0
     num_backend_compilations: int = 0
     num_cudagraph_caputured: int = 0
+    # InductorAdapter.compile calls
+    num_inductor_compiles: int = 0
+    # EagerAdapter.compile calls
+    num_eager_compiles: int = 0
 
     def clone(self) -> "CompilationCounter":
         return copy.deepcopy(self)
diff --git a/vllm/config.py b/vllm/config.py
index fe2ad70f5aac..3172cbe454f6 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4315,15 +4315,10 @@ class VllmConfig:
             self.compilation_config.custom_ops.append("+rms_norm")
         if envs.VLLM_USE_V1 and self.model_config is not None and \
             not self.model_config.enforce_eager:
-            # NOTE(woosuk): Currently, we use inductor because the piecewise
-            # CUDA graphs do not work properly with the custom CUDA kernels.
-            # FIXME(woosuk): Disable inductor to reduce the compilation time
-            # and avoid any potential issues with the inductor.
             # FIXME(rob): Add function to set all of these.
             if not self.compilation_config.custom_ops:
                 self.compilation_config.custom_ops = ["none"]
             self.compilation_config.use_cudagraph = True
-            self.compilation_config.use_inductor = True
             self.compilation_config.cudagraph_num_of_warmups = 1
             self.compilation_config.pass_config.enable_fusion = False
             self.compilation_config.pass_config.enable_noop = False