diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py
index 65bee85e7a1e..b7124ebc1b0f 100644
--- a/tests/tpu/test_compilation.py
+++ b/tests/tpu/test_compilation.py
@@ -4,7 +4,7 @@ import tempfile
 
 import depyf
 
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationLevel
 
 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
@@ -34,8 +34,7 @@ with depyf.prepare_debug(temp_dir):
     # all the control
     llm = LLM(model="google/gemma-2b",
               enforce_eager=True,
-              compilation_config=CompilationConfig(
-                  level=CompilationLevel.DYNAMO_AS_IS))
+              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
     outputs = llm.generate(prompts, sampling_params)
     for output, answer in zip(outputs, answers):
         prompt = output.prompt
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 86b0b6893f1d..2446a64a02eb 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1,4 +1,5 @@
 import itertools
+import json
 import warnings
 from contextlib import contextmanager
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
@@ -9,6 +10,7 @@ from tqdm import tqdm
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
+from vllm.config import CompilationConfig
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
@@ -107,13 +109,16 @@ class LLM:
         hf_overrides: If a dictionary, contains arguments to be forwarded to
             the HuggingFace config. If a callable, it is called to update the
             HuggingFace config.
+        compilation_config: Either an integer or a dictionary. If it is an integer,
+            it is used as the level of compilation optimization. If it is a dictionary,
+            it can specify the full compilation configuration.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
 
     Note:
         This class is intended to be used for offline inference. For online
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
-    """
+    """  # noqa
 
     DEPRECATE_LEGACY: ClassVar[bool] = False
     """A flag to toggle whether to deprecate the legacy generate/encode API."""
@@ -166,6 +171,7 @@ class LLM:
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
         override_pooler_config: Optional[PoolerConfig] = None,
+        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
         '''
@@ -178,6 +184,12 @@ class LLM:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
 
+        if compilation_config is not None:
+            compilation_config_instance = CompilationConfig.from_cli(
+                json.dumps(compilation_config))
+        else:
+            compilation_config_instance = None
+
         engine_args = EngineArgs(
             model=model,
             task=task,
@@ -202,6 +214,7 @@ class LLM:
             hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
             override_pooler_config=override_pooler_config,
+            compilation_config=compilation_config_instance,
             **kwargs,
         )
         # Logic to switch between engines is done at runtime instead of import
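
For context, a minimal usage sketch of the new keyword, assembled from the diff above. The dict form mirrors the updated TPU test; the integer form follows the new docstring. The model name is simply the one used in the test, and `CompilationLevel` values are plain ints, so both forms round-trip cleanly through the `json.dumps()` + `CompilationConfig.from_cli()` path added in `LLM.__init__`:

```python
# Usage sketch for the new `compilation_config` argument (not part of the
# diff itself). Model name taken from tests/tpu/test_compilation.py; any
# supported model should behave the same.
from vllm import LLM, SamplingParams
from vllm.config import CompilationLevel

# Dictionary form: the dict is serialized with json.dumps() and parsed by
# CompilationConfig.from_cli() inside LLM.__init__, as shown in the diff.
llm = LLM(model="google/gemma-2b",
          enforce_eager=True,
          compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})

# Integer form: per the new docstring, a bare int is interpreted as the
# level of compilation optimization.
llm_int = LLM(model="google/gemma-2b",
              compilation_config=CompilationLevel.DYNAMO_AS_IS)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0))
```

Routing the value through `json.dumps()` into `CompilationConfig.from_cli()` presumably keeps the offline `LLM` entrypoint on the same parsing path as the CLI's JSON-string compilation config, so both accept identical shapes.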