[v1] EngineArgs for better config handling for v1 (#10382)
Signed-off-by: rickyx <rickyx@anyscale.com>
This commit is contained in:
parent a6760f6456
commit 519e8e4182
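Below is a minimal usage sketch of the calling convention this commit introduces, drawn from the new tests further down. It is an illustration rather than code from the diff, and it assumes a build where the V1 engine is enabled via VLLM_USE_V1=1 before vllm is imported:

    import os

    os.environ["VLLM_USE_V1"] = "1"  # assumption: enable the V1 engine path

    from vllm.engine.arg_utils import EngineArgs
    from vllm.usage.usage_lib import UsageContext

    # EngineArgs now resolves V1 defaults itself (see arg_utils.py below),
    # so callers pass their usage context into create_engine_config().
    engine_args = EngineArgs(model="facebook/opt-125m")
    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.LLM_CLASS)

    # Defaults picked by this diff when the user set no overrides:
    #   LLM_CLASS         -> max_num_seqs=1024, max_num_batched_tokens=8192
    #   OPENAI_API_SERVER -> max_num_seqs=1024, max_num_batched_tokens=2048
    print(vllm_config.scheduler_config.max_num_seqs)
    print(vllm_config.scheduler_config.max_num_batched_tokens)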
@@ -172,7 +172,7 @@ steps:
   - vllm/
   - tests/v1
   commands:
-    - pytest -v -s v1
+    - VLLM_USE_V1=1 pytest -v -s v1
 
 - label: Examples Test # 15min
   working_dir: "/vllm-workspace/examples"
@@ -32,6 +32,9 @@ async def generate(engine: AsyncLLM, request_id: str,
 
 @pytest.mark.asyncio
 async def test_load(monkeypatch):
+    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
+    # so that in the future when we switch, we don't have to change all the
+    # tests.
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
tests/v1/engine/test_engine_args.py (new file, 42 lines added)
@@ -0,0 +1,42 @@
+import pytest
+
+from vllm import envs
+from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
+from vllm.usage.usage_lib import UsageContext
+
+if not envs.VLLM_USE_V1:
+    pytest.skip(
+        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
+        allow_module_level=True,
+    )
+
+
+def test_defaults():
+    engine_args = EngineArgs(model="facebook/opt-125m")
+
+    # Assert V1 defaults
+    assert (engine_args.enable_prefix_caching
+            ), "V1 turns on prefix caching by default"
+
+
+def test_defaults_with_usage_context():
+    engine_args = EngineArgs(model="facebook/opt-125m")
+    vllm_config: VllmConfig = engine_args.create_engine_config(
+        UsageContext.LLM_CLASS)
+
+    assert vllm_config.scheduler_config.max_num_seqs == 1024
+    assert vllm_config.scheduler_config.max_num_batched_tokens == 8192
+
+    engine_args = EngineArgs(model="facebook/opt-125m")
+    vllm_config = engine_args.create_engine_config(
+        UsageContext.OPENAI_API_SERVER)
+    assert vllm_config.scheduler_config.max_num_seqs == 1024
+    assert vllm_config.scheduler_config.max_num_batched_tokens == 2048
+
+
+def test_prefix_cache_disabled_with_multimodel():
+    engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
+
+    vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
+    assert not vllm_config.cache_config.enable_prefix_caching
@@ -43,7 +43,8 @@ def test_engine_core(monkeypatch):
         m.setenv("VLLM_USE_V1", "1")
         """Setup the EngineCore."""
         engine_args = EngineArgs(model=MODEL_NAME)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
 
         engine_core = EngineCore(vllm_config=vllm_config,
@@ -82,7 +82,8 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
         m.setenv("VLLM_USE_V1", "1")
 
         engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            UsageContext.UNKNOWN_CONTEXT)
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
         client = EngineCoreClient.make_client(
             vllm_config,
@@ -153,7 +154,8 @@ async def test_engine_core_client_asyncio(monkeypatch):
         m.setenv("VLLM_USE_V1", "1")
 
         engine_args = EngineArgs(model=MODEL_NAME)
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(
+            usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = AsyncLLM._get_executor_cls(vllm_config)
         client = EngineCoreClient.make_client(
             vllm_config,
@@ -20,6 +20,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.platforms import current_platform
 from vllm.transformers_utils.utils import check_gguf_file
+from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, StoreBoolean
 
 if TYPE_CHECKING:
@@ -113,7 +114,7 @@ class EngineArgs:
     # NOTE(kzawora): default block size for Gaudi should be 128
     # smaller sizes still work, but very inefficiently
     block_size: int = 16 if not current_platform.is_hpu() else 128
-    enable_prefix_caching: bool = False
+    enable_prefix_caching: Optional[bool] = None
     disable_sliding_window: bool = False
     use_v2_block_manager: bool = True
     swap_space: float = 4  # GiB
@@ -197,6 +198,11 @@ class EngineArgs:
         if not self.tokenizer:
             self.tokenizer = self.model
 
+        # Override the default value of enable_prefix_caching if it's not set
+        # by user.
+        if self.enable_prefix_caching is None:
+            self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
+
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
         # CompilationConfig object
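A short sketch of the default resolution added in the hunk above (illustrative only; assumes VLLM_USE_V1=1 is exported before EngineArgs is constructed):

    from vllm.engine.arg_utils import EngineArgs

    # enable_prefix_caching now defaults to None; __post_init__ resolves it
    # to bool(envs.VLLM_USE_V1) only when the user left it unset.
    args = EngineArgs(model="facebook/opt-125m")
    assert args.enable_prefix_caching  # True when VLLM_USE_V1=1

    # An explicit user setting is never overridden by this logic.
    args = EngineArgs(model="facebook/opt-125m", enable_prefix_caching=False)
    assert not args.enable_prefix_caching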
@@ -953,7 +959,12 @@ class EngineArgs:
             ignore_patterns=self.ignore_patterns,
         )
 
-    def create_engine_config(self) -> VllmConfig:
+    def create_engine_config(self,
+                             usage_context: Optional[UsageContext] = None
+                             ) -> VllmConfig:
+        if envs.VLLM_USE_V1:
+            self._override_v1_engine_args(usage_context)
+
         # gguf file needs a specific model loader and doesn't use hf_repo
         if check_gguf_file(self.model):
             self.quantization = self.load_format = "gguf"
@@ -1170,7 +1181,7 @@ class EngineArgs:
             or "all" in detailed_trace_modules,
         )
 
-        return VllmConfig(
+        config = VllmConfig(
             model_config=model_config,
             cache_config=cache_config,
             parallel_config=parallel_config,
@@ -1185,6 +1196,42 @@ class EngineArgs:
             compilation_config=self.compilation_config,
         )
 
+        if envs.VLLM_USE_V1:
+            self._override_v1_engine_config(config)
+        return config
+
+    def _override_v1_engine_args(self, usage_context: UsageContext) -> None:
+        """
+        Override the EngineArgs's args based on the usage context for V1.
+        """
+        assert envs.VLLM_USE_V1, "V1 is not enabled"
+
+        if self.max_num_batched_tokens is None:
+            # When no user override, set the default values based on the
+            # usage context.
+            if usage_context == UsageContext.LLM_CLASS:
+                logger.warning("Setting max_num_batched_tokens to 8192 "
+                               "for LLM_CLASS usage context.")
+                self.max_num_seqs = 1024
+                self.max_num_batched_tokens = 8192
+            elif usage_context == UsageContext.OPENAI_API_SERVER:
+                logger.warning("Setting max_num_batched_tokens to 2048 "
+                               "for OPENAI_API_SERVER usage context.")
+                self.max_num_seqs = 1024
+                self.max_num_batched_tokens = 2048
+
+    def _override_v1_engine_config(self, engine_config: VllmConfig) -> None:
+        """
+        Override the EngineConfig's configs based on the usage context for V1.
+        """
+        assert envs.VLLM_USE_V1, "V1 is not enabled"
+        # TODO (ywang96): Enable APC by default when VLM supports it.
+        if engine_config.model_config.is_multimodal_model:
+            logger.warning(
+                "Prefix caching is currently not supported for multimodal "
+                "models and has been disabled.")
+            engine_config.cache_config.enable_prefix_caching = False
+
 
 @dataclass
 class AsyncEngineArgs(EngineArgs):
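The two hooks added above are what the new tests exercise; a short sketch of the observable effect of _override_v1_engine_config on a multimodal model (illustrative, assuming VLLM_USE_V1=1 and that the model's Hugging Face config can be fetched):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.usage.usage_lib import UsageContext

    engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
    vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)

    # The V1 default turns prefix caching on at the args level, but the
    # config-level override disables it because the model is multimodal.
    assert engine_args.enable_prefix_caching
    assert not vllm_config.cache_config.enable_prefix_caching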
@@ -680,7 +680,7 @@ class AsyncLLMEngine(EngineClient):
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         if engine_config is None:
-            engine_config = engine_args.create_engine_config()
+            engine_config = engine_args.create_engine_config(usage_context)
 
         executor_class = cls._get_executor_cls(engine_config)
 
@@ -568,7 +568,7 @@ class LLMEngine:
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(usage_context)
         executor_class = cls._get_executor_cls(engine_config)
         # Create the LLM engine.
         engine = cls(
@@ -111,7 +111,7 @@ class MQLLMEngine:
         from vllm.plugins import load_general_plugins
         load_general_plugins()
 
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(usage_context)
         executor_class = LLMEngine._get_executor_cls(engine_config)
 
         use_async_sockets = engine_config.model_config.use_async_output_proc
@@ -135,8 +135,8 @@ async def build_async_engine_client_from_engine_args(
     # TODO: fill out feature matrix.
     if (MQLLMEngineClient.is_unsupported_config(engine_args)
             or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
-
-        engine_config = engine_args.create_engine_config()
+        engine_config = engine_args.create_engine_config(
+            UsageContext.OPENAI_API_SERVER)
         uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
                            "uses_ray", False)
 
@@ -94,7 +94,7 @@ class AsyncLLM(EngineClient):
 
         # Create the engine configs.
         if engine_config is None:
-            vllm_config = engine_args.create_engine_config()
+            vllm_config = engine_args.create_engine_config(usage_context)
         else:
             vllm_config = engine_config
 
@@ -41,19 +41,6 @@ class EngineCore:
         executor_class: Type[GPUExecutor],
         usage_context: UsageContext,
     ):
-        # Override the configs for V1.
-        # FIXME
-        if usage_context == UsageContext.LLM_CLASS:
-            vllm_config.scheduler_config.max_num_seqs = 1024
-            vllm_config.scheduler_config.max_num_batched_tokens = 8192
-        elif usage_context == UsageContext.OPENAI_API_SERVER:
-            vllm_config.scheduler_config.max_num_seqs = 1024
-            vllm_config.scheduler_config.max_num_batched_tokens = 2048
-
-        # TODO (ywang96): Enable APC by default when VLM supports it.
-        if not vllm_config.model_config.is_multimodal_model:
-            vllm_config.cache_config.enable_prefix_caching = True
-
         assert vllm_config.model_config.task != "embedding"
 
         logger.info("Initializing an LLM engine (v%s) with config: %s",
@@ -82,7 +82,7 @@ class LLMEngine:
         """Creates an LLM engine from the engine arguments."""
 
         # Create the engine configs.
-        vllm_config = engine_args.create_engine_config()
+        vllm_config = engine_args.create_engine_config(usage_context)
         executor_class = cls._get_executor_cls(vllm_config)
 
         if VLLM_ENABLE_V1_MULTIPROCESSING: