Improve enable chunked_prefill & prefix_caching logic. (#26623)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent 37b15e97e8
commit f4b76056ee
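
The diff below adds attn_type, is_chunked_prefill_supported, and is_prefix_caching_supported properties to ModelConfig and derives the chunked-prefill and prefix-caching defaults from them. A minimal sketch of how the new properties are queried, modeled on the tests in this diff (the model name is one of the test cases; the public import path is an assumption):

# Sketch only: exercises the ModelConfig properties introduced by this commit.
# The import path is assumed; the tests below construct ModelConfig the same way.
from vllm.config import ModelConfig

config = ModelConfig("Qwen/Qwen3-Embedding-0.6B", trust_remote_code=True)
assert config.attn_type == "decoder"          # causal attention
assert config.is_chunked_prefill_supported    # last pooling + causal attn
assert config.is_prefix_caching_supported
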
@@ -105,8 +105,6 @@ def test_embed_models(
def test_non_causal_models(
    hf_runner, vllm_runner, example_prompts, model: str, dtype: str
) -> None:
    with vllm_runner(
        model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
    ) as vllm_model:
    with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
        cache_config = vllm_model.llm.llm_engine.cache_config
        assert not cache_config.enable_prefix_caching

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
import os
from dataclasses import MISSING, Field, asdict, dataclass, field
from unittest.mock import patch
@@ -602,6 +602,244 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
    assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer)


@pytest.mark.parametrize(
    ("model_id", "expected_attn_type", "expected_result", "reason"),
    [
        # pooling models
        (
            "jason9693/Qwen2.5-1.5B-apeach",
            "decoder",
            True,
            "Pooling models with causal attn and last pooling support chunked prefill.",
        ),
        (
            "Qwen/Qwen3-Embedding-0.6B",
            "decoder",
            True,
            "Pooling models with causal attn and last pooling support chunked prefill.",
        ),
        (
            "Qwen/Qwen2.5-Math-PRM-7B",
            "decoder",
            False,
            "Pooling models with step pooling does not support chunked prefill.",
        ),
        (
            "internlm/internlm2-1_8b-reward",
            "decoder",
            False,
            "Pooling models with all pooling does not support chunked prefill.",
        ),
        (
            "BAAI/bge-base-en",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support chunked prefill.",
        ),
        (
            "boltuix/NeuroBERT-NER",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support chunked prefill.",
        ),
        (
            "papluca/xlm-roberta-base-language-detection",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support chunked prefill.",
        ),
        (
            "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support chunked prefill.",
        ),
        (
            "intfloat/e5-small",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support chunked prefill.",
        ),
        # multimodal models
        (
            "openai/clip-vit-base-patch32",
            "decoder",
            True,
            "Pooling models with causal attn and last pooling support chunked prefill.",
        ),
        (
            "google/siglip-base-patch16-224",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support chunked prefill.",
        ),
        # generate models
        (
            "Qwen/Qwen3-0.6B",
            "decoder",
            True,
            "Generative models support chunked prefill.",
        ),
        (
            "Qwen/Qwen3-Next-80B-A3B-Instruct",
            "hybrid",
            True,
            "Generative models support chunked prefill.",
        ),
        (
            "ibm-granite/granite-4.0-h-small",
            "hybrid",
            True,
            "Generative models support chunked prefill.",
        ),
        (
            "state-spaces/mamba-130m-hf",
            "attention_free",
            True,
            "Generative models support chunked prefill.",
        ),
        # encoder_decoder models
        (
            "openai/whisper-small",
            "encoder_decoder",
            False,
            "Encoder decoder models does not support chunked prefill.",
        ),
    ],
)
def test_is_chunked_prefill_supported(
    model_id: str,
    expected_attn_type: str,
    expected_result: bool,
    reason: str,
    caplog_vllm,
):
    model_config = ModelConfig(model_id, trust_remote_code=True)
    assert model_config.attn_type == expected_attn_type
    with caplog_vllm.at_level(level=logging.DEBUG):
        assert model_config.is_chunked_prefill_supported == expected_result
        assert reason in caplog_vllm.text


@pytest.mark.parametrize(
    ("model_id", "expected_attn_type", "expected_result", "reason"),
    [
        # pooling models
        (
            "jason9693/Qwen2.5-1.5B-apeach",
            "decoder",
            True,
            "Pooling models with causal attn and last pooling support prefix caching.",
        ),
        (
            "Qwen/Qwen3-Embedding-0.6B",
            "decoder",
            True,
            "Pooling models with causal attn and last pooling support prefix caching.",
        ),
        (
            "Qwen/Qwen2.5-Math-PRM-7B",
            "decoder",
            False,
            "Pooling models with step pooling does not support prefix caching.",
        ),
        (
            "internlm/internlm2-1_8b-reward",
            "decoder",
            False,
            "Pooling models with all pooling does not support prefix caching.",
        ),
        (
            "BAAI/bge-base-en",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support prefix caching.",
        ),
        (
            "boltuix/NeuroBERT-NER",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support prefix caching.",
        ),
        (
            "papluca/xlm-roberta-base-language-detection",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support prefix caching.",
        ),
        (
            "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support prefix caching.",
        ),
        (
            "intfloat/e5-small",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support prefix caching.",
        ),
        # multimodal models
        (
            "openai/clip-vit-base-patch32",
            "decoder",
            True,
            "Pooling models with causal attn and last pooling support prefix caching.",
        ),
        (
            "google/siglip-base-patch16-224",
            "encoder_only",
            False,
            "Pooling models with bidirectional attn does not support prefix caching.",
        ),
        # generate models
        (
            "Qwen/Qwen3-0.6B",
            "decoder",
            True,
            "Generative models support prefix caching.",
        ),
        (
            "Qwen/Qwen3-Next-80B-A3B-Instruct",
            "hybrid",
            False,
            "Hybrid models does not support prefix caching since the feature is still experimental.",  # noqa: E501
        ),
        (
            "ibm-granite/granite-4.0-h-small",
            "hybrid",
            False,
            "Hybrid models does not support prefix caching since the feature is still experimental.",  # noqa: E501
        ),
        (
            "state-spaces/mamba-130m-hf",
            "attention_free",
            False,
            "Attention free models does not support prefix caching since the feature is still experimental.",  # noqa: E501
        ),
        # encoder_decoder models
        (
            "openai/whisper-small",
            "encoder_decoder",
            False,
            "Encoder decoder models does not support prefix caching.",
        ),
    ],
)
def test_is_prefix_caching_supported(
    model_id: str,
    expected_attn_type: str,
    expected_result: bool,
    reason: str,
    caplog_vllm,
):
    model_config = ModelConfig(model_id, trust_remote_code=True)
    assert model_config.attn_type == expected_attn_type
    with caplog_vllm.at_level(level=logging.DEBUG):
        assert model_config.is_prefix_caching_supported == expected_result
        assert reason in caplog_vllm.text


@pytest.mark.parametrize(
    ("backend", "custom_ops", "expected"),
    [

@@ -107,6 +107,10 @@ _RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
    "draft": [],
}

AttnTypeStr = Literal[
    "decoder", "encoder", "encoder_only", "encoder_decoder", "attention_free", "hybrid"
]


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -1752,6 +1756,111 @@ class ModelConfig:
        logger.info("Using max model len %s", max_model_len)
        return max_model_len

    @property
    def attn_type(self) -> AttnTypeStr:
        if self.pooler_config is not None:
            pooling_type = self._model_info.default_pooling_type.lower()
            if pooling_type == "cls":
                return "encoder_only"
            else:
                is_causal = getattr(self.hf_config, "is_causal", True)
                return "encoder_only" if not is_causal else self._model_info.attn_type
        elif self.is_hybrid:
            return "hybrid"
        elif self.is_attention_free:
            return "attention_free"
        elif self.is_encoder_decoder:
            return "encoder_decoder"
        else:
            return "decoder"

    @property
    def is_chunked_prefill_supported(self) -> bool:
        attn_type = self.attn_type
        if self.pooler_config is not None:
            # for pooling models
            if attn_type == "encoder_only":
                logger.debug(
                    "Pooling models with bidirectional attn does not support "
                    "chunked prefill."
                )
                return False
            elif attn_type == "decoder":
                pooling_type = self.pooler_config.pooling_type.lower()
                if pooling_type in ["all", "mean", "step", "cls"]:
                    logger.debug(
                        "Pooling models with %s pooling does not "
                        "support chunked prefill.",
                        pooling_type,
                    )
                    return False
                else:
                    # pooling_type == "last"
                    logger.debug(
                        "Pooling models with causal attn and last pooling support "
                        "chunked prefill."
                    )
                    return True
            # vllm currently does not have pooling models using hybrid,
            # attention_free or encoder_decoder attn types.
            return attn_type != "encoder_decoder"
        else:
            if attn_type == "encoder_decoder":
                logger.debug("Encoder decoder models does not support chunked prefill.")
                return False
            logger.debug("Generative models support chunked prefill.")
            return True

    @property
    def is_prefix_caching_supported(self) -> bool:
        attn_type = self.attn_type
        if self.pooler_config is not None:
            # for pooling models
            if attn_type == "encoder_only":
                logger.debug(
                    "Pooling models with bidirectional attn does not "
                    "support prefix caching."
                )
                return False
            elif attn_type == "decoder":
                pooling_type = self.pooler_config.pooling_type.lower()
                if pooling_type in ["all", "mean", "step", "cls"]:
                    logger.debug(
                        "Pooling models with %s pooling does not "
                        "support prefix caching.",
                        pooling_type,
                    )
                    return False
                else:
                    # pooling_type == "last"
                    logger.debug(
                        "Pooling models with causal attn and last pooling support "
                        "prefix caching."
                    )
                    return True
            # vllm currently does not have pooling models using hybrid,
            # attention_free or encoder_decoder attn types.
            return False
        else:
            if attn_type == "hybrid":
                logger.debug(
                    "Hybrid models does not support prefix caching since the feature "
                    "is still experimental."
                )
                return False
            elif attn_type == "attention_free":
                logger.debug(
                    "Attention free models does not support prefix caching since the "
                    "feature is still experimental."
                )
                return False
            elif attn_type == "encoder_decoder":
                logger.debug("Encoder decoder models does not support prefix caching.")
                return False
            else:  # attn_type == "decoder"
                logger.debug("Generative models support prefix caching.")
                return True

    def is_model_moe(
        self,
    ) -> bool:

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any
from typing import Any, Literal

from pydantic.dataclasses import dataclass

@@ -11,13 +11,15 @@ from vllm.utils.hashing import safe_hash

logger = init_logger(__name__)

PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]


@config
@dataclass
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

    pooling_type: str | None = None
    pooling_type: PoolingTypeStr | None = None
    """
    The pooling method of the pooling model. This should be a key in
    [`vllm.model_executor.layers.pooler.PoolingType`][].

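Since pooling_type is now annotated as PoolingTypeStr rather than a bare str, configurations are restricted to the five literal values above. A small illustrative sketch, assuming PoolerConfig is constructed directly from vllm.config.pooler:

# Illustrative only: pooling_type must be one of the PoolingTypeStr values.
from vllm.config.pooler import PoolerConfig

pooler_config = PoolerConfig(pooling_type="LAST")
assert pooler_config.pooling_type == "LAST"
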
@@ -721,65 +721,27 @@ class VllmConfig:
                "correctness and to realize prefill savings. "
            )

        disable_chunked_prefill_reasons: list[str] = []
        if self.model_config and self.model_config.is_encoder_decoder:
            from vllm.multimodal import MULTIMODAL_REGISTRY

        if self.model_config:
            if self.model_config.pooler_config:
                pooling_type = self.model_config.pooler_config.pooling_type
                if pooling_type is None or pooling_type.lower() != "last":
                    disable_chunked_prefill_reasons.append(
                        'Only "last" pooling supports chunked '
                        "prefill and prefix caching; disabling both."
                    )
                if not getattr(self.model_config.hf_config, "is_causal", True):
                    disable_chunked_prefill_reasons.append(
                        "Only models using causal attention support chunked "
                        "prefill and prefix caching; disabling both."
                    )
            elif self.model_config.is_encoder_decoder:
                from vllm.multimodal import MULTIMODAL_REGISTRY

                self.scheduler_config.max_num_encoder_input_tokens = (
                    MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
                )
                logger.debug(
                    "Encoder-decoder model detected: setting "
                    "`max_num_encoder_input_tokens` to encoder length (%s)",
                    self.scheduler_config.max_num_encoder_input_tokens,
                )
                if (
                    self.model_config.architecture == "WhisperForConditionalGeneration"
                    and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
                ):
                    logger.warning(
                        "Whisper is known to have issues with "
                        "forked workers. If startup is hanging, "
                        "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                        "to 'spawn'."
                    )

        # Final off-switch for CP/APC:
        # Disable for (a) collected blockers, (b) encoder-decoder, or
        # (c) explicit CP=False when APC wasn't requested.
        # Do NOT disable merely because the resolved CP flag is False.
        apc_requested = (
            self.cache_config is not None and self.cache_config.enable_prefix_caching
        )
        if (
            disable_chunked_prefill_reasons
            or (self.model_config is not None and self.model_config.is_encoder_decoder)
            or (
                self.scheduler_config.enable_chunked_prefill is False
                and not apc_requested
            self.scheduler_config.max_num_encoder_input_tokens = (
                MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
            )
        ):
            for reason in disable_chunked_prefill_reasons:
                logger.info(reason)
            self.scheduler_config.enable_chunked_prefill = False
            self.scheduler_config.long_prefill_token_threshold = 0

            if self.cache_config is not None:
                self.cache_config.enable_prefix_caching = False
            logger.debug(
                "Encoder-decoder model detected: setting "
                "`max_num_encoder_input_tokens` to encoder length (%s)",
                self.scheduler_config.max_num_encoder_input_tokens,
            )
            if (
                self.model_config.architecture == "WhisperForConditionalGeneration"
                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"
            ):
                logger.warning(
                    "Whisper is known to have issues with "
                    "forked workers. If startup is hanging, "
                    "try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
                    "to 'spawn'."
                )

        if (
            self.kv_events_config is not None

@@ -1349,30 +1349,10 @@ class EngineArgs:
        self.tokenizer = model_config.tokenizer

        self._check_feature_supported(model_config)

        # Set default arguments for V1 Engine.
        self._set_default_args(usage_context, model_config)
        # Disable chunked prefill and prefix caching for:
        # POWER (ppc64le)/s390x/RISCV CPUs in V1
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_chunked_prefill = False
            logger.info(
                "Prefix caching is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

        assert self.enable_chunked_prefill is not None
        self._set_default_chunked_prefill_and_prefix_caching_args(model_config)
        self._set_default_max_num_seqs_and_batched_tokens_args(
            usage_context, model_config
        )

        sliding_window: int | None = None
        if not is_interleaved(model_config.hf_text_config):

@@ -1805,34 +1785,6 @@ class EngineArgs:
            )
            _raise_unsupported_error(feature_name=name)

    @classmethod
    def get_chunked_prefill_prefix_caching_defaults(
        cls,
        model_config: ModelConfig,
    ) -> tuple[bool, bool]:
        if model_config.runner_type != "pooling":
            default_chunked_prefill = True

            # Disable prefix caching default for hybrid models and mamba-only
            # models since the feature is still experimental.
            default_prefix_caching = not (
                model_config.is_hybrid or model_config.is_attention_free
            )
        else:
            assert model_config.pooler_config is not None

            pooling_type = model_config.pooler_config.pooling_type
            incremental_prefill_supported = (
                pooling_type is not None
                and pooling_type.lower() == "last"
                and getattr(model_config.hf_config, "is_causal", True)
            )

            default_chunked_prefill = incremental_prefill_supported
            default_prefix_caching = incremental_prefill_supported

        return default_chunked_prefill, default_prefix_caching

    @classmethod
    def get_batch_defaults(
        cls,

@@ -1916,14 +1868,11 @@ class EngineArgs:

        return default_max_num_batched_tokens, default_max_num_seqs

    def _set_default_args(
        self, usage_context: UsageContext, model_config: ModelConfig
    def _set_default_chunked_prefill_and_prefix_caching_args(
        self, model_config: ModelConfig
    ) -> None:
        """Set Default Arguments for V1 Engine."""
        (
            default_chunked_prefill,
            default_prefix_caching,
        ) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
        default_chunked_prefill = model_config.is_chunked_prefill_supported
        default_prefix_caching = model_config.is_prefix_caching_supported

        if self.prefill_context_parallel_size > 1:
            default_chunked_prefill = False

@@ -1984,6 +1933,29 @@ class EngineArgs:
                scope="local",
            )

        # Disable chunked prefill and prefix caching for:
        # POWER (ppc64le)/s390x/RISCV CPUs in V1
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
            CpuArchEnum.RISCV,
        ):
            logger.info(
                "Chunked prefill is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_chunked_prefill = False
            logger.info(
                "Prefix caching is not supported for ARM and POWER, "
                "S390X and RISC-V CPUs; "
                "disabling it for V1 backend."
            )
            self.enable_prefix_caching = False

    def _set_default_max_num_seqs_and_batched_tokens_args(
        self, usage_context: UsageContext, model_config: ModelConfig
    ):
        world_size = self.pipeline_parallel_size * self.tensor_parallel_size
        (
            default_max_num_batched_tokens,

@@ -32,7 +32,7 @@ from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type
from .interfaces_base import attn_type, default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


@@ -432,7 +432,6 @@ class BertModel(nn.Module, SupportsQuant):
        return loaded_params


@default_pooling_type("ALL")
class BertPoolingModel(BertModel):
    is_pooling_model = True

@@ -864,6 +863,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
        )


@attn_type("encoder_only")
@default_pooling_type("ALL")
class BertForTokenClassification(nn.Module):
    is_pooling_model = True

@@ -19,10 +19,14 @@ from vllm.utils.func_utils import supports_kw

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import PoolingTypeStr
    from vllm.model_executor.layers.pooler import Pooler
else:
    VllmConfig = Any
    Pooler = Any
    PoolingTypeStr = Any
    AttnTypeStr = Any

logger = init_logger(__name__)

@@ -165,7 +169,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    MRO of your model class.
    """

    default_pooling_type: ClassVar[str] = "LAST"
    default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
    """
    Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
    to use by default.
@@ -175,6 +179,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    decorator to conveniently set this field.
    """

    attn_type: ClassVar[AttnTypeStr] = "decoder"
    """
    Indicates the
    [vllm.config.model.ModelConfig.attn_type][]
    to use by default.

    You can use the
    [vllm.model_executor.models.interfaces_base.attn_type][]
    decorator to conveniently set this field.
    """

    pooler: Pooler
    """The pooler is only called on TP rank 0."""

@@ -199,7 +214,7 @@ def is_pooling_model(
_T = TypeVar("_T", bound=type[nn.Module])


def default_pooling_type(pooling_type: str):
def default_pooling_type(pooling_type: PoolingTypeStr):
    """Decorator to set `VllmModelForPooling.default_pooling_type`."""

    def func(model: _T) -> _T:

@@ -209,5 +224,19 @@ def default_pooling_type(pooling_type: str):
    return func


def get_default_pooling_type(model: type[object] | object) -> str:
def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
    return getattr(model, "default_pooling_type", "LAST")


def attn_type(attn_type: AttnTypeStr):
    """Decorator to set `VllmModelForPooling.attn_type`."""

    def func(model: _T) -> _T:
        model.attn_type = attn_type  # type: ignore
        return model

    return func


def get_attn_type(model: type[object] | object) -> AttnTypeStr:
    return getattr(model, "attn_type", "decoder")

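The attn_type decorator defined above is applied to model classes later in this diff (bert.py and modernbert.py). A minimal illustrative sketch of declaring a pooling model's attention type; the class name here is hypothetical:

# Illustrative only: declaring attention type and default pooling type
# via the interfaces_base decorators (class name is made up).
import torch.nn as nn

from vllm.model_executor.models.interfaces_base import attn_type, default_pooling_type


@attn_type("encoder_only")
@default_pooling_type("ALL")
class ExampleTokenClassifier(nn.Module):
    is_pooling_model = True
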
@@ -28,7 +28,7 @@ from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces import SupportsCrossEncoding
from .interfaces_base import default_pooling_type
from .interfaces_base import attn_type, default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


@@ -396,6 +396,7 @@ class ModernBertPredictionHead(nn.Module):
        return self.norm(self.act(self.dense(hidden_states)))


@attn_type("encoder_only")
@default_pooling_type("ALL")
class ModernBertForTokenClassification(nn.Module):
    is_pooling_model = True

@@ -17,7 +17,7 @@ from collections.abc import Callable, Set
from dataclasses import asdict, dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import TypeVar
from typing import TYPE_CHECKING, Any, TypeVar

import torch.nn as nn
import transformers
@@ -33,6 +33,14 @@ from vllm.logging_utils import logtime
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
from vllm.utils.hashing import safe_hash

if TYPE_CHECKING:
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import PoolingTypeStr
else:
    AttnTypeStr = Any
    PoolingTypeStr = Any


from .interfaces import (
    has_inner_state,
    has_noops,
@@ -47,6 +55,7 @@ from .interfaces import (
    supports_transcription,
)
from .interfaces_base import (
    get_attn_type,
    get_default_pooling_type,
    is_pooling_model,
    is_text_generation_model,
@@ -509,7 +518,8 @@ class _ModelInfo:
    architecture: str
    is_text_generation_model: bool
    is_pooling_model: bool
    default_pooling_type: str
    attn_type: AttnTypeStr
    default_pooling_type: PoolingTypeStr
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input_only: bool
@@ -530,6 +540,7 @@ class _ModelInfo:
            is_text_generation_model=is_text_generation_model(model),
            is_pooling_model=is_pooling_model(model),
            default_pooling_type=get_default_pooling_type(model),
            attn_type=get_attn_type(model),
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(

@@ -119,11 +119,12 @@ class EngineCore:
        # Setup scheduler.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

        if len(kv_cache_config.kv_cache_groups) == 0:
        if len(kv_cache_config.kv_cache_groups) == 0:  # noqa: SIM102
            # Encoder models without KV cache don't support
            # chunked prefill. But do SSM models?
            logger.info("Disabling chunked prefill for model without KVCache")
            vllm_config.scheduler_config.enable_chunked_prefill = False
            if vllm_config.scheduler_config.enable_chunked_prefill:
                logger.warning("Disabling chunked prefill for model without KVCache")
                vllm_config.scheduler_config.enable_chunked_prefill = False

        scheduler_block_size = (
            vllm_config.cache_config.block_size