[torch.compile] Passing only necessary compilation config to inductor pass config (#27041)

Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>
This commit is contained in:
Lucia Fang 2025-10-16 17:01:52 -07:00 committed by GitHub
parent 41d3071918
commit 11ae016bd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 2 deletions

View File

@ -341,6 +341,15 @@ def async_tp_pass_on_test_model(
async_tp_pass = AsyncTPPass(vllm_config) async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass) backend = TestBackend(async_tp_pass)
assert (
async_tp_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
async_tp_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
hidden_states = torch.randn( hidden_states = torch.randn(

View File

@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import pytest import pytest
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode from vllm.config.compilation import CompilationMode
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic():
assert vllm_config.compilation_config.use_cudagraph assert vllm_config.compilation_config.use_cudagraph
def test_copy_pass():
    """Deep-copy a FixFunctionalizationPass and compare its config fields.

    The copied pass must report the same `use_inductor_graph_partition` and
    `splitting_ops` values as the `VllmConfig` the pass was built from.
    """
    vllm_config = VllmConfig()
    cloned_pass = copy.deepcopy(FixFunctionalizationPass(vllm_config))

    # Name both sides once so each assertion fits on a single line.
    expected = vllm_config.compilation_config
    actual = cloned_pass.compilation_config

    assert actual.use_inductor_graph_partition == expected.use_inductor_graph_partition
    assert actual.splitting_ops == expected.splitting_ops
def test_custom_op(): def test_custom_op():
# proper syntax # proper syntax
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"]) _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

View File

@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model(
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
func_pass = FixFunctionalizationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config)

View File

@ -3,7 +3,7 @@
import functools import functools
import operator import operator
import time import time
import weakref from dataclasses import dataclass
from typing import ClassVar from typing import ClassVar
import regex as re import regex as re
@ -19,6 +19,12 @@ from .inductor_pass import InductorPass
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class InductorCompilationConfig:
splitting_ops: list[str] | None = None
use_inductor_graph_partition: bool = False
class VllmInductorPass(InductorPass): class VllmInductorPass(InductorPass):
""" """
An inductor pass with access to vLLM PassConfig. An inductor pass with access to vLLM PassConfig.
@ -29,7 +35,12 @@ class VllmInductorPass(InductorPass):
"""Keep track of pass index for debug dump ordering.""" """Keep track of pass index for debug dump ordering."""
def __init__(self, config: VllmConfig): def __init__(self, config: VllmConfig):
self.compilation_config = weakref.proxy(config.compilation_config) # Get only the necessary CompilationConfig for the inductor pass, since
# full `CompilationConfig` contains pointer to model which is unsafe.
self.compilation_config = InductorCompilationConfig(
splitting_ops=config.compilation_config.splitting_ops,
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
)
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
self.model_dtype = config.model_config.dtype if config.model_config else None self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None self.device = config.device_config.device if config.device_config else None