Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-03 18:14:03 +08:00)
[Misc] Enable chunked prefill by default for long context models (#6666)
commit 729171ae58
parent c5e8330997
@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
 
@@ -17,6 +18,8 @@ if TYPE_CHECKING:
     from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
         BaseTokenizerGroup)
 
+logger = init_logger(__name__)
+
 
 def nullable_str(val: str):
     if not val or val == "None":
@@ -95,7 +98,7 @@ class EngineArgs:
     preemption_mode: Optional[str] = None
 
     scheduler_delay_factor: float = 0.0
-    enable_chunked_prefill: bool = False
+    enable_chunked_prefill: Optional[bool] = None
 
     guided_decoding_backend: str = 'outlines'
     # Speculative decoding configuration.
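Changing the dataclass field from `bool = False` to `Optional[bool] = None` makes the setting tri-state: None means "the user did not choose", so the engine can pick a per-model default later, while True/False record an explicit choice. A minimal standalone sketch of that pattern (names here are illustrative, not vLLM's API):

# Sketch of the tri-state default pattern above (illustrative names only).
from dataclasses import dataclass
from typing import Optional


@dataclass
class _SketchArgs:
    # None = "unset"; True/False = an explicit user choice.
    enable_chunked_prefill: Optional[bool] = None

    def resolve(self, long_context: bool) -> bool:
        if self.enable_chunked_prefill is None:
            # Only auto-enable when the user expressed no preference.
            return long_context
        return self.enable_chunked_prefill


assert _SketchArgs().resolve(long_context=True) is True
assert _SketchArgs(enable_chunked_prefill=False).resolve(long_context=True) is False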
@@ -508,7 +511,10 @@ class EngineArgs:
             'prompt latency) before scheduling next prompt.')
         parser.add_argument(
             '--enable-chunked-prefill',
-            action='store_true',
+            action=StoreBoolean,
             default=EngineArgs.enable_chunked_prefill,
+            nargs="?",
+            const="True",
             help='If set, the prefill requests can be chunked based on the '
             'max_num_batched_tokens.')
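With nargs="?" the option may be given with or without a value: a bare --enable-chunked-prefill makes argparse hand const ("True") to the action, --enable-chunked-prefill=False hands it "False", and omitting the flag leaves the None default untouched. A rough standalone check of those three cases, using an inline stand-in for the StoreBoolean action added further down in this diff:

# Standalone sketch (not vLLM code): the same argparse wiring as above,
# with a simplified stand-in for StoreBoolean.
import argparse


class _StoreBoolean(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values.lower() == "true")


parser = argparse.ArgumentParser()
parser.add_argument("--enable-chunked-prefill",
                    action=_StoreBoolean,
                    default=None,
                    nargs="?",
                    const="True")

print(parser.parse_args([]).enable_chunked_prefill)                                  # None
print(parser.parse_args(["--enable-chunked-prefill"]).enable_chunked_prefill)        # True
print(parser.parse_args(["--enable-chunked-prefill=False"]).enable_chunked_prefill)  # False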
@@ -728,6 +734,38 @@ class EngineArgs:
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             distributed_executor_backend=self.distributed_executor_backend)
 
+        max_model_len = model_config.max_model_len
+        use_long_context = max_model_len > 32768
+        if self.enable_chunked_prefill is None:
+            # If not explicitly set, enable chunked prefill by default for
+            # long context (> 32K) models. This is to avoid OOM errors in the
+            # initial memory profiling phase.
+            if use_long_context:
+                is_gpu = device_config.device_type == "cuda"
+                use_sliding_window = (model_config.get_sliding_window()
+                                      is not None)
+                use_spec_decode = self.speculative_model is not None
+                if (is_gpu and not use_sliding_window and not use_spec_decode
+                        and not self.enable_lora
+                        and not self.enable_prompt_adapter
+                        and not self.enable_prefix_caching):
+                    self.enable_chunked_prefill = True
+                    logger.warning(
+                        "Chunked prefill is enabled by default for models with "
+                        "max_model_len > 32K. Currently, chunked prefill might "
+                        "not work with some features or models. If you "
+                        "encounter any issues, please disable chunked prefill "
+                        "by setting --enable-chunked-prefill=False.")
+            if self.enable_chunked_prefill is None:
+                self.enable_chunked_prefill = False
+
+        if not self.enable_chunked_prefill and use_long_context:
+            logger.warning(
+                "The model has a long context length (%s). This may cause OOM "
+                "errors during the initial memory profiling phase, or result "
+                "in low performance due to small KV cache space. Consider "
+                "setting --max-model-len to a smaller value.", max_model_len)
+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
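The default is deliberately conservative: chunked prefill is only switched on when nothing known to conflict with it is in play. Roughly, the auto-enable check above reduces to the following predicate (a sketch with flattened, illustrative parameters instead of vLLM's ModelConfig/DeviceConfig objects):

# Sketch of the auto-enable predicate above, with plain parameters in place of
# vLLM's config objects (parameter names are illustrative, not vLLM's API).
from typing import Optional


def auto_enable_chunked_prefill(max_model_len: int,
                                device_type: str,
                                sliding_window: Optional[int],
                                speculative_model: Optional[str],
                                enable_lora: bool,
                                enable_prompt_adapter: bool,
                                enable_prefix_caching: bool) -> bool:
    use_long_context = max_model_len > 32768
    return (use_long_context
            and device_type == "cuda"        # GPU only
            and sliding_window is None       # no sliding-window attention
            and speculative_model is None    # no speculative decoding
            and not enable_lora
            and not enable_prompt_adapter
            and not enable_prefix_caching)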
@@ -843,6 +881,18 @@ class AsyncEngineArgs(EngineArgs):
         return parser
 
 
+class StoreBoolean(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        if values.lower() == "true":
+            setattr(namespace, self.dest, True)
+        elif values.lower() == "false":
+            setattr(namespace, self.dest, False)
+        else:
+            raise ValueError(f"Invalid boolean value: {values}. "
+                             "Expected 'true' or 'false'.")
+
+
 # These functions are used by sphinx to build the documentation
 def _engine_args_parser():
     return EngineArgs.add_cli_args(FlexibleArgumentParser())
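Putting the pieces together, the flag behaves as below when parsed through vLLM's own CLI wiring (a usage sketch; it assumes this version of vLLM is installed and the parser is built exactly as in the hunks above):

# Usage sketch: parse the new flag through EngineArgs' CLI wiring
# (assumes this version of vLLM is importable).
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

# Flag omitted: stays None, so create_engine_config() decides per model.
print(parser.parse_args([]).enable_chunked_prefill)
# Bare flag: the "True" const goes through StoreBoolean -> True.
print(parser.parse_args(["--enable-chunked-prefill"]).enable_chunked_prefill)
# Explicit value: "--enable-chunked-prefill=False" -> False.
print(parser.parse_args(["--enable-chunked-prefill=False"]).enable_chunked_prefill)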