mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:55:55 +08:00
[torch.compile] Passing only necessary compilation config to inductor pass config (#27041)
Signed-off-by: Lu Fang <fanglu@fb.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com>
This commit is contained in:
parent
41d3071918
commit
11ae016bd7
@ -341,6 +341,15 @@ def async_tp_pass_on_test_model(
|
|||||||
async_tp_pass = AsyncTPPass(vllm_config)
|
async_tp_pass = AsyncTPPass(vllm_config)
|
||||||
backend = TestBackend(async_tp_pass)
|
backend = TestBackend(async_tp_pass)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
async_tp_pass.compilation_config.splitting_ops
|
||||||
|
== vllm_config.compilation_config.splitting_ops
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
async_tp_pass.compilation_config.use_inductor_graph_partition
|
||||||
|
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||||
|
)
|
||||||
|
|
||||||
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
|
model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor
|
||||||
|
|
||||||
hidden_states = torch.randn(
|
hidden_states = torch.randn(
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import copy
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
|
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||||
from vllm.config.compilation import CompilationMode
|
from vllm.config.compilation import CompilationMode
|
||||||
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
|
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
|
||||||
@ -25,6 +28,20 @@ def test_use_cudagraphs_dynamic():
|
|||||||
assert vllm_config.compilation_config.use_cudagraph
|
assert vllm_config.compilation_config.use_cudagraph
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy_pass():
|
||||||
|
vllm_config = VllmConfig()
|
||||||
|
inductor_pass = FixFunctionalizationPass(vllm_config)
|
||||||
|
copied_inductor_pass = copy.deepcopy(inductor_pass)
|
||||||
|
assert (
|
||||||
|
copied_inductor_pass.compilation_config.use_inductor_graph_partition
|
||||||
|
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
copied_inductor_pass.compilation_config.splitting_ops
|
||||||
|
== vllm_config.compilation_config.splitting_ops
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_custom_op():
|
def test_custom_op():
|
||||||
# proper syntax
|
# proper syntax
|
||||||
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
|
_ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
|
||||||
|
|||||||
@ -285,6 +285,14 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
|
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||||
|
assert (
|
||||||
|
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||||
|
== vllm_config.compilation_config.splitting_ops
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
|
||||||
|
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||||
|
)
|
||||||
func_pass = FixFunctionalizationPass(vllm_config)
|
func_pass = FixFunctionalizationPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import operator
|
import operator
|
||||||
import time
|
import time
|
||||||
import weakref
|
from dataclasses import dataclass
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
@ -19,6 +19,12 @@ from .inductor_pass import InductorPass
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InductorCompilationConfig:
|
||||||
|
splitting_ops: list[str] | None = None
|
||||||
|
use_inductor_graph_partition: bool = False
|
||||||
|
|
||||||
|
|
||||||
class VllmInductorPass(InductorPass):
|
class VllmInductorPass(InductorPass):
|
||||||
"""
|
"""
|
||||||
An inductor pass with access to vLLM PassConfig.
|
An inductor pass with access to vLLM PassConfig.
|
||||||
@ -29,7 +35,12 @@ class VllmInductorPass(InductorPass):
|
|||||||
"""Keep track of pass index for debug dump ordering."""
|
"""Keep track of pass index for debug dump ordering."""
|
||||||
|
|
||||||
def __init__(self, config: VllmConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
self.compilation_config = weakref.proxy(config.compilation_config)
|
# Get only the necessary CompilationConfig for the inductor pass, since
|
||||||
|
# full `CompilationConfig` contains pointer to model which is unsafe.
|
||||||
|
self.compilation_config = InductorCompilationConfig(
|
||||||
|
splitting_ops=config.compilation_config.splitting_ops,
|
||||||
|
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
|
||||||
|
)
|
||||||
self.pass_config = config.compilation_config.pass_config
|
self.pass_config = config.compilation_config.pass_config
|
||||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||||
self.device = config.device_config.device if config.device_config else None
|
self.device = config.device_config.device if config.device_config else None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user