mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:25:32 +08:00
[torch.compile] Passing only necessary compilation config to inductor pass config (#27041)
Signed-off-by: Lu Fang <fanglu@fb.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>
This commit is contained in:
parent
41d3071918
commit
11ae016bd7
@ -341,6 +341,15 @@ def async_tp_pass_on_test_model(
|
||||
async_tp_pass = AsyncTPPass(vllm_config)
|
||||
backend = TestBackend(async_tp_pass)
|
||||
|
||||
assert (
|
||||
async_tp_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
assert (
|
||||
async_tp_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
|
||||
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
|
||||
|
||||
hidden_states = torch.randn(
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.compilation import CompilationMode
|
||||
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
|
||||
@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic():
|
||||
assert vllm_config.compilation_config.use_cudagraph
|
||||
|
||||
|
||||
def test_copy_pass():
|
||||
vllm_config = VllmConfig()
|
||||
inductor_pass = FixFunctionalizationPass(vllm_config)
|
||||
copied_inductor_pass = copy.deepcopy(inductor_pass)
|
||||
assert (
|
||||
copied_inductor_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
assert (
|
||||
copied_inductor_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
|
||||
|
||||
def test_custom_op():
|
||||
# proper syntax
|
||||
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
|
||||
|
||||
@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model(
|
||||
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import functools
|
||||
import operator
|
||||
import time
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import regex as re
|
||||
@ -19,6 +19,12 @@ from .inductor_pass import InductorPass
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InductorCompilationConfig:
|
||||
splitting_ops: list[str] | None = None
|
||||
use_inductor_graph_partition: bool = False
|
||||
|
||||
|
||||
class VllmInductorPass(InductorPass):
|
||||
"""
|
||||
An inductor pass with access to vLLM PassConfig.
|
||||
@ -29,7 +35,12 @@ class VllmInductorPass(InductorPass):
|
||||
"""Keep track of pass index for debug dump ordering."""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
self.compilation_config = weakref.proxy(config.compilation_config)
|
||||
# Get only the necessary CompilationConfig for the inductor pass, since
|
||||
# full `CompilationConfig` contains pointer to model which is unsafe.
|
||||
self.compilation_config = InductorCompilationConfig(
|
||||
splitting_ops=config.compilation_config.splitting_ops,
|
||||
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
|
||||
)
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device = config.device_config.device if config.device_config else None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user