From 3729ed00ba4c5a9f351b1241b15a2ca3ca12a5e1 Mon Sep 17 00:00:00 2001 From: "wang.yuqi" Date: Thu, 23 Oct 2025 14:03:42 +0800 Subject: [PATCH] [Model] Add num_cached_tokens for PoolingRequestOutput (#27378) Signed-off-by: wang.yuqi --- .../pooling/test_auto_prefix_cache_support.py | 28 +++++++++++++--- .../pooling/test_extract_hidden_states.py | 33 +++++++++++++++++++ vllm/entrypoints/llm.py | 3 ++ vllm/entrypoints/openai/serving_embedding.py | 1 + vllm/entrypoints/score_utils.py | 1 + vllm/outputs.py | 13 +++++++- vllm/v1/engine/output_processor.py | 1 + 7 files changed, 75 insertions(+), 5 deletions(-) create mode 100644 tests/models/language/pooling/test_extract_hidden_states.py diff --git a/tests/models/language/pooling/test_auto_prefix_cache_support.py b/tests/models/language/pooling/test_auto_prefix_cache_support.py index e95119df95c71..0904c7e877ef4 100644 --- a/tests/models/language/pooling/test_auto_prefix_cache_support.py +++ b/tests/models/language/pooling/test_auto_prefix_cache_support.py @@ -19,14 +19,25 @@ def test_classify_models( model: str, dtype: str, ) -> None: - example_prompts = example_prompts * 2 + # example_prompts is too short for testing prefix_caching + example_prompts = [s * 10 for s in example_prompts] with vllm_runner( model, max_model_len=512, dtype=dtype, enable_prefix_caching=True ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching - vllm_outputs = vllm_model.classify(example_prompts) + + # First Run + vllm_model.classify(example_prompts) + + # assert prefix_caching works + pooling_outputs = vllm_model.llm.encode( + example_prompts, pooling_task="classify" + ) + for output in pooling_outputs: + assert output.num_cached_tokens > 0 + vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs] with hf_runner( model, dtype=dtype, auto_cls=AutoModelForSequenceClassification @@ -54,7 +65,8 @@ def test_embed_models( model: str, dtype: str, ): - example_prompts = [str(s).strip() for s in example_prompts] * 2 + # example_prompts is too short for testing prefix_caching + example_prompts = [str(s).strip() * 10 for s in example_prompts] with vllm_runner( model, @@ -64,7 +76,15 @@ def test_embed_models( ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching - vllm_outputs = vllm_model.embed(example_prompts) + + # First Run + vllm_model.embed(example_prompts) + + # assert prefix_caching works + pooling_outputs = vllm_model.llm.encode(example_prompts, pooling_task="embed") + for output in pooling_outputs: + assert output.num_cached_tokens > 0 + vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs] with hf_runner( model, diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py new file mode 100644 index 0000000000000..f8e3fa7d1560f --- /dev/null +++ b/tests/models/language/pooling/test_extract_hidden_states.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm import TokensPrompt + + +@pytest.mark.parametrize( + "model", + ["Qwen/Qwen3-0.6B"], +) +@torch.inference_mode +def test_embed_models(hf_runner, vllm_runner, model: str): + n_prompt_tokens = [55, 56, 57] + token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens] + + with vllm_runner( + model, + max_model_len=128, + enforce_eager=True, + runner="pooling", + enable_chunked_prefill=False, + enable_prefix_caching=False, + ) as vllm_model: + pooling_outputs = vllm_model.llm.encode( + [TokensPrompt(prompt_token_ids=t) for t in token_prompts], + pooling_task="token_embed", + ) + + for n, output in zip(n_prompt_tokens, pooling_outputs): + assert len(output.prompt_token_ids) == n + assert output.num_cached_tokens == 0 diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 7cf0ad671cf2e..869861afff037 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1078,6 +1078,9 @@ class LLM: PoolingRequestOutput[Any]( request_id="", outputs=processed_outputs, + num_cached_tokens=getattr( + processed_outputs, "num_cached_tokens", 0 + ), prompt_token_ids=[], finished=True, ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 3308198c8bd19..51f6106acec3d 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -583,6 +583,7 @@ class EmbeddingMixin(OpenAIServing): request_id=aggregator["request_id"], prompt_token_ids=original_token_ids, outputs=pooling_output_data, + num_cached_tokens=0, finished=True, ) diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index cd62cfe5448c4..309a4c996392d 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -66,6 +66,7 @@ def _cosine_similarity( request_id=f"{emb_1.request_id}_{emb_2.request_id}", outputs=pair_score, prompt_token_ids=tokens, + num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens, finished=True, ) ) diff --git a/vllm/outputs.py b/vllm/outputs.py index 114c1c5dc4b03..cdfe06f1c7fae 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -201,14 +201,21 @@ class PoolingRequestOutput(Generic[_O]): request_id (str): A unique identifier for the pooling request. outputs (PoolingOutput): The pooling results for the given input. prompt_token_ids (list[int]): A list of token IDs used in the prompt. + num_cached_tokens: The number of tokens with prefix cache hit. finished (bool): A flag indicating whether the pooling is completed. """ def __init__( - self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool + self, + request_id: str, + outputs: _O, + prompt_token_ids: list[int], + num_cached_tokens: int, + finished: bool, ): self.request_id = request_id self.prompt_token_ids = prompt_token_ids + self.num_cached_tokens = num_cached_tokens self.finished = finished self.outputs = outputs @@ -217,6 +224,7 @@ class PoolingRequestOutput(Generic[_O]): f"{type(self).__name__}(request_id={self.request_id!r}, " f"outputs={self.outputs!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"num_cached_tokens={self.num_cached_tokens}, " f"finished={self.finished})" ) @@ -255,6 +263,7 @@ class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]): request_id=request_output.request_id, outputs=EmbeddingOutput.from_base(request_output.outputs), prompt_token_ids=request_output.prompt_token_ids, + num_cached_tokens=request_output.num_cached_tokens, finished=request_output.finished, ) @@ -294,6 +303,7 @@ class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]): request_id=request_output.request_id, outputs=ClassificationOutput.from_base(request_output.outputs), prompt_token_ids=request_output.prompt_token_ids, + num_cached_tokens=request_output.num_cached_tokens, finished=request_output.finished, ) @@ -330,5 +340,6 @@ class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]): request_id=request_output.request_id, outputs=ScoringOutput.from_base(request_output.outputs), prompt_token_ids=request_output.prompt_token_ids, + num_cached_tokens=request_output.num_cached_tokens, finished=request_output.finished, ) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bc1542187c9b..44e4eadce42ac 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -230,6 +230,7 @@ class RequestState: return PoolingRequestOutput( request_id=request_id, outputs=first_output, + num_cached_tokens=self.num_cached_tokens, prompt_token_ids=self.prompt_token_ids, finished=finished, )