Remove deprecated fields from CompilationConfig (#27593)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-11-12 16:10:28 +00:00 committed by GitHub
parent 728a9eb70e
commit a742134cc5
13 changed files with 122 additions and 164 deletions

View File

@@ -443,6 +443,7 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_config.py
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_fusion_attn.py

View File

@@ -218,16 +218,6 @@ outputs = model.generate(
)
```
### Migration from legacy flags
Legacy `use_cudagraph` and `full_cuda_graph` are unified by `cudagraph_mode`:
* `use_cudagraph=False` → `NONE`.
* `use_cudagraph=True` and `full_cuda_graph=False` → `PIECEWISE`.
* `full_cuda_graph=True` → directly set `FULL` and rely on the graceful fallback policy.
Both flags are deprecated and will be removed in the next major or minor release, i.e. v0.11.0 or v1.0.0, so we recommend setting `cudagraph_mode` directly, as sketched below.
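For illustration, a minimal sketch of the replacement, assuming the `CompilationConfig` and `CUDAGraphMode` types importable from `vllm.config` (the same surface used by the tests in this change):

```python
from vllm.config import CompilationConfig, CUDAGraphMode

# use_cudagraph=False                        ->  cudagraph_mode=NONE
cfg_none = CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE)

# use_cudagraph=True, full_cuda_graph=False  ->  cudagraph_mode=PIECEWISE
cfg_piecewise = CompilationConfig(cudagraph_mode=CUDAGraphMode.PIECEWISE)

# full_cuda_graph=True                       ->  cudagraph_mode=FULL
# (the graceful fallback policy still applies at engine initialization)
cfg_full = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL)

# A string is also accepted; a field validator maps it onto the enum.
cfg_str = CompilationConfig(cudagraph_mode="FULL_AND_PIECEWISE")
```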
### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism)
Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs.
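The updated test in `tests/compile/test_config.py` (shown later in this diff) pins down exactly this fallback; a condensed sketch of the expected behaviour, assuming that test's config surface:

```python
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode

config = VllmConfig(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        pass_config={"enable_attn_fusion": True, "enable_noop": True},
        custom_ops=["+quant_fp8"],
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
    )
)
# Without inductor graph partitioning, attention fusion needs the whole graph:
# piecewise splitting is dropped and the CUDA graph mode falls back to FULL.
assert config.compilation_config.splitting_ops == []
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
```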

View File

@@ -203,7 +203,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@@ -281,7 +281,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=False,
cudagraph_mode=CUDAGraphMode.NONE,
splitting_ops=["silly::attention"],
use_inductor_graph_partition=use_inductor_graph_partition,
)

View File

@@ -62,7 +62,6 @@ def _run_simple_model(
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,

View File

@@ -449,7 +449,6 @@ def benchmark():
if piecewise:
compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=cudagraph_sizes,
)

View File

@@ -2,8 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from contextlib import nullcontext
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
@@ -11,7 +13,7 @@ from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
from vllm.utils.torch_utils import _is_torch_equal_or_newer
def test_version():
@@ -23,14 +25,6 @@ def test_version():
assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
def test_use_cudagraphs_dynamic():
vllm_config = VllmConfig()
# use_cudagraph defaults to True, so a freshly constructed VllmConfig
# should report cudagraphs as enabled.
assert vllm_config.compilation_config.use_cudagraph
def test_copy_pass():
vllm_config = VllmConfig()
inductor_pass = FixFunctionalizationPass(vllm_config)
@@ -65,7 +59,7 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
compilation_config = {
"use_cudagraph": False, # speed things up a bit
"cudagraph_mode": CUDAGraphMode.NONE, # speed things up a bit
}
with (
compilation_counter.expect(
@@ -83,20 +77,31 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
@pytest.mark.parametrize(
"cudagraph_mode,num_cudagraph_captured",
[
(CUDAGraphMode.NONE, 0),
(CUDAGraphMode.FULL_DECODE_ONLY, 1),
(CUDAGraphMode.PIECEWISE, 13),
(CUDAGraphMode.FULL_AND_PIECEWISE, 14),
],
)
def test_use_cudagraphs(
vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
compilation_config = {
"cudagraph_capture_sizes": [100],
"use_cudagraph": enabled,
"cudagraph_mode": cudagraph_mode,
}
num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
with (
compilation_counter.expect(
num_graphs_seen=1,
num_gpu_runner_capture_triggers=1 if enabled else 0,
num_cudagraph_captured=13 if enabled else 0,
num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
num_cudagraph_captured=num_cudagraph_captured,
),
# loading the model causes compilation (if enabled) to happen
vllm_runner(
@@ -168,19 +173,18 @@ def test_splitting_ops_dynamic():
assert not config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True
if is_torch_equal_or_newer("2.9.0.dev"):
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
# When attn_fusion pass enabled, splitting_ops now default to attention ops.
# When attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
@@ -189,29 +193,41 @@ def test_splitting_ops_dynamic():
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
# With the new simplified logic, attention fusion works with splitting_ops
assert config.compilation_config.splitting_ops_contain_attention()
# cudagraph mode remains PIECEWISE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
assert config.compilation_config.splitting_ops == []
# cudagraph mode also fall back to FULL
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
# When both use_inductor_graph_partition and attn_fusion pass enabled.
if is_torch_equal_or_newer("2.9.0.dev"):
# splitting_ops can not contain attention ops when attn_fusion
# pass enabled.
with pytest.raises(ValidationError):
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# workaround for accessing all attention ops
splitting_ops=CompilationConfig()._attention_ops,
)
)
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
# When both use_inductor_graph_partition and attn_fusion pass enabled.
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
pass_config={"enable_attn_fusion": True, "enable_noop": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
)
)
# With inductor graph partition, attn_fusion and splitting_ops
# work together. Default splitting_ops include attention ops.
assert config.compilation_config.splitting_ops_contain_attention()
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
def test_should_split():
@@ -293,25 +309,36 @@ def test_should_split():
"tp_size",
"enable_sequence_parallelism",
"max_num_batched_tokens",
"use_cudagraph",
"cudagraph_mode",
"expected_max_size",
),
[
(None, None, 1, False, 2048, True, 512),
([1, 2, 4], 4, 1, False, 2048, True, 4),
([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
([1, 256], None, 1, False, 2048, True, 256),
([], None, 1, False, 2048, False, 0),
(None, 0, 1, False, 2048, False, 0),
(None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
(
[1, 2, 4],
8,
1,
False,
2048,
CUDAGraphMode.FULL_AND_PIECEWISE,
ValidationError,
),
([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
(None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
# truncated to nearest multiple of 8 or 16
(None, 257, 1, False, 2048, True, 256),
([1, 2, 4, 15], None, 1, False, 2048, True, 15), # max from list
([1, 2, 4, 15], None, 2, True, 2048, True, 4), # filtered out 15 due to SP
([1, 2, 4, 15], None, 1, False, 8, True, 4), # limited by the max_tokens
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
# max from list
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
# filtered out 15 due to SP
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# limited by the max_tokens
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# the list should contain at least 1 element when using cudagraphs
([], None, 1, False, 2048, True, RuntimeError),
([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
# the max capturing size should be >= 1 when using cudagraphs
(None, 0, 1, False, 2048, True, RuntimeError),
(None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
],
)
def test_cudagraph_sizes_post_init(
@@ -320,15 +347,17 @@ def test_cudagraph_sizes_post_init(
tp_size,
enable_sequence_parallelism,
max_num_batched_tokens,
use_cudagraph,
cudagraph_mode,
expected_max_size,
):
ctx = nullcontext()
if isinstance(expected_max_size, Exception):
if expected_max_size == ValidationError:
ctx = pytest.raises(expected_max_size)
cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
with ctx:
with (
ctx,
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
):
compilation_config = CompilationConfig(
cudagraph_capture_sizes=cudagraph_capture_sizes,
max_cudagraph_capture_size=max_cudagraph_capture_size,
@@ -342,11 +371,13 @@ def test_cudagraph_sizes_post_init(
engine_args = EngineArgs(
model="facebook/opt-125m",
tensor_parallel_size=tp_size,
max_num_seqs=min(max_num_batched_tokens, 128),
max_num_batched_tokens=max_num_batched_tokens,
compilation_config=compilation_config,
)
vllm_config = engine_args.create_engine_config()
assert (
vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
)
assert (
vllm_config.compilation_config.max_cudagraph_capture_size
== expected_max_size
)

View File

@@ -80,7 +80,6 @@ def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatc
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@@ -215,7 +214,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,
@@ -257,7 +255,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_cudagraph=True,
splitting_ops=["silly::attention"],
cudagraph_capture_sizes=[1, 2],
use_inductor_graph_partition=use_inductor_graph_partition,

View File

@@ -61,10 +61,8 @@ def test_qwen2_5_vl_evs_functionality(
model,
runner="generate",
max_model_len=4000,
max_num_seqs=1,
dtype=dtype,
limit_mm_per_prompt={"video": 1},
tensor_parallel_size=1,
video_pruning_rate=video_pruning_rate,
) as vllm_model:
# Generate output - this should not crash

View File

@@ -206,7 +206,6 @@ class CompilationConfig:
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- [`compile_mm_encoder`][vllm.config.CompilationConfig.compile_mm_encoder]
- CudaGraph capture:
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
- [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
@@ -216,7 +215,6 @@ class CompilationConfig:
[vllm.config.CompilationConfig.cudagraph_num_of_warmups]
- [`cudagraph_copy_inputs`]
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
- Inductor compilation:
- [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
- [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
@@ -396,18 +394,6 @@ class CompilationConfig:
Warning: This flag is new and subject to change; in addition,
more modes may be added.
"""
use_cudagraph: bool = True
"""Whether to use cudagraph inside compilation:
- False: cudagraph inside compilation is not used.\n
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use
cudagraph_mode=FULL_AND_PIECEWISE instead.
"""
cudagraph_num_of_warmups: int = 0
"""Number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
@@ -425,15 +411,6 @@ class CompilationConfig:
internally managed buffer. Default is False.
Note that this flag is only effective when cudagraph_mode is PIECEWISE.
"""
full_cuda_graph: bool | None = False
"""whether to use a full cuda graph for the entire forward pass rather than
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models.
Warning: This flag is deprecated and will be removed in the next major or
minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
FULL_AND_PIECEWISE instead.
"""
cudagraph_specialize_lora: bool = True
"""Whether to create separate cuda graphs for cases with and without active
LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
@@ -603,13 +580,19 @@ class CompilationConfig:
@field_validator("cudagraph_mode", mode="before")
@classmethod
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
"""
enable parse the `cudagraph_mode` enum type from string
"""
"""Enable parsing of the `cudagraph_mode` enum type from string."""
if isinstance(value, str):
return CUDAGraphMode[value.upper()]
return value
@field_validator("pass_config", mode="before")
@classmethod
def validate_pass_config_before(cls, value: Any) -> Any:
"""Enable parsing of the `pass_config` field from a dictionary."""
if isinstance(value, dict):
return PassConfig(**value)
return value
@field_validator("compile_cache_save_format")
@classmethod
def validate_compile_cache_save_format(cls, value: str) -> str:
@@ -666,9 +649,6 @@ class CompilationConfig:
func if isinstance(func, InductorPass) else CallableInductorPass(func)
)
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)
if self.pass_config.enable_qk_norm_rope_fusion:
# TODO(zhuhaoran): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
@@ -684,36 +664,6 @@ class CompilationConfig:
self.inductor_compile_config["combo_kernels"] = True
self.inductor_compile_config["benchmark_combo_kernel"] = True
# migrate the deprecated flags
if not self.use_cudagraph:
logger.warning(
"use_cudagraph is deprecated, use cudagraph_mode=NONE instead."
)
if (
self.cudagraph_mode is not None
and self.cudagraph_mode != CUDAGraphMode.NONE
):
raise ValueError(
"use_cudagraph and cudagraph_mode are mutually"
" exclusive, prefer cudagraph_mode since "
"use_cudagraph is deprecated."
)
self.cudagraph_mode = CUDAGraphMode.NONE
if self.full_cuda_graph:
logger.warning(
"full_cuda_graph is deprecated, use cudagraph_mode=FULL instead."
)
if (
self.cudagraph_mode is not None
and not self.cudagraph_mode.has_full_cudagraphs()
):
raise ValueError(
"full_cuda_graph and cudagraph_mode are "
"mutually exclusive, prefer cudagraph_mode "
"since full_cuda_graph is deprecated."
)
self.cudagraph_mode = CUDAGraphMode.FULL
if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"
):
@@ -891,20 +841,19 @@ class CompilationConfig:
def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.enable_attn_fusion
# For dynamo-partition (non-inductor) attention fusion,
# set splitting_ops to empty to avoid splitting at attention ops
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"enable_attn_fusion is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off. "
"In this case, splitting_ops will be set to empty "
"list, and cudagraph_mode will be set to FULL. "
"Please ensure you are using attention backends that "
"support cudagraph or set cudagraph_mode to NONE "
"explicitly if encountering any problems."
)
self.cudagraph_mode = CUDAGraphMode.FULL
if self.splitting_ops is None:
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"enable_attn_fusion is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off. "
"In this case, splitting_ops will be set to empty "
"list, and cudagraph_mode will be set to FULL. "
"Please ensure you are using attention backends that "
"support cudagraph or set cudagraph_mode to NONE "
"explicitly if encountering any problems."
)
self.cudagraph_mode = CUDAGraphMode.FULL
assert not self.splitting_ops_contain_attention(), (
"attention ops should not be in splitting_ops "

View File

@@ -656,14 +656,6 @@ class VllmConfig:
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
)
# final migrate the deprecated flags
self.compilation_config.use_cudagraph = (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
)
self.compilation_config.full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
if self.parallel_config.enable_dbo:
a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
@@ -853,7 +845,9 @@ class VllmConfig:
)
# de-duplicate the sizes provided by the config
dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
cudagraph_capture_sizes = dedup_sizes
cudagraph_capture_sizes = [
i for i in dedup_sizes if i <= max_num_tokens
]
# sort to make sure the sizes are in ascending order
cudagraph_capture_sizes.sort()
else:

View File

@@ -123,7 +123,7 @@ class Mamba1AttentionMetadataBuilder(
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(

View File

@@ -302,7 +302,7 @@ class Mamba2AttentionMetadataBuilder(
elif (
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
# Pad state tensor for CUDA graph
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)

View File

@@ -81,7 +81,7 @@ class ShortConvAttentionMetadataBuilder(
elif (
num_decodes > 0
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.full_cuda_graph
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(