Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Co-authored-by: adabeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent 00d3310d2d
commit 0838b52e2e
docs/design/optimization_levels.md (new file, 69 lines added)
@@ -0,0 +1,69 @@

<!-- markdownlint-disable -->

# Optimization Levels

## Overview

vLLM now supports optimization levels (`-O0`, `-O1`, `-O2`, `-O3`). Optimization levels provide an intuitive mechanism for trading startup time against performance: higher levels deliver better performance at the cost of longer startup time. Each level comes with associated defaults that give users the desired out-of-the-box performance. Importantly, the values set by an optimization level are purely defaults; explicit user settings will not be overwritten.
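Because level settings are only defaults, an explicitly configured value always wins. Below is a minimal sketch of that behavior, modeled on the configuration tests added alongside this change; the import paths for `CUDAGraphMode` and `OptimizationLevel` are taken from those tests and may differ in other versions:

```python
from vllm.config import CompilationConfig, VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.config.vllm import OptimizationLevel

# Explicitly disable CUDA graphs; -O2 would otherwise default to
# CUDAGraphMode.FULL_AND_PIECEWISE.
config = VllmConfig(
    optimization_level=OptimizationLevel.O2,
    compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE),
)
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
```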

## Level Summaries and Usage Examples

#### `-O0`: No Optimizations

- **Startup**: Fastest startup time
- **Performance**: No compilation, no CUDA graphs
- **Use case**: Debugging and situations where startup time matters more than throughput

```bash
# CLI usage
python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O0
```

```python
# Python API usage
from vllm.entrypoints.llm import LLM

llm = LLM(
    model="RedHatAI/Llama-3.2-1B-FP8",
    optimization_level=0
)
```

#### `-O1`: Quick Optimizations

- **Startup**: Moderate startup time
- **Performance**: Inductor compilation, CUDAGraphMode.PIECEWISE
- **Use case**: A good balance for most development scenarios

```bash
# CLI usage
python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O1
```

```python
# Python API usage
from vllm.entrypoints.llm import LLM

llm = LLM(
    model="RedHatAI/Llama-3.2-1B-FP8",
    optimization_level=1
)
```

#### `-O2`: Full Optimizations (Default)

- **Startup**: Longer startup time
- **Performance**: `-O1` + CUDAGraphMode.FULL_AND_PIECEWISE
- **Use case**: Production workloads where performance is important. This is the default, and it is very similar to the previous default behavior; the primary difference is that the noop and fusion passes are enabled (see the sketch after the example below).

```bash
# CLI usage (default, so the flag is optional)
python -m vllm.entrypoints.api_server --model RedHatAI/Llama-3.2-1B-FP8 -O2
```

```python
# Python API usage
from vllm.entrypoints.llm import LLM

llm = LLM(
    model="RedHatAI/Llama-3.2-1B-FP8",
    optimization_level=2  # This is the default
)
```
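For reference, the following sketch checks the compilation defaults that `-O2` applies when nothing is set explicitly. It mirrors the configuration tests added in this change; the `CompilationMode`/`OptimizationLevel` import paths are an assumption based on those tests:

```python
from vllm.config import VllmConfig
from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.vllm import OptimizationLevel

config = VllmConfig(optimization_level=OptimizationLevel.O2)
# -O2 defaults: torch.compile via Inductor, full-and-piecewise CUDA graphs,
# and the no-op elimination pass enabled.
assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
assert config.compilation_config.pass_config.enable_noop is True
```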

#### `-O3`: Full Optimization

Still in development. The infrastructure was added now so that the API does not have to change in a future release. `-O3` currently behaves the same as `-O2`.

## Troubleshooting

### Common Issues

1. **Startup time too long**: Use `-O0` or `-O1` for faster startup
2. **Compilation errors**: Use `debug_dump_path` for additional debugging information
3. **Performance issues**: Ensure you are using `-O2` for production
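For the compilation-error case, a hedged sketch of wiring `debug_dump_path` through the Python API is shown below. `debug_dump_path` is referenced above as a `CompilationConfig` option, but the dump directory is hypothetical and the exact field type may differ by version, so treat this as illustrative rather than canonical:

```python
from vllm.config import CompilationConfig
from vllm.entrypoints.llm import LLM

# Hypothetical dump directory; pick any writable path.
llm = LLM(
    model="RedHatAI/Llama-3.2-1B-FP8",
    optimization_level=1,
    compilation_config=CompilationConfig(debug_dump_path="/tmp/vllm_compile_debug"),
)
```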
@@ -172,8 +172,8 @@ def test_splitting_ops_dynamic():
    config = VllmConfig()
    # Default V1 config leaves cudagraph mode unset; splitting ops are only
    # populated when the engine decides to use piecewise compilation.
-   assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+   assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
-   assert not config.compilation_config.splitting_ops_contain_attention()
+   assert config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    config = VllmConfig(

@@ -222,6 +222,47 @@ def test_media_io_kwargs_parser(arg, expected):
    assert args.media_io_kwargs == expected


+@pytest.mark.parametrize(
+    ("args", "expected"),
+    [
+        (["-O", "1"], "1"),
+        (["-O", "2"], "2"),
+        (["-O", "3"], "3"),
+        (["-O0"], "0"),
+        (["-O1"], "1"),
+        (["-O2"], "2"),
+        (["-O3"], "3"),
+    ],
+)
+def test_optimization_level(args, expected):
+    """
+    Test space-separated optimization levels (-O 1, -O 2, -O 3) map to
+    optimization_level.
+    """
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+    parsed_args = parser.parse_args(args)
+    assert parsed_args.optimization_level == expected
+    assert parsed_args.compilation_config.mode is None
+
+
+@pytest.mark.parametrize(
+    ("args", "expected"),
+    [
+        (["-O.mode=0"], 0),
+        (["-O.mode=1"], 1),
+        (["-O.mode=2"], 2),
+        (["-O.mode=3"], 3),
+    ],
+)
+def test_mode_parser(args, expected):
+    """
+    Test compilation config modes (-O.mode=int) map to compilation_config.
+    """
+    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
+    parsed_args = parser.parse_args(args)
+    assert parsed_args.compilation_config.mode == expected
+
+
def test_compilation_config():
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

@@ -229,22 +270,6 @@ def test_compilation_config():
    args = parser.parse_args([])
    assert args.compilation_config == CompilationConfig()

-   # set to O3
-   args = parser.parse_args(["-O0"])
-   assert args.compilation_config.mode == 0
-
-   # set to O 3 (space)
-   args = parser.parse_args(["-O", "1"])
-   assert args.compilation_config.mode == 1
-
-   # set to O 3 (equals)
-   args = parser.parse_args(["-O=2"])
-   assert args.compilation_config.mode == 2
-
-   # set to O.mode 3
-   args = parser.parse_args(["-O.mode", "3"])
-   assert args.compilation_config.mode == 3
-
    # set to string form of a dict
    args = parser.parse_args(
        [
@@ -5,7 +5,12 @@ import pytest
import torch

from vllm._aiter_ops import rocm_aiter_ops
-from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
+from vllm.config import (
+    CompilationConfig,
+    VllmConfig,
+    get_cached_compilation_config,
+    set_current_vllm_config,
+)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (
    GeluAndMul,

@@ -86,6 +91,7 @@ def test_enabled_ops(
            backend=backend, mode=compilation_mode, custom_ops=custom_ops
        )
    )
+   get_cached_compilation_config.cache_clear()
    with set_current_vllm_config(vllm_config):
        assert CustomOp.default_on() == default_on

@@ -8,9 +8,20 @@ from unittest.mock import patch
import pytest

from vllm.compilation.backends import VllmBackend
-from vllm.config import ModelConfig, PoolerConfig, VllmConfig, update_config
+from vllm.config import (
+    CompilationConfig,
+    ModelConfig,
+    PoolerConfig,
+    VllmConfig,
+    update_config,
+)
+from vllm.config.compilation import CompilationMode, CUDAGraphMode
from vllm.config.load import LoadConfig
from vllm.config.utils import get_field
+from vllm.config.vllm import (
+    OPTIMIZATION_LEVEL_TO_CONFIG,
+    OptimizationLevel,
+)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform

@@ -235,6 +246,43 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
    assert model_config.pooler_config.pooling_type == pooling_type


+@pytest.mark.parametrize(
+    ("model_id", "expected_is_moe_model"),
+    [
+        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
+        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", False),
+        ("RedHatAI/Llama-3.2-1B-FP8", False),
+        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", False),
+        ("RedHatAI/gpt-oss-20b", True),
+        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
+        ("RedHatAI/Llama-4-Scout-17B-16E-Instruct", True),
+        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", True),
+    ],
+)
+def test_moe_model_detection(model_id, expected_is_moe_model):
+    model_config = ModelConfig(model_id)
+    # Just check that is_moe_model field exists and is a boolean
+    assert model_config.is_model_moe() == expected_is_moe_model
+
+
+@pytest.mark.parametrize(
+    ("model_id", "quantized"),
+    [
+        ("RedHatAI/Qwen3-8B-speculator.eagle3", False),
+        ("RedHatAI/Llama-3.1-8B-Instruct-NVFP4", True),
+        ("RedHatAI/Llama-3.2-1B-FP8", True),
+        ("RedHatAI/Mistral-Small-24B-Instruct-2501-quantized.w8a8", True),
+        ("RedHatAI/gpt-oss-20b", True),
+        ("RedHatAI/DeepSeek-V2.5-1210-FP8", True),
+        ("RedHatAI/Mixtral-8x7B-Instruct-v0.1", False),
+    ],
+)
+def test_is_quantized(model_id, quantized):
+    model_config = ModelConfig(model_id)
+    # Just check that quantized field exists and is a boolean
+    assert model_config.is_quantized() == quantized
+
+
@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
@@ -552,3 +600,260 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
    assert os.path.exists(config1.tokenizer) and os.path.isdir(config1.tokenizer)
    assert os.path.exists(config2.model) and os.path.isdir(config2.model)
    assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer)
+
+
+@pytest.mark.parametrize(
+    ("backend", "custom_ops", "expected"),
+    [
+        ("eager", [], True),
+        ("eager", ["+fused_layernorm"], True),
+        ("eager", ["all", "-fused_layernorm"], False),
+        ("inductor", [], False),
+        ("inductor", ["none", "+fused_layernorm"], True),
+        ("inductor", ["none", "-fused_layernorm"], False),
+    ],
+)
+def test_is_custom_op_enabled(backend: str, custom_ops: list[str], expected: bool):
+    """Test that is_custom_op_enabled works correctly."""
+    config = VllmConfig(
+        compilation_config=CompilationConfig(backend=backend, custom_ops=custom_ops)
+    )
+    assert config.compilation_config.is_custom_op_enabled("fused_layernorm") is expected
+
+
+def test_vllm_config_defaults_are_none():
+    """Verify that optimization-level defaults are None when not set by user."""
+    # Test all optimization levels to ensure defaults work correctly
+    for opt_level in OptimizationLevel:
+        config = object.__new__(VllmConfig)
+        config.compilation_config = CompilationConfig()
+        config.optimization_level = opt_level
+        config.model_config = None
+
+        # Use the global optimization level defaults
+        default_config = OPTIMIZATION_LEVEL_TO_CONFIG[opt_level]
+
+        # Verify that all pass_config values are None before defaults are applied
+        for pass_k in default_config["compilation_config"]["pass_config"]:
+            assert getattr(config.compilation_config.pass_config, pass_k) is None
+
+        # Verify that other config values are None before defaults are applied
+        for k in default_config["compilation_config"]:
+            if k != "pass_config":
+                assert getattr(config.compilation_config, k) is None
+
+
+@pytest.mark.parametrize(
+    ("model_id", "compiliation_config", "optimization_level"),
+    [
+        (
+            None,
+            CompilationConfig(backend="eager", custom_ops=["+quant_fp8"]),
+            OptimizationLevel.O0,
+        ),
+        (None, CompilationConfig(), OptimizationLevel.O0),
+        (None, CompilationConfig(), OptimizationLevel.O1),
+        (None, CompilationConfig(), OptimizationLevel.O2),
+        (None, CompilationConfig(), OptimizationLevel.O3),
+        (
+            "RedHatAI/Qwen3-8B-speculator.eagle3",
+            CompilationConfig(backend="inductor", custom_ops=["+quant_fp8"]),
+            OptimizationLevel.O2,
+        ),
+        (
+            "RedHatAI/Qwen3-8B-speculator.eagle3",
+            CompilationConfig(),
+            OptimizationLevel.O0,
+        ),
+        (
+            "RedHatAI/Qwen3-8B-speculator.eagle3",
+            CompilationConfig(),
+            OptimizationLevel.O1,
+        ),
+        (
+            "RedHatAI/Qwen3-8B-speculator.eagle3",
+            CompilationConfig(),
+            OptimizationLevel.O2,
+        ),
+        (
+            "RedHatAI/Qwen3-8B-speculator.eagle3",
+            CompilationConfig(),
+            OptimizationLevel.O3,
+        ),
+        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O0),
+        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O1),
+        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O2),
+        ("RedHatAI/DeepSeek-V2.5-1210-FP8", CompilationConfig(), OptimizationLevel.O3),
+    ],
+)
+def test_vllm_config_defaults(model_id, compiliation_config, optimization_level):
+    """Test that optimization-level defaults are correctly applied."""
+
+    model_config = None
+    if model_id is not None:
+        model_config = ModelConfig(model_id)
+        vllm_config = VllmConfig(
+            model_config=model_config,
+            compilation_config=compiliation_config,
+            optimization_level=optimization_level,
+        )
+    else:
+        vllm_config = VllmConfig(
+            compilation_config=compiliation_config,
+            optimization_level=optimization_level,
+        )
+    # Use the global optimization level defaults
+    default_config = OPTIMIZATION_LEVEL_TO_CONFIG[optimization_level]
+
+    # Verify pass_config defaults (nested under compilation_config)
+    pass_config_dict = default_config["compilation_config"]["pass_config"]
+    for pass_k, pass_v in pass_config_dict.items():
+        actual = getattr(vllm_config.compilation_config.pass_config, pass_k)
+        expected = pass_v(vllm_config) if callable(pass_v) else pass_v
+        assert actual == expected, (
+            f"pass_config.{pass_k}: expected {expected}, got {actual}"
+        )
+
+    # Verify other compilation_config defaults
+    compilation_config_dict = default_config["compilation_config"]
+    for k, v in compilation_config_dict.items():
+        if k != "pass_config":
+            actual = getattr(vllm_config.compilation_config, k)
+            expected = v(vllm_config) if callable(v) else v
+            assert actual == expected, (
+                f"compilation_config.{k}: expected {expected}, got {actual}"
+            )
+
+
+def test_vllm_config_callable_defaults():
+    """Test that callable defaults work in the config system.
+
+    Verifies that lambdas in default configs can inspect VllmConfig properties
+    (e.g., is_quantized, is_model_moe) to conditionally set optimization flags.
+    """
+    config_no_model = VllmConfig(optimization_level=OptimizationLevel.O2)
+
+    # Callable that checks if model exists
+    has_model = lambda cfg: cfg.model_config is not None
+    assert has_model(config_no_model) is False
+
+    # Test with quantized model
+    quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
+    config_quantized = VllmConfig(
+        model_config=quantized_model, optimization_level=OptimizationLevel.O2
+    )
+    enable_if_quantized = lambda cfg: (
+        cfg.model_config is not None and cfg.model_config.is_quantized()
+    )
+    assert enable_if_quantized(config_quantized) is True
+    assert enable_if_quantized(config_no_model) is False
+
+    # Test with MoE model
+    moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
+    config_moe = VllmConfig(
+        model_config=moe_model, optimization_level=OptimizationLevel.O2
+    )
+    enable_if_sequential = lambda cfg: (
+        cfg.model_config is not None and not cfg.model_config.is_model_moe()
+    )
+    assert enable_if_sequential(config_moe) is False
+    assert enable_if_sequential(config_quantized) is True
+
+
+def test_vllm_config_explicit_overrides():
+    """Test that explicit property overrides work correctly with callable defaults.
+
+    When users explicitly set configuration properties, those values
+    take precedence over callable defaults, across different models and
+    optimization levels.
+    """
+    from vllm.config.compilation import PassConfig
+
+    quantized_model = ModelConfig("RedHatAI/Llama-3.2-1B-FP8")
+    moe_model = ModelConfig("deepseek-ai/DeepSeek-V2-Lite")
+    regular_model = ModelConfig("Qwen/Qwen1.5-7B")
+
+    # Explicit compilation mode override on O0 (where default is NONE)
+    compilation_config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
+    config = VllmConfig(
+        optimization_level=OptimizationLevel.O0,
+        compilation_config=compilation_config,
+    )
+    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+
+    # Explicit pass config flags to override defaults
+    pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True)
+    compilation_config = CompilationConfig(pass_config=pass_config)
+    config = VllmConfig(
+        optimization_level=OptimizationLevel.O0,
+        compilation_config=compilation_config,
+    )
+    assert config.compilation_config.pass_config.enable_noop is True
+    assert config.compilation_config.pass_config.enable_attn_fusion is True
+
+    # Explicit cudagraph mode override on quantized model at O2
+    pass_config = PassConfig(enable_async_tp=True)
+    compilation_config = CompilationConfig(
+        cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
+    )
+    config = VllmConfig(
+        model_config=quantized_model,
+        optimization_level=OptimizationLevel.O2,
+        compilation_config=compilation_config,
+    )
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+    assert config.compilation_config.pass_config.enable_async_tp is True
+    # Mode should still use default for O2
+    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
+
+    # Different optimization levels with same model
+    config_o0 = VllmConfig(
+        model_config=regular_model, optimization_level=OptimizationLevel.O0
+    )
+    config_o2 = VllmConfig(
+        model_config=regular_model, optimization_level=OptimizationLevel.O2
+    )
+    assert config_o0.compilation_config.mode == CompilationMode.NONE
+    assert config_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
+    assert config_o0.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+    assert (
+        config_o2.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
+    )
+
+    # Same optimization level across different model types
+    config_moe_o2 = VllmConfig(
+        model_config=moe_model, optimization_level=OptimizationLevel.O2
+    )
+    config_regular_o2 = VllmConfig(
+        model_config=regular_model, optimization_level=OptimizationLevel.O2
+    )
+    config_quantized_o2 = VllmConfig(
+        model_config=quantized_model, optimization_level=OptimizationLevel.O2
+    )
+    # All should have same base compilation settings at O2
+    assert config_moe_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
+    assert config_regular_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
+    assert config_quantized_o2.compilation_config.mode == CompilationMode.VLLM_COMPILE
+    assert (
+        config_moe_o2.compilation_config.cudagraph_mode
+        == CUDAGraphMode.FULL_AND_PIECEWISE
+    )
+    assert (
+        config_regular_o2.compilation_config.cudagraph_mode
+        == CUDAGraphMode.FULL_AND_PIECEWISE
+    )
+
+    # Override one field but not others
+    pass_config = PassConfig(enable_noop=False)
+    compilation_config = CompilationConfig(pass_config=pass_config)
+    config = VllmConfig(
+        model_config=regular_model,
+        optimization_level=OptimizationLevel.O2,
+        compilation_config=compilation_config,
+    )
+    # Explicit override should be respected
+    assert config.compilation_config.pass_config.enable_noop is False
+    # Other fields should still use defaults
+    assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
+    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
@@ -28,6 +28,7 @@ def parser():
    parser.add_argument("--enable-feature", action="store_true")
    parser.add_argument("--hf-overrides", type=json.loads)
    parser.add_argument("-O", "--compilation-config", type=json.loads)
+   parser.add_argument("--optimization-level", type=int)
    return parser


@@ -217,8 +218,8 @@ def test_dict_args(parser):
            "key15": "-minus.and.dot",
        },
    }
+   assert parsed_args.optimization_level == 1
    assert parsed_args.compilation_config == {
-       "mode": 1,
        "use_inductor_graph_partition": True,
        "backend": "custom",
        "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],

@@ -241,12 +242,13 @@ def test_duplicate_dict_args(caplog_vllm, parser):
    parsed_args = parser.parse_args(args)
    # Should be the last value
    assert parsed_args.hf_overrides == {"key1": "val2"}
-   assert parsed_args.compilation_config == {"mode": 3}
+   assert parsed_args.optimization_level == 3
+   assert parsed_args.compilation_config == {"mode": 2}

    assert len(caplog_vllm.records) == 1
    assert "duplicate" in caplog_vllm.text
    assert "--hf-overrides.key1" in caplog_vllm.text
-   assert "-O.mode" in caplog_vllm.text
+   assert "--optimization-level" in caplog_vllm.text


def test_model_specification(

@@ -383,7 +385,7 @@ def test_compilation_mode_string_values(parser):
    assert args.compilation_config == {"mode": 0}

    args = parser.parse_args(["-O3"])
-   assert args.compilation_config == {"mode": 3}
+   assert args.optimization_level == 3

    args = parser.parse_args(["-O.mode=NONE"])
    assert args.compilation_config == {"mode": "NONE"}
@@ -117,9 +117,9 @@ else:
    combo_cases_2 = [
        ("FA2", "FULL", CompilationMode.NONE, True),
        ("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
-       ("FA2", "PIECEWISE", CompilationMode.NONE, False),
+       ("FA2", "PIECEWISE", CompilationMode.NONE, True),
        ("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
-       ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, False),
+       ("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
        ("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
        ("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
@@ -8,7 +8,7 @@ from dataclasses import asdict, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal

-from pydantic import TypeAdapter, field_validator
+from pydantic import Field, TypeAdapter, field_validator
from pydantic.dataclasses import dataclass

import vllm.envs as envs

@@ -97,19 +97,25 @@ class PassConfig:

    This is separate from general `CompilationConfig` so that inductor passes
    don't all have access to full configuration - that would create a cycle as
-   the `PassManager` is set as a property of config."""
+   the `PassManager` is set as a property of config.
+
+   You must pass PassConfig to VLLMConfig constructor via the CompilationConfig
+   constructor. VLLMConfig's post_init does further initialization.
+   If used outside of the VLLMConfig, some fields may be left in an
+   improper state.
+   """

-   enable_fusion: bool = False
+   enable_fusion: bool = Field(default=None)
    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
-   enable_attn_fusion: bool = False
+   enable_attn_fusion: bool = Field(default=None)
    """Whether to enable the custom attention+quant fusion pass."""
-   enable_noop: bool = False
+   enable_noop: bool = Field(default=None)
    """Whether to enable the custom no-op elimination pass."""
-   enable_sequence_parallelism: bool = False
+   enable_sequence_parallelism: bool = Field(default=None)
    """Whether to enable sequence parallelism."""
-   enable_async_tp: bool = False
+   enable_async_tp: bool = Field(default=None)
    """Whether to enable async TP."""
-   enable_fi_allreduce_fusion: bool = False
+   enable_fi_allreduce_fusion: bool = Field(default=None)
    """Whether to enable flashinfer allreduce fusion."""
    fi_allreduce_fusion_max_size_mb: float | None = None
    """The threshold of the communicated tensor sizes under which

@@ -167,6 +173,22 @@ class PassConfig:
        """
        return InductorPass.hash_dict(asdict(self))

+   @field_validator(
+       "enable_fusion",
+       "enable_attn_fusion",
+       "enable_noop",
+       "enable_sequence_parallelism",
+       "enable_async_tp",
+       "enable_fi_allreduce_fusion",
+       mode="wrap",
+   )
+   @classmethod
+   def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
+       """Skip validation if the value is `None` when initialisation is delayed."""
+       if value is None:
+           return value
+       return handler(value)
+
    def __post_init__(self) -> None:
        if not self.enable_noop:
            if self.enable_fusion:

@@ -243,7 +265,13 @@ class DynamicShapesConfig:
@config
@dataclass
class CompilationConfig:
-   """Configuration for compilation. It has three parts:
+   """Configuration for compilation.
+
+   You must pass CompilationConfig to VLLMConfig constructor.
+   VLLMConfig's post_init does further initialization. If used outside of the
+   VLLMConfig, some fields will be left in an improper state.
+
+   It has three parts:

    - Top-level Compilation control:
    - [`mode`][vllm.config.CompilationConfig.mode]

@@ -282,14 +310,14 @@ class CompilationConfig:
    """

    # Top-level Compilation control
-   level: int | None = None
+   level: int = Field(default=None)
    """
    Level is deprecated and will be removed in the next release,
    either 0.12.0 or 0.11.2 whichever is soonest.
    Please use mode. Currently all levels are mapped to mode.
    """
    # Top-level Compilation control
-   mode: CompilationMode | None = None
+   mode: CompilationMode = Field(default=None)
    """The compilation approach used for torch.compile-based compilation of the
    model.

@@ -390,7 +418,7 @@ class CompilationConfig:
    constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""

    # CudaGraph compilation
-   cudagraph_mode: CUDAGraphMode | None = None
+   cudagraph_mode: CUDAGraphMode = Field(default=None)
    """
    The mode of the cudagraph:

@@ -452,7 +480,7 @@ class CompilationConfig:
    When `enable_lora` is False, this option has no effect.
    """

-   use_inductor_graph_partition: bool = False
+   use_inductor_graph_partition: bool = Field(default=None)
    """Use inductor graph partition to split the graph at cudagraph_unsafe ops.
    This partition happens at inductor codegen time after all passes and fusions
    are finished. It generates a single `call` function which wraps

@@ -648,6 +676,20 @@ class CompilationConfig:
            )
        return value

+   @field_validator(
+       "level",
+       "mode",
+       "cudagraph_mode",
+       "use_inductor_graph_partition",
+       mode="wrap",
+   )
+   @classmethod
+   def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
+       """Skip validation if the value is `None` when initialisation is delayed."""
+       if value is None:
+           return value
+       return handler(value)
+
    def __post_init__(self) -> None:
        if self.level is not None:
            logger.warning(

@@ -948,6 +990,13 @@ class CompilationConfig:
                op,
            )

+   def is_custom_op_enabled(self, op: str) -> bool:
+       if "all" in self.custom_ops:
+           return f"-{op}" not in self.custom_ops
+
+       assert "none" in self.custom_ops
+       return f"+{op}" in self.custom_ops
+
    def adjust_cudagraph_sizes_for_spec_decode(
        self, uniform_decode_query_len: int, tensor_parallel_size: int
    ):
|||||||
@ -1752,6 +1752,14 @@ class ModelConfig:
|
|||||||
logger.info("Using max model len %s", max_model_len)
|
logger.info("Using max model len %s", max_model_len)
|
||||||
return max_model_len
|
return max_model_len
|
||||||
|
|
||||||
|
def is_model_moe(
|
||||||
|
self,
|
||||||
|
) -> bool:
|
||||||
|
return self.get_num_experts() > 1
|
||||||
|
|
||||||
|
def is_quantized(self) -> bool:
|
||||||
|
return getattr(self.hf_config, "quantization_config", None) is not None
|
||||||
|
|
||||||
|
|
||||||
def get_served_model_name(model: str, served_model_name: str | list[str] | None):
|
def get_served_model_name(model: str, served_model_name: str | list[str] | None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -9,8 +9,9 @@ import tempfile
import threading
import time
from contextlib import contextmanager
-from dataclasses import replace
+from dataclasses import is_dataclass, replace
from datetime import datetime
+from enum import IntEnum
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, get_args

@@ -57,6 +58,103 @@ else:
logger = init_logger(__name__)


+class OptimizationLevel(IntEnum):
+    """Optimization level enum."""
+
+    O0 = 0
+    """O0 : No optimization. no compilation, no cudagraphs, no other
+    optimization, just starting up immediately"""
+    O1 = 1
+    """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
+    cudagraphs"""
+    O2 = 2
+    """O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
+    O3 = 3
+    """O3: Currently the same as -O2s."""
+
+
+IS_QUANTIZED = False
+IS_DENSE = False
+# The optimizations that depend on these properties currently set to False
+# in all cases.
+# if model_config is not None:
+#     IS_QUANTIZED = lambda c: c.model_config.is_quantized()
+#     IS_DENSE = lambda c: not c.model_config.is_model_moe()
+# See https://github.com/vllm-project/vllm/issues/25689.
+
+
+def enable_fusion(cfg: "VllmConfig") -> bool:
+    """Returns True if RMS norm or quant FP8 is enabled."""
+    return cfg.compilation_config.is_custom_op_enabled(
+        "rms_norm"
+    ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
+
+
+OPTIMIZATION_LEVEL_00 = {
+    "compilation_config": {
+        "pass_config": {
+            "enable_noop": False,
+            "enable_fusion": False,
+            "enable_fi_allreduce_fusion": False,
+            "enable_attn_fusion": False,
+            "enable_sequence_parallelism": False,
+            "enable_async_tp": False,
+        },
+        "cudagraph_mode": CUDAGraphMode.NONE,
+        "use_inductor_graph_partition": False,
+    },
+}
+OPTIMIZATION_LEVEL_01 = {
+    "compilation_config": {
+        "pass_config": {
+            "enable_noop": True,
+            "enable_fusion": enable_fusion,
+            "enable_fi_allreduce_fusion": False,
+            "enable_attn_fusion": False,
+            "enable_sequence_parallelism": False,
+            "enable_async_tp": False,
+        },
+        "cudagraph_mode": CUDAGraphMode.PIECEWISE,
+        "use_inductor_graph_partition": False,
+    },
+}
+OPTIMIZATION_LEVEL_02 = {
+    "compilation_config": {
+        "pass_config": {
+            "enable_noop": True,
+            "enable_fusion": enable_fusion,
+            "enable_fi_allreduce_fusion": False,
+            "enable_attn_fusion": IS_QUANTIZED,
+            "enable_sequence_parallelism": IS_DENSE,
+            "enable_async_tp": IS_DENSE,
+        },
+        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
+        "use_inductor_graph_partition": False,
+    },
+}
+OPTIMIZATION_LEVEL_03 = {
+    "compilation_config": {
+        "pass_config": {
+            "enable_noop": True,
+            "enable_fusion": enable_fusion,
+            "enable_fi_allreduce_fusion": False,
+            "enable_attn_fusion": IS_QUANTIZED,
+            "enable_sequence_parallelism": IS_DENSE,
+            "enable_async_tp": IS_DENSE,
+        },
+        "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
+        "use_inductor_graph_partition": False,
+    },
+}
+
+OPTIMIZATION_LEVEL_TO_CONFIG = {
+    OptimizationLevel.O0: OPTIMIZATION_LEVEL_00,
+    OptimizationLevel.O1: OPTIMIZATION_LEVEL_01,
+    OptimizationLevel.O2: OPTIMIZATION_LEVEL_02,
+    OptimizationLevel.O3: OPTIMIZATION_LEVEL_03,
+}
+
+
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:

@@ -116,6 +214,11 @@ class VllmConfig:
    you are using. Contents must be hashable."""
    instance_id: str = ""
    """The ID of the vLLM instance."""
+   optimization_level: OptimizationLevel = OptimizationLevel.O2
+   """The optimization level. These levels trade startup time cost for
+   performance, with -O0 having the best startup time and -O3 having the best
+   performance. -02 is used by defult. See OptimizationLevel for full
+   description."""

    def compute_hash(self) -> str:
        """

@@ -297,6 +400,50 @@ class VllmConfig:

        return replace(self, model_config=model_config)

+   def _set_config_default(self, config_obj: Any, key: str, value: Any) -> None:
+       """Set config attribute to default if not already set by user.
+
+       Args:
+           config_obj: Configuration object to update.
+           key: Attribute name.
+           value: Default value (static or callable).
+       """
+       if getattr(config_obj, key) is None:
+           # Some config values are known before initialization and are
+           # hard coded.
+           # Other values depend on the user given configuration, so they are
+           # implemented with lambda functions and decided at run time.
+           setattr(config_obj, key, value(self) if callable(value) else value)
+
+   def _apply_optimization_level_defaults(self, defaults: dict[str, Any]) -> None:
+       """Apply optimization level defaults using self as root.
+
+       Recursively applies values from defaults into nested config objects.
+       Only fields present in defaults are overwritten.
+
+       If the user configuration does not specify a value for a default field
+       and if the default field is still None after all user selections are
+       applied, then default values will be applied to the field. User speciied
+       fields will not be overridden by the default.
+
+       Args:
+           defaults: Dictionary of default values to apply.
+       """
+
+       def apply_recursive(config_obj: Any, config_defaults: dict[str, Any]) -> None:
+           """Recursively apply defaults to config_obj, using self as root."""
+           for key, value in config_defaults.items():
+               if not hasattr(config_obj, key):
+                   continue
+
+               current = getattr(config_obj, key)
+               if isinstance(value, dict) and is_dataclass(current):
+                   apply_recursive(current, value)
+               else:
+                   self._set_config_default(config_obj, key, value)
+
+       apply_recursive(self, defaults)
+
    def _post_init_kv_transfer_config(self) -> None:
        """Update KVTransferConfig based on top-level configs in VllmConfig.

@@ -434,17 +581,47 @@ class VllmConfig:
                "precision for chunked prefill triton kernels."
            )

-       # If the user does not explicitly set a compilation mode, then
-       # we use the default mode. The default mode depends on other
-       # settings (see the below code).
+       if (
+           self.optimization_level > OptimizationLevel.O0
+           and self.model_config is not None
+           and self.model_config.enforce_eager
+       ):
+           logger.warning("Enforce eager set, overriding optimization level to -O0")
+           self.optimization_level = OptimizationLevel.O0
+
+       if self.compilation_config.backend == "eager" or (
+           self.compilation_config.mode is not None
+           and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
+       ):
+           logger.warning(
+               "Inductor compilation was disabled by user settings,"
+               "Optimizations settings that are only active during"
+               "Inductor compilation will be ignored."
+           )
+
+       def has_blocked_weights():
+           if self.quant_config is not None:
+               if hasattr(self.quant_config, "weight_block_size"):
+                   return self.quant_config.weight_block_size is not None
+               elif hasattr(self.quant_config, "has_blocked_weights"):
+                   return self.quant_config.has_blocked_weights()
+           return False
+
+       # Enable quant_fp8 CUDA ops (TODO disable in follow up)
+       # On H100 the CUDA kernel is faster than
+       # native implementation
+       # https://github.com/vllm-project/vllm/issues/25094
+       if has_blocked_weights():
+           custom_ops = self.compilation_config.custom_ops
+           if "-quant_fp8" not in custom_ops:
+               custom_ops.append("+quant_fp8")
+
        if self.compilation_config.mode is None:
-           if self.model_config is not None and not self.model_config.enforce_eager:
+           if self.optimization_level > OptimizationLevel.O0:
                self.compilation_config.mode = CompilationMode.VLLM_COMPILE
            else:
                self.compilation_config.mode = CompilationMode.NONE

-       # If user does not set custom ops via none or all set it here based on
-       # compilation mode and backend.
        if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
            if (
                self.compilation_config.backend == "inductor"

@@ -454,23 +631,33 @@ class VllmConfig:
            else:
                self.compilation_config.custom_ops.append("all")

+       default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
+       self._apply_optimization_level_defaults(default_config)
+       if (
+           self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+           and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
+       ):
+           logger.info(
+               "Cudagraph mode %s is not compatible with compilation mode %s."
+               "Overriding to NONE.",
+               self.compilation_config.cudagraph_mode,
+               self.compilation_config.mode,
+           )
+           self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
+
        # async tp is built on top of sequence parallelism
        # and requires it to be enabled.
        if self.compilation_config.pass_config.enable_async_tp:
            self.compilation_config.pass_config.enable_sequence_parallelism = True
+       if self.compilation_config.pass_config.enable_sequence_parallelism:
+           if "-rms_norm" in self.compilation_config.custom_ops:
+               logger.warning(
+                   "RMS norm force disabled, sequence parallelism might break"
+               )
+           else:
+               self.compilation_config.custom_ops.append("+rms_norm")

        if current_platform.support_static_graph_mode():
-           # if cudagraph_mode is not explicitly set by users, set default
-           # value
-           if self.compilation_config.cudagraph_mode is None:
-               if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
-                   # default to full and piecewise for most models
-                   self.compilation_config.cudagraph_mode = (
-                       CUDAGraphMode.FULL_AND_PIECEWISE
-                   )
-               else:
-                   self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
-
            # if cudagraph_mode has full cudagraphs, we need to check support
            if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
                # decode context parallel does not support full cudagraphs
@@ -77,6 +77,7 @@ from vllm.config.observability import DetailedTraceModules
from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy
from vllm.config.scheduler import SchedulerPolicy
from vllm.config.utils import get_field
+from vllm.config.vllm import OptimizationLevel
from vllm.logger import init_logger, suppress_logging
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins

@@ -560,6 +561,7 @@ class EngineArgs:
    stream_interval: int = SchedulerConfig.stream_interval

    kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
+   optimization_level: OptimizationLevel = VllmConfig.optimization_level

    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
    kv_offloading_backend: KVOffloadingBackend | None = (

@@ -1114,6 +1116,10 @@ class EngineArgs:
            "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"]
        )

+       vllm_group.add_argument(
+           "--optimization-level", **vllm_kwargs["optimization_level"]
+       )
+
        # Other arguments
        parser.add_argument(
            "--disable-log-stats",

@@ -1733,7 +1739,6 @@ class EngineArgs:
        compilation_config.max_cudagraph_capture_size = (
            self.max_cudagraph_capture_size
        )

        config = VllmConfig(
            model_config=model_config,
            cache_config=cache_config,

@@ -1750,6 +1755,7 @@ class EngineArgs:
            kv_events_config=self.kv_events_config,
            ec_transfer_config=self.ec_transfer_config,
            additional_config=self.additional_config,
+           optimization_level=self.optimization_level,
        )

        return config
|||||||
@ -247,16 +247,16 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
|
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
|
||||||
# allow -O flag to be used without space, e.g. -O3 or -Odecode
|
# allow -O flag to be used without space, e.g. -O3 or -Odecode
|
||||||
# -O.<...> handled later
|
# -O.<...> handled later
|
||||||
# also handle -O=<mode> here
|
# also handle -O=<optimization_level> here
|
||||||
mode = arg[3:] if arg[2] == "=" else arg[2:]
|
optimization_level = arg[3:] if arg[2] == "=" else arg[2:]
|
||||||
processed_args.append(f"-O.mode={mode}")
|
processed_args += ["--optimization-level", optimization_level]
|
||||||
elif (
|
elif (
|
||||||
arg == "-O"
|
arg == "-O"
|
||||||
and i + 1 < len(args)
|
and i + 1 < len(args)
|
||||||
and args[i + 1] in {"0", "1", "2", "3"}
|
and args[i + 1] in {"0", "1", "2", "3"}
|
||||||
):
|
):
|
||||||
# Convert -O <n> to -O.mode <n>
|
# Convert -O <n> to --optimization-level <n>
|
||||||
processed_args.append("-O.mode")
|
processed_args.append("--optimization-level")
|
||||||
else:
|
else:
|
||||||
processed_args.append(arg)
|
processed_args.append(arg)
|
||||||
|
|
||||||
@ -294,10 +294,24 @@ class FlexibleArgumentParser(ArgumentParser):
|
|||||||
delete = set[int]()
|
delete = set[int]()
|
||||||
dict_args = defaultdict[str, dict[str, Any]](dict)
|
dict_args = defaultdict[str, dict[str, Any]](dict)
|
||||||
duplicates = set[str]()
|
duplicates = set[str]()
|
||||||
|
# Track regular arguments (non-dict args) for duplicate detection
|
||||||
|
regular_args_seen = set[str]()
|
||||||
for i, processed_arg in enumerate(processed_args):
|
for i, processed_arg in enumerate(processed_args):
|
||||||
if i in delete: # skip if value from previous arg
|
if i in delete: # skip if value from previous arg
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if processed_arg.startswith("--") and "." not in processed_arg:
|
||||||
|
if "=" in processed_arg:
|
||||||
|
arg_name = processed_arg.split("=", 1)[0]
|
||||||
|
else:
|
||||||
|
arg_name = processed_arg
|
||||||
|
|
||||||
|
if arg_name in regular_args_seen:
|
||||||
|
duplicates.add(arg_name)
|
||||||
|
else:
|
||||||
|
regular_args_seen.add(arg_name)
|
||||||
|
continue
|
||||||
|
|
||||||
if processed_arg.startswith("-") and "." in processed_arg:
|
if processed_arg.startswith("-") and "." in processed_arg:
|
||||||
if "=" in processed_arg:
|
if "=" in processed_arg:
|
||||||
processed_arg, value_str = processed_arg.split("=", 1)
|
processed_arg, value_str = processed_arg.split("=", 1)
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class CudaGraphManager:
|
|||||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
self.compilation_config = vllm_config.compilation_config
|
self.compilation_config = vllm_config.compilation_config
|
||||||
assert self.compilation_config is not None
|
assert self.compilation_config is not None
|
||||||
|
self.cudagraph_mode: CUDAGraphMode
|
||||||
if self.compilation_config.cudagraph_mode is None:
|
if self.compilation_config.cudagraph_mode is None:
|
||||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||||
else:
|
else:
|
||||||
|
|||||||