[Model] Allow users to control skip reading cache per request. (#28194)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi 2025-11-16 16:04:50 +08:00 committed by GitHub
parent d231876ce3
commit a55b64635c
5 changed files with 67 additions and 8 deletions


@@ -11,7 +11,7 @@ from vllm import TokensPrompt
     ["Qwen/Qwen3-0.6B"],
 )
 @torch.inference_mode
-def test_embed_models(hf_runner, vllm_runner, model: str):
+def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
     n_prompt_tokens = [55, 56, 57]
     token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
@@ -21,7 +21,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
         enforce_eager=True,
         runner="pooling",
         enable_chunked_prefill=False,
-        enable_prefix_caching=False,
+        enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
             [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
@@ -30,4 +30,29 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
         for n, output in zip(n_prompt_tokens, pooling_outputs):
             assert len(output.prompt_token_ids) == n
             assert len(output.outputs.data) == n
+            assert output.num_cached_tokens == 0
+
+        # test enable_prefix_caching plus all pooling
+        # we need to skip reading cache at this request by
+        # request.skip_reading_prefix_cache
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="token_embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert len(output.outputs.data) == n
+            assert output.num_cached_tokens == 0
+
+        # skip_reading_prefix_cache can still write to cache
+        # to accelerate following requests
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert output.num_cached_tokens > 0
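
The same behaviour can be reproduced outside the test harness. A minimal sketch using the offline API: the model name and engine arguments are illustrative, while encode(), pooling_task and num_cached_tokens come from the test above.

from vllm import LLM, TokensPrompt

llm = LLM(
    model="Qwen/Qwen3-0.6B",        # illustrative; any small pooling-capable model
    runner="pooling",
    enforce_eager=True,
    enable_prefix_caching=True,
)
prompts = [TokensPrompt(prompt_token_ids=[1024 + i for i in range(56)])]

# Token-level ("all") pooling: the request skips prefix-cache reads, so no
# cached tokens are reported even with prefix caching enabled...
token_outputs = llm.encode(prompts, pooling_task="token_embed")
assert token_outputs[0].num_cached_tokens == 0

# ...but the blocks written by that request can still serve a later request.
embed_outputs = llm.encode(prompts, pooling_task="embed")
assert embed_outputs[0].num_cached_tokens > 0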


@@ -57,6 +57,7 @@ class PoolingParams(
     ## Internal use only
     task: PoolingTask | None = None
     requires_token_ids: bool = False
+    skip_reading_prefix_cache: bool = None
     extra_kwargs: dict[str, Any] | None = None
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@@ -93,6 +94,8 @@ class PoolingParams(
         # plugin task uses io_processor.parse_request to verify inputs,
         # skipping PoolingParams verify
         if self.task == "plugin":
+            if self.skip_reading_prefix_cache is None:
+                self.skip_reading_prefix_cache = True
             return

         # NOTE: Task validation needs to done against the model instance,
@@ -122,6 +125,15 @@
             if getattr(self, k, None) is None:
                 setattr(self, k, getattr(pooler_config, k))

+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled,
+            # the output of all pooling may less than n_prompt_tokens,
+            # we need to skip reading cache at this request.
+            if self.task in ["token_embed", "token_classify"]:
+                self.skip_reading_prefix_cache = True
+            else:
+                self.skip_reading_prefix_cache = False
+
         self._verify_step_pooling(pooler_config, valid_parameters)

     def _verify_step_pooling(
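
Note that the defaulting above only runs while skip_reading_prefix_cache is still None, so a value set explicitly on a request's PoolingParams is left untouched. A sketch of the per-request override this enables; the field sits under the "Internal use only" block, so treat this as illustrative rather than a documented public API:

from vllm import PoolingParams

# Pin the flag per request; because it is no longer None, the defaulting in
# the verification above keeps it as-is (the task itself is normally filled
# in later by the engine from the pooling_task argument).
params = PoolingParams(skip_reading_prefix_cache=False)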


@@ -254,6 +254,8 @@ class SamplingParams(
     generated token can complete the sequence."""
     _bad_words_token_ids: list[list[int]] | None = None

+    skip_reading_prefix_cache: bool = None
+
     @staticmethod
     def from_optional(
         n: int | None = 1,
@@ -414,6 +416,12 @@
             self.structured_outputs = self.guided_decoding
             self.guided_decoding = None

+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled,
+            # the output of prompt logprobs may less than n_prompt_tokens,
+            # we need to skip reading cache at this request.
+            self.skip_reading_prefix_cache = self.prompt_logprobs is not None
+
     def _verify_args(self) -> None:
         if not isinstance(self.n, int):
             raise ValueError(f"n must be an int, but is of type {type(self.n)}")
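
On the sampling side the rule is analogous: requesting prompt logprobs marks the request to skip prefix-cache reads unless the caller has already pinned the flag. A minimal sketch, assuming the defaulting above runs during construction (e.g. in __post_init__), as the surrounding diff suggests:

from vllm import SamplingParams

derived = SamplingParams(prompt_logprobs=2)
assert derived.skip_reading_prefix_cache is True   # implied by prompt_logprobs

pinned = SamplingParams(prompt_logprobs=2, skip_reading_prefix_cache=False)
assert pinned.skip_reading_prefix_cache is False   # explicit value wins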


@@ -185,12 +185,11 @@ class KVCacheManager:
             - A list of blocks that are computed for the request.
             - The number of computed tokens.
         """
-        # Prefix caching is disabled or
-        # When the request requires prompt logprobs, we skip prefix caching.
-        if not self.enable_caching or (
-            request.sampling_params is not None
-            and request.sampling_params.prompt_logprobs is not None
-        ):
+        # We skip finding the prefix cache hit when prefix caching is
+        # disabled or the request is marked as skipping kv cache read
+        # (which happens when the request requires prompt logprobs
+        # or calls a pooling model with all pooling).
+        if not self.enable_caching or request.skip_reading_prefix_cache:
             return self.empty_kv_cache_blocks, 0

         # NOTE: When all tokens hit the cache, we must recompute the last token


@@ -127,6 +127,8 @@ class Request:
             self.get_hash_new_full_blocks = partial(block_hasher, self)
             self.block_hashes = self.get_hash_new_full_blocks()

+        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
+
     @classmethod
     def from_engine_core_request(
         cls,
@@ -180,6 +182,19 @@
     def num_output_tokens(self) -> int:
         return len(self._output_token_ids)

+    def get_skip_reading_prefix_cache(self) -> bool:
+        if (
+            self.sampling_params is not None
+            and self.sampling_params.skip_reading_prefix_cache is not None
+        ):
+            return self.sampling_params.skip_reading_prefix_cache
+        elif (
+            self.pooling_params is not None
+            and self.pooling_params.skip_reading_prefix_cache is not None
+        ):
+            return self.pooling_params.skip_reading_prefix_cache
+        return False
+
     def is_finished(self) -> bool:
         return RequestStatus.is_finished(self.status)
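
Taken together: the request resolves its flag once at construction (the value from SamplingParams is checked first, then PoolingParams, falling back to False), and the KV cache manager consults that flag before looking up cached prefixes. A standalone restatement in plain Python, not vLLM's API, for reference:

# Mirrors Request.get_skip_reading_prefix_cache() above; the arguments stand
# in for SamplingParams / PoolingParams objects (or None).
def resolve_skip_reading_prefix_cache(sampling_params=None, pooling_params=None) -> bool:
    for params in (sampling_params, pooling_params):
        if params is not None and params.skip_reading_prefix_cache is not None:
            return params.skip_reading_prefix_cache
    return False

# Mirrors the KVCacheManager gate above: no prefix-cache lookup when caching
# is disabled or the request is marked to skip reads.
def should_look_up_prefix_cache(enable_caching: bool, request) -> bool:
    return enable_caching and not request.skip_reading_prefix_cache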