[Misc] Further clean up chunked prefill and prefix caching init (#29186)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-11-22 19:34:15 +08:00 committed by GitHub
parent 8e22da1d7f
commit 5a4802588e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 33 additions and 29 deletions

View File

@ -279,7 +279,7 @@ def test_prefix_cache_default():
args = parser.parse_args([]) args = parser.parse_args([])
engine_args = EngineArgs.from_cli_args(args=args) engine_args = EngineArgs.from_cli_args(args=args)
assert not engine_args.enable_prefix_caching, "prefix caching defaults to off." assert engine_args.enable_prefix_caching, "prefix caching should default to on."
# with flag to turn it on. # with flag to turn it on.
args = parser.parse_args(["--enable-prefix-caching"]) args = parser.parse_args(["--enable-prefix-caching"])

View File

@ -76,11 +76,11 @@ def test_get_num_unfinished_requests():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs", "enable_prefix_caching, prompt_logprobs",
[ [
(None, None), (False, None),
(True, 5), (True, 5),
], ],
) )
def test_schedule(enable_prefix_caching: bool | None, prompt_logprobs: int | None): def test_schedule(enable_prefix_caching: bool, prompt_logprobs: int | None):
"""Test scheduling. """Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
""" """
@ -582,12 +582,12 @@ def test_check_stop_min_tokens():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs", "enable_prefix_caching, prompt_logprobs",
[ [
(None, None), (False, None),
(True, 5), (True, 5),
], ],
) )
def test_schedule_concurrent_batches( def test_schedule_concurrent_batches(
enable_prefix_caching: bool | None, prompt_logprobs: int | None enable_prefix_caching: bool, prompt_logprobs: int | None
): ):
scheduler = create_scheduler( scheduler = create_scheduler(
max_num_batched_tokens=1024, max_num_batched_tokens=1024,
@ -1425,7 +1425,7 @@ def create_scheduler_with_priority(
model: str = "facebook/opt-125m", model: str = "facebook/opt-125m",
max_num_seqs: int = 16, max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192, max_num_batched_tokens: int = 8192,
enable_prefix_caching: bool | None = None, enable_prefix_caching: bool = False,
long_prefill_token_threshold: int = 0, long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False, disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False, use_kv_connector: bool = False,
@ -1444,7 +1444,7 @@ def create_scheduler_with_priority(
max_num_batch_tokens: max num tokens to batch max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config enable_prefix_caching: optionally force APC config
(True/False) or use default (True/False) or use default
(None) (False)
Returns: Returns:
{class}`Scheduler` instance with priority scheduling {class}`Scheduler` instance with priority scheduling
@ -1467,17 +1467,12 @@ def create_scheduler_with_priority(
seed=42, seed=42,
) )
# Cache config, optionally force APC # Cache config, optionally force APC
kwargs_cache = (
{}
if enable_prefix_caching is None
else {"enable_prefix_caching": enable_prefix_caching}
)
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=block_size, block_size=block_size,
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
swap_space=0, swap_space=0,
cache_dtype="auto", cache_dtype="auto",
**kwargs_cache, enable_prefix_caching=enable_prefix_caching,
) )
kv_transfer_config = ( kv_transfer_config = (
KVTransferConfig( KVTransferConfig(

View File

@ -42,7 +42,7 @@ def create_scheduler(
model: str = "facebook/opt-125m", model: str = "facebook/opt-125m",
max_num_seqs: int = 16, max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192, max_num_batched_tokens: int = 8192,
enable_prefix_caching: bool | None = None, enable_prefix_caching: bool = False,
long_prefill_token_threshold: int = 0, long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False, disable_chunked_mm_input: bool = False,
use_kv_connector: None | bool | MockKVConfig = None, use_kv_connector: None | bool | MockKVConfig = None,
@ -63,7 +63,7 @@ def create_scheduler(
max_num_batch_tokens: max num tokens to batch max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config enable_prefix_caching: optionally force APC config
(True/False) or use default (True/False) or use default
(None) (False)
Returns: Returns:
{class}`Scheduler` instance {class}`Scheduler` instance
@ -87,17 +87,12 @@ def create_scheduler(
skip_tokenizer_init=skip_tokenizer_init, skip_tokenizer_init=skip_tokenizer_init,
) )
# Cache config, optionally force APC # Cache config, optionally force APC
kwargs_cache = (
{}
if enable_prefix_caching is None
else {"enable_prefix_caching": enable_prefix_caching}
)
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=block_size, block_size=block_size,
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
swap_space=0, swap_space=0,
cache_dtype="auto", cache_dtype="auto",
**kwargs_cache, enable_prefix_caching=enable_prefix_caching,
) )
kv_transfer_config = None kv_transfer_config = None
if isinstance(use_kv_connector, MockKVConfig): if isinstance(use_kv_connector, MockKVConfig):

View File

@ -73,8 +73,8 @@ class CacheConfig:
sliding_window: int | None = None sliding_window: int | None = None
"""Sliding window size for the KV cache. This is primarily set in """Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here.""" `ModelConfig` and that value should be manually duplicated here."""
enable_prefix_caching: bool | None = None enable_prefix_caching: bool = True
"""Whether to enable prefix caching. Enabled by default for V1.""" """Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n """Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing.\n - "sha256" uses Pickle for object serialization before hashing.\n

View File

@ -425,7 +425,7 @@ class EngineArgs:
ParallelConfig.max_parallel_loading_workers ParallelConfig.max_parallel_loading_workers
) )
block_size: BlockSize | None = CacheConfig.block_size block_size: BlockSize | None = CacheConfig.block_size
enable_prefix_caching: bool | None = CacheConfig.enable_prefix_caching enable_prefix_caching: bool | None = None
prefix_caching_hash_algo: PrefixCachingHashAlgo = ( prefix_caching_hash_algo: PrefixCachingHashAlgo = (
CacheConfig.prefix_caching_hash_algo CacheConfig.prefix_caching_hash_algo
) )
@ -1975,10 +1975,11 @@ class EngineArgs:
if self.prefill_context_parallel_size > 1: if self.prefill_context_parallel_size > 1:
default_chunked_prefill = False default_chunked_prefill = False
default_prefix_caching = False default_prefix_caching = False
logger.warning( logger.warning_once(
"--prefill-context-parallel-size > 1 is not compatible with " "--prefill-context-parallel-size > 1 is not compatible with "
"chunked prefill and prefix caching now. Chunked prefill " "chunked prefill and prefix caching now. Chunked prefill "
"and prefix caching have been disabled by default." "and prefix caching have been disabled by default.",
scope="local",
) )
if self.enable_chunked_prefill is None: if self.enable_chunked_prefill is None:
@ -1988,15 +1989,27 @@ class EngineArgs:
"%s chunked prefill by default", "%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling", "Enabling" if default_chunked_prefill else "Disabling",
) )
elif (
model_config.runner_type == "generate"
and not self.enable_chunked_prefill
and default_chunked_prefill
):
logger.warning_once(
"This model does not officially support disabling chunked prefill. "
"Disabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
scope="local",
)
elif ( elif (
model_config.runner_type == "pooling" model_config.runner_type == "pooling"
and self.enable_chunked_prefill and self.enable_chunked_prefill
and not default_chunked_prefill and not default_chunked_prefill
): ):
logger.warning( logger.warning_once(
"This model does not officially support chunked prefill. " "This model does not officially support chunked prefill. "
"Enabling this manually may cause the engine to crash " "Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.", "or produce incorrect outputs.",
scope="local",
) )
if self.enable_prefix_caching is None: if self.enable_prefix_caching is None:
@ -2011,10 +2024,11 @@ class EngineArgs:
and self.enable_prefix_caching and self.enable_prefix_caching
and not default_prefix_caching and not default_prefix_caching
): ):
logger.warning( logger.warning_once(
"This model does not officially support prefix caching. " "This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash " "Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.", "or produce incorrect outputs.",
scope="local",
) )
world_size = self.pipeline_parallel_size * self.tensor_parallel_size world_size = self.pipeline_parallel_size * self.tensor_parallel_size

View File

@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
enable_caching=bool(self.cache_config.enable_prefix_caching), enable_caching=self.cache_config.enable_prefix_caching,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
log_stats=self.log_stats, log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events, enable_kv_cache_events=self.enable_kv_cache_events,