[Model] Add num_cached_tokens for PoolingRequestOutput (#27378)

Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi 2025-10-23 14:03:42 +08:00 committed by GitHub
parent 6644796bf4
commit 3729ed00ba
7 changed files with 75 additions and 5 deletions

View File

@@ -19,14 +19,25 @@ def test_classify_models(
     model: str,
     dtype: str,
 ) -> None:
-    example_prompts = example_prompts * 2
+    # example_prompts is too short for testing prefix_caching
+    example_prompts = [s * 10 for s in example_prompts]
 
     with vllm_runner(
         model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
     ) as vllm_model:
         cache_config = vllm_model.llm.llm_engine.cache_config
         assert cache_config.enable_prefix_caching
-        vllm_outputs = vllm_model.classify(example_prompts)
+
+        # First Run
+        vllm_model.classify(example_prompts)
+
+        # assert prefix_caching works
+        pooling_outputs = vllm_model.llm.encode(
+            example_prompts, pooling_task="classify"
+        )
+        for output in pooling_outputs:
+            assert output.num_cached_tokens > 0
+        vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]
 
     with hf_runner(
         model, dtype=dtype, auto_cls=AutoModelForSequenceClassification
@@ -54,7 +65,8 @@ def test_embed_models(
     model: str,
     dtype: str,
 ):
-    example_prompts = [str(s).strip() for s in example_prompts] * 2
+    # example_prompts is too short for testing prefix_caching
+    example_prompts = [str(s).strip() * 10 for s in example_prompts]
 
     with vllm_runner(
         model,
@@ -64,7 +76,15 @@
     ) as vllm_model:
         cache_config = vllm_model.llm.llm_engine.cache_config
         assert cache_config.enable_prefix_caching
-        vllm_outputs = vllm_model.embed(example_prompts)
+
+        # First Run
+        vllm_model.embed(example_prompts)
+
+        # assert prefix_caching works
+        pooling_outputs = vllm_model.llm.encode(example_prompts, pooling_task="embed")
+        for output in pooling_outputs:
+            assert output.num_cached_tokens > 0
+        vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]
 
     with hf_runner(
         model,

View File

@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from vllm import TokensPrompt
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["Qwen/Qwen3-0.6B"],
+)
+@torch.inference_mode
+def test_embed_models(hf_runner, vllm_runner, model: str):
+    n_prompt_tokens = [55, 56, 57]
+    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
+
+    with vllm_runner(
+        model,
+        max_model_len=128,
+        enforce_eager=True,
+        runner="pooling",
+        enable_chunked_prefill=False,
+        enable_prefix_caching=False,
+    ) as vllm_model:
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="token_embed",
+        )
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert output.num_cached_tokens == 0

View File

@@ -1078,6 +1078,9 @@ class LLM:
         PoolingRequestOutput[Any](
             request_id="",
             outputs=processed_outputs,
+            num_cached_tokens=getattr(
+                processed_outputs, "num_cached_tokens", 0
+            ),
             prompt_token_ids=[],
             finished=True,
         )

View File

@@ -583,6 +583,7 @@ class EmbeddingMixin(OpenAIServing):
             request_id=aggregator["request_id"],
             prompt_token_ids=original_token_ids,
             outputs=pooling_output_data,
+            num_cached_tokens=0,
             finished=True,
         )

View File

@@ -66,6 +66,7 @@ def _cosine_similarity(
                 request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                 outputs=pair_score,
                 prompt_token_ids=tokens,
+                num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
                 finished=True,
             )
         )
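Note: for embedding models, scoring goes through the cosine-similarity path above, so the reported value is the sum over both embedded inputs. A minimal sketch (not part of this diff); the model name and prompts are illustrative:

from vllm import LLM

# Hedged sketch: each ScoringRequestOutput produced by the cosine-similarity
# path carries the combined cached-token count of its two embedded inputs.
llm = LLM(model="Qwen/Qwen3-Embedding-0.6B", enable_prefix_caching=True)

outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The capital of Germany is Berlin."],
)
for out in outputs:
    print(out.request_id, out.num_cached_tokens)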

View File

@@ -201,14 +201,21 @@ class PoolingRequestOutput(Generic[_O]):
         request_id (str): A unique identifier for the pooling request.
         outputs (PoolingOutput): The pooling results for the given input.
         prompt_token_ids (list[int]): A list of token IDs used in the prompt.
+        num_cached_tokens: The number of tokens with prefix cache hit.
         finished (bool): A flag indicating whether the pooling is completed.
     """
 
     def __init__(
-        self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool
+        self,
+        request_id: str,
+        outputs: _O,
+        prompt_token_ids: list[int],
+        num_cached_tokens: int,
+        finished: bool,
     ):
         self.request_id = request_id
         self.prompt_token_ids = prompt_token_ids
+        self.num_cached_tokens = num_cached_tokens
         self.finished = finished
         self.outputs = outputs
@@ -217,6 +224,7 @@ class PoolingRequestOutput(Generic[_O]):
             f"{type(self).__name__}(request_id={self.request_id!r}, "
             f"outputs={self.outputs!r}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
+            f"num_cached_tokens={self.num_cached_tokens}, "
            f"finished={self.finished})"
         )
@@ -255,6 +263,7 @@ class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
             request_id=request_output.request_id,
             outputs=EmbeddingOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
@@ -294,6 +303,7 @@ class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
             request_id=request_output.request_id,
             outputs=ClassificationOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
@@ -330,5 +340,6 @@ class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
             request_id=request_output.request_id,
             outputs=ScoringOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
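Note: num_cached_tokens is now a required argument of PoolingRequestOutput.__init__, so code constructing it directly must pass the field. A minimal construction sketch (not part of this diff), assuming the existing PoolingOutput dataclass with a data tensor field:

import torch

from vllm.outputs import PoolingOutput, PoolingRequestOutput

# Hedged sketch of the updated signature; all values are placeholders.
output = PoolingRequestOutput(
    request_id="req-0",
    outputs=PoolingOutput(data=torch.zeros(4)),
    prompt_token_ids=[101, 102, 103],
    num_cached_tokens=0,  # new field added by this commit
    finished=True,
)
print(output)  # repr now includes num_cached_tokens=0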

View File

@@ -230,6 +230,7 @@ class RequestState:
         return PoolingRequestOutput(
             request_id=request_id,
             outputs=first_output,
+            num_cached_tokens=self.num_cached_tokens,
             prompt_token_ids=self.prompt_token_ids,
             finished=finished,
         )
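Note: end to end, the engine fills the field via RequestState above and it surfaces on the results of LLM.encode. A minimal usage sketch (not part of this diff); the model name is illustrative, and cache hits require prefix caching plus repeated long prompts, as in the tests above:

from vllm import LLM

# Hedged sketch: run the same long prompts twice so the second pass can reuse
# cached prefix blocks; num_cached_tokens then reports the per-request hits.
llm = LLM(
    model="Qwen/Qwen3-Embedding-0.6B",  # illustrative embedding model
    runner="pooling",
    enable_prefix_caching=True,
)

prompts = ["vLLM pools hidden states into embeddings. " * 20] * 4

llm.encode(prompts, pooling_task="embed")            # first run warms the cache
outputs = llm.encode(prompts, pooling_task="embed")  # second run should hit it

for out in outputs:
    print(f"{out.request_id}: {out.num_cached_tokens}/{len(out.prompt_token_ids)} "
          "prompt tokens were prefix-cache hits")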