[Misc] Enable chunked prefill by default for long context models (#6666)

Author: Woosuk Kwon
Date: 2024-07-22 20:03:13 -07:00 (committed by GitHub)
Parent: c5e8330997
Commit: 729171ae58


@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
@@ -17,6 +18,8 @@ if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
         BaseTokenizerGroup)
 
+logger = init_logger(__name__)
+
 
 def nullable_str(val: str):
     if not val or val == "None":
@@ -95,7 +98,7 @@ class EngineArgs:
     preemption_mode: Optional[str] = None
 
     scheduler_delay_factor: float = 0.0
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: Optional[bool] = None
 
     guided_decoding_backend: str = 'outlines'
     # Speculative decoding configuration.
@@ -508,7 +511,10 @@ class EngineArgs:
             'prompt latency) before scheduling next prompt.')
         parser.add_argument(
             '--enable-chunked-prefill',
-            action='store_true',
+            action=StoreBoolean,
+            default=EngineArgs.enable_chunked_prefill,
+            nargs="?",
+            const="True",
             help='If set, the prefill requests can be chunked based on the '
             'max_num_batched_tokens.')
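
The switch from action='store_true' to the StoreBoolean action (defined near the bottom of this diff) is what creates the new three-state flag: an omitted flag leaves enable_chunked_prefill as None so create_engine_config() can choose a default, while --enable-chunked-prefill and --enable-chunked-prefill=False force it on or off. The standalone sketch below demonstrates this with a plain argparse parser; only the StoreBoolean class and the add_argument options are taken from the patch, the surrounding parser setup and prints are illustrative.

import argparse


class StoreBoolean(argparse.Action):
    # Copied from the action class added at the end of this commit.
    def __call__(self, parser, namespace, values, option_string=None):
        if values.lower() == "true":
            setattr(namespace, self.dest, True)
        elif values.lower() == "false":
            setattr(namespace, self.dest, False)
        else:
            raise ValueError(f"Invalid boolean value: {values}. "
                             "Expected 'true' or 'false'.")


parser = argparse.ArgumentParser()
parser.add_argument("--enable-chunked-prefill",
                    action=StoreBoolean,
                    default=None,  # EngineArgs.enable_chunked_prefill is now None
                    nargs="?",
                    const="True")

print(parser.parse_args([]))                                  # None  -> "decide later"
print(parser.parse_args(["--enable-chunked-prefill"]))        # True  (const="True")
print(parser.parse_args(["--enable-chunked-prefill=False"]))  # False (explicit opt-out)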
@@ -728,6 +734,38 @@ class EngineArgs:
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend)
 
+        max_model_len = model_config.max_model_len
+        use_long_context = max_model_len > 32768
+        if self.enable_chunked_prefill is None:
+            # If not explicitly set, enable chunked prefill by default for
+            # long context (> 32K) models. This is to avoid OOM errors in the
+            # initial memory profiling phase.
+            if use_long_context:
+                is_gpu = device_config.device_type == "cuda"
+                use_sliding_window = (model_config.get_sliding_window()
+                                      is not None)
+                use_spec_decode = self.speculative_model is not None
+                if (is_gpu and not use_sliding_window and not use_spec_decode
+                        and not self.enable_lora
+                        and not self.enable_prompt_adapter
+                        and not self.enable_prefix_caching):
+                    self.enable_chunked_prefill = True
+                    logger.warning(
+                        "Chunked prefill is enabled by default for models with "
+                        "max_model_len > 32K. Currently, chunked prefill might "
+                        "not work with some features or models. If you "
+                        "encounter any issues, please disable chunked prefill "
+                        "by setting --enable-chunked-prefill=False.")
+            if self.enable_chunked_prefill is None:
+                self.enable_chunked_prefill = False
+
+        if not self.enable_chunked_prefill and use_long_context:
+            logger.warning(
+                "The model has a long context length (%s). This may cause OOM "
+                "errors during the initial memory profiling phase, or result "
+                "in low performance due to small KV cache space. Consider "
+                "setting --max-model-len to a smaller value.", max_model_len)
+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
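
Condensed, the block above implements a three-way decision: an explicit user setting always wins; otherwise chunked prefill is switched on only for >32K-context models running on CUDA without sliding-window attention, speculative decoding, LoRA, prompt adapters, or prefix caching; in every other case it falls back to the old default of False. The helper below is an illustrative restatement of that logic (it is not part of vLLM), useful for checking the behavior without building a full engine config.

from typing import Optional


def resolve_chunked_prefill(user_setting: Optional[bool],
                            max_model_len: int,
                            is_gpu: bool,
                            uses_incompatible_feature: bool) -> bool:
    """Mirror of the default-selection logic in create_engine_config().

    `uses_incompatible_feature` stands for any of: sliding window,
    speculative decoding, LoRA, prompt adapters, or prefix caching.
    """
    if user_setting is not None:
        # --enable-chunked-prefill / --enable-chunked-prefill=False wins.
        return user_setting
    # Auto-enable only for long-context (> 32K) models on CUDA with no
    # known-incompatible feature requested.
    return max_model_len > 32768 and is_gpu and not uses_incompatible_feature


# 128K-context model, plain CUDA setup -> enabled by default.
assert resolve_chunked_prefill(None, 131072, True, False) is True
# Same model, but the user passed --enable-chunked-prefill=False.
assert resolve_chunked_prefill(False, 131072, True, False) is False
# Short-context model with no explicit flag -> old default (disabled).
assert resolve_chunked_prefill(None, 4096, True, False) is False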
@@ -843,6 +881,18 @@ class AsyncEngineArgs(EngineArgs):
         return parser
 
 
+class StoreBoolean(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        if values.lower() == "true":
+            setattr(namespace, self.dest, True)
+        elif values.lower() == "false":
+            setattr(namespace, self.dest, False)
+        else:
+            raise ValueError(f"Invalid boolean value: {values}. "
+                             "Expected 'true' or 'false'.")
+
+
 # These functions are used by sphinx to build the documentation
 def _engine_args_parser():
     return EngineArgs.add_cli_args(FlexibleArgumentParser())
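
End to end, the flag now flows from the CLI through EngineArgs into create_engine_config(). Below is a hedged usage sketch assuming the standard vllm.engine.arg_utils location of EngineArgs; FlexibleArgumentParser, EngineArgs.add_cli_args, and EngineArgs.from_cli_args are existing vLLM entry points, but the exact invocation is illustrative rather than taken from this commit.

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

# Flag omitted: enable_chunked_prefill stays None, so create_engine_config()
# applies the long-context default added in this commit.
args = parser.parse_args(["--model", "facebook/opt-125m"])
print(EngineArgs.from_cli_args(args).enable_chunked_prefill)  # None

# Explicit opt-out, which the old store_true flag could not express:
args = parser.parse_args(["--model", "facebook/opt-125m",
                          "--enable-chunked-prefill=False"])
print(EngineArgs.from_cli_args(args).enable_chunked_prefill)  # False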