[torch.compile] CUDAGraph Inductor partition integration (#24281)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Signed-off-by: boyuanfeng <boyuan@meta.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Boyuan Feng 2025-09-19 18:02:15 -07:00 committed by GitHub
parent b8a287a0a8
commit 8945b001db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 280 additions and 32 deletions

View File

@ -15,6 +15,7 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter
@ -50,16 +51,21 @@ class SillyModel(nn.Module):
return x
@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1
def _run_simple_model(
splitting_ops,
use_inductor_graph_partition,
use_inductor,
expected_num_piecewise_graphs_seen,
expected_num_piecewise_capturable_graphs_seen,
expected_num_backend_compilations,
expected_num_cudagraph_captured,
):
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=["silly.attention"],
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
@ -70,11 +76,11 @@ def test_simple_piecewise_compile(use_inductor):
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=
expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
), set_forward_context(None,
vllm_config=vllm_config): # background context
# warm up with background context
@ -104,3 +110,46 @@ def test_simple_piecewise_compile(use_inductor):
output = model(input)
assert get_global_counter() == 2
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1
_run_simple_model(
splitting_ops=["silly.attention"],
use_inductor_graph_partition=False,
use_inductor=use_inductor,
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
expected_num_backend_compilations=
3, # num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
)
@torch.inference_mode()
@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
def test_simple_inductor_graph_partition(splitting_ops):
assert VLLM_USE_V1
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
_run_simple_model(
# inductor graph partition automatically resets splitting_ops
# to be an empty list
splitting_ops=splitting_ops,
use_inductor_graph_partition=True,
use_inductor=True,
expected_num_piecewise_graphs_seen=
1, # since not splitting at fx graph level
expected_num_piecewise_capturable_graphs_seen=
1, # since not splitting at fx graph level
expected_num_backend_compilations=
1, # since not splitting at fx graph level
expected_num_cudagraph_captured=
6, # inductor graph partition still captures 6
# graph, same as fx graph partition.
)

View File

@ -60,4 +60,5 @@ direct_register_custom_op(
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import logging
import tempfile
from typing import Any, Optional, Union
@ -10,9 +11,13 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
PassConfig)
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test
@ -105,6 +110,18 @@ def test_full_graph(
(CompilationConfig(level=CompilationLevel.PIECEWISE,
debug_dump_path=tempfile.gettempdir()),
("facebook/opt-125m", {})),
] + [
# graph inductor partition
(
CompilationConfig(
level=CompilationLevel.PIECEWISE,
# inductor graph partition uses
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2]),
model) for model in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev")
])
# only test some of the models
@create_new_process_for_each_test()
@ -112,11 +129,51 @@ def test_custom_compile_config(
compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]],
):
if (compilation_config.use_inductor_graph_partition
and not is_torch_equal_or_newer("2.9.0.dev")):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
model, model_kwargs = model_info
print(f"MODEL={model}")
run_model(compilation_config, model, model_kwargs)
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"],
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
)
model_kwargs = {
"kv_cache_dtype": "fp8",
"max_model_len": 1024,
}
with caplog_vllm.at_level(
logging.DEBUG), global_force_attn_backend_context_manager(
_Backend.FLASHINFER):
run_model(compilation_config, model, model_kwargs)
try:
assert ("Fused quantization onto 48 attention nodes"
in caplog_vllm.text), caplog_vllm.text
except AssertionError:
# Note: this message is only triggered when the compilation goes
# through the custom pass. Due to multiple layers of cache on
# PyTorch side, the compilation of a graph may be cached such
# that custom pass directly goes through cache. In this case,
# we go through this branch and assert that the pass is not
# triggered.
assert "Fused quantization" not in caplog_vllm.text
def run_model(compile_config: Union[int, CompilationConfig], model: str,
model_kwargs: dict[str, Any]):
prompts = [

View File

@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.v1.kv_cache_interface import AttentionSpec
FP8_DTYPE = current_platform.fp8_dtype()
@ -339,6 +340,10 @@ else:
@pytest.mark.parametrize(
"split_attention",
[False, True] if current_platform.is_rocm() else [False])
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
"use_inductor_graph_partition",
[False] if current_platform.is_rocm() else [False, True])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@ -352,9 +357,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
dtype: torch.dtype, model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend, split_attention: bool,
monkeypatch, dist_init):
use_inductor_graph_partition: bool,
monkeypatch, dist_init, caplog_vllm):
"""Test AttentionStaticQuantPattern fusion pass"""
if use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")
monkeypatch.setenv("VLLM_USE_V1", "1")
if split_attention:
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
@ -372,6 +383,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+quant_fp8"],
use_inductor_graph_partition=use_inductor_graph_partition,
),
cache_config=CacheConfig(cache_dtype="fp8"))
@ -444,6 +456,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
backend=test_backend,
fullgraph=True)
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v)
if backend == _Backend.FLASHINFER:
@ -453,6 +466,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
# _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None
torch.testing.assert_close(result_unfused,

View File

@ -577,6 +577,7 @@ direct_register_custom_op(
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)
@ -627,4 +628,5 @@ direct_register_custom_op(
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)

View File

@ -329,6 +329,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time
compiled_graph_for_dynamic_shape = self.vllm_backend.\
compiler_manager.compile(
submod,
@ -339,7 +340,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
num_graphs=len(self.compile_submod_names),
runtime_shape=None)
# Lazy import here to avoid circular import
from .cuda_graph import CUDAGraphOptions
from .cuda_piecewise_backend import PiecewiseBackend
piecewise_backend = PiecewiseBackend(
@ -347,7 +347,13 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_dynamic_shape, self.vllm_backend)
if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and
not self.compilation_config.use_inductor_graph_partition):
# We're using Dynamo-based piecewise splitting, so we wrap
# the whole subgraph with a static graph wrapper.
from .cuda_graph import CUDAGraphOptions
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import inspect
from typing import Callable, Optional, TypeVar, Union, overload
from unittest.mock import patch
@ -14,7 +15,7 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
from .monitor import start_monitoring_torch_compile
@ -301,8 +302,11 @@ def _support_torch_compile(
with patch.object(InliningInstructionTranslator, 'inline_call',
patched_inline_call), torch._dynamo.config.patch(
**dynamo_config_patches):
**dynamo_config_patches
), maybe_use_cudagraph_partition_wrapper(
self.vllm_config):
output = self.compiled_callable(*args, **kwargs)
return output
# usually, capturing the model once is enough, and then we can
@ -314,3 +318,52 @@ def _support_torch_compile(
cls.__call__ = __call__
return cls
@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
"""
Context manager to set/unset customized cudagraph partition wrappers.
If we're using Inductor-based graph partitioning, we currently have the
whole `fx.Graph` before Inductor lowering and and the piecewise
splitting happens after all graph passes and fusions. Here, we add
a custom hook for Inductor to wrap each partition with our static
graph wrapper class to maintain more control over static graph
capture and replay.
"""
from vllm.config import CUDAGraphMode
compilation_config = vllm_config.compilation_config
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and compilation_config.use_inductor_graph_partition):
from torch._inductor.utils import CUDAGraphWrapperMetadata
from vllm.compilation.cuda_graph import CUDAGraphOptions
from vllm.platforms import current_platform
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls())
def customized_cudagraph_wrapper(f,
metadata: CUDAGraphWrapperMetadata):
partition_id = metadata.partition_index
num_partitions = metadata.num_partitions
return static_graph_wrapper_class(
runnable=f,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=partition_id == 0,
gc_disable=partition_id != 0,
weak_ref_output=partition_id == num_partitions - 1,
))
torch._inductor.utils.set_customized_partition_wrappers(
customized_cudagraph_wrapper)
yield
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and compilation_config.use_inductor_graph_partition):
torch._inductor.utils.set_customized_partition_wrappers(None)

View File

@ -299,6 +299,26 @@ class CompilationConfig:
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
"""
use_inductor_graph_partition: bool = False
"""Use inductor graph partition to split the graph at cudagraph_unsafe ops.
This partition happens at inductor codegen time after all passes and fusions
are finished. It generates a single `call` function which wraps
cudagraph-safe ops into partition functions and leave cudagraph-unsafe ops
outside the partition functions. For a graph with N cudagraph-unsafe ops
(e.g., Attention), there would be N+1 partitions. To mark an op as
cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when
register the custom op.
This config supports both full cudagraph and piecewise cudagraph without
compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper
to each partition. For N+1 partitions, there would be N+1
CUDAGraph wrapper instances.
For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the
inductor `call` function in the model runner. The top-level full cudagraph
capture ignores all partitioning.
"""
pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
@ -461,6 +481,12 @@ class CompilationConfig:
"since full_cuda_graph is deprecated.")
self.cudagraph_mode = CUDAGraphMode.FULL
if (self.use_inductor_graph_partition
and not is_torch_equal_or_newer("2.9.0.dev")):
raise ValueError("use_inductor_graph_partition is only "
"supported with torch>=2.9.0.dev. Set "
"use_inductor_graph_partition=False instead.")
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@ -540,19 +566,36 @@ class CompilationConfig:
"set_splitting_ops_for_v1 should only be called when "
"level is CompilationLevel.PIECEWISE")
use_inductor_graph_partition_msg = (
"When use_inductor_graph_partition=True, splitting_ops "
"are ignored and set to an empty list. Instead, "
"\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
"used to annotate custom ops for graph partition.")
if self.splitting_ops is None:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture the
# full cudagraph outside the fx graph. This reduces some cpu
# overhead when the runtime batch_size is not cudagraph captured.
# see https://github.com/vllm-project/vllm/pull/20059 for details.
# make a copy to avoid mutating the class-level list via reference.
self.splitting_ops = list(self._attention_ops)
if self.use_inductor_graph_partition:
# When using inductor graph partition, we set splitting_ops
# to be empty and rely on torch._C.Tag.cudagraph_unsafe to
# annotate custom ops as splitting ops.
logger.warning_once(use_inductor_graph_partition_msg)
self.splitting_ops = []
else:
# NOTE: When using full cudagraph, instead of setting an empty
# list and capture the full cudagraph inside the flattened fx
# graph, we keep the piecewise fx graph structure but capture
# the full cudagraph outside the fx graph. This reduces some
# cpu overhead when the runtime batch_size is not cudagraph
# captured. see https://github.com/vllm-project/vllm/pull/20059
# for details. make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
elif len(self.splitting_ops) == 0:
logger.warning_once("Using piecewise compilation with empty "
"splitting_ops.")
if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
logger.warning_once(
"Using piecewise compilation with empty "
"splitting_ops and use_inductor_graph_partition"
f"={self.use_inductor_graph_partition}.")
if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
and not self.use_inductor_graph_partition):
logger.warning_once(
"When compilation level is piecewise with empty "
"splitting_ops, PIECEWISE cudagraph_mode will be "
@ -562,7 +605,26 @@ class CompilationConfig:
"any problems.")
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
elif self.use_inductor_graph_partition:
logger.warning_once(use_inductor_graph_partition_msg)
self.splitting_ops = []
def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(
op in self.splitting_ops for op in self._attention_ops)
def is_attention_compiled_piecewise(self) -> bool:
use_fx_graph_piecewise_compilation = (
self.level == CompilationLevel.PIECEWISE
and self.splitting_ops_contain_attention())
inductor_used = (self.level == CompilationLevel.PIECEWISE
and self.use_inductor) or (
self.level >= CompilationLevel.DYNAMO_AS_IS
and self.backend == "inductor")
use_inductor_piecewise_compilation = (
inductor_used and self.use_inductor_graph_partition
and not self.splitting_ops_contain_attention())
return use_fx_graph_piecewise_compilation or \
use_inductor_piecewise_compilation

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
@ -39,11 +39,15 @@ class CudagraphDispatcher:
CUDAGraphMode.FULL: set(),
}
assert not self.cudagraph_mode.requires_piecewise_compilation() or \
(self.compilation_config.level == CompilationLevel.PIECEWISE and
self.compilation_config.splitting_ops_contain_attention()), \
not_use_piecewise_compilation = (
not self.cudagraph_mode.requires_piecewise_compilation())
assert not_use_piecewise_compilation or \
self.compilation_config.is_attention_compiled_piecewise(), \
"Compilation level should be CompilationLevel.PIECEWISE when "\
"cudagraph_mode piecewise cudagraphs is used, "\
"and attention should be in splitting_ops or "\
"inductor splitting should be used. " \
f"cudagraph_mode={self.cudagraph_mode}, "\
f"compilation_level={self.compilation_config.level}, "\
f"splitting_ops={self.compilation_config.splitting_ops}"