From f4b76056ee5c3a3f917527da5be3786e1b8530c6 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Fri, 28 Nov 2025 14:05:48 +0800 Subject: [PATCH] Improve enable chunked_prefill & prefix_caching logic. (#26623) Signed-off-by: wang.yuqi Signed-off-by: wang.yuqi Co-authored-by: Cyrus Leung --- .../pooling/test_auto_prefix_cache_support.py | 4 +- tests/test_config.py | 240 +++++++++++++++++- vllm/config/model.py | 109 ++++++++ vllm/config/pooler.py | 6 +- vllm/config/vllm.py | 76 ++---- vllm/engine/arg_utils.py | 90 +++---- vllm/model_executor/models/bert.py | 4 +- vllm/model_executor/models/interfaces_base.py | 35 ++- vllm/model_executor/models/modernbert.py | 3 +- vllm/model_executor/models/registry.py | 15 +- vllm/v1/engine/core.py | 7 +- 11 files changed, 456 insertions(+), 133 deletions(-) diff --git a/tests/models/language/pooling/test_auto_prefix_cache_support.py b/tests/models/language/pooling/test_auto_prefix_cache_support.py index 0904c7e877ef4..3795f2a5d8664 100644 --- a/tests/models/language/pooling/test_auto_prefix_cache_support.py +++ b/tests/models/language/pooling/test_auto_prefix_cache_support.py @@ -105,8 +105,6 @@ def test_embed_models( def test_non_causal_models( hf_runner, vllm_runner, example_prompts, model: str, dtype: str ) -> None: - with vllm_runner( - model, max_model_len=512, dtype=dtype, enable_prefix_caching=True - ) as vllm_model: + with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert not cache_config.enable_prefix_caching diff --git a/tests/test_config.py b/tests/test_config.py index 080e4d2afacc6..112b02edd0389 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import logging import os from dataclasses import MISSING, Field, asdict, dataclass, field from unittest.mock import patch @@ -602,6 +602,244 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files): assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer) +@pytest.mark.parametrize( + ("model_id", "expected_attn_type", "expected_result", "reason"), + [ + # pooling models + ( + "jason9693/Qwen2.5-1.5B-apeach", + "decoder", + True, + "Pooling models with causal attn and last pooling support chunked prefill.", + ), + ( + "Qwen/Qwen3-Embedding-0.6B", + "decoder", + True, + "Pooling models with causal attn and last pooling support chunked prefill.", + ), + ( + "Qwen/Qwen2.5-Math-PRM-7B", + "decoder", + False, + "Pooling models with step pooling does not support chunked prefill.", + ), + ( + "internlm/internlm2-1_8b-reward", + "decoder", + False, + "Pooling models with all pooling does not support chunked prefill.", + ), + ( + "BAAI/bge-base-en", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support chunked prefill.", + ), + ( + "boltuix/NeuroBERT-NER", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support chunked prefill.", + ), + ( + "papluca/xlm-roberta-base-language-detection", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support chunked prefill.", + ), + ( + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support chunked prefill.", + ), + ( + "intfloat/e5-small", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support chunked prefill.", + ), + # multimodal 
models + ( + "openai/clip-vit-base-patch32", + "decoder", + True, + "Pooling models with causal attn and last pooling support chunked prefill.", + ), + ( + "google/siglip-base-patch16-224", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support chunked prefill.", + ), + # generate models + ( + "Qwen/Qwen3-0.6B", + "decoder", + True, + "Generative models support chunked prefill.", + ), + ( + "Qwen/Qwen3-Next-80B-A3B-Instruct", + "hybrid", + True, + "Generative models support chunked prefill.", + ), + ( + "ibm-granite/granite-4.0-h-small", + "hybrid", + True, + "Generative models support chunked prefill.", + ), + ( + "state-spaces/mamba-130m-hf", + "attention_free", + True, + "Generative models support chunked prefill.", + ), + # encoder_decoder models + ( + "openai/whisper-small", + "encoder_decoder", + False, + "Encoder decoder models does not support chunked prefill.", + ), + ], +) +def test_is_chunked_prefill_supported( + model_id: str, + expected_attn_type: str, + expected_result: bool, + reason: str, + caplog_vllm, +): + model_config = ModelConfig(model_id, trust_remote_code=True) + assert model_config.attn_type == expected_attn_type + with caplog_vllm.at_level(level=logging.DEBUG): + assert model_config.is_chunked_prefill_supported == expected_result + assert reason in caplog_vllm.text + + +@pytest.mark.parametrize( + ("model_id", "expected_attn_type", "expected_result", "reason"), + [ + # pooling models + ( + "jason9693/Qwen2.5-1.5B-apeach", + "decoder", + True, + "Pooling models with causal attn and last pooling support prefix caching.", + ), + ( + "Qwen/Qwen3-Embedding-0.6B", + "decoder", + True, + "Pooling models with causal attn and last pooling support prefix caching.", + ), + ( + "Qwen/Qwen2.5-Math-PRM-7B", + "decoder", + False, + "Pooling models with step pooling does not support prefix caching.", + ), + ( + "internlm/internlm2-1_8b-reward", + "decoder", + False, + "Pooling models with all pooling does not support prefix caching.", + ), + ( + "BAAI/bge-base-en", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support prefix caching.", + ), + ( + "boltuix/NeuroBERT-NER", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support prefix caching.", + ), + ( + "papluca/xlm-roberta-base-language-detection", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support prefix caching.", + ), + ( + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support prefix caching.", + ), + ( + "intfloat/e5-small", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support prefix caching.", + ), + # multimodal models + ( + "openai/clip-vit-base-patch32", + "decoder", + True, + "Pooling models with causal attn and last pooling support prefix caching.", + ), + ( + "google/siglip-base-patch16-224", + "encoder_only", + False, + "Pooling models with bidirectional attn does not support prefix caching.", + ), + # generate models + ( + "Qwen/Qwen3-0.6B", + "decoder", + True, + "Generative models support prefix caching.", + ), + ( + "Qwen/Qwen3-Next-80B-A3B-Instruct", + "hybrid", + False, + "Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501 + ), + ( + "ibm-granite/granite-4.0-h-small", + "hybrid", + False, + "Hybrid models does not support prefix caching since the feature is still experimental.", # noqa: E501 + ), + ( + "state-spaces/mamba-130m-hf", 
+ "attention_free", + False, + "Attention free models does not support prefix caching since the feature is still experimental.", # noqa: E501 + ), + # encoder_decoder models + ( + "openai/whisper-small", + "encoder_decoder", + False, + "Encoder decoder models does not support prefix caching.", + ), + ], +) +def test_is_prefix_caching_supported( + model_id: str, + expected_attn_type: str, + expected_result: bool, + reason: str, + caplog_vllm, +): + model_config = ModelConfig(model_id, trust_remote_code=True) + assert model_config.attn_type == expected_attn_type + with caplog_vllm.at_level(level=logging.DEBUG): + assert model_config.is_prefix_caching_supported == expected_result + assert reason in caplog_vllm.text + + @pytest.mark.parametrize( ("backend", "custom_ops", "expected"), [ diff --git a/vllm/config/model.py b/vllm/config/model.py index 21d602b30ac1a..b9ae4fec14efa 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -107,6 +107,10 @@ _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = { "draft": [], } +AttnTypeStr = Literal[ + "decoder", "encoder", "encoder_only", "encoder_decoder", "attention_free", "hybrid" +] + @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) @@ -1752,6 +1756,111 @@ class ModelConfig: logger.info("Using max model len %s", max_model_len) return max_model_len + @property + def attn_type(self) -> AttnTypeStr: + if self.pooler_config is not None: + pooling_type = self._model_info.default_pooling_type.lower() + if pooling_type == "cls": + return "encoder_only" + else: + is_causal = getattr(self.hf_config, "is_causal", True) + return "encoder_only" if not is_causal else self._model_info.attn_type + elif self.is_hybrid: + return "hybrid" + elif self.is_attention_free: + return "attention_free" + elif self.is_encoder_decoder: + return "encoder_decoder" + else: + return "decoder" + + @property + def is_chunked_prefill_supported(self) -> bool: + attn_type = self.attn_type + if self.pooler_config is not None: + # for pooling models + if attn_type == "encoder_only": + logger.debug( + "Pooling models with bidirectional attn does not support " + "chunked prefill." + ) + return False + elif attn_type == "decoder": + pooling_type = self.pooler_config.pooling_type.lower() + if pooling_type in ["all", "mean", "step", "cls"]: + logger.debug( + "Pooling models with %s pooling does not " + "support chunked prefill.", + pooling_type, + ) + return False + else: + # pooling_type == "last" + logger.debug( + "Pooling models with causal attn and last pooling support " + "chunked prefill." + ) + return True + # vllm currently does not have pooling models using hybrid, + # attention_free or encoder_decoder attn types. + return attn_type != "encoder_decoder" + else: + if attn_type == "encoder_decoder": + logger.debug("Encoder decoder models does not support chunked prefill.") + return False + logger.debug("Generative models support chunked prefill.") + return True + + @property + def is_prefix_caching_supported(self) -> bool: + attn_type = self.attn_type + if self.pooler_config is not None: + # for pooling models + if attn_type == "encoder_only": + logger.debug( + "Pooling models with bidirectional attn does not " + "support prefix caching." 
+ ) + return False + elif attn_type == "decoder": + pooling_type = self.pooler_config.pooling_type.lower() + if pooling_type in ["all", "mean", "step", "cls"]: + logger.debug( + "Pooling models with %s pooling does not " + "support prefix caching.", + pooling_type, + ) + return False + else: + # pooling_type == "last" + logger.debug( + "Pooling models with causal attn and last pooling support " + "prefix caching." + ) + return True + # vllm currently does not have pooling models using hybrid, + # attention_free or encoder_decoder attn types. + return False + else: + if attn_type == "hybrid": + logger.debug( + "Hybrid models does not support prefix caching since the feature " + "is still experimental." + ) + return False + elif attn_type == "attention_free": + logger.debug( + "Attention free models does not support prefix caching since the " + "feature is still experimental." + ) + return False + elif attn_type == "encoder_decoder": + logger.debug("Encoder decoder models does not support prefix caching.") + return False + else: # attn_type == "decoder" + logger.debug("Generative models support prefix caching.") + return True + def is_model_moe( self, ) -> bool: diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py index 85950bbcd666f..aa4e7006d0247 100644 --- a/vllm/config/pooler.py +++ b/vllm/config/pooler.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any +from typing import Any, Literal from pydantic.dataclasses import dataclass @@ -11,13 +11,15 @@ from vllm.utils.hashing import safe_hash logger = init_logger(__name__) +PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"] + @config @dataclass class PoolerConfig: """Controls the behavior of output pooling in pooling models.""" - pooling_type: str | None = None + pooling_type: PoolingTypeStr | None = None """ The pooling method of the pooling model. This should be a key in [`vllm.model_executor.layers.pooler.PoolingType`][]. diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index c576275e80fe3..7ac8cc764322e 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -721,65 +721,27 @@ class VllmConfig: "correctness and to realize prefill savings. " ) - disable_chunked_prefill_reasons: list[str] = [] + if self.model_config and self.model_config.is_encoder_decoder: + from vllm.multimodal import MULTIMODAL_REGISTRY - if self.model_config: - if self.model_config.pooler_config: - pooling_type = self.model_config.pooler_config.pooling_type - if pooling_type is None or pooling_type.lower() != "last": - disable_chunked_prefill_reasons.append( - 'Only "last" pooling supports chunked ' - "prefill and prefix caching; disabling both." - ) - if not getattr(self.model_config.hf_config, "is_causal", True): - disable_chunked_prefill_reasons.append( - "Only models using causal attention support chunked " - "prefill and prefix caching; disabling both." 
- ) - elif self.model_config.is_encoder_decoder: - from vllm.multimodal import MULTIMODAL_REGISTRY - - self.scheduler_config.max_num_encoder_input_tokens = ( - MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) - ) - logger.debug( - "Encoder-decoder model detected: setting " - "`max_num_encoder_input_tokens` to encoder length (%s)", - self.scheduler_config.max_num_encoder_input_tokens, - ) - if ( - self.model_config.architecture == "WhisperForConditionalGeneration" - and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn" - ): - logger.warning( - "Whisper is known to have issues with " - "forked workers. If startup is hanging, " - "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " - "to 'spawn'." - ) - - # Final off-switch for CP/APC: - # Disable for (a) collected blockers, (b) encoder–decoder, or - # (c) explicit CP=False when APC wasn't requested. - # Do NOT disable merely because the resolved CP flag is False. - apc_requested = ( - self.cache_config is not None and self.cache_config.enable_prefix_caching - ) - if ( - disable_chunked_prefill_reasons - or (self.model_config is not None and self.model_config.is_encoder_decoder) - or ( - self.scheduler_config.enable_chunked_prefill is False - and not apc_requested + self.scheduler_config.max_num_encoder_input_tokens = ( + MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) ) - ): - for reason in disable_chunked_prefill_reasons: - logger.info(reason) - self.scheduler_config.enable_chunked_prefill = False - self.scheduler_config.long_prefill_token_threshold = 0 - - if self.cache_config is not None: - self.cache_config.enable_prefix_caching = False + logger.debug( + "Encoder-decoder model detected: setting " + "`max_num_encoder_input_tokens` to encoder length (%s)", + self.scheduler_config.max_num_encoder_input_tokens, + ) + if ( + self.model_config.architecture == "WhisperForConditionalGeneration" + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn" + ): + logger.warning( + "Whisper is known to have issues with " + "forked workers. If startup is hanging, " + "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " + "to 'spawn'." + ) if ( self.kv_events_config is not None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e4c9a82d25223..ad5a34c56161c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1349,30 +1349,10 @@ class EngineArgs: self.tokenizer = model_config.tokenizer self._check_feature_supported(model_config) - - # Set default arguments for V1 Engine. - self._set_default_args(usage_context, model_config) - # Disable chunked prefill and prefix caching for: - # POWER (ppc64le)/s390x/RISCV CPUs in V1 - if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( - CpuArchEnum.POWERPC, - CpuArchEnum.S390X, - CpuArchEnum.RISCV, - ): - logger.info( - "Chunked prefill is not supported for ARM and POWER, " - "S390X and RISC-V CPUs; " - "disabling it for V1 backend." - ) - self.enable_chunked_prefill = False - logger.info( - "Prefix caching is not supported for ARM and POWER, " - "S390X and RISC-V CPUs; " - "disabling it for V1 backend." 
- ) - self.enable_prefix_caching = False - - assert self.enable_chunked_prefill is not None + self._set_default_chunked_prefill_and_prefix_caching_args(model_config) + self._set_default_max_num_seqs_and_batched_tokens_args( + usage_context, model_config + ) sliding_window: int | None = None if not is_interleaved(model_config.hf_text_config): @@ -1805,34 +1785,6 @@ class EngineArgs: ) _raise_unsupported_error(feature_name=name) - @classmethod - def get_chunked_prefill_prefix_caching_defaults( - cls, - model_config: ModelConfig, - ) -> tuple[bool, bool]: - if model_config.runner_type != "pooling": - default_chunked_prefill = True - - # Disable prefix caching default for hybrid models and mamba-only - # models since the feature is still experimental. - default_prefix_caching = not ( - model_config.is_hybrid or model_config.is_attention_free - ) - else: - assert model_config.pooler_config is not None - - pooling_type = model_config.pooler_config.pooling_type - incremental_prefill_supported = ( - pooling_type is not None - and pooling_type.lower() == "last" - and getattr(model_config.hf_config, "is_causal", True) - ) - - default_chunked_prefill = incremental_prefill_supported - default_prefix_caching = incremental_prefill_supported - - return default_chunked_prefill, default_prefix_caching - @classmethod def get_batch_defaults( cls, @@ -1916,14 +1868,11 @@ class EngineArgs: return default_max_num_batched_tokens, default_max_num_seqs - def _set_default_args( - self, usage_context: UsageContext, model_config: ModelConfig + def _set_default_chunked_prefill_and_prefix_caching_args( + self, model_config: ModelConfig ) -> None: - """Set Default Arguments for V1 Engine.""" - ( - default_chunked_prefill, - default_prefix_caching, - ) = self.get_chunked_prefill_prefix_caching_defaults(model_config) + default_chunked_prefill = model_config.is_chunked_prefill_supported + default_prefix_caching = model_config.is_prefix_caching_supported if self.prefill_context_parallel_size > 1: default_chunked_prefill = False @@ -1984,6 +1933,29 @@ class EngineArgs: scope="local", ) + # Disable chunked prefill and prefix caching for: + # POWER (ppc64le)/s390x/RISCV CPUs in V1 + if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( + CpuArchEnum.POWERPC, + CpuArchEnum.S390X, + CpuArchEnum.RISCV, + ): + logger.info( + "Chunked prefill is not supported for ARM and POWER, " + "S390X and RISC-V CPUs; " + "disabling it for V1 backend." + ) + self.enable_chunked_prefill = False + logger.info( + "Prefix caching is not supported for ARM and POWER, " + "S390X and RISC-V CPUs; " + "disabling it for V1 backend." 
+ ) + self.enable_prefix_caching = False + + def _set_default_max_num_seqs_and_batched_tokens_args( + self, usage_context: UsageContext, model_config: ModelConfig + ): world_size = self.pipeline_parallel_size * self.tensor_parallel_size ( default_max_num_batched_tokens, diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 2679448bce775..e774cd647ea8c 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -32,7 +32,7 @@ from vllm.tasks import PoolingTask from vllm.v1.pool.metadata import PoolingMetadata from .interfaces import SupportsCrossEncoding, SupportsQuant -from .interfaces_base import default_pooling_type +from .interfaces_base import attn_type, default_pooling_type from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -432,7 +432,6 @@ class BertModel(nn.Module, SupportsQuant): return loaded_params -@default_pooling_type("ALL") class BertPoolingModel(BertModel): is_pooling_model = True @@ -864,6 +863,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu ) +@attn_type("encoder_only") @default_pooling_type("ALL") class BertForTokenClassification(nn.Module): is_pooling_model = True diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 85c5574bacf0a..2c99fce8d918c 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -19,10 +19,14 @@ from vllm.utils.func_utils import supports_kw if TYPE_CHECKING: from vllm.config import VllmConfig + from vllm.config.model import AttnTypeStr + from vllm.config.pooler import PoolingTypeStr from vllm.model_executor.layers.pooler import Pooler else: VllmConfig = Any Pooler = Any + PoolingTypeStr = Any + AttnTypeStr = Any logger = init_logger(__name__) @@ -165,7 +169,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): MRO of your model class. """ - default_pooling_type: ClassVar[str] = "LAST" + default_pooling_type: ClassVar[PoolingTypeStr] = "LAST" """ Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][] to use by default. @@ -175,6 +179,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): decorator to conveniently set this field. """ + attn_type: ClassVar[AttnTypeStr] = "decoder" + """ + Indicates the + [vllm.config.model.ModelConfig.attn_type][] + to use by default. + + You can use the + [vllm.model_executor.models.interfaces_base.attn_type][] + decorator to conveniently set this field. 
+ """ + pooler: Pooler """The pooler is only called on TP rank 0.""" @@ -199,7 +214,7 @@ def is_pooling_model( _T = TypeVar("_T", bound=type[nn.Module]) -def default_pooling_type(pooling_type: str): +def default_pooling_type(pooling_type: PoolingTypeStr): """Decorator to set `VllmModelForPooling.default_pooling_type`.""" def func(model: _T) -> _T: @@ -209,5 +224,19 @@ def default_pooling_type(pooling_type: str): return func -def get_default_pooling_type(model: type[object] | object) -> str: +def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr: return getattr(model, "default_pooling_type", "LAST") + + +def attn_type(attn_type: AttnTypeStr): + """Decorator to set `VllmModelForPooling.attn_type`.""" + + def func(model: _T) -> _T: + model.attn_type = attn_type # type: ignore + return model + + return func + + +def get_attn_type(model: type[object] | object) -> AttnTypeStr: + return getattr(model, "attn_type", "decoder") diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 3a8a6c74d9d15..743bc23d9876f 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -28,7 +28,7 @@ from vllm.tasks import PoolingTask from vllm.v1.pool.metadata import PoolingMetadata from .interfaces import SupportsCrossEncoding -from .interfaces_base import default_pooling_type +from .interfaces_base import attn_type, default_pooling_type from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix @@ -396,6 +396,7 @@ class ModernBertPredictionHead(nn.Module): return self.norm(self.act(self.dense(hidden_states))) +@attn_type("encoder_only") @default_pooling_type("ALL") class ModernBertForTokenClassification(nn.Module): is_pooling_model = True diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 0d582043e8c02..73a61f1148b50 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -17,7 +17,7 @@ from collections.abc import Callable, Set from dataclasses import asdict, dataclass, field from functools import lru_cache from pathlib import Path -from typing import TypeVar +from typing import TYPE_CHECKING, Any, TypeVar import torch.nn as nn import transformers @@ -33,6 +33,14 @@ from vllm.logging_utils import logtime from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module from vllm.utils.hashing import safe_hash +if TYPE_CHECKING: + from vllm.config.model import AttnTypeStr + from vllm.config.pooler import PoolingTypeStr +else: + AttnTypeStr = Any + PoolingTypeStr = Any + + from .interfaces import ( has_inner_state, has_noops, @@ -47,6 +55,7 @@ from .interfaces import ( supports_transcription, ) from .interfaces_base import ( + get_attn_type, get_default_pooling_type, is_pooling_model, is_text_generation_model, @@ -509,7 +518,8 @@ class _ModelInfo: architecture: str is_text_generation_model: bool is_pooling_model: bool - default_pooling_type: str + attn_type: AttnTypeStr + default_pooling_type: PoolingTypeStr supports_cross_encoding: bool supports_multimodal: bool supports_multimodal_raw_input_only: bool @@ -530,6 +540,7 @@ class _ModelInfo: is_text_generation_model=is_text_generation_model(model), is_pooling_model=is_pooling_model(model), default_pooling_type=get_default_pooling_type(model), + attn_type=get_attn_type(model), supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), 
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 8657a95b5e6e7..e3a5f51a8fc56 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -119,11 +119,12 @@ class EngineCore: # Setup scheduler. Scheduler = vllm_config.scheduler_config.get_scheduler_cls() - if len(kv_cache_config.kv_cache_groups) == 0: + if len(kv_cache_config.kv_cache_groups) == 0: # noqa: SIM102 # Encoder models without KV cache don't support # chunked prefill. But do SSM models? - logger.info("Disabling chunked prefill for model without KVCache") - vllm_config.scheduler_config.enable_chunked_prefill = False + if vllm_config.scheduler_config.enable_chunked_prefill: + logger.warning("Disabling chunked prefill for model without KVCache") + vllm_config.scheduler_config.enable_chunked_prefill = False scheduler_block_size = ( vllm_config.cache_config.block_size
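
For reference, a minimal sketch (not part of the patch) of how the new metadata is meant to be consumed: a pooling model declares its attention type via the attn_type decorator added in interfaces_base.py, and ModelConfig exposes the derived feature support that EngineArgs now reads when defaulting chunked prefill and prefix caching. The class below is a hypothetical stand-in for a real nn.Module; the ModelConfig checks mirror cases exercised in tests/test_config.py and assume the Hugging Face configs for the referenced checkpoints are reachable.

    # Illustrative sketch only -- not part of the patch.
    from vllm.config import ModelConfig
    from vllm.model_executor.models.interfaces_base import (
        attn_type,
        default_pooling_type,
        get_attn_type,
    )


    @attn_type("encoder_only")    # bidirectional attention, as in BertForTokenClassification
    @default_pooling_type("ALL")  # token-level pooling
    class MyEncoderOnlyModel:     # hypothetical stand-in for an nn.Module subclass
        is_pooling_model = True


    assert get_attn_type(MyEncoderOnlyModel) == "encoder_only"

    # ModelConfig derives attn_type from the registry metadata plus the pooler
    # config, then reports whether the engine may default-enable the features.
    bge = ModelConfig("BAAI/bge-base-en")
    assert bge.attn_type == "encoder_only"
    assert not bge.is_chunked_prefill_supported
    assert not bge.is_prefix_caching_supported

    qwen = ModelConfig("Qwen/Qwen3-0.6B")
    assert qwen.attn_type == "decoder"
    assert qwen.is_chunked_prefill_supported
    assert qwen.is_prefix_caching_supported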