mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:14:54 +08:00
[Frontend] Change CompilationMode to a proper Enum (#28165)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
parent
1788aa1efb
commit
48c879369f
@ -127,7 +127,9 @@ def test_compile_correctness(
|
|||||||
CompilationMode.VLLM_COMPILE,
|
CompilationMode.VLLM_COMPILE,
|
||||||
]:
|
]:
|
||||||
for mode in [CompilationMode.NONE, comp_mode]:
|
for mode in [CompilationMode.NONE, comp_mode]:
|
||||||
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=inductor"])
|
all_args.append(
|
||||||
|
final_args + [f"-O.mode={mode.name}", "-O.backend=inductor"]
|
||||||
|
)
|
||||||
|
|
||||||
# inductor will change the output, so we only compare if the output
|
# inductor will change the output, so we only compare if the output
|
||||||
# is close, not exactly the same.
|
# is close, not exactly the same.
|
||||||
@ -146,7 +148,7 @@ def test_compile_correctness(
|
|||||||
CompilationMode.DYNAMO_TRACE_ONCE,
|
CompilationMode.DYNAMO_TRACE_ONCE,
|
||||||
CompilationMode.VLLM_COMPILE,
|
CompilationMode.VLLM_COMPILE,
|
||||||
]:
|
]:
|
||||||
all_args.append(final_args + [f"-O.mode={mode}", "-O.backend=eager"])
|
all_args.append(final_args + [f"-O.mode={mode.name}", "-O.backend=eager"])
|
||||||
all_envs.append({})
|
all_envs.append({})
|
||||||
all_envs.append({})
|
all_envs.append({})
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import os
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
|
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
|
||||||
|
|
||||||
@ -376,6 +377,65 @@ def test_load_config_file(tmp_path):
|
|||||||
os.remove(str(config_file_path))
|
os.remove(str(config_file_path))
|
||||||
|
|
||||||
|
|
||||||
|
def test_compilation_mode_string_values(parser):
    """Check what `-O.mode ...` / `-O<n>` arguments store in `compilation_config`.

    Numeric values are expected to come back as ints, while mode names are
    expected to come back as the exact string that was passed on the command
    line (both upper-case and lower-case spellings are exercised).
    """
    # (argv, expected compilation_config) pairs, checked in order.
    cases = [
        # Numeric modes.
        (["-O.mode", "0"], {"mode": 0}),
        (["-O3"], {"mode": 3}),
        # Upper-case mode names, in both `key=value` and `key value` form.
        (["-O.mode=NONE"], {"mode": "NONE"}),
        (["-O.mode", "STOCK_TORCH_COMPILE"], {"mode": "STOCK_TORCH_COMPILE"}),
        (["-O.mode=DYNAMO_TRACE_ONCE"], {"mode": "DYNAMO_TRACE_ONCE"}),
        (["-O.mode", "VLLM_COMPILE"], {"mode": "VLLM_COMPILE"}),
        # Lower-case mode names are passed through unchanged by the parser.
        (["-O.mode=none"], {"mode": "none"}),
        (["-O.mode=vllm_compile"], {"mode": "vllm_compile"}),
    ]

    for argv, expected in cases:
        parsed = parser.parse_args(argv)
        assert parsed.compilation_config == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_compilation_config_mode_validator():
    """Check that `CompilationConfig(mode=...)` yields `CompilationMode` members.

    Integer inputs and mode names (upper-case or lower-case) must each compare
    equal to the matching `CompilationMode` member; an unknown name must raise
    a `ValidationError` mentioning "Invalid compilation mode".
    """
    from vllm.config.compilation import CompilationConfig, CompilationMode

    # (raw `mode` argument, expected member) pairs, checked in order.
    accepted = [
        (0, CompilationMode.NONE),
        (3, CompilationMode.VLLM_COMPILE),
        ("NONE", CompilationMode.NONE),
        ("STOCK_TORCH_COMPILE", CompilationMode.STOCK_TORCH_COMPILE),
        ("DYNAMO_TRACE_ONCE", CompilationMode.DYNAMO_TRACE_ONCE),
        ("VLLM_COMPILE", CompilationMode.VLLM_COMPILE),
        ("none", CompilationMode.NONE),
        ("vllm_compile", CompilationMode.VLLM_COMPILE),
    ]

    for raw_mode, expected_mode in accepted:
        config = CompilationConfig(mode=raw_mode)
        assert config.mode == expected_mode

    # A name that is not a CompilationMode member is rejected.
    with pytest.raises(ValidationError, match="Invalid compilation mode"):
        CompilationConfig(mode="INVALID_MODE")
|
||||||
|
|
||||||
|
|
||||||
def test_flat_product():
|
def test_flat_product():
|
||||||
# Check regular itertools.product behavior
|
# Check regular itertools.product behavior
|
||||||
result1 = list(flat_product([1, 2, 3], ["a", "b"]))
|
result1 = list(flat_product([1, 2, 3], ["a", "b"]))
|
||||||
|
|||||||
@ -31,7 +31,9 @@ class TorchCompileWrapperWithCustomDispatcher:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, compiled_callable: Callable | None = None, compilation_mode: int = 0
|
self,
|
||||||
|
compiled_callable: Callable | None = None,
|
||||||
|
compilation_mode: CompilationMode = CompilationMode.NONE,
|
||||||
):
|
):
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
|
|||||||
@ -28,7 +28,7 @@ else:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CompilationMode:
|
class CompilationMode(enum.IntEnum):
|
||||||
"""The compilation approach used for torch.compile-based compilation of the
|
"""The compilation approach used for torch.compile-based compilation of the
|
||||||
model."""
|
model."""
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ class CompilationConfig:
|
|||||||
Please use mode. Currently all levels are mapped to mode.
|
Please use mode. Currently all levels are mapped to mode.
|
||||||
"""
|
"""
|
||||||
# Top-level Compilation control
|
# Top-level Compilation control
|
||||||
mode: int | None = None
|
mode: CompilationMode | None = None
|
||||||
"""The compilation approach used for torch.compile-based compilation of the
|
"""The compilation approach used for torch.compile-based compilation of the
|
||||||
model.
|
model.
|
||||||
|
|
||||||
@ -579,6 +579,27 @@ class CompilationConfig:
|
|||||||
|
|
||||||
__str__ = __repr__
|
__str__ = __repr__
|
||||||
|
|
||||||
|
@field_validator("mode", mode="before")
@classmethod
def validate_mode_before(cls, value: Any) -> Any:
    """
    Allow the `mode` field to be given as a mode name.

    String input is upper-cased and looked up among the `CompilationMode`
    member names (NONE, STOCK_TORCH_COMPILE, DYNAMO_TRACE_ONCE,
    VLLM_COMPILE), so the lookup is case-insensitive, and the matching
    member is returned. Any non-string input (e.g. an integer 0-3) is
    returned unchanged.

    Raises:
        ValueError: if a string is given that does not name a member.
    """
    # Only strings need translating; everything else passes straight through.
    if not isinstance(value, str):
        return value

    member_name = value.upper()
    if member_name in CompilationMode.__members__:
        return CompilationMode[member_name]

    # Report the caller's original spelling, plus the accepted names.
    raise ValueError(
        f"Invalid compilation mode: {value}. "
        f"Valid modes are: {', '.join(CompilationMode.__members__.keys())}"
    )
|
||||||
|
|
||||||
@field_validator("cudagraph_mode", mode="before")
|
@field_validator("cudagraph_mode", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
|
def validate_cudagraph_mode_before(cls, value: Any) -> Any:
|
||||||
@ -904,7 +925,7 @@ class CompilationConfig:
|
|||||||
return self.mode == CompilationMode.VLLM_COMPILE
|
return self.mode == CompilationMode.VLLM_COMPILE
|
||||||
|
|
||||||
# Inductor partition case
|
# Inductor partition case
|
||||||
return self.backend == "inductor" and self.mode > CompilationMode.NONE
|
return self.backend == "inductor" and self.mode != CompilationMode.NONE
|
||||||
|
|
||||||
def custom_op_log_check(self):
|
def custom_op_log_check(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -422,16 +422,13 @@ class VllmConfig:
|
|||||||
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
|
self.compilation_config.mode = CompilationMode.VLLM_COMPILE
|
||||||
else:
|
else:
|
||||||
self.compilation_config.mode = CompilationMode.NONE
|
self.compilation_config.mode = CompilationMode.NONE
|
||||||
else:
|
|
||||||
assert self.compilation_config.mode >= CompilationMode.NONE
|
|
||||||
assert self.compilation_config.mode <= CompilationMode.VLLM_COMPILE
|
|
||||||
|
|
||||||
# If user does not set custom ops via none or all set it here based on
|
# If user does not set custom ops via none or all set it here based on
|
||||||
# compilation mode and backend.
|
# compilation mode and backend.
|
||||||
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
|
if all(s not in self.compilation_config.custom_ops for s in ("all", "none")):
|
||||||
if (
|
if (
|
||||||
self.compilation_config.backend == "inductor"
|
self.compilation_config.backend == "inductor"
|
||||||
and self.compilation_config.mode > CompilationMode.NONE
|
and self.compilation_config.mode != CompilationMode.NONE
|
||||||
):
|
):
|
||||||
self.compilation_config.custom_ops.append("none")
|
self.compilation_config.custom_ops.append("none")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from vllm.config import (
|
|||||||
StructuredOutputsConfig,
|
StructuredOutputsConfig,
|
||||||
is_init_field,
|
is_init_field,
|
||||||
)
|
)
|
||||||
|
from vllm.config.compilation import CompilationMode
|
||||||
from vllm.config.model import (
|
from vllm.config.model import (
|
||||||
ConvertOption,
|
ConvertOption,
|
||||||
HfOverrides,
|
HfOverrides,
|
||||||
@ -259,7 +260,9 @@ class LLM:
|
|||||||
|
|
||||||
if compilation_config is not None:
|
if compilation_config is not None:
|
||||||
if isinstance(compilation_config, int):
|
if isinstance(compilation_config, int):
|
||||||
compilation_config_instance = CompilationConfig(mode=compilation_config)
|
compilation_config_instance = CompilationConfig(
|
||||||
|
mode=CompilationMode(compilation_config)
|
||||||
|
)
|
||||||
elif isinstance(compilation_config, dict):
|
elif isinstance(compilation_config, dict):
|
||||||
compilation_config_instance = CompilationConfig(
|
compilation_config_instance = CompilationConfig(
|
||||||
**{
|
**{
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user