[Model] Allow users to control skip reading cache per request. (#28194)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>

parent d231876ce3
commit a55b64635c
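In short, the commit threads a per-request skip_reading_prefix_cache flag from SamplingParams and PoolingParams through Request into KVCacheManager, replacing the old hard-coded prompt-logprobs check. A minimal usage sketch of the per-request control (hedged: the model name is illustrative, and the flag is normally left as None so the task-dependent defaults computed in the diffs below apply):

from vllm import LLM, SamplingParams

# Engine-level prefix caching stays on...
llm = LLM(model="Qwen/Qwen3-0.6B", enable_prefix_caching=True)

# ...but this single request opts out of *reading* cached prefix blocks.
params = SamplingParams(max_tokens=8, skip_reading_prefix_cache=True)
outputs = llm.generate(["San Francisco is a"], params)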
@@ -11,7 +11,7 @@ from vllm import TokensPrompt
     ["Qwen/Qwen3-0.6B"],
 )
 @torch.inference_mode
-def test_embed_models(hf_runner, vllm_runner, model: str):
+def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
     n_prompt_tokens = [55, 56, 57]
     token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
 
@@ -21,7 +21,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
         enforce_eager=True,
         runner="pooling",
         enable_chunked_prefill=False,
-        enable_prefix_caching=False,
+        enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
             [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
@@ -30,4 +30,29 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
         for n, output in zip(n_prompt_tokens, pooling_outputs):
             assert len(output.prompt_token_ids) == n
             assert len(output.outputs.data) == n
             assert output.num_cached_tokens == 0
+
+        # Test enable_prefix_caching plus all pooling: reading the cache
+        # must be skipped for this request via
+        # request.skip_reading_prefix_cache.
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="token_embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert len(output.outputs.data) == n
+            assert output.num_cached_tokens == 0
+
+        # skip_reading_prefix_cache still allows writing to the cache,
+        # which accelerates subsequent requests.
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert output.num_cached_tokens > 0
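Why token-level ("all") pooling has to skip cache reads: a prefix-cache hit serves whole blocks from cache, so only the uncached tail of the prompt produces hidden states, and a token-level task would come up short. A worked example with assumed numbers (vLLM's default KV-cache block size is 16; the prompt length matches the test above):

n_prompt = 56     # one of the test's prompt lengths
block_size = 16   # vLLM's default KV-cache block size

# A repeat request could hit up to 3 full cached blocks...
n_cached = (n_prompt // block_size) * block_size  # 48
# ...leaving only 8 tokens to run through the model. token_embed must
# return one embedding per prompt token (56 rows), so reading the cache
# would under-produce; the request skips reads but still writes blocks.
n_recomputed = n_prompt - n_cached
assert n_recomputed == 8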
@@ -57,6 +57,7 @@ class PoolingParams(
     ## Internal use only
     task: PoolingTask | None = None
     requires_token_ids: bool = False
+    skip_reading_prefix_cache: bool | None = None
     extra_kwargs: dict[str, Any] | None = None
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
 
@@ -93,6 +94,8 @@ class PoolingParams(
         # plugin task uses io_processor.parse_request to verify inputs,
         # skipping PoolingParams verify
         if self.task == "plugin":
+            if self.skip_reading_prefix_cache is None:
+                self.skip_reading_prefix_cache = True
             return
 
         # NOTE: Task validation needs to be done against the model instance,
@@ -122,6 +125,15 @@ class PoolingParams(
             if getattr(self, k, None) is None:
                 setattr(self, k, getattr(pooler_config, k))
 
+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled, the output of all pooling may
+            # contain fewer than n_prompt_tokens entries, so we need to
+            # skip reading the cache for this request.
+            if self.task in ["token_embed", "token_classify"]:
+                self.skip_reading_prefix_cache = True
+            else:
+                self.skip_reading_prefix_cache = False
+
         self._verify_step_pooling(pooler_config, valid_parameters)
 
     def _verify_step_pooling(
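Taken together, the three PoolingParams hunks reduce to one defaulting rule. A hedged standalone restatement (the helper name is hypothetical; the task strings are the ones in the diff):

# Hypothetical restatement of the rule PoolingParams now applies during
# verification: an explicit user value always wins; otherwise token-level
# tasks, and "plugin" tasks whose outputs vLLM cannot inspect, default to
# skipping prefix-cache reads.
def default_skip_reading_prefix_cache(task: str, user_value: bool | None) -> bool:
    if user_value is not None:
        return user_value
    return task in ("plugin", "token_embed", "token_classify")

assert default_skip_reading_prefix_cache("token_embed", None) is True
assert default_skip_reading_prefix_cache("embed", None) is False
assert default_skip_reading_prefix_cache("token_embed", False) is False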
@@ -254,6 +254,8 @@ class SamplingParams(
     generated token can complete the sequence."""
     _bad_words_token_ids: list[list[int]] | None = None
 
+    skip_reading_prefix_cache: bool | None = None
+
     @staticmethod
     def from_optional(
         n: int | None = 1,
@@ -414,6 +416,12 @@ class SamplingParams(
             self.structured_outputs = self.guided_decoding
             self.guided_decoding = None
 
+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled, the prompt logprobs output may
+            # contain fewer than n_prompt_tokens entries, so we need to
+            # skip reading the cache for this request.
+            self.skip_reading_prefix_cache = self.prompt_logprobs is not None
+
     def _verify_args(self) -> None:
         if not isinstance(self.n, int):
             raise ValueError(f"n must be an int, but is of type {type(self.n)}")
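On the generation side the same rule keys off prompt_logprobs: logprobs are needed for every prompt token, and a partial cache hit would skip some of them. A hedged sketch of the resulting behavior (assuming this commit is applied and that the defaulting runs at construction time, as the hunk's placement in __post_init__ suggests):

from vllm import SamplingParams

# Requesting prompt logprobs flips the default to True...
params = SamplingParams(prompt_logprobs=5)
assert params.skip_reading_prefix_cache is True

# ...while an ordinary request keeps cache reads enabled by default,
# and an explicit user value is never overridden.
assert SamplingParams().skip_reading_prefix_cache is False
forced = SamplingParams(prompt_logprobs=5, skip_reading_prefix_cache=False)
assert forced.skip_reading_prefix_cache is False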
@@ -185,12 +185,11 @@ class KVCacheManager:
         - A list of blocks that are computed for the request.
         - The number of computed tokens.
         """
-        # Prefix caching is disabled, or
-        # the request requires prompt logprobs, so we skip prefix caching.
-        if not self.enable_caching or (
-            request.sampling_params is not None
-            and request.sampling_params.prompt_logprobs is not None
-        ):
+        # We skip finding the prefix cache hit when prefix caching is
+        # disabled or the request is marked as skipping kv cache reads
+        # (which happens when the request requires prompt logprobs
+        # or calls a pooling model with all pooling).
+        if not self.enable_caching or request.skip_reading_prefix_cache:
             return self.empty_kv_cache_blocks, 0
 
         # NOTE: When all tokens hit the cache, we must recompute the last token
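The KVCacheManager change is purely a consumer: it no longer inspects sampling params itself, it just honors the precomputed flag. A hedged skeleton of the gate (FakeRequest is a stand-in for vLLM's Request; only the early-return logic is shown):

from dataclasses import dataclass

@dataclass
class FakeRequest:
    # Stand-in for vllm.v1.request.Request; only the flag matters here.
    skip_reading_prefix_cache: bool

def should_lookup_prefix_cache(enable_caching: bool, request: FakeRequest) -> bool:
    # Mirrors the early return above: no lookup when caching is off or the
    # request opted out of reads; cache writes are unaffected either way.
    return enable_caching and not request.skip_reading_prefix_cache

assert should_lookup_prefix_cache(True, FakeRequest(False)) is True
assert should_lookup_prefix_cache(True, FakeRequest(True)) is False
assert should_lookup_prefix_cache(False, FakeRequest(False)) is False

Pushing the decision onto the request keeps the cache manager policy-free: new skip reasons can be added later without touching it.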
@@ -127,6 +127,8 @@ class Request:
         self.get_hash_new_full_blocks = partial(block_hasher, self)
         self.block_hashes = self.get_hash_new_full_blocks()
 
+        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
+
     @classmethod
     def from_engine_core_request(
         cls,
@@ -180,6 +182,19 @@ class Request:
     def num_output_tokens(self) -> int:
         return len(self._output_token_ids)
 
+    def get_skip_reading_prefix_cache(self) -> bool:
+        if (
+            self.sampling_params is not None
+            and self.sampling_params.skip_reading_prefix_cache is not None
+        ):
+            return self.sampling_params.skip_reading_prefix_cache
+        elif (
+            self.pooling_params is not None
+            and self.pooling_params.skip_reading_prefix_cache is not None
+        ):
+            return self.pooling_params.skip_reading_prefix_cache
+        return False
+
     def is_finished(self) -> bool:
         return RequestStatus.is_finished(self.status)
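Request.get_skip_reading_prefix_cache() resolves the two possible sources once at construction time; a request carries either sampling_params (generation) or pooling_params (pooling), so at most one branch fires. A standalone restatement of the precedence (the function name is hypothetical):

def resolve_skip_flag(sampling_flag: bool | None, pooling_flag: bool | None) -> bool:
    # Sampling params are consulted first, then pooling params; an unset
    # flag on both sides falls back to normal cache reads.
    if sampling_flag is not None:
        return sampling_flag
    if pooling_flag is not None:
        return pooling_flag
    return False

assert resolve_skip_flag(None, True) is True
assert resolve_skip_flag(False, None) is False
assert resolve_skip_flag(None, None) is False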