From 729171ae58cce74753e628056bda2b6df6b65f4a Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 22 Jul 2024 20:03:13 -0700
Subject: [PATCH] [Misc] Enable chunked prefill by default for long context
 models (#6666)

---
 vllm/engine/arg_utils.py | 54 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 52 insertions(+), 2 deletions(-)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 4db071e4caef4..c34b88b53f656 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -10,6 +10,7 @@
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
@@ -17,6 +18,8 @@
 if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
         BaseTokenizerGroup)
 
+logger = init_logger(__name__)
+
 
 def nullable_str(val: str):
     if not val or val == "None":
@@ -95,7 +98,7 @@ class EngineArgs:
     preemption_mode: Optional[str] = None
 
     scheduler_delay_factor: float = 0.0
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: Optional[bool] = None
 
     guided_decoding_backend: str = 'outlines'
     # Speculative decoding configuration.
@@ -508,7 +511,10 @@ class EngineArgs:
             'prompt latency) before scheduling next prompt.')
         parser.add_argument(
             '--enable-chunked-prefill',
-            action='store_true',
+            action=StoreBoolean,
+            default=EngineArgs.enable_chunked_prefill,
+            nargs="?",
+            const="True",
             help='If set, the prefill requests can be chunked based on the '
             'max_num_batched_tokens.')
 
@@ -728,6 +734,38 @@ class EngineArgs:
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend)
 
+        max_model_len = model_config.max_model_len
+        use_long_context = max_model_len > 32768
+        if self.enable_chunked_prefill is None:
+            # If not explicitly set, enable chunked prefill by default for
+            # long context (> 32K) models. This is to avoid OOM errors in the
+            # initial memory profiling phase.
+            if use_long_context:
+                is_gpu = device_config.device_type == "cuda"
+                use_sliding_window = (model_config.get_sliding_window()
+                                      is not None)
+                use_spec_decode = self.speculative_model is not None
+                if (is_gpu and not use_sliding_window and not use_spec_decode
+                        and not self.enable_lora
+                        and not self.enable_prompt_adapter
+                        and not self.enable_prefix_caching):
+                    self.enable_chunked_prefill = True
+                    logger.warning(
+                        "Chunked prefill is enabled by default for models with "
+                        "max_model_len > 32K. Currently, chunked prefill might "
+                        "not work with some features or models. If you "
+                        "encounter any issues, please disable chunked prefill "
+                        "by setting --enable-chunked-prefill=False.")
+            if self.enable_chunked_prefill is None:
+                self.enable_chunked_prefill = False
+
+        if not self.enable_chunked_prefill and use_long_context:
+            logger.warning(
+                "The model has a long context length (%s). This may cause OOM "
+                "errors during the initial memory profiling phase, or result "
+                "in low performance due to small KV cache space. Consider "
+                "setting --max-model-len to a smaller value.", max_model_len)
+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
@@ -843,6 +881,18 @@ class AsyncEngineArgs(EngineArgs):
         return parser
 
 
+class StoreBoolean(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        if values.lower() == "true":
+            setattr(namespace, self.dest, True)
+        elif values.lower() == "false":
+            setattr(namespace, self.dest, False)
+        else:
+            raise ValueError(f"Invalid boolean value: {values}. "
+                             "Expected 'true' or 'false'.")
+
+
 # These functions are used by sphinx to build the documentation
 def _engine_args_parser():
     return EngineArgs.add_cli_args(FlexibleArgumentParser())
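
Reviewer note (not part of the patch): the flag is now tri-state. Leaving --enable-chunked-prefill unset keeps the dataclass default of None so create_engine_config can decide; passing the bare flag routes the const="True" string through StoreBoolean; passing --enable-chunked-prefill=False is the explicit opt-out mentioned in the warning message. The snippet below is a minimal standalone sketch of that behavior using plain argparse and a local copy of the action; the toy parser and the loop are illustrative only and do not appear in vLLM.

    # Standalone sketch (assumed, not from the patch): wires up
    # --enable-chunked-prefill the same way the diff does, with a local copy
    # of StoreBoolean, to show the three states the engine can observe.
    import argparse


    class StoreBoolean(argparse.Action):
        # Same logic as the class added at the bottom of arg_utils.py.
        def __call__(self, parser, namespace, values, option_string=None):
            if values.lower() == "true":
                setattr(namespace, self.dest, True)
            elif values.lower() == "false":
                setattr(namespace, self.dest, False)
            else:
                raise ValueError(f"Invalid boolean value: {values}. "
                                 "Expected 'true' or 'false'.")


    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--enable-chunked-prefill",
        action=StoreBoolean,
        default=None,   # mirrors EngineArgs.enable_chunked_prefill = None
        nargs="?",
        const="True")   # a bare flag is treated as "True"

    for argv in ([], ["--enable-chunked-prefill"],
                 ["--enable-chunked-prefill=False"]):
        print(argv, "->", parser.parse_args(argv).enable_chunked_prefill)
    # []                                 -> None   (create_engine_config decides)
    # ['--enable-chunked-prefill']       -> True
    # ['--enable-chunked-prefill=False'] -> False  (explicit opt-out)

Keeping None distinct from False is what lets the new block in create_engine_config tell "the user said nothing" apart from "the user explicitly disabled it"; only the former is overridden for models with max_model_len > 32768.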