[Frontend][torch.compile] CompilationConfig Overhaul (#20283): Set up -O infrastructure (#26847)

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Co-authored-by: adabeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Morrison Turnansky 2025-11-27 04:55:58 -05:00 committed by GitHub
parent 00d3310d2d
commit 0838b52e2e
13 changed files with 735 additions and 64 deletions

View File

@ -0,0 +1,69 @@
<!-- markdownlint-disable -->
# Optimization Levels
## Overview
vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechanism for trading startup time against performance: higher levels deliver better performance at the cost of longer startup. Each level comes with associated defaults to give users good out-of-the-box performance. Importantly, values set by optimization levels are purely defaults; explicit user settings will not be overwritten.
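For example, a minimal sketch of this precedence using the developer-facing config objects exercised by the tests in this commit (the `LLM` and CLI entrypoints apply the same rule):
```python
from vllm.config import CompilationConfig, VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.config.vllm import OptimizationLevel

# -O2 would normally default cudagraph_mode to FULL_AND_PIECEWISE, but an
# explicit user setting survives the defaulting pass untouched.
config = VllmConfig(
    optimization_level=OptimizationLevel.O2,
    compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE),
)
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
```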
## Level Summaries and Usage Examples
#### `-O0`: No Optimization
- **Startup**: Fastest startup time
- **Performance**: No compilation, no CUDA graphs
- **Use case**: Debugging and rapid iteration where startup time matters most
```bash
# CLI usage
python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O0
```
```python
# Python API usage
from vllm.entrypoints.llm import LLM

llm = LLM(
    model="RedHatAI/Llama-3.2-1B-FP8",
    optimization_level=0
)
```
#### `-O1`: Quick Optimizations
- **Startup**: Moderate startup time
- **Performance**: Inductor compilation, CUDAGraphMode.PIECEWISE
- **Use case**: Balanced choice for most development scenarios
```bash
# CLI usage
python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O1
```
```python
# Python API usage
from vllm.entrypoints.llm import LLM

llm = LLM(
    model="RedHatAI/Llama-3.2-1B-FP8",
    optimization_level=1
)
```
#### `-O2`: Full Optimizations (Default)
- **Startup**: Longer startup time
- **Performance**: `-O1` + CUDAGraphMode.FULL_AND_PIECEWISE
- **Use case**: Production workloads where performance matters. This is the default level and is very similar to the previous vLLM default; the primary difference is that the noop and fusion passes are enabled.
```bash
# CLI usage (default, so optional)
python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O2
```
```python
# Python API usage
from vllm.entrypoints.llm import LLM

llm = LLM(
    model="RedHatAI/Llama-3.2-1B-FP8",
    optimization_level=2  # This is the default
)
```
#### `-O3`: Full Optimization
Still in development. The infrastructure is added now so that the API does not
need to change in a future release. Currently behaves the same as `-O2`.
## Troubleshooting
### Common Issues
1. **Startup Time Too Long**: Use `-O0` or `-O1` for faster startup
2. **Compilation Errors**: Use `debug_dump_path` for additional debugging information
3. **Performance Issues**: Make sure you are running with `-O2` for production workloads
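For example, a compilation debug dump can be requested alongside an optimization level via the dotted `-O.<field>` syntax (a sketch; the dump path is illustrative):
```bash
python -m vllm.entrypoints.api_server \
    --model RedHatAI/Llama-3.2-1B-FP8 \
    -O2 -O.debug_dump_path=/tmp/vllm_compile_dump
```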

View File

@ -172,8 +172,8 @@ def test_splitting_ops_dynamic():
config = VllmConfig()
# Default V1 config leaves cudagraph mode unset; splitting ops are only
# populated when the engine decides to use piecewise compilation.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
assert not config.compilation_config.splitting_ops_contain_attention()
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
assert config.compilation_config.splitting_ops_contain_attention()
# When use_inductor_graph_partition=True
config = VllmConfig(

View File

@ -222,6 +222,47 @@ def test_media_io_kwargs_parser(arg, expected):
assert args.media_io_kwargs == expected
@pytest.mark.parametrize(
("args", "expected"),
[
(["-O", "1"], "1"),
(["-O", "2"], "2"),
(["-O", "3"], "3"),
(["-O0"], "0"),
(["-O1"], "1"),
(["-O2"], "2"),
(["-O3"], "3"),
],
)
def test_optimization_level(args, expected):
"""
Test that the -O optimization level spellings ("-O 1" and "-O1" forms) map to
optimization_level and leave compilation_config.mode unset.
"""
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
parsed_args = parser.parse_args(args)
assert parsed_args.optimization_level == expected
assert parsed_args.compilation_config.mode is None
@pytest.mark.parametrize(
("args", "expected"),
[
(["-O.mode=0"], 0),
(["-O.mode=1"], 1),
(["-O.mode=2"], 2),
(["-O.mode=3"], 3),
],
)
def test_mode_parser(args, expected):
"""
Test compilation config modes (-O.mode=int) map to compilation_config.
"""
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
parsed_args = parser.parse_args(args)
assert parsed_args.compilation_config.mode == expected
def test_compilation_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
@ -229,22 +270,6 @@ def test_compilation_config():
args = parser.parse_args([])
assert args.compilation_config == CompilationConfig()
# set to O3
args = parser.parse_args(["-O0"])
assert args.compilation_config.mode == 0
# set to O 3 (space)
args = parser.parse_args(["-O", "1"])
assert args.compilation_config.mode == 1
# set to O 3 (equals)
args = parser.parse_args(["-O=2"])
assert args.compilation_config.mode == 2
# set to O.mode 3
args = parser.parse_args(["-O.mode", "3"])
assert args.compilation_config.mode == 3
# set to string form of a dict
args = parser.parse_args(
[

View File

@ -5,7 +5,12 @@ import pytest
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.config import (
CompilationConfig,
VllmConfig,
get_cached_compilation_config,
set_current_vllm_config,
)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (
GeluAndMul,
@ -86,6 +91,7 @@ def test_enabled_ops(
backend=backend, mode=compilation_mode, custom_ops=custom_ops
)
)
get_cached_compilation_config.cache_clear()
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on

View File

@ -8,9 +8,20 @@ from unittest.mock import patch
import pytest
from vllm.compilation.backends import VllmBackend
from vllm.config import ModelConfig, PoolerConfig, VllmConfig, update_config
from vllm.config import (
CompilationConfig,
ModelConfig,
PoolerConfig,
VllmConfig,
update_config,
)
from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.load import LoadConfig
from vllm.config.utils import get_field
from vllm.config.vllm import (
OPTIMIZATION_LEVEL_TO_CONFIG,
OptimizationLevel,
)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@ -235,6 +246,43 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
assert model_config.pooler_config.pooling_type == pooling_type
@pytest.mark.parametrize(
("model_id", "expected_is_moe_model"),
[
("RedHatAI/Qwen3-8B-speculator.eagle3", False),
("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", False),
("RedHatAI/Llama-3.2-1B-FP8", False),
("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", False),
("RedHatAI/gpt-oss-20b", True),
("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
("RedHatAI/Llama-4-Scout-17B-16E-Instruct", True),
("RedHatAI/Mixtral-8x7B-Instruct-v0.1", True),
],
)
def test_moe_model_detection(model_id, expected_is_moe_model):
model_config = ModelConfig(model_id)
# Check that MoE detection matches the expected value
assert model_config.is_model_moe() == expected_is_moe_model
@pytest.mark.parametrize(
("model_id", "quantized"),
[
("RedHatAI/Qwen3-8B-speculator.eagle3", False),
("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", True),
("RedHatAI/Llama-3.2-1B-FP8", True),
("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", True),
("RedHatAI/gpt-oss-20b", True),
("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
("RedHatAI/Mixtral-8x7B-Instruct-v0.1", False),
],
)
def test_is_quantized(model_id, quantized):
model_config = ModelConfig(model_id)
# Check that quantization detection matches the expected value
assert model_config.is_quantized() == quantized
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
@ -552,3 +600,260 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
assert os.path.exists(config1.tokenizer) and os.path.isdir(config1.tokenizer)
assert os.path.exists(config2.model) and os.path.isdir(config2.model)
assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer)
@pytest.mark.parametrize(
("backend", "custom_ops", "expected"),
[
("eager", [], True),
("eager", ["+fused_layernorm"], True),
("eager", ["all", "-fused_layernorm"], False),
("inductor", [], False),
("inductor", ["none", "+fused_layernorm"], True),
("inductor", ["none", "-fused_layernorm"], False),
],
)
def test_is_custom_op_enabled(backend: str, custom_ops: list[str], expected: bool):
"""Test that is_custom_op_enabled works correctly."""
config = VllmConfig(
compilation_config=CompilationConfig(backend=backend, custom_ops=custom_ops)
)
assert config.compilation_config.is_custom_op_enabled("fused_layernorm") is expected
def test_vllm_config_defaults_are_none():
"""Verify that optimization-level defaults are None when not set by user."""
# Test all optimization levels to ensure defaults work correctly
for opt_level in OptimizationLevel:
config = object.__new__(VllmConfig)
config.compilation_config = CompilationConfig()
config.optimization_level = opt_level
config.model_config = None
# Use the global optimization level defaults
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[opt_level]
# Verify that all pass_config values are None before defaults are applied
for pass_k in default_config["compilation_config"]["pass_config"]:
assert getattr(config.compilation_config.pass_config, pass_k) is None
# Verify that other config values are None before defaults are applied
for k in default_config["compilation_config"]:
if k != "pass_config":
assert getattr(config.compilation_config, k) is None
@pytest.mark.parametrize(
("model_id", "compiliation_config", "optimization_level"),
[
(
None,
CompilationConfig(backend="eager", custom_ops=["+quant_fp8"]),
OptimizationLevel.O0,
),
(None, CompilationConfig(), OptimizationLevel.O0),
(None, CompilationConfig(), OptimizationLevel.O1),
(None, CompilationConfig(), OptimizationLevel.O2),
(None, CompilationConfig(), OptimizationLevel.O3),
(
"RedHatAI/Qwen3-8B-speculator.eagle3",
CompilationConfig(backend="inductor", custom_ops=["+quant_fp8"]),
OptimizationLevel.O2,
),
(
"RedHatAI/Qwen3-8B-speculator.eagle3",
CompilationConfig(),
OptimizationLevel.O0,
),
(
"RedHatAI/Qwen3-8B-speculator.eagle3",
CompilationConfig(),
OptimizationLevel.O1,
),
(
"RedHatAI/Qwen3-8B-speculator.eagle3",
CompilationConfig(),
OptimizationLevel.O2,
),
(
"RedHatAI/Qwen3-8B-speculator.eagle3",
CompilationConfig(),
OptimizationLevel.O3,
),
("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O0),
("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O1),
("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O2),
("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O3),
],
)
def test_vllm_config_defaults(model_id, compiliation_config, optimization_level):
"""Test that optimization-level defaults are correctly applied."""
model_config = None
if model_id is not None:
model_config = ModelConfig(model_id)
vllm_config = VllmConfig(
model_config=model_config,
compilation_config=compiliation_config,
optimization_level=optimization_level,
)
else:
vllm_config = VllmConfig(
compilation_config=compiliation_config,
optimization_level=optimization_level,
)
# Use the global optimization level defaults
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[optimization_level]
# Verify pass_config defaults (nested under compilation_config)
pass_config_dict = default_config["compilation_config"]["pass_config"]
for pass_k, pass_v in pass_config_dict.items():
actual = getattr(vllm_config.compilation_config.pass_config, pass_k)
expected = pass_v(vllm_config) if callable(pass_v) else pass_v
assert actual == expected, (
f"pass_config.{pass_k}: expected {expected}, got {actual}"
)
# Verify other compilation_config defaults
compilation_config_dict = default_config["compilation_config"]
for k, v in compilation_config_dict.items():
if k != "pass_config":
actual = getattr(vllm_config.compilation_config, k)
expected = v(vllm_config) if callable(v) else v
assert actual == expected, (
f"compilation_config.{k}: expected {expected}, got {actual}"
)
def test_vllm_config_callable_defaults():
"""Test that callable defaults work in the config system.
Verifies that lambdas in default configs can inspect VllmConfig properties
(e.g., is_quantized, is_model_moe) to conditionally set optimization flags.
"""
config_no_model = VllmConfig(optimization_level=OptimizationLevel.O2)
# Callable that checks if model exists
has_model = lambda cfg: cfg.model_config is not None
assert has_model(config_no_model) is False
# Test with quantized model
quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
config_quantized = VllmConfig(
model_config=quantized_model, optimization_level=OptimizationLevel.O2
)
enable_if_quantized = lambda cfg: (
cfg.model_config is not None and cfg.model_config.is_quantized()
)
assert enable_if_quantized(config_quantized) is True
assert enable_if_quantized(config_no_model) is False
# Test with MoE model
moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
config_moe = VllmConfig(
model_config=moe_model, optimization_level=OptimizationLevel.O2
)
enable_if_sequential = lambda cfg: (
cfg.model_config is not None and not cfg.model_config.is_model_moe()
)
assert enable_if_sequential(config_moe) is False
assert enable_if_sequential(config_quantized) is True
def test_vllm_config_explicit_overrides():
"""Test that explicit property overrides work correctly with callable defaults.
When users explicitly set configuration properties, those values
take precedence over callable defaults, across different models and
optimization levels.
"""
from vllm.config.compilation import PassConfig
quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
regular_model = ModelConfig("Qwen/Qwen1.5-7B")
# Explicit compilation mode override on O0 (where default is NONE)
compilation_config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
config = VllmConfig(
optimization_level=OptimizationLevel.O0,
compilation_config=compilation_config,
)
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
# Explicit pass config flags to override defaults
pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True)
compilation_config = CompilationConfig(pass_config=pass_config)
config = VllmConfig(
optimization_level=OptimizationLevel.O0,
compilation_config=compilation_config,
)
assert config.compilation_config.pass_config.enable_noop is True
assert config.compilation_config.pass_config.enable_attn_fusion is True
# Explicit cudagraph mode override on quantized model at O2
pass_config = PassConfig(enable_async_tp=True)
compilation_config = CompilationConfig(
cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
)
config = VllmConfig(
model_config=quantized_model,
optimization_level=OptimizationLevel.O2,
compilation_config=compilation_config,
)
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
assert config.compilation_config.pass_config.enable_async_tp is True
# Mode should still use default for O2
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
# Different optimization levels with same model
config_o0 = VllmConfig(
model_config=regular_model, optimization_level=OptimizationLevel.O0
)
config_o2 = VllmConfig(
model_config=regular_model, optimization_level=OptimizationLevel.O2
)
assert config_o0.compilation_config.mode == CompilationMode.NONE
assert config_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
assert config_o0.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
assert (
config_o2.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
)
# Same optimization level across different model types
config_moe_o2 = VllmConfig(
model_config=moe_model, optimization_level=OptimizationLevel.O2
)
config_regular_o2 = VllmConfig(
model_config=regular_model, optimization_level=OptimizationLevel.O2
)
config_quantized_o2 = VllmConfig(
model_config=quantized_model, optimization_level=OptimizationLevel.O2
)
# All should have same base compilation settings at O2
assert config_moe_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
assert config_regular_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
assert config_quantized_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
assert (
config_moe_o2.compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_AND_PIECEWISE
)
assert (
config_regular_o2.compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_AND_PIECEWISE
)
# Override one field but not others
pass_config = PassConfig(enable_noop=False)
compilation_config = CompilationConfig(pass_config=pass_config)
config = VllmConfig(
model_config=regular_model,
optimization_level=OptimizationLevel.O2,
compilation_config=compilation_config,
)
# Explicit override should be respected
assert config.compilation_config.pass_config.enable_noop is False
# Other fields should still use defaults
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE

View File

@ -28,6 +28,7 @@ def parser():
parser.add_argument("--enable-feature", action="store_true")
parser.add_argument("--hf-overrides", type=json.loads)
parser.add_argument("-O", "--compilation-config", type=json.loads)
parser.add_argument("--optimization-level", type=int)
return parser
@ -217,8 +218,8 @@ def test_dict_args(parser):
"key15": "-minus.and.dot",
},
}
assert parsed_args.optimization_level == 1
assert parsed_args.compilation_config == {
"mode": 1,
"use_inductor_graph_partition": True,
"backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
@ -241,12 +242,13 @@ def test_duplicate_dict_args(caplog_vllm, parser):
parsed_args = parser.parse_args(args)
# Should be the last value
assert parsed_args.hf_overrides == {"key1": "val2"}
assert parsed_args.compilation_config == {"mode": 3}
assert parsed_args.optimization_level == 3
assert parsed_args.compilation_config == {"mode": 2}
assert len(caplog_vllm.records) == 1
assert "duplicate" in caplog_vllm.text
assert "--hf-overrides.key1" in caplog_vllm.text
assert "-O.mode" in caplog_vllm.text
assert "--optimization-level" in caplog_vllm.text
def test_model_specification(
@ -383,7 +385,7 @@ def test_compilation_mode_string_values(parser):
assert args.compilation_config == {"mode": 0}
args = parser.parse_args(["-O3"])
assert args.compilation_config == {"mode": 3}
assert args.optimization_level == 3
args = parser.parse_args(["-O.mode=NONE"])
assert args.compilation_config == {"mode": "NONE"}

View File

@ -117,9 +117,9 @@ else:
combo_cases_2 = [
("FA2", "FULL", CompilationMode.NONE, True),
("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
("FA2", "PIECEWISE", CompilationMode.NONE, False),
("FA2", "PIECEWISE", CompilationMode.NONE, True),
("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),

View File

@ -8,7 +8,7 @@ from dataclasses import asdict, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from pydantic import TypeAdapter, field_validator
from pydantic import Field, TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
@ -97,19 +97,25 @@ class PassConfig:
This is separate from general `CompilationConfig` so that inductor passes
don't all have access to full configuration - that would create a cycle as
the `PassManager` is set as a property of config."""
the `PassManager` is set as a property of config.
enable_fusion: bool = False
You must pass PassConfig to VLLMConfig constructor via the CompilationConfig
constructor. VLLMConfig's post_init does further initialization.
If used outside of the VLLMConfig, some fields may be left in an
improper state.
"""
enable_fusion: bool = Field(default=None)
"""Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
enable_attn_fusion: bool = False
enable_attn_fusion: bool = Field(default=None)
"""Whether to enable the custom attention+quant fusion pass."""
enable_noop: bool = False
enable_noop: bool = Field(default=None)
"""Whether to enable the custom no-op elimination pass."""
enable_sequence_parallelism: bool = False
enable_sequence_parallelism: bool = Field(default=None)
"""Whether to enable sequence parallelism."""
enable_async_tp: bool = False
enable_async_tp: bool = Field(default=None)
"""Whether to enable async TP."""
enable_fi_allreduce_fusion: bool = False
enable_fi_allreduce_fusion: bool = Field(default=None)
"""Whether to enable flashinfer allreduce fusion."""
fi_allreduce_fusion_max_size_mb: float | None = None
"""The threshold of the communicated tensor sizes under which
@ -167,6 +173,22 @@ class PassConfig:
"""
return InductorPass.hash_dict(asdict(self))
@field_validator(
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
mode="wrap",
)
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
if value is None:
return value
return handler(value)
def __post_init__(self) -> None:
if not self.enable_noop:
if self.enable_fusion:
@ -243,7 +265,13 @@ class DynamicShapesConfig:
@config
@dataclass
class CompilationConfig:
"""Configuration for compilation. It has three parts:
"""Configuration for compilation.
You must pass CompilationConfig to VLLMConfig constructor.
VLLMConfig's post_init does further initialization. If used outside of the
VLLMConfig, some fields will be left in an improper state.
It has three parts:
- Top-level Compilation control:
- [`mode`][vllm.config.CompilationConfig.mode]
@ -282,14 +310,14 @@ class CompilationConfig:
"""
# Top-level Compilation control
level: int | None = None
level: int = Field(default=None)
"""
Level is deprecated and will be removed in the next release,
either 0.12.0 or 0.11.2 whichever is soonest.
Please use mode. Currently all levels are mapped to mode.
"""
# Top-level Compilation control
mode: CompilationMode | None = None
mode: CompilationMode = Field(default=None)
"""The compilation approach used for torch.compile-based compilation of the
model.
@ -390,7 +418,7 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
cudagraph_mode: CUDAGraphMode | None = None
cudagraph_mode: CUDAGraphMode = Field(default=None)
"""
The mode of the cudagraph:
@ -452,7 +480,7 @@ class CompilationConfig:
When `enable_lora` is False, this option has no effect.
"""
use_inductor_graph_partition: bool = False
use_inductor_graph_partition: bool = Field(default=None)
"""Use inductor graph partition to split the graph at cudagraph_unsafe ops.
This partition happens at inductor codegen time after all passes and fusions
are finished. It generates a single `call` function which wraps
@ -648,6 +676,20 @@ class CompilationConfig:
)
return value
@field_validator(
"level",
"mode",
"cudagraph_mode",
"use_inductor_graph_partition",
mode="wrap",
)
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
"""Skip validation if the value is `None` when initialisation is delayed."""
if value is None:
return value
return handler(value)
def __post_init__(self) -> None:
if self.level is not None:
logger.warning(
@ -948,6 +990,13 @@ class CompilationConfig:
op,
)
def is_custom_op_enabled(self, op: str) -> bool:
if "all" in self.custom_ops:
return f"-{op}" not in self.custom_ops
assert "none" in self.custom_ops
return f"+{op}" in self.custom_ops
def adjust_cudagraph_sizes_for_spec_decode(
self, uniform_decode_query_len: int, tensor_parallel_size: int
):
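A quick sketch of the semantics of the new `is_custom_op_enabled` helper, mirroring the `test_is_custom_op_enabled` cases added above (not part of the diff):
```python
from vllm.config import CompilationConfig

cfg = CompilationConfig(custom_ops=["none", "+fused_layernorm"])
assert cfg.is_custom_op_enabled("fused_layernorm")      # explicitly enabled
assert not cfg.is_custom_op_enabled("rms_norm")         # "none" disables the rest

cfg = CompilationConfig(custom_ops=["all", "-fused_layernorm"])
assert not cfg.is_custom_op_enabled("fused_layernorm")  # explicitly disabled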

View File

@ -1752,6 +1752,14 @@ class ModelConfig:
logger.info("Using max model len %s", max_model_len)
return max_model_len
def is_model_moe(
self,
) -> bool:
return self.get_num_experts() > 1
def is_quantized(self) -> bool:
return getattr(self.hf_config, "quantization_config", None) is not None
def get_served_model_name(model: str, served_model_name: str | list[str] | None):
"""

View File

@ -9,8 +9,9 @@ import tempfile
import threading
import time
from contextlib import contextmanager
from dataclasses import replace
from dataclasses import is_dataclass, replace
from datetime import datetime
from enum import IntEnum
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, get_args
@ -57,6 +58,103 @@ else:
logger = init_logger(__name__)
class OptimizationLevel(IntEnum):
"""Optimization level enum."""
O0 = 0
"""O0 : No optimization. no compilation, no cudagraphs, no other
optimization, just starting up immediately"""
O1 = 1
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
cudagraphs"""
O2 = 2
"""O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
O3 = 3
"""O3: Currently the same as -O2s."""
IS_QUANTIZED = False
IS_DENSE = False
# The optimizations that depend on these properties currently set to False
# in all cases.
# if model_config is not None:
# IS_QUANTIZED = lambda c: c.model_config.is_quantized()
# IS_DENSE = lambda c: not c.model_config.is_model_moe()
# See https://github.com/vllm-project/vllm/issues/25689.
def enable_fusion(cfg: "VllmConfig") -> bool:
"""Returns True if RMS norm or quant FP8 is enabled."""
return cfg.compilation_config.is_custom_op_enabled(
"rms_norm"
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
OPTIMIZATION_LEVEL_00 = {
"compilation_config": {
"pass_config": {
"enable_noop": False,
"enable_fusion": False,
"enable_fi_allreduce_fusion": False,
"enable_attn_fusion": False,
"enable_sequence_parallelism": False,
"enable_async_tp": False,
},
"cudagraph_mode": CUDAGraphMode.NONE,
"use_inductor_graph_partition": False,
},
}
OPTIMIZATION_LEVEL_01 = {
"compilation_config": {
"pass_config": {
"enable_noop": True,
"enable_fusion": enable_fusion,
"enable_fi_allreduce_fusion": False,
"enable_attn_fusion": False,
"enable_sequence_parallelism": False,
"enable_async_tp": False,
},
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
"use_inductor_graph_partition": False,
},
}
OPTIMIZATION_LEVEL_02 = {
"compilation_config": {
"pass_config": {
"enable_noop": True,
"enable_fusion": enable_fusion,
"enable_fi_allreduce_fusion": False,
"enable_attn_fusion": IS_QUANTIZED,
"enable_sequence_parallelism": IS_DENSE,
"enable_async_tp": IS_DENSE,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
},
}
OPTIMIZATION_LEVEL_03 = {
"compilation_config": {
"pass_config": {
"enable_noop": True,
"enable_fusion": enable_fusion,
"enable_fi_allreduce_fusion": False,
"enable_attn_fusion": IS_QUANTIZED,
"enable_sequence_parallelism": IS_DENSE,
"enable_async_tp": IS_DENSE,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
},
}
OPTIMIZATION_LEVEL_TO_CONFIG = {
OptimizationLevel.O0: OPTIMIZATION_LEVEL_00,
OptimizationLevel.O1: OPTIMIZATION_LEVEL_01,
OptimizationLevel.O2: OPTIMIZATION_LEVEL_02,
OptimizationLevel.O3: OPTIMIZATION_LEVEL_03,
}
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
@ -116,6 +214,11 @@ class VllmConfig:
you are using. Contents must be hashable."""
instance_id: str = ""
"""The ID of the vLLM instance."""
optimization_level: OptimizationLevel = OptimizationLevel.O2
"""The optimization level. These levels trade startup time cost for
performance, with -O0 having the best startup time and -O3 having the best
performance. -02 is used by defult. See OptimizationLevel for full
description."""
def compute_hash(self) -> str:
"""
@ -297,6 +400,50 @@ class VllmConfig:
return replace(self, model_config=model_config)
def _set_config_default(self, config_obj: Any, key: str, value: Any) -> None:
"""Set config attribute to default if not already set by user.
Args:
config_obj: Configuration object to update.
key: Attribute name.
value: Default value (static or callable).
"""
if getattr(config_obj, key) is None:
# Some config values are known before initialization and are
# hard coded.
# Other values depend on the user given configuration, so they are
# implemented with lambda functions and decided at run time.
setattr(config_obj, key, value(self) if callable(value) else value)
def _apply_optimization_level_defaults(self, defaults: dict[str, Any]) -> None:
"""Apply optimization level defaults using self as root.
Recursively applies values from defaults into nested config objects.
Only fields present in defaults are overwritten.
If the user configuration does not specify a value for a field covered by
the defaults (i.e. the field is still None after all user selections are
applied), the default value is applied to that field. User-specified fields
will not be overridden by the defaults.
Args:
defaults: Dictionary of default values to apply.
"""
def apply_recursive(config_obj: Any, config_defaults: dict[str, Any]) -> None:
"""Recursively apply defaults to config_obj, using self as root."""
for key, value in config_defaults.items():
if not hasattr(config_obj, key):
continue
current = getattr(config_obj, key)
if isinstance(value, dict) and is_dataclass(current):
apply_recursive(current, value)
else:
self._set_config_default(config_obj, key, value)
apply_recursive(self, defaults)
def _post_init_kv_transfer_config(self) -> None:
"""Update KVTransferConfig based on top-level configs in VllmConfig.
@ -434,17 +581,47 @@ class VllmConfig:
"precision for chunked prefill triton kernels."
)
# If the user does not explicitly set a compilation mode, then
# we use the default mode. The default mode depends on other
# settings (see the below code).
if (
self.optimization_level > OptimizationLevel.O0
and self.model_config is not None
and self.model_config.enforce_eager
):
logger.warning("Enforce eager set, overriding optimization level to -O0")
self.optimization_level = OptimizationLevel.O0
if self.compilation_config.backend == "eager" or (
self.compilation_config.mode is not None
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
):
logger.warning(
"Inductor compilation was disabled by user settings,"
"Optimizations settings that are only active during"
"Inductor compilation will be ignored."
)
def has_blocked_weights():
if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"):
return self.quant_config.weight_block_size is not None
elif hasattr(self.quant_config, "has_blocked_weights"):
return self.quant_config.has_blocked_weights()
return False
# Enable quant_fp8 CUDA ops (TODO disable in follow up)
# On H100 the CUDA kernel is faster than
# native implementation
# https://github.com/vllm-project/vllm/issues/25094
if has_blocked_weights():
custom_ops = self.compilation_config.custom_ops
if "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")
if self.compilation_config.mode is None:
if self.model_config is not None and not self.model_config.enforce_eager:
if self.optimization_level > OptimizationLevel.O0:
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
else:
self.compilation_config.mode = CompilationMode.NONE
# If user does not set custom ops via none or all set it here based on
# compilation mode and backend.
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
if (
self.compilation_config.backend == "inductor"
@ -454,23 +631,33 @@ class VllmConfig:
else:
self.compilation_config.custom_ops.append("all")
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
self._apply_optimization_level_defaults(default_config)
if (
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
):
logger.info(
"Cudagraph mode %s is not compatible with compilation mode %s."
"Overriding to NONE.",
self.compilation_config.cudagraph_mode,
self.compilation_config.mode,
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = True
if self.compilation_config.pass_config.enable_sequence_parallelism:
if "-rms_norm" in self.compilation_config.custom_ops:
logger.warning(
"RMS norm force disabled, sequence parallelism might break"
)
else:
self.compilation_config.custom_ops.append("+rms_norm")
if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default
# value
if self.compilation_config.cudagraph_mode is None:
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
# default to full and piecewise for most models
self.compilation_config.cudagraph_mode = (
CUDAGraphMode.FULL_AND_PIECEWISE
)
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# if cudagraph_mode has full cudagraphs, we need to check support
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
# decode context parallel does not support full cudagraphs
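To round out the new `__post_init__` flow above: a small sketch (not part of the diff) of how a callable entry from the defaults table resolves once the config is constructed:
```python
from vllm.config import CompilationConfig, VllmConfig
from vllm.config.vllm import OPTIMIZATION_LEVEL_TO_CONFIG, OptimizationLevel

cfg = VllmConfig(
    optimization_level=OptimizationLevel.O2,
    compilation_config=CompilationConfig(custom_ops=["none", "+quant_fp8"]),
)
o2_defaults = OPTIMIZATION_LEVEL_TO_CONFIG[OptimizationLevel.O2]
fusion_default = o2_defaults["compilation_config"]["pass_config"]["enable_fusion"]

# enable_fusion is stored as a callable and evaluated against the VllmConfig;
# quant_fp8 is enabled in custom_ops above, so it resolves to True here.
assert callable(fusion_default)
assert fusion_default(cfg) is True
```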

View File

@ -77,6 +77,7 @@ from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
from vllm.config.utils import get_field
from vllm.config.vllm import OptimizationLevel
from vllm.logger import init_logger, suppress_logging
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@ -560,6 +561,7 @@ class EngineArgs:
stream_interval: int = SchedulerConfig.stream_interval
kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
optimization_level: OptimizationLevel = VllmConfig.optimization_level
kv_offloading_size: float | None = CacheConfig.kv_offloading_size
kv_offloading_backend: KVOffloadingBackend | None = (
@ -1114,6 +1116,10 @@ class EngineArgs:
"--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
)
vllm_group.add_argument(
"--optimization-level", **vllm_kwargs["optimization_level"]
)
# Other arguments
parser.add_argument(
"--disable-log-stats",
@ -1733,7 +1739,6 @@ class EngineArgs:
compilation_config.max_cudagraph_capture_size = (
self.max_cudagraph_capture_size
)
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
@ -1750,6 +1755,7 @@ class EngineArgs:
kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config,
additional_config=self.additional_config,
optimization_level=self.optimization_level,
)
return config

View File

@ -247,16 +247,16 @@ class FlexibleArgumentParser(ArgumentParser):
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<mode> here
mode = arg[3:] if arg[2] == "=" else arg[2:]
processed_args.append(f"-O.mode={mode}")
# also handle -O=<optimization_level> here
optimization_level = arg[3:] if arg[2] == "=" else arg[2:]
processed_args += ["--optimization-level", optimization_level]
elif (
arg == "-O"
and i + 1 < len(args)
and args[i + 1] in {"0", "1", "2", "3"}
):
# Convert -O <n> to -O.mode <n>
processed_args.append("-O.mode")
# Convert -O <n> to --optimization-level <n>
processed_args.append("--optimization-level")
else:
processed_args.append(arg)
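The net effect of this preprocessing, mirroring the new CLI tests above (a sketch; the exact import location of `FlexibleArgumentParser` may differ):
```python
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser  # assumed import path

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

# -O<n> and "-O <n>" now set --optimization-level ...
args = parser.parse_args(["-O3"])
assert args.optimization_level == "3"
assert args.compilation_config.mode is None

# ... while the dotted form still targets the compilation config.
args = parser.parse_args(["-O.mode=3"])
assert args.compilation_config.mode == 3
```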
@ -294,10 +294,24 @@ class FlexibleArgumentParser(ArgumentParser):
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
duplicates = set[str]()
# Track regular arguments (non-dict args) for duplicate detection
regular_args_seen = set[str]()
for i, processed_arg in enumerate(processed_args):
if i in delete: # skip if value from previous arg
continue
if processed_arg.startswith("--") and "." not in processed_arg:
if "=" in processed_arg:
arg_name = processed_arg.split("=", 1)[0]
else:
arg_name = processed_arg
if arg_name in regular_args_seen:
duplicates.add(arg_name)
else:
regular_args_seen.add(arg_name)
continue
if processed_arg.startswith("-") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value_str = processed_arg.split("=", 1)

View File

@ -37,7 +37,7 @@ class CudaGraphManager:
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode: CUDAGraphMode
if self.compilation_config.cudagraph_mode is None:
self.cudagraph_mode = CUDAGraphMode.NONE
else: