diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 60856f5a5806..cce99d0c4f4c 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -341,6 +341,15 @@ def async_tp_pass_on_test_model( async_tp_pass = AsyncTPPass(vllm_config) backend = TestBackend(async_tp_pass) + assert ( + async_tp_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + assert ( + async_tp_pass.compilation_config.use_inductor_graph_partition + == vllm_config.compilation_config.use_inductor_graph_partition + ) + model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor hidden_states = torch.randn( diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 7f51c763da73..87b5d167d168 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy + import pytest from vllm.compilation.counter import compilation_counter +from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config.compilation import CompilationMode from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer @@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic(): assert vllm_config.compilation_config.use_cudagraph +def test_copy_pass(): + vllm_config = VllmConfig() + inductor_pass = FixFunctionalizationPass(vllm_config) + copied_inductor_pass = copy.deepcopy(inductor_pass) + assert ( + copied_inductor_pass.compilation_config.use_inductor_graph_partition + == vllm_config.compilation_config.use_inductor_graph_partition + ) + assert ( + copied_inductor_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + + def test_custom_op(): # proper syntax _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"]) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 6abab88e6369..9969a629c008 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model( noop_pass = NoOpEliminationPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config) + assert ( + sequence_parallelism_pass.compilation_config.splitting_ops + == vllm_config.compilation_config.splitting_ops + ) + assert ( + sequence_parallelism_pass.compilation_config.use_inductor_graph_partition + == vllm_config.compilation_config.use_inductor_graph_partition + ) func_pass = FixFunctionalizationPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index beac928b5d71..7ef2dddcb407 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -3,7 +3,7 @@ import functools import operator import time -import weakref +from dataclasses import dataclass from typing import ClassVar import regex as re @@ -19,6 +19,12 @@ from .inductor_pass import InductorPass logger = init_logger(__name__) +@dataclass +class InductorCompilationConfig: + splitting_ops: list[str] | None = None + use_inductor_graph_partition: bool = False + + class VllmInductorPass(InductorPass): """ An inductor pass with access to vLLM PassConfig. @@ -29,7 +35,12 @@ class VllmInductorPass(InductorPass): """Keep track of pass index for debug dump ordering.""" def __init__(self, config: VllmConfig): - self.compilation_config = weakref.proxy(config.compilation_config) + # Get only the necessary CompilationConfig for the inductor pass, since + # full `CompilationConfig` contains pointer to model which is unsafe. + self.compilation_config = InductorCompilationConfig( + splitting_ops=config.compilation_config.splitting_ops, + use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition, + ) self.pass_config = config.compilation_config.pass_config self.model_dtype = config.model_config.dtype if config.model_config else None self.device = config.device_config.device if config.device_config else None