[torch.compile] CUDAGraph Inductor partition integration (#24281)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Signed-off-by: boyuanfeng <boyuan@meta.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent b8a287a0a8
commit 8945b001db
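
As a quick orientation before the hunks: this change adds a `use_inductor_graph_partition` option to `CompilationConfig` and tags attention custom ops with `torch._C.Tag.cudagraph_unsafe` so Inductor can split the graph around them at codegen time. A minimal, hypothetical usage sketch based on the options the tests below exercise (the model name, prompt, and token count are placeholders, not part of this commit; requires torch >= 2.9.0.dev):

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode

llm = LLM(
    model="facebook/opt-125m",  # placeholder model, as in the updated tests
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor_graph_partition=True,  # new option added by this commit
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
    ),
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))
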
@@ -15,6 +15,7 @@ from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                          VllmConfig, set_current_vllm_config)
 from vllm.envs import VLLM_USE_V1
 from vllm.forward_context import BatchDescriptor, set_forward_context
+from vllm.utils import is_torch_equal_or_newer
 
 # This import automatically registers `torch.ops.silly.attention`
 from ..silly_attention import get_global_counter, reset_global_counter
@@ -50,16 +51,21 @@ class SillyModel(nn.Module):
         return x
 
 
-@pytest.mark.parametrize("use_inductor", [True, False])
-@torch.inference_mode()
-def test_simple_piecewise_compile(use_inductor):
-    assert VLLM_USE_V1
+def _run_simple_model(
+    splitting_ops,
+    use_inductor_graph_partition,
+    use_inductor,
+    expected_num_piecewise_graphs_seen,
+    expected_num_piecewise_capturable_graphs_seen,
+    expected_num_backend_compilations,
+    expected_num_cudagraph_captured,
+):
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
         use_inductor=use_inductor,
-        splitting_ops=["silly.attention"],
+        splitting_ops=splitting_ops,
+        use_inductor_graph_partition=use_inductor_graph_partition,
         cudagraph_copy_inputs=True,
         cudagraph_capture_sizes=[1, 2],
     ))
@@ -70,11 +76,11 @@ def test_simple_piecewise_compile(use_inductor):
 
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
-            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
-            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
-            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
-            num_cudagraph_captured=
-            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+            num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
+            num_piecewise_capturable_graphs_seen=
+            expected_num_piecewise_capturable_graphs_seen,
+            num_backend_compilations=expected_num_backend_compilations,
+            num_cudagraph_captured=expected_num_cudagraph_captured,
     ), set_forward_context(None,
                            vllm_config=vllm_config):  # background context
         # warm up with background context
@@ -104,3 +110,46 @@ def test_simple_piecewise_compile(use_inductor):
         output = model(input)
     assert get_global_counter() == 2
     assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
+
+
+@pytest.mark.parametrize("use_inductor", [True, False])
+@torch.inference_mode()
+def test_simple_piecewise_compile(use_inductor):
+    assert VLLM_USE_V1
+    _run_simple_model(
+        splitting_ops=["silly.attention"],
+        use_inductor_graph_partition=False,
+        use_inductor=use_inductor,
+        expected_num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
+        expected_num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
+        expected_num_backend_compilations=
+        3,  # num_piecewise_capturable_graphs_seen
+        expected_num_cudagraph_captured=
+        6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
+    )
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
+def test_simple_inductor_graph_partition(splitting_ops):
+    assert VLLM_USE_V1
+    if not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
+    _run_simple_model(
+        # inductor graph partition automatically resets splitting_ops
+        # to be an empty list
+        splitting_ops=splitting_ops,
+        use_inductor_graph_partition=True,
+        use_inductor=True,
+        expected_num_piecewise_graphs_seen=
+        1,  # since not splitting at fx graph level
+        expected_num_piecewise_capturable_graphs_seen=
+        1,  # since not splitting at fx graph level
+        expected_num_backend_compilations=
+        1,  # since not splitting at fx graph level
+        expected_num_cudagraph_captured=
+        6,  # inductor graph partition still captures 6
+        # graph, same as fx graph partition.
+    )
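
For orientation, the expected counters in the two tests above follow from the toy model's structure and the config passed to `_run_simple_model`; a short sketch of the arithmetic (the layer count comes from the existing `2 * num_layers + 1` comments, not from new information):

num_layers = 2            # silly.attention is called once per layer
num_cudagraph_sizes = 2   # cudagraph_capture_sizes=[1, 2]

# fx-graph splitting on silly.attention:
piecewise_graphs = 2 * num_layers + 1                          # 5
capturable_graphs = num_layers + 1                             # 3 (pieces without attention)
cudagraphs_captured = num_cudagraph_sizes * capturable_graphs  # 6

# inductor graph partition: no fx-level split (1 graph, 1 compilation),
# yet still 6 cudagraphs captured, as the second test asserts.
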

@@ -60,4 +60,5 @@ direct_register_custom_op(
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
     target_lib=silly_lib,
+    tags=(torch._C.Tag.cudagraph_unsafe, ),
 )

@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import logging
 import tempfile
 from typing import Any, Optional, Union
 
@@ -10,9 +11,13 @@ import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
+from tests.v1.attention.utils import _Backend
 from vllm import LLM, SamplingParams
-from vllm.config import CompilationConfig, CompilationLevel, PassConfig
+from vllm.attention.selector import global_force_attn_backend_context_manager
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         PassConfig)
 from vllm.platforms import current_platform
+from vllm.utils import is_torch_equal_or_newer
 
 from ..utils import create_new_process_for_each_test
 
@@ -105,6 +110,18 @@ def test_full_graph(
     (CompilationConfig(level=CompilationLevel.PIECEWISE,
                        debug_dump_path=tempfile.gettempdir()),
      ("facebook/opt-125m", {})),
+] + [
+    # graph inductor partition
+    (
+        CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
+            # inductor graph partition uses
+            # torch._C.Tag.cudagraph_unsafe to specify splitting ops
+            use_inductor_graph_partition=True,
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+            compile_sizes=[1, 2]),
+        model) for model in models_list(all=False)
+    if is_torch_equal_or_newer("2.9.0.dev")
 ])
 # only test some of the models
 @create_new_process_for_each_test()
@@ -112,11 +129,51 @@ def test_custom_compile_config(
     compilation_config: CompilationConfig,
     model_info: tuple[str, dict[str, Any]],
 ):
+    if (compilation_config.use_inductor_graph_partition
+            and not is_torch_equal_or_newer("2.9.0.dev")):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
     model, model_kwargs = model_info
     print(f"MODEL={model}")
     run_model(compilation_config, model, model_kwargs)
 
 
+def test_inductor_graph_partition_attn_fusion(caplog_vllm):
+    if not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
+    model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
+    compilation_config = CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        use_inductor_graph_partition=True,
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        custom_ops=["+quant_fp8"],
+        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
+    )
+    model_kwargs = {
+        "kv_cache_dtype": "fp8",
+        "max_model_len": 1024,
+    }
+    with caplog_vllm.at_level(
+            logging.DEBUG), global_force_attn_backend_context_manager(
+                _Backend.FLASHINFER):
+        run_model(compilation_config, model, model_kwargs)
+
+    try:
+        assert ("Fused quantization onto 48 attention nodes"
+                in caplog_vllm.text), caplog_vllm.text
+    except AssertionError:
+        # Note: this message is only triggered when the compilation goes
+        # through the custom pass. Due to multiple layers of cache on
+        # PyTorch side, the compilation of a graph may be cached such
+        # that custom pass directly goes through cache. In this case,
+        # we go through this branch and assert that the pass is not
+        # triggered.
+        assert "Fused quantization" not in caplog_vllm.text
+
+
 def run_model(compile_config: Union[int, CompilationConfig], model: str,
               model_kwargs: dict[str, Any]):
     prompts = [

@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp)
 from vllm.platforms import current_platform
+from vllm.utils import is_torch_equal_or_newer
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -339,6 +340,10 @@ else:
 @pytest.mark.parametrize(
     "split_attention",
     [False, True] if current_platform.is_rocm() else [False])
+# TODO(boyuan): test inductor graph partition on rocm
+@pytest.mark.parametrize(
+    "use_inductor_graph_partition",
+    [False] if current_platform.is_rocm() else [False, True])
 @pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Only test ROCm or CUDA")
 @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@@ -352,9 +357,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
                                  dtype: torch.dtype, model_name: str,
                                  model_class: type[AttentionQuantPatternModel],
                                  backend: _Backend, split_attention: bool,
-                                 monkeypatch, dist_init):
+                                 use_inductor_graph_partition: bool,
+                                 monkeypatch, dist_init, caplog_vllm):
     """Test AttentionStaticQuantPattern fusion pass"""
 
+    if use_inductor_graph_partition and not is_torch_equal_or_newer(
+            "2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available "
+                    "in PyTorch 2.9+")
+
     monkeypatch.setenv("VLLM_USE_V1", "1")
     if split_attention:
         monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
@@ -372,6 +383,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         compilation_config=CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             custom_ops=["+quant_fp8"],
+            use_inductor_graph_partition=use_inductor_graph_partition,
         ),
         cache_config=CacheConfig(cache_dtype="fp8"))
 
@@ -444,6 +456,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
                                    backend=test_backend,
                                    fullgraph=True)
     assert model_compiled.attn._o_scale_float is None
+
     result_fused_1 = model_compiled(q, k, v)
 
     if backend == _Backend.FLASHINFER:
@@ -453,6 +466,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         # _o_scale_float
         assert model_compiled.attn._o_scale_float is not None
         result_fused_2 = model_compiled(q, k, v)
+
         assert model_compiled.attn._o_scale_float is not None
 
     torch.testing.assert_close(result_unfused,

@@ -577,6 +577,7 @@ direct_register_custom_op(
     mutates_args=[],
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
+    tags=(torch._C.Tag.cudagraph_unsafe, ),
 )
 
 
@@ -627,4 +628,5 @@ direct_register_custom_op(
     mutates_args=["output", "output_block_scale"],
     fake_impl=unified_attention_with_output_fake,
     dispatch_key=current_platform.dispatch_key,
+    tags=(torch._C.Tag.cudagraph_unsafe, ),
 )

@@ -329,6 +329,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
             ]
             global compilation_start_time
+
             compiled_graph_for_dynamic_shape = self.vllm_backend.\
                 compiler_manager.compile(
                 submod,
@@ -339,7 +340,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None)
             # Lazy import here to avoid circular import
-            from .cuda_graph import CUDAGraphOptions
             from .cuda_piecewise_backend import PiecewiseBackend
 
             piecewise_backend = PiecewiseBackend(
@@ -347,7 +347,13 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_dynamic_shape, self.vllm_backend)
 
-            if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                    and
+                    not self.compilation_config.use_inductor_graph_partition):
+                # We're using Dynamo-based piecewise splitting, so we wrap
+                # the whole subgraph with a static graph wrapper.
+                from .cuda_graph import CUDAGraphOptions
+
                 # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
                 # class) as platform dependent.
                 static_graph_wrapper_class = resolve_obj_by_qualname(

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import contextlib
 import inspect
 from typing import Callable, Optional, TypeVar, Union, overload
 from unittest.mock import patch
@@ -14,7 +15,7 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
-from vllm.utils import supports_dynamo
+from vllm.utils import resolve_obj_by_qualname, supports_dynamo
 
 from .monitor import start_monitoring_torch_compile
 
@@ -301,8 +302,11 @@ def _support_torch_compile(
 
         with patch.object(InliningInstructionTranslator, 'inline_call',
                           patched_inline_call), torch._dynamo.config.patch(
-                              **dynamo_config_patches):
+                              **dynamo_config_patches
+                          ), maybe_use_cudagraph_partition_wrapper(
+                              self.vllm_config):
             output = self.compiled_callable(*args, **kwargs)
+
         return output
 
         # usually, capturing the model once is enough, and then we can
@@ -314,3 +318,52 @@ def _support_torch_compile(
 
     cls.__call__ = __call__
     return cls
+
+
+@contextlib.contextmanager
+def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
+    """
+    Context manager to set/unset customized cudagraph partition wrappers.
+
+    If we're using Inductor-based graph partitioning, we currently have the
+    whole `fx.Graph` before Inductor lowering and and the piecewise
+    splitting happens after all graph passes and fusions. Here, we add
+    a custom hook for Inductor to wrap each partition with our static
+    graph wrapper class to maintain more control over static graph
+    capture and replay.
+    """
+    from vllm.config import CUDAGraphMode
+
+    compilation_config = vllm_config.compilation_config
+    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+            and compilation_config.use_inductor_graph_partition):
+        from torch._inductor.utils import CUDAGraphWrapperMetadata
+
+        from vllm.compilation.cuda_graph import CUDAGraphOptions
+        from vllm.platforms import current_platform
+
+        static_graph_wrapper_class = resolve_obj_by_qualname(
+            current_platform.get_static_graph_wrapper_cls())
+
+        def customized_cudagraph_wrapper(f,
+                                         metadata: CUDAGraphWrapperMetadata):
+            partition_id = metadata.partition_index
+            num_partitions = metadata.num_partitions
+            return static_graph_wrapper_class(
+                runnable=f,
+                vllm_config=vllm_config,
+                runtime_mode=CUDAGraphMode.PIECEWISE,
+                cudagraph_options=CUDAGraphOptions(
+                    debug_log_enable=partition_id == 0,
+                    gc_disable=partition_id != 0,
+                    weak_ref_output=partition_id == num_partitions - 1,
+                ))
+
+        torch._inductor.utils.set_customized_partition_wrappers(
+            customized_cudagraph_wrapper)
+
+    yield
+
+    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+            and compilation_config.use_inductor_graph_partition):
+        torch._inductor.utils.set_customized_partition_wrappers(None)
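
For reference, the hook installed above follows the contract of `torch._inductor.utils.set_customized_partition_wrappers` (PyTorch 2.9+, per the version check in this diff): a factory that receives each partition callable plus `CUDAGraphWrapperMetadata` and returns the runnable to use. A stripped-down, hypothetical sketch with a pass-through wrapper in place of vLLM's static graph wrapper:

import torch
from torch._inductor.utils import CUDAGraphWrapperMetadata

def passthrough_wrapper(f, metadata: CUDAGraphWrapperMetadata):
    # Identity wrapper: runs each inductor partition unchanged, but shows
    # where partition_index / num_partitions are exposed to the hook.
    print(f"wrapping partition {metadata.partition_index + 1}"
          f"/{metadata.num_partitions}")
    return f

torch._inductor.utils.set_customized_partition_wrappers(passthrough_wrapper)
# ... run a model compiled with use_inductor_graph_partition=True here ...
torch._inductor.utils.set_customized_partition_wrappers(None)  # uninstall
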

@@ -299,6 +299,26 @@ class CompilationConfig:
     minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
     """
 
+    use_inductor_graph_partition: bool = False
+    """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
+    This partition happens at inductor codegen time after all passes and fusions
+    are finished. It generates a single `call` function which wraps
+    cudagraph-safe ops into partition functions and leave cudagraph-unsafe ops
+    outside the partition functions. For a graph with N cudagraph-unsafe ops
+    (e.g., Attention), there would be N+1 partitions. To mark an op as
+    cudagraph unsafe, we can add `tags=(torch._C.Tag.cudagraph_unsafe)` when
+    register the custom op.
+
+    This config supports both full cudagraph and piecewise cudagraph without
+    compiling twice. For piecewise cudagraph, it applies vLLM CUDAGraph wrapper
+    to each partition. For N+1 partitions, there would be N+1
+    CUDAGraph wrapper instances.
+
+    For full CUDAGraph, we always apply a single CUDAGraph wrapper outside the
+    inductor `call` function in the model runner. The top-level full cudagraph
+    capture ignores all partitioning.
+    """
+
     pass_config: PassConfig = field(default_factory=PassConfig)
     """Custom inductor passes, see PassConfig for more details"""
 
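
The last paragraph of the new docstring describes tagging custom ops; the `silly.attention` and unified-attention hunks earlier in this diff do exactly that through `direct_register_custom_op`. A minimal, hypothetical registration sketch (the op name, bodies, and keyword usage are assumptions modeled on the call sites shown above):

import torch
from vllm.utils import direct_register_custom_op

def my_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            out: torch.Tensor) -> None:
    # Placeholder body: writes into the preallocated output tensor.
    out.copy_(q + k + v)

def my_attn_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                 out: torch.Tensor) -> None:
    # Fake (meta) impl used during tracing; no real computation.
    return

direct_register_custom_op(
    op_name="my_attn",
    op_func=my_attn,
    mutates_args=["out"],
    fake_impl=my_attn_fake,
    # The tag below is what inductor graph partition splits on.
    tags=(torch._C.Tag.cudagraph_unsafe, ),
)
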
@@ -461,6 +481,12 @@ class CompilationConfig:
                 "since full_cuda_graph is deprecated.")
             self.cudagraph_mode = CUDAGraphMode.FULL
 
+        if (self.use_inductor_graph_partition
+                and not is_torch_equal_or_newer("2.9.0.dev")):
+            raise ValueError("use_inductor_graph_partition is only "
+                             "supported with torch>=2.9.0.dev. Set "
+                             "use_inductor_graph_partition=False instead.")
+
     def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
@@ -540,19 +566,36 @@ class CompilationConfig:
                 "set_splitting_ops_for_v1 should only be called when "
                 "level is CompilationLevel.PIECEWISE")
 
+        use_inductor_graph_partition_msg = (
+            "When use_inductor_graph_partition=True, splitting_ops "
+            "are ignored and set to an empty list. Instead, "
+            "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
+            "used to annotate custom ops for graph partition.")
+
         if self.splitting_ops is None:
-            # NOTE: When using full cudagraph, instead of setting an empty
-            # list and capture the full cudagraph inside the flattened fx
-            # graph, we keep the piecewise fx graph structure but capture the
-            # full cudagraph outside the fx graph. This reduces some cpu
-            # overhead when the runtime batch_size is not cudagraph captured.
-            # see https://github.com/vllm-project/vllm/pull/20059 for details.
-            # make a copy to avoid mutating the class-level list via reference.
-            self.splitting_ops = list(self._attention_ops)
+            if self.use_inductor_graph_partition:
+                # When using inductor graph partition, we set splitting_ops
+                # to be empty and rely on torch._C.Tag.cudagraph_unsafe to
+                # annotate custom ops as splitting ops.
+                logger.warning_once(use_inductor_graph_partition_msg)
+                self.splitting_ops = []
+            else:
+                # NOTE: When using full cudagraph, instead of setting an empty
+                # list and capture the full cudagraph inside the flattened fx
+                # graph, we keep the piecewise fx graph structure but capture
+                # the full cudagraph outside the fx graph. This reduces some
+                # cpu overhead when the runtime batch_size is not cudagraph
+                # captured. see https://github.com/vllm-project/vllm/pull/20059
+                # for details. make a copy to avoid mutating the class-level
+                # list via reference.
+                self.splitting_ops = list(self._attention_ops)
         elif len(self.splitting_ops) == 0:
-            logger.warning_once("Using piecewise compilation with empty "
-                                "splitting_ops.")
-            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+            logger.warning_once(
+                "Using piecewise compilation with empty "
+                "splitting_ops and use_inductor_graph_partition"
+                f"={self.use_inductor_graph_partition}.")
+            if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
+                    and not self.use_inductor_graph_partition):
                 logger.warning_once(
                     "When compilation level is piecewise with empty "
                     "splitting_ops, PIECEWISE cudagraph_mode will be "
@@ -562,7 +605,26 @@ class CompilationConfig:
                     "any problems.")
                 self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
+        elif self.use_inductor_graph_partition:
+            logger.warning_once(use_inductor_graph_partition_msg)
+            self.splitting_ops = []
 
     def splitting_ops_contain_attention(self) -> bool:
         return self.splitting_ops is not None and all(
             op in self.splitting_ops for op in self._attention_ops)
+
+    def is_attention_compiled_piecewise(self) -> bool:
+        use_fx_graph_piecewise_compilation = (
+            self.level == CompilationLevel.PIECEWISE
+            and self.splitting_ops_contain_attention())
+
+        inductor_used = (self.level == CompilationLevel.PIECEWISE
+                         and self.use_inductor) or (
+                             self.level >= CompilationLevel.DYNAMO_AS_IS
+                             and self.backend == "inductor")
+        use_inductor_piecewise_compilation = (
+            inductor_used and self.use_inductor_graph_partition
+            and not self.splitting_ops_contain_attention())
+
+        return use_fx_graph_piecewise_compilation or \
+            use_inductor_piecewise_compilation

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Optional
 
-from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
 from vllm.logger import init_logger
 
@@ -39,11 +39,15 @@ class CudagraphDispatcher:
             CUDAGraphMode.FULL: set(),
         }
 
-        assert not self.cudagraph_mode.requires_piecewise_compilation() or \
-            (self.compilation_config.level == CompilationLevel.PIECEWISE and
-             self.compilation_config.splitting_ops_contain_attention()), \
+        not_use_piecewise_compilation = (
+            not self.cudagraph_mode.requires_piecewise_compilation())
+
+        assert not_use_piecewise_compilation or \
+            self.compilation_config.is_attention_compiled_piecewise(), \
             "Compilation level should be CompilationLevel.PIECEWISE when "\
             "cudagraph_mode piecewise cudagraphs is used, "\
+            "and attention should be in splitting_ops or "\
+            "inductor splitting should be used. " \
             f"cudagraph_mode={self.cudagraph_mode}, "\
             f"compilation_level={self.compilation_config.level}, "\
             f"splitting_ops={self.compilation_config.splitting_ops}"