mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 10:09:08 +08:00
Improve configs - the rest! (#17562)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
7e3571134f
commit
4b2ed7926a
@ -9,7 +9,7 @@ import torch
|
|||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.config import CompilationConfig, CompilationLevel
|
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import create_new_process_for_each_test
|
from ..utils import create_new_process_for_each_test
|
||||||
@ -95,9 +95,6 @@ def test_full_graph(
|
|||||||
run_model(optimization_level, model, model_kwargs)
|
run_model(optimization_level, model, model_kwargs)
|
||||||
|
|
||||||
|
|
||||||
PassConfig = CompilationConfig.PassConfig
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(luka) add other supported compilation config scenarios here
|
# TODO(luka) add other supported compilation config scenarios here
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"compilation_config, model_info",
|
"compilation_config, model_info",
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
|||||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.config import CompilationConfig, VllmConfig
|
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||||
|
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
|||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.compilation_config = CompilationConfig(pass_config= \
|
vllm_config.compilation_config = CompilationConfig(
|
||||||
CompilationConfig.PassConfig(enable_fusion=do_fusion,
|
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
|
||||||
enable_noop=True))
|
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
fusion_pass = FusionPass.instance(vllm_config)
|
fusion_pass = FusionPass.instance(vllm_config)
|
||||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||||
|
|||||||
@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
|||||||
FusionPass, QuantKey)
|
FusionPass, QuantKey)
|
||||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
|
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||||
|
VllmConfig)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
|
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
|
||||||
@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
|
|||||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||||
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
|
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
|
||||||
vllm_config.compilation_config.pass_config = \
|
vllm_config.compilation_config.pass_config = \
|
||||||
CompilationConfig.PassConfig(enable_fusion=True,
|
PassConfig(enable_fusion=True, enable_noop=True)
|
||||||
enable_noop=True)
|
|
||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
|
|||||||
find_specified_fn_maybe, is_func)
|
find_specified_fn_maybe, is_func)
|
||||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||||
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
|
||||||
VllmConfig)
|
PassConfig, VllmConfig)
|
||||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||||
initialize_model_parallel)
|
initialize_model_parallel)
|
||||||
@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
|
|||||||
|
|
||||||
# configure vllm config for SequenceParallelismPass
|
# configure vllm config for SequenceParallelismPass
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.compilation_config = CompilationConfig(
|
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
|
||||||
pass_config=CompilationConfig.PassConfig(
|
enable_sequence_parallelism=True))
|
||||||
enable_sequence_parallelism=True, ), )
|
|
||||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||||
|
|
||||||
# this is a fake model name to construct the model config
|
# this is a fake model name to construct the model config
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import vllm.envs as envs
|
|||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
|
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
|
||||||
from vllm.config import CompilationConfig, VllmConfig
|
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
|
|||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
config = VllmConfig()
|
config = VllmConfig()
|
||||||
config.compilation_config = CompilationConfig(
|
config.compilation_config = CompilationConfig(
|
||||||
pass_config=CompilationConfig.PassConfig(enable_fusion=True,
|
pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
|
||||||
enable_reshape=True))
|
|
||||||
fusion_pass = ActivationQuantFusionPass(config)
|
fusion_pass = ActivationQuantFusionPass(config)
|
||||||
|
|
||||||
backend = TestBackend(fusion_pass)
|
backend = TestBackend(fusion_pass)
|
||||||
|
|||||||
@ -206,7 +206,7 @@ def _compare_sp(
|
|||||||
'compile_sizes': [4, 8],
|
'compile_sizes': [4, 8],
|
||||||
'splitting_ops': [],
|
'splitting_ops': [],
|
||||||
'pass_config': {
|
'pass_config': {
|
||||||
'enable_sequence_parallism': sp_enabled,
|
'enable_sequence_parallelism': sp_enabled,
|
||||||
'enable_noop': True,
|
'enable_noop': True,
|
||||||
'enable_fusion': True,
|
'enable_fusion': True,
|
||||||
},
|
},
|
||||||
@ -223,7 +223,7 @@ def _compare_sp(
|
|||||||
"--distributed-executor-backend",
|
"--distributed-executor-backend",
|
||||||
distributed_backend,
|
distributed_backend,
|
||||||
"--compilation_config",
|
"--compilation_config",
|
||||||
str(compilation_config),
|
json.dumps(compilation_config),
|
||||||
]
|
]
|
||||||
|
|
||||||
tp_env = {
|
tp_env = {
|
||||||
|
|||||||
@ -8,21 +8,18 @@ from typing import Literal, Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.config import config
|
from vllm.config import CompilationConfig, config
|
||||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||||
get_type, is_not_builtin, is_type,
|
get_type, is_not_builtin, is_type,
|
||||||
literal_to_kwargs, nullable_kvs,
|
literal_to_kwargs, nullable_kvs,
|
||||||
optional_type)
|
optional_type, parse_type)
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type", "value", "expected"), [
|
@pytest.mark.parametrize(("type", "value", "expected"), [
|
||||||
(int, "42", 42),
|
(int, "42", 42),
|
||||||
(int, "None", None),
|
|
||||||
(float, "3.14", 3.14),
|
(float, "3.14", 3.14),
|
||||||
(float, "None", None),
|
|
||||||
(str, "Hello World!", "Hello World!"),
|
(str, "Hello World!", "Hello World!"),
|
||||||
(str, "None", None),
|
|
||||||
(json.loads, '{"foo":1,"bar":2}', {
|
(json.loads, '{"foo":1,"bar":2}', {
|
||||||
"foo": 1,
|
"foo": 1,
|
||||||
"bar": 2
|
"bar": 2
|
||||||
@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
|
|||||||
"foo": 1,
|
"foo": 1,
|
||||||
"bar": 2
|
"bar": 2
|
||||||
}),
|
}),
|
||||||
(json.loads, "None", None),
|
|
||||||
])
|
])
|
||||||
def test_optional_type(type, value, expected):
|
def test_parse_type(type, value, expected):
|
||||||
optional_type_func = optional_type(type)
|
parse_type_func = parse_type(type)
|
||||||
context = nullcontext()
|
context = nullcontext()
|
||||||
if value == "foo=1,bar=2":
|
if value == "foo=1,bar=2":
|
||||||
context = pytest.warns(DeprecationWarning)
|
context = pytest.warns(DeprecationWarning)
|
||||||
with context:
|
with context:
|
||||||
assert optional_type_func(value) == expected
|
assert parse_type_func(value) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_optional_type():
|
||||||
|
optional_type_func = optional_type(int)
|
||||||
|
assert optional_type_func("None") is None
|
||||||
|
assert optional_type_func("42") == 42
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
|
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
|
||||||
@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
|
|||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass
|
@dataclass
|
||||||
class DummyConfigClass:
|
class NestedConfig:
|
||||||
|
field: int = 1
|
||||||
|
"""field"""
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class FromCliConfig1:
|
||||||
|
field: int = 1
|
||||||
|
"""field"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cli(cls, cli_value: str):
|
||||||
|
inst = cls(**json.loads(cli_value))
|
||||||
|
inst.field += 1
|
||||||
|
return inst
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class FromCliConfig2:
|
||||||
|
field: int = 1
|
||||||
|
"""field"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cli(cls, cli_value: str):
|
||||||
|
inst = cls(**json.loads(cli_value))
|
||||||
|
inst.field += 2
|
||||||
|
return inst
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class DummyConfig:
|
||||||
regular_bool: bool = True
|
regular_bool: bool = True
|
||||||
"""Regular bool with default True"""
|
"""Regular bool with default True"""
|
||||||
optional_bool: Optional[bool] = None
|
optional_bool: Optional[bool] = None
|
||||||
@ -108,18 +143,24 @@ class DummyConfigClass:
|
|||||||
"""Literal of literals with default 1"""
|
"""Literal of literals with default 1"""
|
||||||
json_tip: dict = field(default_factory=dict)
|
json_tip: dict = field(default_factory=dict)
|
||||||
"""Dict which will be JSON in CLI"""
|
"""Dict which will be JSON in CLI"""
|
||||||
|
nested_config: NestedConfig = field(default_factory=NestedConfig)
|
||||||
|
"""Nested config"""
|
||||||
|
from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
|
||||||
|
"""Config with from_cli method"""
|
||||||
|
from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
|
||||||
|
"""Different config with from_cli method"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("type_hint", "expected"), [
|
@pytest.mark.parametrize(("type_hint", "expected"), [
|
||||||
(int, False),
|
(int, False),
|
||||||
(DummyConfigClass, True),
|
(DummyConfig, True),
|
||||||
])
|
])
|
||||||
def test_is_not_builtin(type_hint, expected):
|
def test_is_not_builtin(type_hint, expected):
|
||||||
assert is_not_builtin(type_hint) == expected
|
assert is_not_builtin(type_hint) == expected
|
||||||
|
|
||||||
|
|
||||||
def test_get_kwargs():
|
def test_get_kwargs():
|
||||||
kwargs = get_kwargs(DummyConfigClass)
|
kwargs = get_kwargs(DummyConfig)
|
||||||
print(kwargs)
|
print(kwargs)
|
||||||
|
|
||||||
# bools should not have their type set
|
# bools should not have their type set
|
||||||
@ -142,6 +183,11 @@ def test_get_kwargs():
|
|||||||
# dict should have json tip in help
|
# dict should have json tip in help
|
||||||
json_tip = "\n\nShould be a valid JSON string."
|
json_tip = "\n\nShould be a valid JSON string."
|
||||||
assert kwargs["json_tip"]["help"].endswith(json_tip)
|
assert kwargs["json_tip"]["help"].endswith(json_tip)
|
||||||
|
# nested config should should construct the nested config
|
||||||
|
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
|
||||||
|
# from_cli configs should be constructed with the correct method
|
||||||
|
assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
|
||||||
|
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("arg", "expected"), [
|
@pytest.mark.parametrize(("arg", "expected"), [
|
||||||
@ -177,7 +223,7 @@ def test_compilation_config():
|
|||||||
|
|
||||||
# default value
|
# default value
|
||||||
args = parser.parse_args([])
|
args = parser.parse_args([])
|
||||||
assert args.compilation_config is None
|
assert args.compilation_config == CompilationConfig()
|
||||||
|
|
||||||
# set to O3
|
# set to O3
|
||||||
args = parser.parse_args(["-O3"])
|
args = parser.parse_args(["-O3"])
|
||||||
@ -194,7 +240,7 @@ def test_compilation_config():
|
|||||||
# set to string form of a dict
|
# set to string form of a dict
|
||||||
args = parser.parse_args([
|
args = parser.parse_args([
|
||||||
"--compilation-config",
|
"--compilation-config",
|
||||||
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
|
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
|
||||||
])
|
])
|
||||||
assert (args.compilation_config.level == 3 and
|
assert (args.compilation_config.level == 3 and
|
||||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
||||||
@ -202,7 +248,7 @@ def test_compilation_config():
|
|||||||
# set to string form of a dict
|
# set to string form of a dict
|
||||||
args = parser.parse_args([
|
args = parser.parse_args([
|
||||||
"--compilation-config="
|
"--compilation-config="
|
||||||
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
|
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
|
||||||
])
|
])
|
||||||
assert (args.compilation_config.level == 3 and
|
assert (args.compilation_config.level == 3 and
|
||||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import time
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import CompilationConfig, VllmConfig
|
from vllm.config import PassConfig, VllmConfig
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
|
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
@ -56,10 +56,7 @@ class VllmInductorPass(InductorPass):
|
|||||||
|
|
||||||
class PrinterInductorPass(VllmInductorPass):
|
class PrinterInductorPass(VllmInductorPass):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, name: str, config: PassConfig, always=False):
|
||||||
name: str,
|
|
||||||
config: CompilationConfig.PassConfig,
|
|
||||||
always=False):
|
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.name = name
|
self.name = name
|
||||||
self.always = always
|
self.always = always
|
||||||
|
|||||||
513
vllm/config.py
513
vllm/config.py
@ -11,8 +11,8 @@ import textwrap
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
|
from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
|
||||||
replace)
|
is_dataclass, replace)
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -20,7 +20,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
|
|||||||
Protocol, TypeVar, Union, cast, get_args, get_origin)
|
Protocol, TypeVar, Union, cast, get_args, get_origin)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
|
||||||
from torch.distributed import ProcessGroup, ReduceOp
|
from torch.distributed import ProcessGroup, ReduceOp
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
@ -57,7 +56,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
ConfigType = type[DataclassInstance]
|
ConfigType = type[DataclassInstance]
|
||||||
else:
|
else:
|
||||||
QuantizationConfig = None
|
QuantizationConfig = Any
|
||||||
ConfigType = type
|
ConfigType = type
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -169,6 +168,12 @@ def config(cls: ConfigT) -> ConfigT:
|
|||||||
"""
|
"""
|
||||||
A decorator that ensures all fields in a dataclass have default values
|
A decorator that ensures all fields in a dataclass have default values
|
||||||
and that each field has a docstring.
|
and that each field has a docstring.
|
||||||
|
|
||||||
|
If a `ConfigT` is used as a CLI argument itself, the default value provided
|
||||||
|
by `get_kwargs` will be the result parsing a JSON string as the kwargs
|
||||||
|
(i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT`
|
||||||
|
requires custom construction from CLI (i.e. `CompilationConfig`), it can
|
||||||
|
have a `from_cli` method, which will be called instead.
|
||||||
"""
|
"""
|
||||||
if not is_dataclass(cls):
|
if not is_dataclass(cls):
|
||||||
raise TypeError("The decorated class must be a dataclass.")
|
raise TypeError("The decorated class must be a dataclass.")
|
||||||
@ -202,7 +207,7 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
|||||||
cls_fields = {f.name: f for f in fields(cls)}
|
cls_fields = {f.name: f for f in fields(cls)}
|
||||||
if name not in cls_fields:
|
if name not in cls_fields:
|
||||||
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
|
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
|
||||||
named_field: Field = cls_fields.get(name)
|
named_field: Field = cls_fields[name]
|
||||||
if (default_factory := named_field.default_factory) is not MISSING:
|
if (default_factory := named_field.default_factory) is not MISSING:
|
||||||
return field(default_factory=default_factory)
|
return field(default_factory=default_factory)
|
||||||
if (default := named_field.default) is not MISSING:
|
if (default := named_field.default) is not MISSING:
|
||||||
@ -211,6 +216,10 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
|||||||
f"{cls.__name__}.{name} must have a default value or default factory.")
|
f"{cls.__name__}.{name} must have a default value or default factory.")
|
||||||
|
|
||||||
|
|
||||||
|
def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||||
|
return next(f for f in fields(cls) if f.name == name).init
|
||||||
|
|
||||||
|
|
||||||
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
|
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
|
||||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||||
|
|
||||||
@ -2007,13 +2016,13 @@ class SchedulerConfig:
|
|||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.max_model_len is None:
|
if self.max_model_len is None:
|
||||||
self.max_model_len = 8192
|
self.max_model_len = 8192
|
||||||
logger.warning(
|
logger.warning_once(
|
||||||
"max_model_len was is not set. Defaulting to arbitrary value "
|
"max_model_len was is not set. Defaulting to arbitrary value "
|
||||||
"of %d.", self.max_model_len)
|
"of %d.", self.max_model_len)
|
||||||
|
|
||||||
if self.max_num_seqs is None:
|
if self.max_num_seqs is None:
|
||||||
self.max_num_seqs = 128
|
self.max_num_seqs = 128
|
||||||
logger.warning(
|
logger.warning_once(
|
||||||
"max_num_seqs was is not set. Defaulting to arbitrary value "
|
"max_num_seqs was is not set. Defaulting to arbitrary value "
|
||||||
"of %d.", self.max_num_seqs)
|
"of %d.", self.max_num_seqs)
|
||||||
|
|
||||||
@ -2840,8 +2849,8 @@ class PromptAdapterConfig:
|
|||||||
class MultiModalConfig:
|
class MultiModalConfig:
|
||||||
"""Controls the behavior of multimodal models."""
|
"""Controls the behavior of multimodal models."""
|
||||||
|
|
||||||
limit_per_prompt: dict[str, int] = get_field(ModelConfig,
|
limit_per_prompt: dict[str, int] = \
|
||||||
"limit_mm_per_prompt")
|
cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt"))
|
||||||
"""
|
"""
|
||||||
The maximum number of input items allowed per prompt for each modality.
|
The maximum number of input items allowed per prompt for each modality.
|
||||||
Defaults to 1 (V0) or 999 (V1) for each modality.
|
Defaults to 1 (V0) or 999 (V1) for each modality.
|
||||||
@ -3415,41 +3424,49 @@ class ObservabilityConfig:
|
|||||||
self.collect_detailed_traces[0].split(","))
|
self.collect_detailed_traces[0].split(","))
|
||||||
|
|
||||||
|
|
||||||
class KVTransferConfig(BaseModel):
|
KVProducer = Literal["kv_producer", "kv_both"]
|
||||||
|
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||||
|
KVRole = Literal[KVProducer, KVConsumer]
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class KVTransferConfig:
|
||||||
"""Configuration for distributed KV cache transfer."""
|
"""Configuration for distributed KV cache transfer."""
|
||||||
|
|
||||||
# The KV connector for vLLM to transmit KV caches between vLLM instances.
|
|
||||||
kv_connector: Optional[str] = None
|
kv_connector: Optional[str] = None
|
||||||
|
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||||
|
"""
|
||||||
|
|
||||||
# The device used by kv connector to buffer the KV cache.
|
|
||||||
# Currently only support 'cuda'.
|
|
||||||
kv_buffer_device: Optional[str] = "cuda"
|
kv_buffer_device: Optional[str] = "cuda"
|
||||||
|
"""The device used by kv connector to buffer the KV cache.
|
||||||
|
Currently only support 'cuda'."""
|
||||||
|
|
||||||
# The buffer size for TorchDistributedConnector. Measured in number of
|
|
||||||
# bytes. Recommended value: 1e9 (about 1GB).
|
|
||||||
kv_buffer_size: float = 1e9
|
kv_buffer_size: float = 1e9
|
||||||
|
"""The buffer size for TorchDistributedConnector. Measured in number of
|
||||||
|
bytes. Recommended value: 1e9 (about 1GB)."""
|
||||||
|
|
||||||
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
kv_role: Optional[KVRole] = None
|
||||||
# are 'kv_producer', 'kv_consumer', and 'both'.
|
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||||
kv_role: Optional[str] = None
|
are 'kv_producer', 'kv_consumer', and 'both'."""
|
||||||
|
|
||||||
# The rank of this vLLM instance in the KV cache transfer. Typical value:
|
|
||||||
# 0 for prefill instance, 1 for decode instance.
|
|
||||||
# Currently only 1P1D is supported.
|
|
||||||
kv_rank: Optional[int] = None
|
kv_rank: Optional[int] = None
|
||||||
|
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||||
|
0 for prefill instance, 1 for decode instance.
|
||||||
|
Currently only 1P1D is supported."""
|
||||||
|
|
||||||
# The number of parallel instances for KV cache transfer. For
|
|
||||||
# PyNcclConnector, this should be 2.
|
|
||||||
kv_parallel_size: int = 1
|
kv_parallel_size: int = 1
|
||||||
|
"""The number of parallel instances for KV cache transfer. For
|
||||||
|
PyNcclConnector, this should be 2."""
|
||||||
|
|
||||||
# The KV connector ip, used to build distributed connection
|
|
||||||
kv_ip: str = "127.0.0.1"
|
kv_ip: str = "127.0.0.1"
|
||||||
|
"""The KV connector ip, used to build distributed connection."""
|
||||||
|
|
||||||
# The KV connector port, used to build distributed connection
|
|
||||||
kv_port: int = 14579
|
kv_port: int = 14579
|
||||||
|
"""The KV connector port, used to build distributed connection."""
|
||||||
|
|
||||||
# any extra config that the connector may need
|
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
|
||||||
kv_connector_extra_config: dict[str, Any] = {}
|
"""any extra config that the connector may need."""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
@ -3470,46 +3487,37 @@ class KVTransferConfig(BaseModel):
|
|||||||
usedforsecurity=False).hexdigest()
|
usedforsecurity=False).hexdigest()
|
||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
@classmethod
|
def __post_init__(self) -> None:
|
||||||
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
|
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
|
||||||
"""Parse the CLI value for the kv cache transfer config."""
|
raise ValueError(f"Unsupported kv_role: {self.kv_role}. "
|
||||||
return KVTransferConfig.model_validate_json(cli_value)
|
f"Supported roles are {get_args(KVRole)}")
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
|
||||||
|
|
||||||
if self.kv_role is not None and self.kv_role not in [
|
|
||||||
"kv_producer", "kv_consumer", "kv_both"
|
|
||||||
]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported kv_role: {self.kv_role}. "
|
|
||||||
f"Supported roles are `kv_producer`, `kv_consumer`, "
|
|
||||||
f"and `kv_both`")
|
|
||||||
|
|
||||||
if self.kv_connector is not None and self.kv_role is None:
|
if self.kv_connector is not None and self.kv_role is None:
|
||||||
raise ValueError("Please specify kv_disagg_role when kv_connector "
|
raise ValueError("Please specify kv_disagg_role when kv_connector "
|
||||||
"is set, supported roles are `kv_producer`, "
|
f"is set, supported roles are {get_args(KVRole)}")
|
||||||
"`kv_consumer`, and `kv_both`")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_kv_transfer_instance(self) -> bool:
|
def is_kv_transfer_instance(self) -> bool:
|
||||||
return self.kv_connector is not None and \
|
return self.kv_connector is not None and \
|
||||||
self.kv_role in ["kv_producer", "kv_consumer", "kv_both"]
|
self.kv_role in get_args(KVRole)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_kv_producer(self) -> bool:
|
def is_kv_producer(self) -> bool:
|
||||||
return self.kv_connector is not None and \
|
return self.kv_connector is not None and \
|
||||||
self.kv_role in ["kv_producer", "kv_both"]
|
self.kv_role in get_args(KVProducer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_kv_consumer(self) -> bool:
|
def is_kv_consumer(self) -> bool:
|
||||||
return self.kv_connector is not None and \
|
return self.kv_connector is not None and \
|
||||||
self.kv_role in ["kv_consumer", "kv_both"]
|
self.kv_role in get_args(KVConsumer)
|
||||||
|
|
||||||
def get_from_extra_config(self, key, default) -> Any:
|
def get_from_extra_config(self, key, default) -> Any:
|
||||||
return self.kv_connector_extra_config.get(key, default)
|
return self.kv_connector_extra_config.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
class KVEventsConfig(BaseModel):
|
@config
|
||||||
|
@dataclass
|
||||||
|
class KVEventsConfig:
|
||||||
"""Configuration for KV event publishing."""
|
"""Configuration for KV event publishing."""
|
||||||
|
|
||||||
enable_kv_cache_events: bool = False
|
enable_kv_cache_events: bool = False
|
||||||
@ -3548,11 +3556,6 @@ class KVEventsConfig(BaseModel):
|
|||||||
this topic to receive events.
|
this topic to receive events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_cli(cls, cli_value: str) -> "KVEventsConfig":
|
|
||||||
"""Parse the CLI value for the event publisher config."""
|
|
||||||
return KVEventsConfig.model_validate_json(cli_value)
|
|
||||||
|
|
||||||
|
|
||||||
class CompilationLevel:
|
class CompilationLevel:
|
||||||
# constants for the levels of the compilation process
|
# constants for the levels of the compilation process
|
||||||
@ -3562,80 +3565,72 @@ class CompilationLevel:
|
|||||||
PIECEWISE = 3
|
PIECEWISE = 3
|
||||||
|
|
||||||
|
|
||||||
class CompilationConfig(BaseModel):
|
@config
|
||||||
"""
|
@dataclass
|
||||||
Configuration for compilation.
|
class PassConfig:
|
||||||
It has three parts:
|
"""Configuration for custom Inductor passes.
|
||||||
|
|
||||||
|
This is separate from general `CompilationConfig` so that inductor passes
|
||||||
|
don't all have access to full configuration - that would create a cycle as
|
||||||
|
the `PassManager` is set as a property of config."""
|
||||||
|
|
||||||
|
dump_graph_stages: list[str] = field(default_factory=list)
|
||||||
|
"""List of stages for which we want to dump the graph. Each pass defines
|
||||||
|
its own stages (before, after, maybe in-between)."""
|
||||||
|
dump_graph_dir: Path = Path(".")
|
||||||
|
"""Directory to dump the graphs."""
|
||||||
|
# TODO(luka) better pass enabling system.
|
||||||
|
enable_fusion: bool = True
|
||||||
|
"""Whether to enable the custom fusion pass."""
|
||||||
|
enable_noop: bool = True
|
||||||
|
"""Whether to enable the custom no-op elimination pass."""
|
||||||
|
enable_sequence_parallelism: bool = False
|
||||||
|
"""Whether to enable sequence parallelism."""
|
||||||
|
|
||||||
|
def uuid(self):
|
||||||
|
"""
|
||||||
|
Produces a hash unique to the pass configuration.
|
||||||
|
Any new fields that affect compilation should be added to the hash.
|
||||||
|
Do not include dump_graph_* in the hash - they don't affect
|
||||||
|
compilation.
|
||||||
|
"""
|
||||||
|
include = {
|
||||||
|
"enable_fusion", "enable_noop", "enable_sequence_parallelism"
|
||||||
|
}
|
||||||
|
dict_ = {k: v for k, v in asdict(self).items() if k in include}
|
||||||
|
return InductorPass.hash_dict(dict_)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if not self.enable_noop and self.enable_fusion:
|
||||||
|
logger.warning_once(
|
||||||
|
"Fusion enabled but reshape elimination disabled. "
|
||||||
|
"RMSNorm + quant (fp8) fusion might not work")
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class CompilationConfig:
|
||||||
|
"""Configuration for compilation. It has three parts:
|
||||||
|
|
||||||
- Top-level Compilation control:
|
- Top-level Compilation control:
|
||||||
- level: the level of compilation.
|
- {attr}`level`
|
||||||
- 0: no compilation.
|
- {attr}`debug_dump_path`
|
||||||
- 1: dynamo as is.
|
- {attr}`cache_dir`
|
||||||
- 2: dynamo once.
|
- {attr}`backend`
|
||||||
- 3: piecewise compilation.
|
- {attr}`custom_ops`
|
||||||
- debug_dump_path: the path to dump the debug information.
|
- {attr}`splitting_ops`
|
||||||
- cache_dir: the directory to store the compiled graph, to
|
|
||||||
accelerate Inductor compilation. By default, it will use
|
|
||||||
model-related information to generate a cache directory.
|
|
||||||
- backend: the backend for compilation. It needs to be a string.
|
|
||||||
- "" (empty string): use the default backend.
|
|
||||||
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
|
|
||||||
- "full.module.name": a qualified name which can be used to import the backend function.
|
|
||||||
We use string to avoid serialization issues when using compilation in a distributed setting.
|
|
||||||
When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
|
|
||||||
When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
|
|
||||||
- custom_ops: fine-grained control over which custom ops to enable/disable.
|
|
||||||
Use 'all' to enable all, 'none' to disable all.
|
|
||||||
Also specify a list of custom op names to enable (prefixed with a '+'),
|
|
||||||
or disable (prefixed with a '-').
|
|
||||||
Examples:
|
|
||||||
- 'all,-op1' to enable all except op1
|
|
||||||
- 'none,+op1,+op2' to enable only op1 and op2
|
|
||||||
By default, all custom ops are enabled when running without Inductor
|
|
||||||
and disabled when running with Inductor (compile_level >= Inductor).
|
|
||||||
- splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
|
|
||||||
- CudaGraph capture:
|
- CudaGraph capture:
|
||||||
- use_cudagraph: whether to use cudagraph inside compilation.
|
- {attr}`use_cudagraph`
|
||||||
- False: cudagraph inside compilation is not used.
|
- {attr}`cudagraph_capture_sizes`
|
||||||
- True: cudagraph inside compilation is used. It requires
|
- {attr}`cudagraph_num_of_warmups`
|
||||||
that all input buffers have fixed addresses, and all
|
- {attr}`cudagraph_copy_inputs`
|
||||||
splitting ops write their outputs to input buffers.
|
- {attr}`full_cuda_graph`
|
||||||
Note that this is orthogonal to the cudagraph capture logic
|
|
||||||
outside of compilation.
|
|
||||||
TODO: move outside cudagraph logic into compilation.
|
|
||||||
torch.compile will handle cudagraph capture logic in the future.
|
|
||||||
- cudagraph_capture_sizes: sizes to capture cudagraph.
|
|
||||||
- None (default): capture sizes are inferred from vllm config.
|
|
||||||
- list[int]: capture sizes are specified as given.
|
|
||||||
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
|
|
||||||
It means the first several runs will be treated as warmup runs.
|
|
||||||
Only after that, the execution will be recorded, and the recorded
|
|
||||||
cudagraph will be used for subsequent runs.
|
|
||||||
- cudagraph_copy_inputs: whether to copy input tensors for
|
|
||||||
cudagraph. If the caller can guarantee that the same input buffers
|
|
||||||
are always used, it can set this to False. Otherwise, it should
|
|
||||||
set this to True, and the compiler will copy the input to an
|
|
||||||
internally managed buffer. Default is False.
|
|
||||||
- full_cuda_graph: whether to use a full cuda graph for the entire forward
|
|
||||||
pass rather than splitting certain operations such as attention into subgraphs.
|
|
||||||
Thus this flag cannot be used together with splitting_ops. This may provide
|
|
||||||
performance benefits for smaller models.
|
|
||||||
- Inductor compilation:
|
- Inductor compilation:
|
||||||
- use_inductor: whether to use inductor compilation.
|
- {attr}`use_inductor`
|
||||||
- False: inductor compilation is not used. graph runs in eager.
|
- {attr}`compile_sizes`
|
||||||
- True: inductor compilation is used. one graph for symbolic shape
|
- {attr}`inductor_compile_config`
|
||||||
is compiled. In addition, compile for compile_sizes,
|
- {attr}`inductor_passes`
|
||||||
using configurations in inductor_compile_config.
|
- custom inductor passes
|
||||||
- compile_sizes: sizes to compile for inductor. In addition
|
|
||||||
to integers, it also supports "cudagraph_capture_sizes" to
|
|
||||||
specify the sizes for cudagraph capture.
|
|
||||||
- inductor_compile_config: additional configurations for inductor.
|
|
||||||
- None: use default configurations.
|
|
||||||
- inductor_passes: additional passes for inductor. It is a dictionary
|
|
||||||
from pass name to pass function qualified name. We use function
|
|
||||||
name because the config uses json format. If we pass the config
|
|
||||||
from Python, functions can also be passed directly via Python object
|
|
||||||
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
|
|
||||||
- custom inductor passes: see PassConfig for more details
|
|
||||||
|
|
||||||
Why we have different sizes for cudagraph and inductor:
|
Why we have different sizes for cudagraph and inductor:
|
||||||
- cudagraph: a cudagraph captured for a specific size can only be used
|
- cudagraph: a cudagraph captured for a specific size can only be used
|
||||||
@ -3646,83 +3641,135 @@ class CompilationConfig(BaseModel):
|
|||||||
static shapes. However, we find the general shape compilation is
|
static shapes. However, we find the general shape compilation is
|
||||||
sufficient for most cases. It might be beneficial to compile for
|
sufficient for most cases. It might be beneficial to compile for
|
||||||
certain small batchsizes, where inductor is good at optimizing.
|
certain small batchsizes, where inductor is good at optimizing.
|
||||||
""" # noqa
|
"""
|
||||||
|
# Top-level Compilation control
|
||||||
level: int = 0
|
level: int = 0
|
||||||
|
"""The level of compilation:
|
||||||
|
|
||||||
|
- 0: no compilation.
|
||||||
|
- 1: dynamo as is.
|
||||||
|
- 2: dynamo once.
|
||||||
|
- 3: piecewise compilation."""
|
||||||
debug_dump_path: str = ""
|
debug_dump_path: str = ""
|
||||||
|
"""The path to dump the debug information."""
|
||||||
cache_dir: str = ""
|
cache_dir: str = ""
|
||||||
|
"""The directory to store the compiled graph, to accelerate Inductor
|
||||||
|
compilation. By default, it will use model-related information to generate
|
||||||
|
a cache directory."""
|
||||||
backend: str = ""
|
backend: str = ""
|
||||||
custom_ops: list[str] = Field(default_factory=list)
|
"""The backend for compilation. It needs to be a string:
|
||||||
splitting_ops: list[str] = Field(default=None) # type: ignore
|
|
||||||
|
|
||||||
|
- "" (empty string): use the default backend.
|
||||||
|
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
|
||||||
|
- "full.module.name": a qualified name which can be used to import the
|
||||||
|
|
||||||
|
backend function.
|
||||||
|
We use string to avoid serialization issues when using compilation in a
|
||||||
|
distributed setting. When the compilation level is 1 or 2, the backend is
|
||||||
|
used for the compilation directly (it sees the whole graph). When the
|
||||||
|
compilation level is 3, the backend is used for the piecewise compilation
|
||||||
|
(it sees a part of the graph)."""
|
||||||
|
custom_ops: list[str] = field(default_factory=list)
|
||||||
|
"""Fine-grained control over which custom ops to enable/disable. Use 'all'
|
||||||
|
to enable all, 'none' to disable all. Also specify a list of custom op
|
||||||
|
names to enable (prefixed with a '+'), or disable (prefixed with a '-').
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
- 'all,-op1' to enable all except op1
|
||||||
|
- 'none,+op1,+op2' to enable only op1 and op2
|
||||||
|
|
||||||
|
By default, all custom ops are enabled when running without Inductor and
|
||||||
|
disabled when running with Inductor (compile_level >= Inductor)."""
|
||||||
|
splitting_ops: list[str] = field(default_factory=list)
|
||||||
|
"""A list of ops to split the full graph into subgraphs, used in piecewise
|
||||||
|
compilation."""
|
||||||
|
|
||||||
|
# Inductor capture
|
||||||
use_inductor: bool = True
|
use_inductor: bool = True
|
||||||
compile_sizes: Optional[list[Union[int, str]]] = Field(default=None)
|
"""Whether to use inductor compilation:
|
||||||
inductor_compile_config: dict = Field(default_factory=dict)
|
|
||||||
inductor_passes: dict[str, str] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
- False: inductor compilation is not used. graph runs in eager.
|
||||||
|
- True: inductor compilation is used. one graph for symbolic shape
|
||||||
|
is compiled. In addition, compile for compile_sizes,
|
||||||
|
using configurations in inductor_compile_config."""
|
||||||
|
compile_sizes: Optional[list[Union[int, str]]] = None
|
||||||
|
"""Sizes to compile for inductor. In addition
|
||||||
|
to integers, it also supports "cudagraph_capture_sizes" to
|
||||||
|
specify the sizes for cudagraph capture."""
|
||||||
|
inductor_compile_config: dict = field(default_factory=dict)
|
||||||
|
"""Additional configurations for inductor.
|
||||||
|
- None: use default configurations."""
|
||||||
|
inductor_passes: dict[str, str] = field(default_factory=dict)
|
||||||
|
"""Additional passes for inductor. It is a dictionary
|
||||||
|
from pass name to pass function qualified name. We use function
|
||||||
|
name because the config uses JSON format. If we pass the config
|
||||||
|
from Python, functions can also be passed directly via Python object
|
||||||
|
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
|
||||||
|
|
||||||
|
# CudaGraph compilation
|
||||||
use_cudagraph: bool = False
|
use_cudagraph: bool = False
|
||||||
|
"""Whether to use cudagraph inside compilation.
|
||||||
|
- False: cudagraph inside compilation is not used.
|
||||||
|
- True: cudagraph inside compilation is used. It requires
|
||||||
|
that all input buffers have fixed addresses, and all
|
||||||
|
splitting ops write their outputs to input buffers.
|
||||||
|
Note that this is orthogonal to the cudagraph capture logic
|
||||||
|
outside of compilation.
|
||||||
|
TODO: move outside cudagraph logic into compilation.
|
||||||
|
torch.compile will handle cudagraph capture logic in the future."""
|
||||||
cudagraph_num_of_warmups: int = 0
|
cudagraph_num_of_warmups: int = 0
|
||||||
|
"""Number of warmup runs for cudagraph.
|
||||||
|
It means the first several runs will be treated as warmup runs.
|
||||||
|
Only after that, the execution will be recorded, and the recorded
|
||||||
|
cudagraph will be used for subsequent runs."""
|
||||||
cudagraph_capture_sizes: Optional[list[int]] = None
|
cudagraph_capture_sizes: Optional[list[int]] = None
|
||||||
|
"""Sizes to capture cudagraph.
|
||||||
|
- None (default): capture sizes are inferred from vllm config.
|
||||||
|
- list[int]: capture sizes are specified as given."""
|
||||||
cudagraph_copy_inputs: bool = False
|
cudagraph_copy_inputs: bool = False
|
||||||
|
"""Whether to copy input tensors for
|
||||||
|
cudagraph. If the caller can guarantee that the same input buffers
|
||||||
|
are always used, it can set this to False. Otherwise, it should
|
||||||
|
set this to True, and the compiler will copy the input to an
|
||||||
|
internally managed buffer. Default is False."""
|
||||||
full_cuda_graph: bool = False
|
full_cuda_graph: bool = False
|
||||||
|
"""whether to use a full cuda graph for the entire forward pass rather than
|
||||||
|
splitting certain operations such as attention into subgraphs. Thus this
|
||||||
|
flag cannot be used together with splitting_ops. This may provide
|
||||||
|
performance benefits for smaller models."""
|
||||||
|
|
||||||
class PassConfig(BaseModel):
|
pass_config: PassConfig = field(default_factory=PassConfig)
|
||||||
"""
|
"""Custom inductor passes, see PassConfig for more details"""
|
||||||
Configuration for custom Inductor passes.
|
|
||||||
This is separate from general CompilationConfig so that inductor passes
|
|
||||||
don't all have access to full configuration - that would create a cycle
|
|
||||||
as the PassManager is set as a property of config.
|
|
||||||
- dump_graph_stages: list of stages for which we want to dump the graph.
|
|
||||||
Each pass defines its own stages (before, after, maybe in-between).
|
|
||||||
- dump_graph_dir: directory to dump the graphs. Default is .
|
|
||||||
- enable_fusion: whether to enable the custom fusion pass.
|
|
||||||
- enable_noop: whether to enable the custom no-op elimination pass.
|
|
||||||
TODO(luka) better pass enabling system.
|
|
||||||
- enable_sequence_parallelism: whether to enable sequence parallelism.
|
|
||||||
"""
|
|
||||||
dump_graph_stages: list[str] = Field(default_factory=list)
|
|
||||||
dump_graph_dir: Path = Field(default=Path("."))
|
|
||||||
enable_fusion: bool = True
|
|
||||||
enable_noop: bool = True
|
|
||||||
enable_sequence_parallelism: bool = False
|
|
||||||
|
|
||||||
def uuid(self):
|
max_capture_size: int = field(default=None, init=False) # type: ignore
|
||||||
"""
|
"""not configurable, computed after init"""
|
||||||
Produces a hash unique to the pass configuration.
|
local_cache_dir: str = field(default=None, init=False) # type: ignore
|
||||||
Any new fields that affect compilation should be added to the hash.
|
"""local cache dir for each rank"""
|
||||||
Do not include dump_graph_* in the hash - they don't affect
|
bs_to_padded_graph_size: list[int] = field(
|
||||||
compilation.
|
default=None, # type: ignore
|
||||||
"""
|
init=False)
|
||||||
dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \
|
"""optimization:
|
||||||
"enable_sequence_parallelism"})
|
Intuitively, bs_to_padded_graph_size should be dict[int, int].
|
||||||
return InductorPass.hash_dict(dict_)
|
since we know all keys are in a range [0, max_capture_size],
|
||||||
|
we can optimize it to list[int] for better lookup performance."""
|
||||||
def model_post_init(self, __context: Any) -> None:
|
|
||||||
if not self.enable_noop and self.enable_fusion:
|
|
||||||
logger.warning_once(
|
|
||||||
"Fusion enabled but reshape elimination disabled. "
|
|
||||||
"RMSNorm + quant (fp8) fusion might not work")
|
|
||||||
|
|
||||||
pass_config: PassConfig = Field(default_factory=PassConfig)
|
|
||||||
|
|
||||||
# not configurable, computed after init
|
|
||||||
max_capture_size: int = PrivateAttr
|
|
||||||
local_cache_dir: str = PrivateAttr # local cache dir for each rank
|
|
||||||
# optimization:
|
|
||||||
# Intuitively, bs_to_padded_graph_size should be dict[int, int].
|
|
||||||
# since we know all keys are in a range [0, max_capture_size],
|
|
||||||
# we can optimize it to list[int] for better lookup performance.
|
|
||||||
bs_to_padded_graph_size: list[int] = PrivateAttr
|
|
||||||
|
|
||||||
# keep track of enabled and disabled custom ops
|
# keep track of enabled and disabled custom ops
|
||||||
enabled_custom_ops: Counter[str] = PrivateAttr
|
enabled_custom_ops: Counter[str] = field(default_factory=Counter,
|
||||||
disabled_custom_ops: Counter[str] = PrivateAttr
|
init=False)
|
||||||
traced_files: set[str] = PrivateAttr
|
"""custom ops that are enabled"""
|
||||||
compilation_time: float = PrivateAttr
|
disabled_custom_ops: Counter[str] = field(default_factory=Counter,
|
||||||
|
init=False)
|
||||||
|
"""custom ops that are disabled"""
|
||||||
|
traced_files: set[str] = field(default_factory=set, init=False)
|
||||||
|
"""files that are traced for compilation"""
|
||||||
|
compilation_time: float = field(default=0.0, init=False)
|
||||||
|
"""time taken for compilation"""
|
||||||
|
|
||||||
# Per-model forward context
|
static_forward_context: dict[str, Any] = field(default_factory=dict,
|
||||||
# Map from layer name to layer objects that need to be accessed outside
|
init=False)
|
||||||
# model code, e.g., Attention, FusedMOE when dp_size>1.
|
"""Per-model forward context
|
||||||
static_forward_context: dict[str, Any] = PrivateAttr
|
Map from layer name to layer objects that need to be accessed outside
|
||||||
|
model code, e.g., Attention, FusedMOE when dp_size>1."""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
@ -3757,7 +3804,17 @@ class CompilationConfig(BaseModel):
|
|||||||
"pass_config",
|
"pass_config",
|
||||||
"traced_files",
|
"traced_files",
|
||||||
}
|
}
|
||||||
return self.model_dump_json(exclude=exclude, exclude_unset=True)
|
include = dict()
|
||||||
|
for k, v in asdict(self).items():
|
||||||
|
if k in exclude:
|
||||||
|
continue
|
||||||
|
f = get_field(CompilationConfig, k)
|
||||||
|
if (d := f.default) is not MISSING and d == v:
|
||||||
|
continue
|
||||||
|
if (df := f.default_factory) is not MISSING and df() == v:
|
||||||
|
continue
|
||||||
|
include[k] = v
|
||||||
|
return json.dumps(include)
|
||||||
|
|
||||||
__str__ = __repr__
|
__str__ = __repr__
|
||||||
|
|
||||||
@ -3766,12 +3823,9 @@ class CompilationConfig(BaseModel):
|
|||||||
"""Parse the CLI value for the compilation config."""
|
"""Parse the CLI value for the compilation config."""
|
||||||
if cli_value in ["0", "1", "2", "3"]:
|
if cli_value in ["0", "1", "2", "3"]:
|
||||||
return cls(level=int(cli_value))
|
return cls(level=int(cli_value))
|
||||||
# do not use `eval`, it is dangerous and can execute arbitrary code
|
return cls(**json.loads(cli_value))
|
||||||
dict_value = ast.literal_eval(cli_value)
|
|
||||||
return CompilationConfig.model_validate(dict_value)
|
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
count_none = self.custom_ops.count("none")
|
count_none = self.custom_ops.count("none")
|
||||||
count_all = self.custom_ops.count("all")
|
count_all = self.custom_ops.count("all")
|
||||||
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
|
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
|
||||||
@ -3789,9 +3843,6 @@ class CompilationConfig(BaseModel):
|
|||||||
if KEY not in self.inductor_compile_config:
|
if KEY not in self.inductor_compile_config:
|
||||||
self.inductor_compile_config[KEY] = False
|
self.inductor_compile_config[KEY] = False
|
||||||
|
|
||||||
if self.splitting_ops is None:
|
|
||||||
self.splitting_ops = []
|
|
||||||
|
|
||||||
for k, v in self.inductor_passes.items():
|
for k, v in self.inductor_passes.items():
|
||||||
if not isinstance(v, str):
|
if not isinstance(v, str):
|
||||||
assert callable(v), (
|
assert callable(v), (
|
||||||
@ -3808,11 +3859,8 @@ class CompilationConfig(BaseModel):
|
|||||||
self.inductor_compile_config[k] = func if isinstance(
|
self.inductor_compile_config[k] = func if isinstance(
|
||||||
func, InductorPass) else CallableInductorPass(func)
|
func, InductorPass) else CallableInductorPass(func)
|
||||||
|
|
||||||
self.enabled_custom_ops = Counter()
|
if isinstance(self.pass_config, dict):
|
||||||
self.disabled_custom_ops = Counter()
|
self.pass_config = PassConfig(**self.pass_config)
|
||||||
self.traced_files = set()
|
|
||||||
self.static_forward_context = {}
|
|
||||||
self.compilation_time = 0.0
|
|
||||||
|
|
||||||
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
|
||||||
if self.level == CompilationLevel.NO_COMPILATION:
|
if self.level == CompilationLevel.NO_COMPILATION:
|
||||||
@ -3899,39 +3947,67 @@ class CompilationConfig(BaseModel):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
@dataclass
|
@dataclass
|
||||||
class VllmConfig:
|
class VllmConfig:
|
||||||
"""Dataclass which contains all vllm-related configuration. This
|
"""Dataclass which contains all vllm-related configuration. This
|
||||||
simplifies passing around the distinct configurations in the codebase.
|
simplifies passing around the distinct configurations in the codebase.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config: ModelConfig = field(default=None, init=True) # type: ignore
|
model_config: ModelConfig = field(default_factory=ModelConfig)
|
||||||
cache_config: CacheConfig = field(default=None, init=True) # type: ignore
|
"""Model configuration."""
|
||||||
parallel_config: ParallelConfig = field(default_factory=ParallelConfig,
|
cache_config: CacheConfig = field(default_factory=CacheConfig)
|
||||||
init=True)
|
"""Cache configuration."""
|
||||||
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig,
|
parallel_config: ParallelConfig = field(default_factory=ParallelConfig)
|
||||||
init=True)
|
"""Parallel configuration."""
|
||||||
device_config: DeviceConfig = field(default=None,
|
scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
|
||||||
init=True) # type: ignore
|
"""Scheduler configuration."""
|
||||||
load_config: LoadConfig = field(default=None, init=True) # type: ignore
|
device_config: DeviceConfig = field(default_factory=DeviceConfig)
|
||||||
|
"""Device configuration."""
|
||||||
|
load_config: LoadConfig = field(default_factory=LoadConfig)
|
||||||
|
"""Load configuration."""
|
||||||
lora_config: Optional[LoRAConfig] = None
|
lora_config: Optional[LoRAConfig] = None
|
||||||
speculative_config: SpeculativeConfig = field(default=None,
|
"""LoRA configuration."""
|
||||||
init=True) # type: ignore
|
speculative_config: Optional[SpeculativeConfig] = None
|
||||||
|
"""Speculative decoding configuration."""
|
||||||
decoding_config: Optional[DecodingConfig] = None
|
decoding_config: Optional[DecodingConfig] = None
|
||||||
|
"""Decoding configuration."""
|
||||||
observability_config: Optional[ObservabilityConfig] = None
|
observability_config: Optional[ObservabilityConfig] = None
|
||||||
|
"""Observability configuration."""
|
||||||
prompt_adapter_config: Optional[PromptAdapterConfig] = None
|
prompt_adapter_config: Optional[PromptAdapterConfig] = None
|
||||||
|
"""Prompt adapter configuration."""
|
||||||
quant_config: Optional[QuantizationConfig] = None
|
quant_config: Optional[QuantizationConfig] = None
|
||||||
compilation_config: CompilationConfig = field(default=None,
|
"""Quantization configuration."""
|
||||||
init=True) # type: ignore
|
compilation_config: CompilationConfig = field(
|
||||||
kv_transfer_config: KVTransferConfig = field(default=None,
|
default_factory=CompilationConfig)
|
||||||
init=True) # type: ignore
|
"""`torch.compile` configuration for the model.
|
||||||
|
|
||||||
|
When it is a number (0, 1, 2, 3), it will be interpreted as the
|
||||||
|
optimization level.
|
||||||
|
|
||||||
|
NOTE: level 0 is the default level without any optimization. level 1 and 2
|
||||||
|
are for internal testing only. level 3 is the recommended level for
|
||||||
|
production.
|
||||||
|
|
||||||
|
Following the convention of traditional compilers, using `-O` without space
|
||||||
|
is also supported. `-O3` is equivalent to `-O 3`.
|
||||||
|
|
||||||
|
You can specify the full compilation config like so:
|
||||||
|
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
|
||||||
|
"""
|
||||||
|
kv_transfer_config: Optional[KVTransferConfig] = None
|
||||||
|
"""The configurations for distributed KV cache transfer."""
|
||||||
kv_events_config: Optional[KVEventsConfig] = None
|
kv_events_config: Optional[KVEventsConfig] = None
|
||||||
|
"""The configurations for event publishing."""
|
||||||
# some opaque config, only used to provide additional information
|
# some opaque config, only used to provide additional information
|
||||||
# for the hash computation, mainly used for testing, debugging or out of
|
# for the hash computation, mainly used for testing, debugging or out of
|
||||||
# tree config registration.
|
# tree config registration.
|
||||||
additional_config: SupportsHash = field(default=None,
|
additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
|
||||||
init=True) # type: ignore
|
"""Additional config for specified platform. Different platforms may
|
||||||
|
support different configs. Make sure the configs are valid for the platform
|
||||||
|
you are using. Contents must be hashable."""
|
||||||
instance_id: str = ""
|
instance_id: str = ""
|
||||||
|
"""The ID of the vLLM instance."""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
@ -4012,7 +4088,14 @@ class VllmConfig:
|
|||||||
else:
|
else:
|
||||||
vllm_factors.append("None")
|
vllm_factors.append("None")
|
||||||
if self.additional_config:
|
if self.additional_config:
|
||||||
vllm_factors.append(self.additional_config.compute_hash())
|
if isinstance(additional_config := self.additional_config, dict):
|
||||||
|
additional_config_hash = hashlib.md5(
|
||||||
|
json.dumps(additional_config, sort_keys=True).encode(),
|
||||||
|
usedforsecurity=False,
|
||||||
|
).hexdigest()
|
||||||
|
else:
|
||||||
|
additional_config_hash = additional_config.compute_hash()
|
||||||
|
vllm_factors.append(additional_config_hash)
|
||||||
else:
|
else:
|
||||||
vllm_factors.append("None")
|
vllm_factors.append("None")
|
||||||
factors.append(vllm_factors)
|
factors.append(vllm_factors)
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from dataclasses import asdict
|
||||||
from itertools import count
|
from itertools import count
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
@ -284,7 +285,7 @@ class EventPublisherFactory:
|
|||||||
if not config:
|
if not config:
|
||||||
return NullEventPublisher()
|
return NullEventPublisher()
|
||||||
|
|
||||||
config_dict = config.model_dump()
|
config_dict = asdict(config)
|
||||||
|
|
||||||
kind = config_dict.pop("publisher", "null")
|
kind = config_dict.pop("publisher", "null")
|
||||||
config_dict.pop("enable_kv_cache_events")
|
config_dict.pop("enable_kv_cache_events")
|
||||||
|
|||||||
@ -7,10 +7,10 @@ import json
|
|||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import MISSING, dataclass, fields
|
from dataclasses import MISSING, dataclass, fields, is_dataclass
|
||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
|
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
|
||||||
TypeVar, Union, cast, get_args, get_origin)
|
Type, TypeVar, Union, cast, get_args, get_origin)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import TypeIs, deprecated
|
from typing_extensions import TypeIs, deprecated
|
||||||
@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager
|
|||||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
|
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
|
||||||
|
is_in_ray_actor)
|
||||||
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
|
|
||||||
@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object]
|
|||||||
TypeHintT = Union[type[T], object]
|
TypeHintT = Union[type[T], object]
|
||||||
|
|
||||||
|
|
||||||
def optional_type(
|
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
|
||||||
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
|
|
||||||
|
|
||||||
def _optional_type(val: str) -> Optional[T]:
|
def _parse_type(val: str) -> T:
|
||||||
if val == "" or val == "None":
|
|
||||||
return None
|
|
||||||
try:
|
try:
|
||||||
if return_type is json.loads and not re.match("^{.*}$", val):
|
if return_type is json.loads and not re.match("^{.*}$", val):
|
||||||
return cast(T, nullable_kvs(val))
|
return cast(T, nullable_kvs(val))
|
||||||
@ -62,14 +60,24 @@ def optional_type(
|
|||||||
raise argparse.ArgumentTypeError(
|
raise argparse.ArgumentTypeError(
|
||||||
f"Value {val} cannot be converted to {return_type}.") from e
|
f"Value {val} cannot be converted to {return_type}.") from e
|
||||||
|
|
||||||
|
return _parse_type
|
||||||
|
|
||||||
|
|
||||||
|
def optional_type(
|
||||||
|
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
|
||||||
|
|
||||||
|
def _optional_type(val: str) -> Optional[T]:
|
||||||
|
if val == "" or val == "None":
|
||||||
|
return None
|
||||||
|
return parse_type(return_type)(val)
|
||||||
|
|
||||||
return _optional_type
|
return _optional_type
|
||||||
|
|
||||||
|
|
||||||
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
|
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
|
||||||
if not re.match("^{.*}$", val):
|
if not re.match("^{.*}$", val):
|
||||||
return str(val)
|
return str(val)
|
||||||
else:
|
return optional_type(json.loads)(val)
|
||||||
return optional_type(json.loads)(val)
|
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
@ -144,10 +152,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
|||||||
cls_docs = get_attr_docs(cls)
|
cls_docs = get_attr_docs(cls)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
for field in fields(cls):
|
for field in fields(cls):
|
||||||
|
# Get the set of possible types for the field
|
||||||
|
type_hints: set[TypeHint] = set()
|
||||||
|
if get_origin(field.type) in {Union, Annotated}:
|
||||||
|
type_hints.update(get_args(field.type))
|
||||||
|
else:
|
||||||
|
type_hints.add(field.type)
|
||||||
|
|
||||||
|
# If the field is a dataclass, we can use the model_validate_json
|
||||||
|
generator = (th for th in type_hints if is_dataclass(th))
|
||||||
|
dataclass_cls = next(generator, None)
|
||||||
|
|
||||||
# Get the default value of the field
|
# Get the default value of the field
|
||||||
default = field.default
|
if field.default is not MISSING:
|
||||||
if field.default_factory is not MISSING:
|
default = field.default
|
||||||
default = field.default_factory()
|
elif field.default_factory is not MISSING:
|
||||||
|
if is_dataclass(field.default_factory) and is_in_doc_build():
|
||||||
|
default = {}
|
||||||
|
else:
|
||||||
|
default = field.default_factory()
|
||||||
|
|
||||||
# Get the help text for the field
|
# Get the help text for the field
|
||||||
name = field.name
|
name = field.name
|
||||||
@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
|||||||
# Initialise the kwargs dictionary for the field
|
# Initialise the kwargs dictionary for the field
|
||||||
kwargs[name] = {"default": default, "help": help}
|
kwargs[name] = {"default": default, "help": help}
|
||||||
|
|
||||||
# Get the set of possible types for the field
|
|
||||||
type_hints: set[TypeHint] = set()
|
|
||||||
if get_origin(field.type) is Union:
|
|
||||||
type_hints.update(get_args(field.type))
|
|
||||||
else:
|
|
||||||
type_hints.add(field.type)
|
|
||||||
|
|
||||||
# Set other kwargs based on the type hints
|
# Set other kwargs based on the type hints
|
||||||
json_tip = "\n\nShould be a valid JSON string."
|
json_tip = "\n\nShould be a valid JSON string."
|
||||||
if contains_type(type_hints, bool):
|
if dataclass_cls is not None:
|
||||||
|
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
|
||||||
|
# Special case for configs with a from_cli method
|
||||||
|
if hasattr(dataclass_cls, "from_cli"):
|
||||||
|
from_cli = dataclass_cls.from_cli
|
||||||
|
dataclass_init = lambda x, f=from_cli: f(x)
|
||||||
|
kwargs[name]["type"] = dataclass_init
|
||||||
|
kwargs[name]["help"] += json_tip
|
||||||
|
elif contains_type(type_hints, bool):
|
||||||
# Creates --no-<name> and --<name> flags
|
# Creates --no-<name> and --<name> flags
|
||||||
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
||||||
elif contains_type(type_hints, Literal):
|
elif contains_type(type_hints, Literal):
|
||||||
@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
|||||||
kwargs[name]["type"] = union_dict_and_str
|
kwargs[name]["type"] = union_dict_and_str
|
||||||
elif contains_type(type_hints, dict):
|
elif contains_type(type_hints, dict):
|
||||||
# Dict arguments will always be optional
|
# Dict arguments will always be optional
|
||||||
kwargs[name]["type"] = optional_type(json.loads)
|
kwargs[name]["type"] = parse_type(json.loads)
|
||||||
kwargs[name]["help"] += json_tip
|
kwargs[name]["help"] += json_tip
|
||||||
elif (contains_type(type_hints, str)
|
elif (contains_type(type_hints, str)
|
||||||
or any(is_not_builtin(th) for th in type_hints)):
|
or any(is_not_builtin(th) for th in type_hints)):
|
||||||
@ -771,63 +795,20 @@ class EngineArgs:
|
|||||||
scheduler_group.add_argument("--scheduler-cls",
|
scheduler_group.add_argument("--scheduler-cls",
|
||||||
**scheduler_kwargs["scheduler_cls"])
|
**scheduler_kwargs["scheduler_cls"])
|
||||||
|
|
||||||
# Compilation arguments
|
|
||||||
# compilation_kwargs = get_kwargs(CompilationConfig)
|
|
||||||
compilation_group = parser.add_argument_group(
|
|
||||||
title="CompilationConfig",
|
|
||||||
description=CompilationConfig.__doc__,
|
|
||||||
)
|
|
||||||
compilation_group.add_argument(
|
|
||||||
"--compilation-config",
|
|
||||||
"-O",
|
|
||||||
type=CompilationConfig.from_cli,
|
|
||||||
default=None,
|
|
||||||
help="torch.compile configuration for the model. "
|
|
||||||
"When it is a number (0, 1, 2, 3), it will be "
|
|
||||||
"interpreted as the optimization level.\n"
|
|
||||||
"NOTE: level 0 is the default level without "
|
|
||||||
"any optimization. level 1 and 2 are for internal "
|
|
||||||
"testing only. level 3 is the recommended level "
|
|
||||||
"for production.\n"
|
|
||||||
"To specify the full compilation config, "
|
|
||||||
"use a JSON string, e.g. ``{\"level\": 3, "
|
|
||||||
"\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n"
|
|
||||||
"Following the convention of traditional "
|
|
||||||
"compilers, using ``-O`` without space is also "
|
|
||||||
"supported. ``-O3`` is equivalent to ``-O 3``.")
|
|
||||||
|
|
||||||
# KVTransfer arguments
|
|
||||||
# kv_transfer_kwargs = get_kwargs(KVTransferConfig)
|
|
||||||
kv_transfer_group = parser.add_argument_group(
|
|
||||||
title="KVTransferConfig",
|
|
||||||
description=KVTransferConfig.__doc__,
|
|
||||||
)
|
|
||||||
kv_transfer_group.add_argument(
|
|
||||||
"--kv-transfer-config",
|
|
||||||
type=KVTransferConfig.from_cli,
|
|
||||||
default=None,
|
|
||||||
help="The configurations for distributed KV cache "
|
|
||||||
"transfer. Should be a JSON string.")
|
|
||||||
kv_transfer_group.add_argument(
|
|
||||||
'--kv-events-config',
|
|
||||||
type=KVEventsConfig.from_cli,
|
|
||||||
default=None,
|
|
||||||
help='The configurations for event publishing.')
|
|
||||||
|
|
||||||
# vLLM arguments
|
# vLLM arguments
|
||||||
# vllm_kwargs = get_kwargs(VllmConfig)
|
vllm_kwargs = get_kwargs(VllmConfig)
|
||||||
vllm_group = parser.add_argument_group(
|
vllm_group = parser.add_argument_group(
|
||||||
title="VllmConfig",
|
title="VllmConfig",
|
||||||
description=VllmConfig.__doc__,
|
description=VllmConfig.__doc__,
|
||||||
)
|
)
|
||||||
vllm_group.add_argument(
|
vllm_group.add_argument("--kv-transfer-config",
|
||||||
"--additional-config",
|
**vllm_kwargs["kv_transfer_config"])
|
||||||
type=json.loads,
|
vllm_group.add_argument('--kv-events-config',
|
||||||
default=None,
|
**vllm_kwargs["kv_events_config"])
|
||||||
help="Additional config for specified platform in JSON format. "
|
vllm_group.add_argument("--compilation-config", "-O",
|
||||||
"Different platforms may support different configs. Make sure the "
|
**vllm_kwargs["compilation_config"])
|
||||||
"configs are valid for the platform you are using. The input format"
|
vllm_group.add_argument("--additional-config",
|
||||||
" is like '{\"config_key\":\"config_value\"}'")
|
**vllm_kwargs["additional_config"])
|
||||||
|
|
||||||
# Other arguments
|
# Other arguments
|
||||||
parser.add_argument('--use-v2-block-manager',
|
parser.add_argument('--use-v2-block-manager',
|
||||||
|
|||||||
@ -13,7 +13,8 @@ from typing_extensions import TypeVar, deprecated
|
|||||||
|
|
||||||
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
||||||
BeamSearchSequence, get_beam_search_score)
|
BeamSearchSequence, get_beam_search_score)
|
||||||
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
|
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
|
||||||
|
is_init_field)
|
||||||
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
|
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
|
||||||
TaskOption)
|
TaskOption)
|
||||||
from vllm.engine.llm_engine import LLMEngine
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
@ -204,9 +205,13 @@ class LLM:
|
|||||||
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
|
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
|
||||||
|
|
||||||
if compilation_config is not None:
|
if compilation_config is not None:
|
||||||
if isinstance(compilation_config, (int, dict)):
|
if isinstance(compilation_config, int):
|
||||||
compilation_config_instance = CompilationConfig.from_cli(
|
compilation_config_instance = CompilationConfig(
|
||||||
str(compilation_config))
|
level=compilation_config)
|
||||||
|
elif isinstance(compilation_config, dict):
|
||||||
|
predicate = lambda x: is_init_field(CompilationConfig, x[0])
|
||||||
|
compilation_config_instance = CompilationConfig(
|
||||||
|
**dict(filter(predicate, compilation_config.items())))
|
||||||
else:
|
else:
|
||||||
compilation_config_instance = compilation_config
|
compilation_config_instance = compilation_config
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tpu_info import device
|
from tpu_info import device
|
||||||
@ -13,9 +13,10 @@ from vllm.sampling_params import SamplingParams, SamplingType
|
|||||||
from .interface import Platform, PlatformEnum, _Backend
|
from .interface import Platform, PlatformEnum, _Backend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig, VllmConfig
|
from vllm.config import BlockSize, ModelConfig, VllmConfig
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
else:
|
else:
|
||||||
|
BlockSize = None
|
||||||
ModelConfig = None
|
ModelConfig = None
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
PoolingParams = None
|
PoolingParams = None
|
||||||
@ -94,7 +95,7 @@ class TpuPlatform(Platform):
|
|||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
# For v0, the default block size is 16.
|
# For v0, the default block size is 16.
|
||||||
if cache_config and cache_config.block_size is None:
|
if cache_config and cache_config.block_size is None:
|
||||||
cache_config.block_size = 16
|
cache_config.block_size = cast(BlockSize, 16)
|
||||||
compilation_config = vllm_config.compilation_config
|
compilation_config = vllm_config.compilation_config
|
||||||
|
|
||||||
# TPU only supports DYNAMO_ONCE compilation level
|
# TPU only supports DYNAMO_ONCE compilation level
|
||||||
@ -118,7 +119,7 @@ class TpuPlatform(Platform):
|
|||||||
from vllm.v1.attention.backends.pallas import (
|
from vllm.v1.attention.backends.pallas import (
|
||||||
PallasAttentionBackend)
|
PallasAttentionBackend)
|
||||||
cache_config.block_size = PallasAttentionBackend.get_page_size(
|
cache_config.block_size = PallasAttentionBackend.get_page_size(
|
||||||
vllm_config)
|
vllm_config) # type: ignore[assignment]
|
||||||
min_page_size = PallasAttentionBackend.get_min_page_size(
|
min_page_size = PallasAttentionBackend.get_min_page_size(
|
||||||
vllm_config)
|
vllm_config)
|
||||||
if min_page_size > cache_config.block_size:
|
if min_page_size > cache_config.block_size:
|
||||||
@ -128,7 +129,7 @@ class TpuPlatform(Platform):
|
|||||||
cache_config.block_size,
|
cache_config.block_size,
|
||||||
min_page_size,
|
min_page_size,
|
||||||
)
|
)
|
||||||
cache_config.block_size = min_page_size
|
cache_config.block_size = min_page_size # type: ignore[assignment]
|
||||||
|
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
scheduler_config = vllm_config.scheduler_config
|
scheduler_config = vllm_config.scheduler_config
|
||||||
|
|||||||
@ -1820,6 +1820,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
|
|||||||
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
|
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def is_in_doc_build() -> bool:
|
||||||
|
try:
|
||||||
|
from sphinx.ext.autodoc.mock import _MockModule
|
||||||
|
return isinstance(zmq, _MockModule)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
|
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
|
||||||
"""
|
"""
|
||||||
Import a Python file according to its file path.
|
Import a Python file according to its file path.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user