[Chore] Rename SchedulerConfig.chunked_prefill_enabled (#28735)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-11-15 02:39:57 +08:00 · committed by GitHub
parent 67187554dd
commit e2741f6cbc
9 changed files with 21 additions and 19 deletions

View File

@@ -2282,7 +2282,6 @@ def _validate_chunked_prefill_settings_for_encoder_decoder(
) -> None:
"""Validate chunked prefill settings in the scheduler config for
encoder-decoder models."""
-assert scheduler_config.chunked_prefill_enabled is expect_enabled
assert scheduler_config.enable_chunked_prefill is expect_enabled
if is_encoder_decoder:
# Encoder-decoder models should automatically disable chunked multimodal

View File

@@ -272,7 +272,7 @@ def test_speculators_model_integration(
@pytest.mark.parametrize(
["model_setup", "mm_enabled", "chunked_prefill_enabled"],
["model_setup", "mm_enabled", "enable_chunked_prefill"],
[
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False, False),
pytest.param(
@@ -358,7 +358,7 @@ def test_eagle_correctness(
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
-chunked_prefill_enabled: bool,
+enable_chunked_prefill: bool,
attn_backend: str,
):
if attn_backend == "TREE_ATTN":
@@ -396,9 +396,7 @@ def test_eagle_correctness(
method, model_name, spec_model_name, tp_size = model_setup
max_model_len = 2048
-max_num_batched_tokens = max_model_len
-if chunked_prefill_enabled:
-max_num_batched_tokens = 128
+max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
ref_llm = LLM(
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
@@ -420,7 +418,7 @@ def test_eagle_correctness(
},
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
-enable_chunked_prefill=chunked_prefill_enabled,
+enable_chunked_prefill=enable_chunked_prefill,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0

View File

@@ -571,7 +571,7 @@ def test_encoder_instance_zero_kv_cache(
)
# Check 5: Verify chunked prefill is disabled
-assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
+assert not vllm_config.scheduler_config.enable_chunked_prefill, (
"Encoder instance should disable chunked prefill (no KV cache)"
)

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
-from typing_extensions import Self
+from typing_extensions import Self, deprecated
from vllm.config.utils import config
from vllm.logger import init_logger
@@ -233,6 +233,11 @@ class SchedulerConfig:
)
@property
+@deprecated(
+"`SchedulerConfig.chunked_prefill_enabled` has been renamed to "
+"`SchedulerConfig.enable_chunked_prefill`. "
+"The old name will be removed in v0.12."
+)
def chunked_prefill_enabled(self) -> bool:
return self.enable_chunked_prefill
@@ -244,7 +249,7 @@ class SchedulerConfig:
def _verify_args(self) -> Self:
if (
self.max_num_batched_tokens < self.max_model_len
-and not self.chunked_prefill_enabled
+and not self.enable_chunked_prefill
):
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
@@ -271,7 +276,7 @@ class SchedulerConfig:
)
if self.max_num_partial_prefills > 1:
-if not self.chunked_prefill_enabled:
+if not self.enable_chunked_prefill:
raise ValueError(
"Chunked prefill must be enabled to set "
"max_num_partial_prefills > 1."

View File

@@ -411,7 +411,7 @@ class VllmConfig:
if (
self.model_config is not None
-and self.scheduler_config.chunked_prefill_enabled
+and self.scheduler_config.enable_chunked_prefill
and self.model_config.dtype == torch.float32
and current_platform.get_device_capability() == (7, 5)
):
@@ -584,7 +584,7 @@ class VllmConfig:
):
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
-self.scheduler_config.chunked_prefill_enabled = False
+self.scheduler_config.enable_chunked_prefill = False
self.scheduler_config.long_prefill_token_threshold = 0
if self.cache_config is not None:
@@ -1026,7 +1026,7 @@ class VllmConfig:
f"seed={self.model_config.seed}, "
f"served_model_name={self.model_config.served_model_name}, "
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"enable_chunked_prefill={self.scheduler_config.enable_chunked_prefill}, " # noqa
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}"
)

View File

@@ -192,7 +192,7 @@ class CpuPlatform(Platform):
scheduler_config = vllm_config.scheduler_config
if (
-scheduler_config.chunked_prefill_enabled
+scheduler_config.enable_chunked_prefill
or cache_config.enable_prefix_caching
) and cache_config.cache_dtype != "auto":
raise RuntimeError(

View File

@@ -497,7 +497,7 @@ class Scheduler(SchedulerInterface):
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if (
-not self.scheduler_config.chunked_prefill_enabled
+not self.scheduler_config.enable_chunked_prefill
and num_new_tokens > token_budget
):
self.waiting.pop_request()

View File

@@ -124,7 +124,7 @@ class EngineCore:
# Encoder models without KV cache don't support
# chunked prefill. But do SSM models?
logger.info("Disabling chunked prefill for model without KVCache")
-vllm_config.scheduler_config.chunked_prefill_enabled = False
+vllm_config.scheduler_config.enable_chunked_prefill = False
scheduler_block_size = (
vllm_config.cache_config.block_size

View File

@@ -2031,7 +2031,7 @@ class GPUModelRunner(
supported_tasks = list(model.pooler.get_supported_tasks())
-if self.scheduler_config.chunked_prefill_enabled:
+if self.scheduler_config.enable_chunked_prefill:
if "token_embed" in supported_tasks:
supported_tasks.remove("token_embed")
if "token_classify" in supported_tasks:
@@ -3825,7 +3825,7 @@ class GPUModelRunner(
supported_pooling_tasks = self.get_supported_pooling_tasks()
if not supported_pooling_tasks:
-if self.scheduler_config.chunked_prefill_enabled:
+if self.scheduler_config.enable_chunked_prefill:
raise RuntimeError(
f"Model {self.model_config.model} does not support "
"any pooling tasks with chunked prefill enabled. "