Remove deprecated fields from CompilationConfig (#27593)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
728a9eb70e
commit
a742134cc5
@@ -443,6 +443,7 @@ steps:
  - vllm/
  - tests/compile
  commands:
  - pytest -v -s compile/test_config.py
  - pytest -v -s compile/test_pass_manager.py
  - pytest -v -s compile/test_fusion.py
  - pytest -v -s compile/test_fusion_attn.py
@@ -218,16 +218,6 @@ outputs = model.generate(
 )
 ```
 
-### Migration from legacy flags
-
-Legacy `use_cudagraph` and `full_cuda_graph` are unified by `cudagraph_mode`:
-
-* `use_cudagraph=False` → `NONE`.
-* `use_cudagraph=True` and `full_cuda_graph=False` → `PIECEWISE`.
-* `full_cuda_graph=True` → directly set `FULL` and rely on the graceful fallback policy.
-
-As they are deprecated and will be removed in the next major or minor release, i.e., v0.11.0 or v1.0.0, we recommend using `cudagraph_mode` instead.
-
 ### Piecewise compilation and full graph custom passes (attention fusion, sequence parallelism)
 
 Unfortunately, some custom compile passes have to see the whole graph to be effective and hence aren't compatible with piecewise compilation. This includes `AttnFusionPass` and `SequenceParallelismPass`. As a short-term solution, we automatically disable piecewise compilation (by setting `splitting_ops=[]`) when attention fusion is enabled. We use CUDA Graph modes `FULL` or `FULL_DECODE_ONLY` (depending on backend support). However, this leads to another optimization incompatibility and confusing performance tradeoffs.
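For readers migrating configs by hand, the deleted mapping translates directly into constructor calls. A minimal sketch, using the `CompilationConfig` and `CUDAGraphMode` names that appear elsewhere in this commit:

```python
from vllm.config import CompilationConfig, CUDAGraphMode

# was: CompilationConfig(use_cudagraph=False)
no_graphs = CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE)

# was: CompilationConfig(use_cudagraph=True, full_cuda_graph=False)
piecewise = CompilationConfig(cudagraph_mode=CUDAGraphMode.PIECEWISE)

# was: CompilationConfig(full_cuda_graph=True); FULL relies on the graceful
# fallback policy mentioned in the removed docs above.
full = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL)
```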
@@ -203,7 +203,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            use_cudagraph=True,
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
             use_inductor_graph_partition=use_inductor_graph_partition,
@@ -281,7 +281,7 @@ def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            use_cudagraph=False,
+            cudagraph_mode=CUDAGraphMode.NONE,
             splitting_ops=["silly::attention"],
             use_inductor_graph_partition=use_inductor_graph_partition,
         )
@@ -62,7 +62,6 @@ def _run_simple_model(
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            use_cudagraph=True,
             use_inductor=use_inductor,
             splitting_ops=splitting_ops,
             use_inductor_graph_partition=use_inductor_graph_partition,
@@ -449,7 +449,6 @@ def benchmark():
     if piecewise:
         compilation_config = CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=cudagraph_sizes,
         )
@@ -2,8 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 from contextlib import nullcontext
+from unittest.mock import patch
 
 import pytest
+from pydantic import ValidationError
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
@@ -11,7 +13,7 @@ from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.config.compilation import CompilationMode
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
-from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
+from vllm.utils.torch_utils import _is_torch_equal_or_newer
 
 
 def test_version():
@@ -23,14 +25,6 @@ def test_version():
     assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
 
 
-def test_use_cudagraphs_dynamic():
-    vllm_config = VllmConfig()
-    # Default V1 configuration enables cudagraphs; the engine decides at
-    # runtime when to actually capture them.
-    assert vllm_config.compilation_config.use_cudagraph
-
-
 def test_copy_pass():
     vllm_config = VllmConfig()
     inductor_pass = FixFunctionalizationPass(vllm_config)
@@ -65,7 +59,7 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)
 
     compilation_config = {
-        "use_cudagraph": False,  # speed things up a bit
+        "cudagraph_mode": CUDAGraphMode.NONE,  # speed things up a bit
     }
     with (
         compilation_counter.expect(
@@ -83,20 +77,31 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
 
 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
 @pytest.mark.forked
-@pytest.mark.parametrize("enabled", [True, False])
-def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
+@pytest.mark.parametrize(
+    "cudagraph_mode,num_cudagraph_captured",
+    [
+        (CUDAGraphMode.NONE, 0),
+        (CUDAGraphMode.FULL_DECODE_ONLY, 1),
+        (CUDAGraphMode.PIECEWISE, 13),
+        (CUDAGraphMode.FULL_AND_PIECEWISE, 14),
+    ],
+)
+def test_use_cudagraphs(
+    vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
+):
     # Disable multiprocessing so that the counter is in the same process
     monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
     compilation_config = {
         "cudagraph_capture_sizes": [100],
-        "use_cudagraph": enabled,
+        "cudagraph_mode": cudagraph_mode,
     }
+    num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
     with (
         compilation_counter.expect(
             num_graphs_seen=1,
-            num_gpu_runner_capture_triggers=1 if enabled else 0,
-            num_cudagraph_captured=13 if enabled else 0,
+            num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
+            num_cudagraph_captured=num_cudagraph_captured,
         ),
         # loading the model causes compilation (if enabled) to happen
         vllm_runner(
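The dict form of `compilation_config` used by the test fixture is the same shape accepted at the public entry point. A sketch, assuming the `vllm.LLM` constructor's `compilation_config` keyword (not part of this diff):

```python
from vllm import LLM

# Capture a single batch size and use piecewise cudagraphs only; the string
# form of the mode is parsed by the field validator shown later in this diff.
llm = LLM(
    model="facebook/opt-125m",
    compilation_config={
        "cudagraph_capture_sizes": [100],
        "cudagraph_mode": "PIECEWISE",
    },
)
```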
@@ -168,19 +173,18 @@ def test_splitting_ops_dynamic():
     assert not config.compilation_config.splitting_ops_contain_attention()
 
     # When use_inductor_graph_partition=True
-    if is_torch_equal_or_newer("2.9.0.dev"):
-        config = VllmConfig(
-            compilation_config=CompilationConfig(
-                mode=CompilationMode.VLLM_COMPILE,
-                use_inductor_graph_partition=True,
-                splitting_ops=["vllm::unified_attention"],
-            )
-        )
-        # with inductor partition we use splitting_ops directly for
-        # partition rules
-        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
+    config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            use_inductor_graph_partition=True,
+            splitting_ops=["vllm::unified_attention"],
+        )
+    )
+    # with inductor partition we use splitting_ops directly for
+    # partition rules
+    assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
 
-    # When attn_fusion pass enabled, splitting_ops now default to attention ops.
+    # When attn_fusion pass enabled.
     config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
@@ -189,29 +193,41 @@ def test_splitting_ops_dynamic():
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
     )
-    # With the new simplified logic, attention fusion works with splitting_ops
-    assert config.compilation_config.splitting_ops_contain_attention()
-    # cudagraph mode remains PIECEWISE
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
+    assert config.compilation_config.splitting_ops == []
+    # cudagraph mode also falls back to FULL
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
 
-    # When both use_inductor_graph_partition and attn_fusion pass enabled.
-    if is_torch_equal_or_newer("2.9.0.dev"):
-        config = VllmConfig(
-            compilation_config=CompilationConfig(
-                mode=CompilationMode.VLLM_COMPILE,
-                use_inductor_graph_partition=True,
-                pass_config={"enable_attn_fusion": True, "enable_noop": True},
-                custom_ops=["+quant_fp8"],
-                cudagraph_mode=CUDAGraphMode.PIECEWISE,
-            )
-        )
-        # With inductor graph partition, attn_fusion and splitting_ops
-        # work together. Default splitting_ops include attention ops.
-        assert config.compilation_config.splitting_ops_contain_attention()
-        # enable_attn_fusion is directly supported under
-        # use_inductor_graph_partition=True, and cudagraph_mode
-        # is unchanged.
-        assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
+    # splitting_ops can not contain attention ops when attn_fusion
+    # pass enabled.
+    with pytest.raises(ValidationError):
+        config = VllmConfig(
+            compilation_config=CompilationConfig(
+                mode=CompilationMode.VLLM_COMPILE,
+                use_inductor_graph_partition=True,
+                pass_config={"enable_attn_fusion": True, "enable_noop": True},
+                custom_ops=["+quant_fp8"],
+                cudagraph_mode=CUDAGraphMode.PIECEWISE,
+                # work around for accessing all attention ops
+                splitting_ops=CompilationConfig()._attention_ops,
+            )
+        )
+
+    # When both use_inductor_graph_partition and attn_fusion pass enabled.
+    config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            use_inductor_graph_partition=True,
+            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        )
+    )
+    # With inductor graph partition, attn_fusion and splitting_ops
+    # work together. Default splitting_ops include attention ops.
+    assert config.compilation_config.splitting_ops_contain_attention()
+    # enable_attn_fusion is directly supported under
+    # use_inductor_graph_partition=True, and cudagraph_mode
+    # is unchanged.
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
 
 def test_should_split():
@@ -293,25 +309,36 @@ def test_should_split():
         "tp_size",
         "enable_sequence_parallelism",
         "max_num_batched_tokens",
-        "use_cudagraph",
+        "cudagraph_mode",
         "expected_max_size",
     ),
     [
-        (None, None, 1, False, 2048, True, 512),
-        ([1, 2, 4], 4, 1, False, 2048, True, 4),
-        ([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
-        ([1, 256], None, 1, False, 2048, True, 256),
-        ([], None, 1, False, 2048, False, 0),
-        (None, 0, 1, False, 2048, False, 0),
+        (None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
+        ([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
+        (
+            [1, 2, 4],
+            8,
+            1,
+            False,
+            2048,
+            CUDAGraphMode.FULL_AND_PIECEWISE,
+            ValidationError,
+        ),
+        ([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
+        ([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
+        (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
         # truncated to nearest multiple of 8 or 16
-        (None, 257, 1, False, 2048, True, 256),
-        ([1, 2, 4, 15], None, 1, False, 2048, True, 15),  # max from list
-        ([1, 2, 4, 15], None, 2, True, 2048, True, 4),  # filtered out 15 due to SP
-        ([1, 2, 4, 15], None, 1, False, 8, True, 4),  # limited by the max_tokens
+        (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
+        # max from list
+        ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
+        # filtered out 15 due to SP
+        ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
+        # limited by the max_tokens
+        ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
         # the list should contain at least 1 element when use cudagraph
-        ([], None, 1, False, 2048, True, RuntimeError),
+        ([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
         # the max capturing size should be >= 1 when use cudagraph
-        (None, 0, 1, False, 2048, True, RuntimeError),
+        (None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
     ],
 )
 def test_cudagraph_sizes_post_init(
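A worked reading of the expectations above (my gloss, not part of the diff): `expected_max_size` is the resolved `max_cudagraph_capture_size` after de-duplication, sequence-parallel filtering, and truncation to the token budget. For instance, for the "limited by the max_tokens" row:

```python
# Hypothetical standalone recomputation of the "limited by the max_tokens" row.
sizes = [1, 2, 4, 15]
max_num_batched_tokens = 8
kept = [s for s in sizes if s <= max_num_batched_tokens]  # -> [1, 2, 4]
assert max(kept) == 4  # the row's expected_max_size
```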
@@ -320,15 +347,17 @@ def test_cudagraph_sizes_post_init(
     tp_size,
     enable_sequence_parallelism,
     max_num_batched_tokens,
-    use_cudagraph,
+    cudagraph_mode,
     expected_max_size,
 ):
     ctx = nullcontext()
-    if isinstance(expected_max_size, Exception):
+    if expected_max_size == ValidationError:
         ctx = pytest.raises(expected_max_size)
 
-    cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
-    with ctx:
+    with (
+        ctx,
+        patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
+    ):
         compilation_config = CompilationConfig(
             cudagraph_capture_sizes=cudagraph_capture_sizes,
             max_cudagraph_capture_size=max_cudagraph_capture_size,
@@ -342,11 +371,13 @@ def test_cudagraph_sizes_post_init(
         engine_args = EngineArgs(
             model="facebook/opt-125m",
+            tensor_parallel_size=tp_size,
             max_num_seqs=min(max_num_batched_tokens, 128),
             max_num_batched_tokens=max_num_batched_tokens,
             compilation_config=compilation_config,
         )
         vllm_config = engine_args.create_engine_config()
 
-    assert (
-        vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
-    )
+        assert (
+            vllm_config.compilation_config.max_cudagraph_capture_size
+            == expected_max_size
+        )
@@ -80,7 +80,6 @@ def test_ignore_torch_compile_decorator(use_inductor_graph_partition, monkeypatch):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
             use_inductor_graph_partition=use_inductor_graph_partition,
@@ -215,7 +214,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch):
         ),
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
             use_inductor_graph_partition=use_inductor_graph_partition,
@@ -257,7 +255,6 @@ def test_conditional_compile_enable_if(use_inductor_graph_partition, monkeypatch):
         ),
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            use_cudagraph=True,
             splitting_ops=["silly::attention"],
             cudagraph_capture_sizes=[1, 2],
             use_inductor_graph_partition=use_inductor_graph_partition,
@@ -61,10 +61,8 @@ def test_qwen2_5_vl_evs_functionality(
        model,
        runner="generate",
        max_model_len=4000,
        max_num_seqs=1,
        dtype=dtype,
        limit_mm_per_prompt={"video": 1},
        tensor_parallel_size=1,
        video_pruning_rate=video_pruning_rate,
    ) as vllm_model:
        # Generate output - this should not crash
@@ -206,7 +206,6 @@ class CompilationConfig:
     - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
     - [`compile_mm_encoder`][vllm.config.CompilationConfig.compile_mm_encoder]
     - CudaGraph capture:
-        - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
         - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode]
         - [`cudagraph_capture_sizes`]
          [vllm.config.CompilationConfig.cudagraph_capture_sizes]
@@ -216,7 +215,6 @@ class CompilationConfig:
          [vllm.config.CompilationConfig.cudagraph_num_of_warmups]
         - [`cudagraph_copy_inputs`]
          [vllm.config.CompilationConfig.cudagraph_copy_inputs]
-        - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
     - Inductor compilation:
         - [`use_inductor`][vllm.config.CompilationConfig.use_inductor]
         - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes]
@@ -396,18 +394,6 @@ class CompilationConfig:
     Warning: This flag is new and subject to change in addition
     more modes may be added.
     """
-    use_cudagraph: bool = True
-    """Whether to use cudagraph inside compilation:
-
-    - False: cudagraph inside compilation is not used.\n
-    - True: cudagraph inside compilation is used. It requires
-        that all input buffers have fixed addresses, and all
-        splitting ops write their outputs to input buffers.
-
-    Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=FULL_AND
-    _PIECEWISE instead.
-    """
     cudagraph_num_of_warmups: int = 0
     """Number of warmup runs for cudagraph.
     It means the first several runs will be treated as warmup runs.
@@ -425,15 +411,6 @@ class CompilationConfig:
     internally managed buffer. Default is False.
     Note that this flag is only effective when cudagraph_mode is PIECEWISE.
     """
-    full_cuda_graph: bool | None = False
-    """whether to use a full cuda graph for the entire forward pass rather than
-    splitting certain operations such as attention into subgraphs. Thus this
-    flag cannot be used together with splitting_ops. This may provide
-    performance benefits for smaller models.
-    Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=
-    FULL_AND_PIECEWISE instead.
-    """
     cudagraph_specialize_lora: bool = True
     """Whether to create separate cuda graphs for cases with and without active
     LoRA adapters. When set to False, the LoRA-enabled cuda graph will be used
@@ -603,13 +580,19 @@ class CompilationConfig:
     @field_validator("cudagraph_mode", mode="before")
     @classmethod
     def validate_cudagraph_mode_before(cls, value: Any) -> Any:
-        """
-        enable parse the `cudagraph_mode` enum type from string
-        """
+        """Enable parsing of the `cudagraph_mode` enum type from string."""
         if isinstance(value, str):
             return CUDAGraphMode[value.upper()]
         return value
 
+    @field_validator("pass_config", mode="before")
+    @classmethod
+    def validate_pass_config_before(cls, value: Any) -> Any:
+        """Enable parsing of the `pass_config` field from a dictionary."""
+        if isinstance(value, dict):
+            return PassConfig(**value)
+        return value
+
     @field_validator("compile_cache_save_format")
     @classmethod
     def validate_compile_cache_save_format(cls, value: str) -> str:
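Taken together, the existing `cudagraph_mode` validator and the `pass_config` validator added here let both fields accept friendlier spellings. A minimal sketch, assuming `CompilationConfig` is constructible standalone as it is elsewhere in this diff:

```python
from vllm.config import CompilationConfig, CUDAGraphMode

cfg = CompilationConfig(
    cudagraph_mode="piecewise",         # upper-cased, then parsed to the enum
    pass_config={"enable_noop": True},  # expanded into PassConfig(**value)
)
assert cfg.cudagraph_mode == CUDAGraphMode.PIECEWISE
assert cfg.pass_config.enable_noop
```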
@@ -666,9 +649,6 @@ class CompilationConfig:
                 func if isinstance(func, InductorPass) else CallableInductorPass(func)
             )
 
-        if isinstance(self.pass_config, dict):
-            self.pass_config = PassConfig(**self.pass_config)
-
         if self.pass_config.enable_qk_norm_rope_fusion:
             # TODO(zhuhaoran): support rope native forward match and remove this.
             # Linked issue: https://github.com/vllm-project/vllm/issues/28042
@@ -684,36 +664,6 @@ class CompilationConfig:
             self.inductor_compile_config["combo_kernels"] = True
             self.inductor_compile_config["benchmark_combo_kernel"] = True
 
-        # migrate the deprecated flags
-        if not self.use_cudagraph:
-            logger.warning(
-                "use_cudagraph is deprecated, use cudagraph_mode=NONE instead."
-            )
-            if (
-                self.cudagraph_mode is not None
-                and self.cudagraph_mode != CUDAGraphMode.NONE
-            ):
-                raise ValueError(
-                    "use_cudagraph and cudagraph_mode are mutually"
-                    " exclusive, prefer cudagraph_mode since "
-                    "use_cudagraph is deprecated."
-                )
-            self.cudagraph_mode = CUDAGraphMode.NONE
-        if self.full_cuda_graph:
-            logger.warning(
-                "full_cuda_graph is deprecated, use cudagraph_mode=FULL instead."
-            )
-            if (
-                self.cudagraph_mode is not None
-                and not self.cudagraph_mode.has_full_cudagraphs()
-            ):
-                raise ValueError(
-                    "full_cuda_graph and cudagraph_mode are "
-                    "mutually exclusive, prefer cudagraph_mode "
-                    "since full_cuda_graph is deprecated."
-                )
-            self.cudagraph_mode = CUDAGraphMode.FULL
-
         if self.use_inductor_graph_partition and not is_torch_equal_or_newer(
             "2.9.0.dev"
         ):
@@ -891,20 +841,19 @@ class CompilationConfig:
 
     def set_splitting_ops_for_attn_fusion(self):
         assert self.pass_config.enable_attn_fusion
-        # For dynamo-partition (non-inductor) attention fusion,
-        # set splitting_ops to empty to avoid splitting at attention ops
-        self.splitting_ops = []
-        if self.cudagraph_mode.has_piecewise_cudagraphs():
-            logger.warning_once(
-                "enable_attn_fusion is incompatible with piecewise "
-                "cudagraph when use_inductor_graph_partition is off. "
-                "In this case, splitting_ops will be set to empty "
-                "list, and cudagraph_mode will be set to FULL. "
-                "Please ensure you are using attention backends that "
-                "support cudagraph or set cudagraph_mode to NONE "
-                "explicitly if encountering any problems."
-            )
-            self.cudagraph_mode = CUDAGraphMode.FULL
+        if self.splitting_ops is None:
+            self.splitting_ops = []
+            if self.cudagraph_mode.has_piecewise_cudagraphs():
+                logger.warning_once(
+                    "enable_attn_fusion is incompatible with piecewise "
+                    "cudagraph when use_inductor_graph_partition is off. "
+                    "In this case, splitting_ops will be set to empty "
+                    "list, and cudagraph_mode will be set to FULL. "
+                    "Please ensure you are using attention backends that "
+                    "support cudagraph or set cudagraph_mode to NONE "
+                    "explicitly if encountering any problems."
+                )
+                self.cudagraph_mode = CUDAGraphMode.FULL
 
         assert not self.splitting_ops_contain_attention(), (
             "attention ops should not be in splitting_ops "
@@ -656,14 +656,6 @@ class VllmConfig:
                 f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
             )
 
-        # final migrate the deprecated flags
-        self.compilation_config.use_cudagraph = (
-            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-        )
-        self.compilation_config.full_cuda_graph = (
-            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
-        )
-
         if self.parallel_config.enable_dbo:
             a2a_backend = self.parallel_config.all2all_backend
             assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
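With this back-migration gone, nothing keeps the removed booleans in sync with `cudagraph_mode`, so consumers must query the mode directly — exactly the substitution the attention-metadata builders below make. A hypothetical consumer-side helper:

```python
from vllm.config import CUDAGraphMode

def uses_full_cudagraphs(compilation_config) -> bool:
    # Replaces reads of the removed full_cuda_graph flag; mirrors the deleted
    # migration: full_cuda_graph = cudagraph_mode.has_full_cudagraphs()
    return compilation_config.cudagraph_mode.has_full_cudagraphs()

def uses_any_cudagraphs(compilation_config) -> bool:
    # Replaces reads of the removed use_cudagraph flag; mirrors the deleted
    # migration: use_cudagraph = cudagraph_mode != CUDAGraphMode.NONE
    return compilation_config.cudagraph_mode != CUDAGraphMode.NONE
```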
@@ -853,7 +845,9 @@ class VllmConfig:
             )
             # de-duplicate the sizes provided by the config
             dedup_sizes = list(set(self.compilation_config.cudagraph_capture_sizes))
-            cudagraph_capture_sizes = dedup_sizes
+            cudagraph_capture_sizes = [
+                i for i in dedup_sizes if i <= max_num_tokens
+            ]
             # sort to make sure the sizes are in ascending order
             cudagraph_capture_sizes.sort()
         else:
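In isolation, the new comprehension drops any configured capture size above the token budget before sorting — a standalone sketch with made-up numbers:

```python
dedup_sizes = [512, 4, 1, 2]   # de-duplicated, unordered config input
max_num_tokens = 256           # hypothetical budget
cudagraph_capture_sizes = [i for i in dedup_sizes if i <= max_num_tokens]
cudagraph_capture_sizes.sort()  # ascending order, as the surrounding code expects
assert cudagraph_capture_sizes == [1, 2, 4]
```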
@@ -123,7 +123,7 @@ class Mamba1AttentionMetadataBuilder(
         elif (
             num_decodes > 0
             and num_decodes <= self.decode_cudagraph_max_bs
-            and self.compilation_config.full_cuda_graph
+            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         ):
             padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
             self.state_indices_tensor[:num_decodes].copy_(
@@ -302,7 +302,7 @@ class Mamba2AttentionMetadataBuilder(
 
         elif (
             num_decodes <= self.decode_cudagraph_max_bs
-            and self.compilation_config.full_cuda_graph
+            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         ):
             # Pad state tensor for CUDA graph
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
@@ -81,7 +81,7 @@ class ShortConvAttentionMetadataBuilder(
         elif (
             num_decodes > 0
             and num_decodes <= self.decode_cudagraph_max_bs
-            and self.compilation_config.full_cuda_graph
+            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         ):
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
             self.state_indices_tensor[:num_decodes].copy_(