[torch.compile] CUDAGraph Inductor partition integration (#24281)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Signed-off-by: boyuanfeng <boyuan@meta.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Boyuan Feng 2025-09-19 18:02:15 -07:00 committed by GitHub
parent b8a287a0a8
commit 8945b001db
9 changed files with 280 additions and 32 deletions
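
At a glance, this change lets Inductor do the piecewise splitting: attention ops are tagged `torch._C.Tag.cudagraph_unsafe`, Inductor partitions the lowered graph around them at codegen time, and vLLM's static-graph wrapper is installed on each partition through Inductor's partition-wrapper hook. Below is a minimal sketch of turning the new flag on, using only `CompilationConfig` fields that appear in this diff; how the config object reaches the engine mirrors the tests further down and may differ by entrypoint.

# Hedged sketch: enable Inductor-level graph partitioning for piecewise
# cudagraphs. Requires torch>=2.9.0.dev (enforced by CompilationConfig below).
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode

compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_inductor_graph_partition=True,
    cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
# splitting_ops is deliberately left unset: set_splitting_ops_for_v1() resets
# it to [] and relies on the cudagraph_unsafe tag instead (see the config
# changes in this diff).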

View File

@@ -15,6 +15,7 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                          VllmConfig, set_current_vllm_config)
 from vllm.envs import VLLM_USE_V1
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer
 
 # This import automatically registers `torch.ops.silly.attention`
 from ..silly_attention import get_global_counter, reset_global_counter
@@ -50,16 +51,21 @@ class SillyModel(nn.Module):
         return x
 
 
-@pytest.mark.parametrize("use_inductor", [True, False])
-@torch.inference_mode()
-def test_simple_piecewise_compile(use_inductor):
-    assert VLLM_USE_V1
+def _run_simple_model(
+    splitting_ops,
+    use_inductor_graph_partition,
+    use_inductor,
+    expected_num_piecewise_graphs_seen,
+    expected_num_piecewise_capturable_graphs_seen,
+    expected_num_backend_compilations,
+    expected_num_cudagraph_captured,
+):
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
         use_inductor=use_inductor,
-        splitting_ops=["silly.attention"],
+        splitting_ops=splitting_ops,
+        use_inductor_graph_partition=use_inductor_graph_partition,
         cudagraph_copy_inputs=True,
         cudagraph_capture_sizes=[1, 2],
     ))
@@ -70,11 +76,11 @@ def test_simple_piecewise_compile(use_inductor):
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
-            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
-            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
-            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
-            num_cudagraph_captured=
-            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+            num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+            num_piecewise_capturable_graphs_seen=
+            expected_num_piecewise_capturable_graphs_seen,
+            num_backend_compilations=expected_num_backend_compilations,
+            num_cudagraph_captured=expected_num_cudagraph_captured,
     ), set_forward_context(None,
                            vllm_config=vllm_config):  # background context
         # warm up with background context
@@ -104,3 +110,46 @@ def test_simple_piecewise_compile(use_inductor):
         output = model(input)
     assert get_global_counter() == 2
     assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
+
+
+@pytest.mark.parametrize("use_inductor", [True, False])
+@torch.inference_mode()
+def test_simple_piecewise_compile(use_inductor):
+    assert VLLM_USE_V1
+    _run_simple_model(
+        splitting_ops=["silly.attention"],
+        use_inductor_graph_partition=False,
+        use_inductor=use_inductor,
+        expected_num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
+        expected_num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
+        expected_num_backend_compilations=
+        3,  # num_piecewise_capturable_graphs_seen
+        expected_num_cudagraph_captured=
+        6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+    )
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
+def test_simple_inductor_graph_partition(splitting_ops):
+    assert VLLM_USE_V1
+    if not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+    _run_simple_model(
+        # inductor graph partition automatically resets splitting_ops
+        # to an empty list
+        splitting_ops=splitting_ops,
+        use_inductor_graph_partition=True,
+        use_inductor=True,
+        expected_num_piecewise_graphs_seen=
+        1,  # since not splitting at fx graph level
+        expected_num_piecewise_capturable_graphs_seen=
+        1,  # since not splitting at fx graph level
+        expected_num_backend_compilations=
+        1,  # since not splitting at fx graph level
+        expected_num_cudagraph_captured=
+        6,  # inductor graph partition still captures 6
+        # graphs, same as fx graph partition.
+    )
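
The expected counter values in these tests follow directly from the inline comments; a small worked example for the two-layer toy model (a sketch with illustrative variable names):

# Worked numbers for SillyModel with 2 layers and cudagraph_capture_sizes=[1, 2].
num_layers = 2
num_cudagraph_sizes = 2  # len(cudagraph_capture_sizes)

# fx-graph (Dynamo-level) splitting on silly.attention:
assert 2 * num_layers + 1 == 5       # num_piecewise_graphs_seen
assert 1 + num_layers == 3           # capturable graphs == backend compilations
assert num_cudagraph_sizes * 3 == 6  # num_cudagraph_captured

# Inductor graph partition: the fx graph is never split, so the three
# fx-level counters are all 1, but 6 cudagraphs are still captured because
# Inductor creates the equivalent per-partition capture points.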

View File

@@ -60,4 +60,5 @@ direct_register_custom_op(
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
     target_lib=silly_lib,
+    tags=(torch._C.Tag.cudagraph_unsafe, ),
 )
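
For context, the hunk above shows only the tail of the registration. Below is a hedged sketch of a complete `direct_register_custom_op` call carrying the new tag, with a hypothetical `silly_marker` op standing in for the test's attention op; the keyword names are taken from the fragments in this diff, and the exact `vllm.utils.direct_register_custom_op` signature may differ across versions.

# Hedged sketch: registering a custom op marked cudagraph-unsafe so that
# Inductor graph partition splits around it. Requires a PyTorch build that
# defines torch._C.Tag.cudagraph_unsafe (2.9+).
import torch
from vllm.utils import direct_register_custom_op

def silly_marker(q: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(q)  # stand-in for a real attention kernel

def silly_marker_fake(q: torch.Tensor, out: torch.Tensor) -> None:
    return

direct_register_custom_op(
    op_name="silly_marker",
    op_func=silly_marker,
    mutates_args=["out"],
    fake_impl=silly_marker_fake,
    tags=(torch._C.Tag.cudagraph_unsafe, ),
)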

View File

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
+import logging
 import tempfile
 from typing import Any, Optional, Union
 
@@ -10,9 +11,13 @@ import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
+from tests.v1.attention.utils import _Backend
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationConfig, CompilationLevel, PassConfig
+from vllm.attention.selector import global_force_attn_backend_context_manager
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         PassConfig)
 from vllm.platforms import current_platform
+from vllm.utils import is_torch_equal_or_newer
 
 from ..utils import create_new_process_for_each_test
 
@@ -105,6 +110,18 @@ def test_full_graph(
         (CompilationConfig(level=CompilationLevel.PIECEWISE,
                            debug_dump_path=tempfile.gettempdir()),
          ("facebook/opt-125m", {})),
+    ] + [
+        # graph inductor partition
+        (
+            CompilationConfig(
+                level=CompilationLevel.PIECEWISE,
+                # inductor graph partition uses
+                # torch._C.Tag.cudagraph_unsafe to specify splitting ops
+                use_inductor_graph_partition=True,
+                cudagraph_mode=CUDAGraphMode.PIECEWISE,
+                compile_sizes=[1, 2]),
+            model) for model in models_list(all=False)
+        if is_torch_equal_or_newer("2.9.0.dev")
     ])
 # only test some of the models
 @create_new_process_for_each_test()
@@ -112,11 +129,51 @@ def test_custom_compile_config(
     compilation_config: CompilationConfig,
     model_info: tuple[str, dict[str, Any]],
 ):
+    if (compilation_config.use_inductor_graph_partition
+            and not is_torch_equal_or_newer("2.9.0.dev")):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
     model, model_kwargs = model_info
     print(f"MODEL={model}")
     run_model(compilation_config, model, model_kwargs)
 
 
+def test_inductor_graph_partition_attn_fusion(caplog_vllm):
+    if not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
+    model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
+    compilation_config = CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_inductor_graph_partition=True,
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        custom_ops=["+quant_fp8"],
+        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
+    )
+    model_kwargs = {
+        "kv_cache_dtype": "fp8",
+        "max_model_len": 1024,
+    }
+    with caplog_vllm.at_level(
+            logging.DEBUG), global_force_attn_backend_context_manager(
+                _Backend.FLASHINFER):
+        run_model(compilation_config, model, model_kwargs)
+
+    try:
+        assert ("Fused quantization onto 48 attention nodes"
+                in caplog_vllm.text), caplog_vllm.text
+    except AssertionError:
+        # Note: this message is only emitted when the compilation actually
+        # goes through the custom pass. Due to multiple layers of caching on
+        # the PyTorch side, the compilation of a graph may be cached such
+        # that the custom pass is served directly from the cache. In that
+        # case, we take this branch and assert that the pass was not
+        # triggered.
+        assert "Fused quantization" not in caplog_vllm.text
+
+
 def run_model(compile_config: Union[int, CompilationConfig], model: str,
               model_kwargs: dict[str, Any]):
     prompts = [

View File

@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp)
 from vllm.platforms import current_platform
+from vllm.utils import is_torch_equal_or_newer
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -339,6 +340,10 @@ else:
 @pytest.mark.parametrize(
     "split_attention",
     [False, True] if current_platform.is_rocm() else [False])
+# TODO(boyuan): test inductor graph partition on rocm
+@pytest.mark.parametrize(
+    "use_inductor_graph_partition",
+    [False] if current_platform.is_rocm() else [False, True])
 @pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Only test ROCm or CUDA")
 @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@@ -352,9 +357,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
                                  dtype: torch.dtype, model_name: str,
                                  model_class: type[AttentionQuantPatternModel],
                                  backend: _Backend, split_attention: bool,
-                                 monkeypatch, dist_init):
+                                 use_inductor_graph_partition: bool,
+                                 monkeypatch, dist_init, caplog_vllm):
     """Test AttentionStaticQuantPattern fusion pass"""
+
+    if use_inductor_graph_partition and not is_torch_equal_or_newer(
+            "2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
     monkeypatch.setenv("VLLM_USE_V1", "1")
     if split_attention:
         monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
@@ -372,6 +383,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             custom_ops=["+quant_fp8"],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         ),
         cache_config=CacheConfig(cache_dtype="fp8"))
@@ -444,6 +456,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         backend=test_backend,
         fullgraph=True)
     assert model_compiled.attn._o_scale_float is None
+
     result_fused_1 = model_compiled(q, k, v)
 
     if backend == _Backend.FLASHINFER:
@@ -453,6 +466,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         # _o_scale_float
         assert model_compiled.attn._o_scale_float is not None
+
         result_fused_2 = model_compiled(q, k, v)
         assert model_compiled.attn._o_scale_float is not None
 
     torch.testing.assert_close(result_unfused,

View File

@@ -577,6 +577,7 @@ direct_register_custom_op(
     mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
+    tags=(torch._C.Tag.cudagraph_unsafe, ),
 )
 
@@ -627,4 +628,5 @@ direct_register_custom_op(
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
+    tags=(torch._C.Tag.cudagraph_unsafe, ),
 )

View File

@@ -329,6 +329,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
             i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
         ]
         global compilation_start_time
+
         compiled_graph_for_dynamic_shape = self.vllm_backend.\
             compiler_manager.compile(
                 submod,
@@ -339,7 +340,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None)
         # Lazy import here to avoid circular import
-        from .cuda_graph import CUDAGraphOptions
         from .cuda_piecewise_backend import PiecewiseBackend
 
         piecewise_backend = PiecewiseBackend(
@@ -347,7 +347,13 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
             len(self.compile_submod_names), sym_shape_indices,
             compiled_graph_for_dynamic_shape, self.vllm_backend)
 
-        if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and
+                not self.compilation_config.use_inductor_graph_partition):
+            # We're using Dynamo-based piecewise splitting, so we wrap
+            # the whole subgraph with a static graph wrapper.
+            from .cuda_graph import CUDAGraphOptions
+
             # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
             # class) as platform dependent.
             static_graph_wrapper_class = resolve_obj_by_qualname(

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import contextlib
 import inspect
 from typing import Callable, Optional, TypeVar, Union, overload
 from unittest.mock import patch
@@ -14,7 +15,7 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
-from vllm.utils import supports_dynamo
+from vllm.utils import resolve_obj_by_qualname, supports_dynamo
 
 from .monitor import start_monitoring_torch_compile
@@ -301,8 +302,11 @@ def _support_torch_compile(
         with patch.object(InliningInstructionTranslator, 'inline_call',
                           patched_inline_call), torch._dynamo.config.patch(
-                              **dynamo_config_patches):
+                              **dynamo_config_patches
+                          ), maybe_use_cudagraph_partition_wrapper(
+                              self.vllm_config):
             output = self.compiled_callable(*args, **kwargs)
         return output
 
         # usually, capturing the model once is enough, and then we can
@@ -314,3 +318,52 @@ def _support_torch_compile(
     cls.__call__ = __call__
     return cls
+
+
+@contextlib.contextmanager
+def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
+    """
+    Context manager to set/unset customized cudagraph partition wrappers.
+
+    If we're using Inductor-based graph partitioning, we currently have the
+    whole `fx.Graph` before Inductor lowering, and the piecewise
+    splitting happens after all graph passes and fusions. Here, we add
+    a custom hook for Inductor to wrap each partition with our static
+    graph wrapper class to maintain more control over static graph
+    capture and replay.
+    """
+    from vllm.config import CUDAGraphMode
+
+    compilation_config = vllm_config.compilation_config
+    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+            and compilation_config.use_inductor_graph_partition):
+        from torch._inductor.utils import CUDAGraphWrapperMetadata
+
+        from vllm.compilation.cuda_graph import CUDAGraphOptions
+        from vllm.platforms import current_platform
+
+        static_graph_wrapper_class = resolve_obj_by_qualname(
+            current_platform.get_static_graph_wrapper_cls())
+
+        def customized_cudagraph_wrapper(f,
+                                         metadata: CUDAGraphWrapperMetadata):
+            partition_id = metadata.partition_index
+            num_partitions = metadata.num_partitions
+            return static_graph_wrapper_class(
+                runnable=f,
+                vllm_config=vllm_config,
+                runtime_mode=CUDAGraphMode.PIECEWISE,
+                cudagraph_options=CUDAGraphOptions(
+                    debug_log_enable=partition_id == 0,
+                    gc_disable=partition_id != 0,
+                    weak_ref_output=partition_id == num_partitions - 1,
+                ))
+
+        torch._inductor.utils.set_customized_partition_wrappers(
+            customized_cudagraph_wrapper)
+
+    yield
+
+    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+            and compilation_config.use_inductor_graph_partition):
+        torch._inductor.utils.set_customized_partition_wrappers(None)
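
The context manager above is built on Inductor's partition-wrapper hook. A stripped-down sketch of the same hook, wrapping each partition in a logger instead of a CUDAGraph wrapper, to make the callback contract visible; it assumes the torch>=2.9 APIs referenced above, with `CUDAGraphWrapperMetadata` exposing `partition_index` and `num_partitions` as used in this diff.

# Hedged sketch: the hook receives (runnable, metadata) per partition and must
# return a runnable; vLLM returns a CUDAGraph wrapper, this toy returns a logger.
import torch
from torch._inductor.utils import CUDAGraphWrapperMetadata

def logging_partition_wrapper(f, metadata: CUDAGraphWrapperMetadata):
    def wrapped(*args, **kwargs):
        print(f"partition {metadata.partition_index + 1}/{metadata.num_partitions}")
        return f(*args, **kwargs)
    return wrapped

torch._inductor.utils.set_customized_partition_wrappers(logging_partition_wrapper)
# ... compile and run a model with use_inductor_graph_partition=True ...
torch._inductor.utils.set_customized_partition_wrappers(None)  # uninstall the
# hook afterwards, as the context manager above does on exit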

View File

@@ -299,6 +299,26 @@ class CompilationConfig:
     minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
     """
 
+    use_inductor_graph_partition: bool = False
+    """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
+    This partition happens at inductor codegen time after all passes and fusions
+    are finished. It generates a single `call` function which wraps
+    cudagraph-safe ops into partition functions and leaves cudagraph-unsafe ops
+    outside the partition functions. For a graph with N cudagraph-unsafe ops
+    (e.g., Attention), there would be N+1 partitions. To mark an op as
+    cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe,)` when
+    registering the custom op.
+
+    This config supports both full cudagraph and piecewise cudagraph without
+    compiling twice. For piecewise cudagraph, it applies the vLLM CUDAGraph
+    wrapper to each partition. For N+1 partitions, there would be N+1
+    CUDAGraph wrapper instances.
+
+    For full cudagraph, we always apply a single CUDAGraph wrapper outside the
+    inductor `call` function in the model runner. The top-level full cudagraph
+    capture ignores all partitioning.
+    """
+
     pass_config: PassConfig = field(default_factory=PassConfig)
     """Custom inductor passes, see PassConfig for more details"""
@@ -461,6 +481,12 @@ class CompilationConfig:
                 "since full_cuda_graph is deprecated.")
             self.cudagraph_mode = CUDAGraphMode.FULL
 
+        if (self.use_inductor_graph_partition
+                and not is_torch_equal_or_newer("2.9.0.dev")):
+            raise ValueError("use_inductor_graph_partition is only "
+                             "supported with torch>=2.9.0.dev. Set "
+                             "use_inductor_graph_partition=False instead.")
+
     def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
@@ -540,19 +566,36 @@ class CompilationConfig:
             "set_splitting_ops_for_v1 should only be called when "
             "level is CompilationLevel.PIECEWISE")
 
+        use_inductor_graph_partition_msg = (
+            "When use_inductor_graph_partition=True, splitting_ops "
+            "are ignored and set to an empty list. Instead, "
+            "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
+            "used to annotate custom ops for graph partition.")
+
         if self.splitting_ops is None:
-            # NOTE: When using full cudagraph, instead of setting an empty
-            # list and capture the full cudagraph inside the flattened fx
-            # graph, we keep the piecewise fx graph structure but capture the
-            # full cudagraph outside the fx graph. This reduces some cpu
-            # overhead when the runtime batch_size is not cudagraph captured.
-            # see https://github.com/vllm-project/vllm/pull/20059 for details.
-            # make a copy to avoid mutating the class-level list via reference.
-            self.splitting_ops = list(self._attention_ops)
+            if self.use_inductor_graph_partition:
+                # When using inductor graph partition, we set splitting_ops
+                # to be empty and rely on torch._C.Tag.cudagraph_unsafe to
+                # annotate custom ops as splitting ops.
+                logger.warning_once(use_inductor_graph_partition_msg)
+                self.splitting_ops = []
+            else:
+                # NOTE: When using full cudagraph, instead of setting an empty
+                # list and capture the full cudagraph inside the flattened fx
+                # graph, we keep the piecewise fx graph structure but capture
+                # the full cudagraph outside the fx graph. This reduces some
+                # cpu overhead when the runtime batch_size is not cudagraph
+                # captured. see https://github.com/vllm-project/vllm/pull/20059
+                # for details. make a copy to avoid mutating the class-level
+                # list via reference.
+                self.splitting_ops = list(self._attention_ops)
         elif len(self.splitting_ops) == 0:
-            logger.warning_once("Using piecewise compilation with empty "
-                                "splitting_ops.")
-            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+            logger.warning_once(
+                "Using piecewise compilation with empty "
+                "splitting_ops and use_inductor_graph_partition"
+                f"={self.use_inductor_graph_partition}.")
+            if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
+                    and not self.use_inductor_graph_partition):
                 logger.warning_once(
                     "When compilation level is piecewise with empty "
                     "splitting_ops, PIECEWISE cudagraph_mode will be "
@@ -562,7 +605,26 @@ class CompilationConfig:
                     "any problems.")
                 self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
+        elif self.use_inductor_graph_partition:
+            logger.warning_once(use_inductor_graph_partition_msg)
+            self.splitting_ops = []
 
     def splitting_ops_contain_attention(self) -> bool:
         return self.splitting_ops is not None and all(
             op in self.splitting_ops for op in self._attention_ops)
+
+    def is_attention_compiled_piecewise(self) -> bool:
+        use_fx_graph_piecewise_compilation = (
+            self.level == CompilationLevel.PIECEWISE
+            and self.splitting_ops_contain_attention())
+
+        inductor_used = (self.level == CompilationLevel.PIECEWISE
+                         and self.use_inductor) or (
+                             self.level >= CompilationLevel.DYNAMO_AS_IS
+                             and self.backend == "inductor")
+        use_inductor_piecewise_compilation = (
+            inductor_used and self.use_inductor_graph_partition
+            and not self.splitting_ops_contain_attention())
+
+        return use_fx_graph_piecewise_compilation or \
+            use_inductor_piecewise_compilation
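
The two branches of `is_attention_compiled_piecewise` can be read as a simple boolean rule; the sketch below restates it over plain flags for clarity (illustrative names only, the real method reads these values from the config itself).

# Hedged sketch: the decision implemented by is_attention_compiled_piecewise().
def attention_compiled_piecewise(piecewise_level: bool, inductor_used: bool,
                                 splitting_ops_contain_attention: bool,
                                 use_inductor_graph_partition: bool) -> bool:
    # Path 1: Dynamo/fx-level piecewise splitting around attention ops.
    fx_graph_piecewise = piecewise_level and splitting_ops_contain_attention
    # Path 2: Inductor-level graph partition, where attention is left out of
    # splitting_ops and tagged cudagraph_unsafe instead.
    inductor_piecewise = (inductor_used and use_inductor_graph_partition
                          and not splitting_ops_contain_attention)
    return fx_graph_piecewise or inductor_piecewise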

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Optional
 
-from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
 from vllm.logger import init_logger
 
@@ -39,11 +39,15 @@ class CudagraphDispatcher:
             CUDAGraphMode.FULL: set(),
         }
 
-        assert not self.cudagraph_mode.requires_piecewise_compilation() or \
-            (self.compilation_config.level == CompilationLevel.PIECEWISE and
-             self.compilation_config.splitting_ops_contain_attention()), \
+        not_use_piecewise_compilation = (
+            not self.cudagraph_mode.requires_piecewise_compilation())
+
+        assert not_use_piecewise_compilation or \
+            self.compilation_config.is_attention_compiled_piecewise(), \
             "Compilation level should be CompilationLevel.PIECEWISE when "\
             "cudagraph_mode piecewise cudagraphs is used, "\
+            "and attention should be in splitting_ops or "\
+            "inductor splitting should be used. " \
            f"cudagraph_mode={self.cudagraph_mode}, "\
            f"compilation_level={self.compilation_config.level}, "\
            f"splitting_ops={self.compilation_config.splitting_ops}"