Improve enable chunked_prefill & prefix_caching logic. (#26623)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
wang.yuqi 2025-11-28 14:05:48 +08:00 committed by GitHub
parent 37b15e97e8
commit f4b76056ee
11 changed files with 456 additions and 133 deletions

View File

@@ -105,8 +105,6 @@ def test_embed_models(
def test_non_causal_models(
hf_runner, vllm_runner, example_prompts, model: str, dtype: str
) -> None:
with vllm_runner(
model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
) as vllm_model:
with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
cache_config = vllm_model.llm.llm_engine.cache_config
assert not cache_config.enable_prefix_caching
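For context, the behavior this updated test relies on can also be reproduced directly with the LLM API. A minimal sketch, assuming a bidirectional-attention pooling model such as BAAI/bge-base-en (the model choice and max_model_len here are illustrative):

from vllm import LLM

# Prefix caching now defaults to off for non-causal pooling models,
# so the test above no longer passes enable_prefix_caching=True.
llm = LLM(model="BAAI/bge-base-en", max_model_len=512)
assert not llm.llm_engine.cache_config.enable_prefix_caching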

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
from dataclasses import MISSING, Field, asdict, dataclass, field
from unittest.mock import patch
@@ -602,6 +602,244 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer)
@pytest.mark.parametrize(
("model_id", "expected_attn_type", "expected_result", "reason"),
[
# pooling models
(
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support chunked prefill.",
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support chunked prefill.",
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support chunked prefill.",
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support chunked prefill.",
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support chunked prefill.",
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
True,
"Generative models support chunked prefill.",
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
True,
"Generative models support chunked prefill.",
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
True,
"Generative models support chunked prefill.",
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support chunked prefill.",
),
],
)
def test_is_chunked_prefill_supported(
model_id: str,
expected_attn_type: str,
expected_result: bool,
reason: str,
caplog_vllm,
):
model_config = ModelConfig(model_id, trust_remote_code=True)
assert model_config.attn_type == expected_attn_type
with caplog_vllm.at_level(level=logging.DEBUG):
assert model_config.is_chunked_prefill_supported == expected_result
assert reason in caplog_vllm.text
@pytest.mark.parametrize(
("model_id", "expected_attn_type", "expected_result", "reason"),
[
# pooling models
(
"jason9693/Qwen2.5-1.5B-apeach",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
),
(
"Qwen/Qwen3-Embedding-0.6B",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
),
(
"Qwen/Qwen2.5-Math-PRM-7B",
"decoder",
False,
"Pooling models with step pooling does not support prefix caching.",
),
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support prefix caching.",
),
(
"BAAI/bge-base-en",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
(
"boltuix/NeuroBERT-NER",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
(
"papluca/xlm-roberta-base-language-detection",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
(
"Alibaba-NLP/gte-Qwen2-1.5B-instruct",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
(
"intfloat/e5-small",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
# multimodal models
(
"openai/clip-vit-base-patch32",
"decoder",
True,
"Pooling models with causal attn and last pooling support prefix caching.",
),
(
"google/siglip-base-patch16-224",
"encoder_only",
False,
"Pooling models with bidirectional attn does not support prefix caching.",
),
# generate models
(
"Qwen/Qwen3-0.6B",
"decoder",
True,
"Generative models support prefix caching.",
),
(
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"ibm-granite/granite-4.0-h-small",
"hybrid",
False,
"Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501
),
(
"state-spaces/mamba-130m-hf",
"attention_free",
False,
"Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501
),
# encoder_decoder models
(
"openai/whisper-small",
"encoder_decoder",
False,
"Encoder decoder models does not support prefix caching.",
),
],
)
def test_is_prefix_caching_supported(
model_id: str,
expected_attn_type: str,
expected_result: bool,
reason: str,
caplog_vllm,
):
model_config = ModelConfig(model_id, trust_remote_code=True)
assert model_config.attn_type == expected_attn_type
with caplog_vllm.at_level(level=logging.DEBUG):
assert model_config.is_prefix_caching_supported == expected_result
assert reason in caplog_vllm.text
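Outside the test suite, the same properties can be inspected directly on a ModelConfig. A minimal sketch using one of the models parametrized above (resolving the config downloads it from the Hub; the commented values reflect the expectations in these tests):

from vllm.config import ModelConfig

cfg = ModelConfig("Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True)
print(cfg.attn_type)                     # "decoder" for this model
print(cfg.is_chunked_prefill_supported)  # True: causal attn + last pooling
print(cfg.is_prefix_caching_supported)   # True: causal attn + last pooling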
@pytest.mark.parametrize(
("backend", "custom_ops", "expected"),
[

View File

@@ -107,6 +107,10 @@ _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
"draft": [],
}
AttnTypeStr = Literal[
"decoder", "encoder", "encoder_only", "encoder_decoder", "attention_free", "hybrid"
]
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -1752,6 +1756,111 @@ class ModelConfig:
logger.info("Using max model len %s", max_model_len)
return max_model_len
@property
def attn_type(self) -> AttnTypeStr:
if self.pooler_config is not None:
pooling_type = self._model_info.default_pooling_type.lower()
if pooling_type == "cls":
return "encoder_only"
else:
is_causal = getattr(self.hf_config, "is_causal", True)
return "encoder_only" if not is_causal else self._model_info.attn_type
elif self.is_hybrid:
return "hybrid"
elif self.is_attention_free:
return "attention_free"
elif self.is_encoder_decoder:
return "encoder_decoder"
else:
return "decoder"
@property
def is_chunked_prefill_supported(self) -> bool:
attn_type = self.attn_type
if self.pooler_config is not None:
# for pooling models
if attn_type == "encoder_only":
logger.debug(
"Pooling models with bidirectional attn does not support "
"chunked prefill."
)
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support chunked prefill.",
pooling_type,
)
return False
else:
# pooling_type == "last"
logger.debug(
"Pooling models with causal attn and last pooling support "
"chunked prefill."
)
return True
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder"
else:
if attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support chunked prefill.")
return False
logger.debug("Generative models support chunked prefill.")
return True
@property
def is_prefix_caching_supported(self) -> bool:
attn_type = self.attn_type
if self.pooler_config is not None:
# for pooling models
if attn_type == "encoder_only":
logger.debug(
"Pooling models with bidirectional attn does not "
"support prefix caching."
)
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support prefix caching.",
pooling_type,
)
return False
else:
# pooling_type == "last"
logger.debug(
"Pooling models with causal attn and last pooling support "
"prefix caching."
)
return True
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return False
else:
if attn_type == "hybrid":
logger.debug(
"Hybrid models does not support prefix caching since the feature "
"is still experimental."
)
return False
elif attn_type == "attention_free":
logger.debug(
"Attention free models does not support prefix caching since the "
"feature is still experimental."
)
return False
elif attn_type == "encoder_decoder":
logger.debug("Encoder decoder models does not support prefix caching.")
return False
else: # attn_type == "decoder"
logger.debug("Generative models support prefix caching.")
return True
def is_model_moe(
self,
) -> bool:

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from typing import Any, Literal
from pydantic.dataclasses import dataclass
@@ -11,13 +11,15 @@ from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
@config
@dataclass
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
pooling_type: str | None = None
pooling_type: PoolingTypeStr | None = None
"""
The pooling method of the pooling model. This should be a key in
[`vllm.model_executor.layers.pooler.PoolingType`][].
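A small illustration of the tightened typing, assuming PoolerConfig is importable from vllm.config.pooler as the imports elsewhere in this diff suggest:

from vllm.config.pooler import PoolerConfig

# pooling_type is now typed as PoolingTypeStr, i.e. one of
# "LAST", "ALL", "CLS", "STEP" or "MEAN" (or None to use the model default).
pooler_config = PoolerConfig(pooling_type="LAST")
assert pooler_config.pooling_type == "LAST"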

View File

@@ -721,65 +721,27 @@ class VllmConfig:
"correctness and to realize prefill savings. "
)
disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.is_encoder_decoder:
from vllm.multimodal import MULTIMODAL_REGISTRY
if self.model_config:
if self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
'Only "last" pooling supports chunked '
"prefill and prefix caching; disabling both."
)
if not getattr(self.model_config.hf_config, "is_causal", True):
disable_chunked_prefill_reasons.append(
"Only models using causal attention support chunked "
"prefill and prefix caching; disabling both."
)
elif self.model_config.is_encoder_decoder:
from vllm.multimodal import MULTIMODAL_REGISTRY
self.scheduler_config.max_num_encoder_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
)
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens,
)
if (
self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
# Final off-switch for CP/APC:
# Disable for (a) collected blockers, (b) encoder-decoder, or
# (c) explicit CP=False when APC wasn't requested.
# Do NOT disable merely because the resolved CP flag is False.
apc_requested = (
self.cache_config is not None and self.cache_config.enable_prefix_caching
)
if (
disable_chunked_prefill_reasons
or (self.model_config is not None and self.model_config.is_encoder_decoder)
or (
self.scheduler_config.enable_chunked_prefill is False
and not apc_requested
self.scheduler_config.max_num_encoder_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
)
):
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
self.scheduler_config.enable_chunked_prefill = False
self.scheduler_config.long_prefill_token_threshold = 0
if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens,
)
if (
self.model_config.architecture == "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'."
)
if (
self.kv_events_config is not None

View File

@@ -1349,30 +1349,10 @@ class EngineArgs:
self.tokenizer = model_config.tokenizer
self._check_feature_supported(model_config)
# Set default arguments for V1 Engine.
self._set_default_args(usage_context, model_config)
# Disable chunked prefill and prefix caching for:
# POWER (ppc64le)/s390x/RISCV CPUs in V1
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC,
CpuArchEnum.S390X,
CpuArchEnum.RISCV,
):
logger.info(
"Chunked prefill is not supported for ARM and POWER, "
"S390X and RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_chunked_prefill = False
logger.info(
"Prefix caching is not supported for ARM and POWER, "
"S390X and RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_prefix_caching = False
assert self.enable_chunked_prefill is not None
self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
self._set_default_max_num_seqs_and_batched_tokens_args(
usage_context, model_config
)
sliding_window: int | None = None
if not is_interleaved(model_config.hf_text_config):
@@ -1805,34 +1785,6 @@
)
_raise_unsupported_error(feature_name=name)
@classmethod
def get_chunked_prefill_prefix_caching_defaults(
cls,
model_config: ModelConfig,
) -> tuple[bool, bool]:
if model_config.runner_type != "pooling":
default_chunked_prefill = True
# Disable prefix caching default for hybrid models and mamba-only
# models since the feature is still experimental.
default_prefix_caching = not (
model_config.is_hybrid or model_config.is_attention_free
)
else:
assert model_config.pooler_config is not None
pooling_type = model_config.pooler_config.pooling_type
incremental_prefill_supported = (
pooling_type is not None
and pooling_type.lower() == "last"
and getattr(model_config.hf_config, "is_causal", True)
)
default_chunked_prefill = incremental_prefill_supported
default_prefix_caching = incremental_prefill_supported
return default_chunked_prefill, default_prefix_caching
@classmethod
def get_batch_defaults(
cls,
@@ -1916,14 +1868,11 @@
return default_max_num_batched_tokens, default_max_num_seqs
def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
def _set_default_chunked_prefill_and_prefix_caching_args(
self, model_config: ModelConfig
) -> None:
"""Set Default Arguments for V1 Engine."""
(
default_chunked_prefill,
default_prefix_caching,
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
default_chunked_prefill = model_config.is_chunked_prefill_supported
default_prefix_caching = model_config.is_prefix_caching_supported
if self.prefill_context_parallel_size > 1:
default_chunked_prefill = False
@@ -1984,6 +1933,29 @@
scope="local",
)
# Disable chunked prefill and prefix caching for:
# POWER (ppc64le)/s390x/RISCV CPUs in V1
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC,
CpuArchEnum.S390X,
CpuArchEnum.RISCV,
):
logger.info(
"Chunked prefill is not supported for ARM and POWER, "
"S390X and RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_chunked_prefill = False
logger.info(
"Prefix caching is not supported for ARM and POWER, "
"S390X and RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_prefix_caching = False
def _set_default_max_num_seqs_and_batched_tokens_args(
self, usage_context: UsageContext, model_config: ModelConfig
):
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
(
default_max_num_batched_tokens,

View File

@@ -32,7 +32,7 @@ from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type
from .interfaces_base import attn_type, default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -432,7 +432,6 @@ class BertModel(nn.Module, SupportsQuant):
return loaded_params
@default_pooling_type("ALL")
class BertPoolingModel(BertModel):
is_pooling_model = True
@@ -864,6 +863,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
)
@attn_type("encoder_only")
@default_pooling_type("ALL")
class BertForTokenClassification(nn.Module):
is_pooling_model = True

View File

@@ -19,10 +19,14 @@ from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.model import AttnTypeStr
from vllm.config.pooler import PoolingTypeStr
from vllm.model_executor.layers.pooler import Pooler
else:
VllmConfig = Any
Pooler = Any
PoolingTypeStr = Any
AttnTypeStr = Any
logger = init_logger(__name__)
@@ -165,7 +169,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
MRO of your model class.
"""
default_pooling_type: ClassVar[str] = "LAST"
default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
"""
Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
to use by default.
@@ -175,6 +179,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
decorator to conveniently set this field.
"""
attn_type: ClassVar[AttnTypeStr] = "decoder"
"""
Indicates the
[vllm.config.model.ModelConfig.attn_type][]
to use by default.
You can use the
[vllm.model_executor.models.interfaces_base.attn_type][]
decorator to conveniently set this field.
"""
pooler: Pooler
"""The pooler is only called on TP rank 0."""
@@ -199,7 +214,7 @@ def is_pooling_model(
_T = TypeVar("_T", bound=type[nn.Module])
def default_pooling_type(pooling_type: str):
def default_pooling_type(pooling_type: PoolingTypeStr):
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
def func(model: _T) -> _T:
@@ -209,5 +224,19 @@ def default_pooling_type(pooling_type: str):
return func
def get_default_pooling_type(model: type[object] | object) -> str:
def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
return getattr(model, "default_pooling_type", "LAST")
def attn_type(attn_type: AttnTypeStr):
"""Decorator to set `VllmModelForPooling.attn_type`."""
def func(model: _T) -> _T:
model.attn_type = attn_type # type: ignore
return model
return func
def get_attn_type(model: type[object] | object) -> AttnTypeStr:
return getattr(model, "attn_type", "decoder")

View File

@@ -28,7 +28,7 @@ from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
from .interfaces_base import attn_type, default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -396,6 +396,7 @@ class ModernBertPredictionHead(nn.Module):
return self.norm(self.act(self.dense(hidden_states)))
@attn_type("encoder_only")
@default_pooling_type("ALL")
class ModernBertForTokenClassification(nn.Module):
is_pooling_model = True

View File

@@ -17,7 +17,7 @@ from collections.abc import Callable, Set
from dataclasses import asdict, dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import TypeVar
from typing import TYPE_CHECKING, Any, TypeVar
import torch.nn as nn
import transformers
@@ -33,6 +33,14 @@ from vllm.logging_utils import logtime
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.config.model import AttnTypeStr
from vllm.config.pooler import PoolingTypeStr
else:
AttnTypeStr = Any
PoolingTypeStr = Any
from .interfaces import (
has_inner_state,
has_noops,
@@ -47,6 +55,7 @@ from .interfaces import (
supports_transcription,
)
from .interfaces_base import (
get_attn_type,
get_default_pooling_type,
is_pooling_model,
is_text_generation_model,
@@ -509,7 +518,8 @@ class _ModelInfo:
architecture: str
is_text_generation_model: bool
is_pooling_model: bool
default_pooling_type: str
attn_type: AttnTypeStr
default_pooling_type: PoolingTypeStr
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input_only: bool
@@ -530,6 +540,7 @@ _ModelInfo:
is_text_generation_model=is_text_generation_model(model),
is_pooling_model=is_pooling_model(model),
default_pooling_type=get_default_pooling_type(model),
attn_type=get_attn_type(model),
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(

View File

@@ -119,11 +119,12 @@ class EngineCore:
# Setup scheduler.
Scheduler = vllm_config.scheduler_config.get_scheduler_cls()
if len(kv_cache_config.kv_cache_groups) == 0:
if len(kv_cache_config.kv_cache_groups) == 0: # noqa: SIM102
# Encoder models without KV cache don't support
# chunked prefill. But do SSM models?
logger.info("Disabling chunked prefill for model without KVCache")
vllm_config.scheduler_config.enable_chunked_prefill = False
if vllm_config.scheduler_config.enable_chunked_prefill:
logger.warning("Disabling chunked prefill for model without KVCache")
vllm_config.scheduler_config.enable_chunked_prefill = False
scheduler_block_size = (
vllm_config.cache_config.block_size