Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Model] Add num_cached_tokens for PoolingRequestOutput (#27378)

Signed-off-by: wang.yuqi <noooop@126.com>

parent 6644796bf4
commit 3729ed00ba
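
With this change, every pooling entrypoint reports how much of the prompt was served from the prefix cache. A minimal usage sketch (the model name and engine flags below are illustrative, not part of this commit):

```python
from vllm import LLM

# Illustrative setup: any pooling-capable model works.
llm = LLM(
    model="Qwen/Qwen3-0.6B",
    runner="pooling",
    enable_prefix_caching=True,
)

# Send the same long prompt twice so the second request can hit the cache.
prompts = ["vLLM is a fast and easy-to-use library for LLM inference. " * 20] * 2

outputs = llm.encode(prompts, pooling_task="embed")
for out in outputs:
    print(
        f"{len(out.prompt_token_ids)} prompt tokens, "
        f"{out.num_cached_tokens} served from the prefix cache"
    )
```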
```diff
@@ -19,14 +19,25 @@ def test_classify_models(
     model: str,
     dtype: str,
 ) -> None:
-    example_prompts = example_prompts * 2
+    # example_prompts is too short for testing prefix_caching
+    example_prompts = [s * 10 for s in example_prompts]

     with vllm_runner(
-        model, max_model_len=512, dtype=dtype
+        model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
     ) as vllm_model:
-        vllm_outputs = vllm_model.classify(example_prompts)
+        cache_config = vllm_model.llm.llm_engine.cache_config
+        assert cache_config.enable_prefix_caching
+
+        # First Run
+        vllm_model.classify(example_prompts)
+
+        # assert prefix_caching works
+        pooling_outputs = vllm_model.llm.encode(
+            example_prompts, pooling_task="classify"
+        )
+        for output in pooling_outputs:
+            assert output.num_cached_tokens > 0
+        vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]

     with hf_runner(
         model, dtype=dtype, auto_cls=AutoModelForSequenceClassification
```
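Why both tests stretch the prompts first: prefix-cache hits are recorded at KV-cache block granularity (16 tokens by default), so a prompt shorter than one block can never report a non-zero num_cached_tokens. A sketch of the value one would expect on a full-prefix re-run, assuming the default block size:

```python
def expected_cached_tokens(n_prompt_tokens: int, block_size: int = 16) -> int:
    # Only whole blocks can be reused; the trailing partial block
    # is always recomputed.
    return (n_prompt_tokens // block_size) * block_size

assert expected_cached_tokens(7) == 0    # too short: the motivation for s * 10
assert expected_cached_tokens(70) == 64  # a lengthened prompt clears 4 blocks
```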
```diff
@@ -54,7 +65,8 @@ def test_embed_models(
     model: str,
     dtype: str,
 ):
-    example_prompts = [str(s).strip() for s in example_prompts] * 2
+    # example_prompts is too short for testing prefix_caching
+    example_prompts = [str(s).strip() * 10 for s in example_prompts]

     with vllm_runner(
         model,
```
```diff
@@ -64,7 +76,15 @@ def test_embed_models(
     ) as vllm_model:
-        vllm_outputs = vllm_model.embed(example_prompts)
+        cache_config = vllm_model.llm.llm_engine.cache_config
+        assert cache_config.enable_prefix_caching
+
+        # First Run
+        vllm_model.embed(example_prompts)
+
+        # assert prefix_caching works
+        pooling_outputs = vllm_model.llm.encode(example_prompts, pooling_task="embed")
+        for output in pooling_outputs:
+            assert output.num_cached_tokens > 0
+        vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]

     with hf_runner(
         model,
```
tests/models/language/pooling/test_extract_hidden_states.py (new file, 33 lines)

```diff
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+from vllm import TokensPrompt
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["Qwen/Qwen3-0.6B"],
+)
+@torch.inference_mode
+def test_embed_models(hf_runner, vllm_runner, model: str):
+    n_prompt_tokens = [55, 56, 57]
+    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
+
+    with vllm_runner(
+        model,
+        max_model_len=128,
+        enforce_eager=True,
+        runner="pooling",
+        enable_chunked_prefill=False,
+        enable_prefix_caching=False,
+    ) as vllm_model:
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="token_embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert output.num_cached_tokens == 0
```
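Two details of the new test: prefix caching is explicitly disabled, so num_cached_tokens must come back as exactly 0, and the prompts are built as TokensPrompt objects so the token counts stay under the test's control. A short sketch of consuming the token_embed results, assuming the per-token output data is a (num_tokens, hidden_size) tensor:

```python
# Hypothetical continuation inside the `with vllm_runner(...)` block above.
for output in pooling_outputs:
    hidden_states = output.outputs.data  # assumed shape: (num_tokens, hidden_size)
    assert hidden_states.shape[0] == len(output.prompt_token_ids)
```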
```diff
@@ -1078,6 +1078,9 @@ class LLM:
                 PoolingRequestOutput[Any](
                     request_id="",
                     outputs=processed_outputs,
+                    num_cached_tokens=getattr(
+                        processed_outputs, "num_cached_tokens", 0
+                    ),
                     prompt_token_ids=[],
                     finished=True,
                 )
```
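The getattr fallback is presumably defensive: this path wraps post-processed outputs (for example, from IO-processor plugins) that need not carry the new attribute. A minimal illustration of the pattern, using a hypothetical legacy output type:

```python
class LegacyPluginOutput:
    """Hypothetical plugin result that predates num_cached_tokens."""

# getattr with a default keeps such outputs working: a missing
# attribute yields 0 cached tokens instead of an AttributeError.
assert getattr(LegacyPluginOutput(), "num_cached_tokens", 0) == 0
```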
```diff
@@ -583,6 +583,7 @@ class EmbeddingMixin(OpenAIServing):
                 request_id=aggregator["request_id"],
                 prompt_token_ids=original_token_ids,
                 outputs=pooling_output_data,
+                num_cached_tokens=0,
                 finished=True,
             )
```
```diff
@@ -66,6 +66,7 @@ def _cosine_similarity(
                 request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                 outputs=pair_score,
                 prompt_token_ids=tokens,
+                num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
                 finished=True,
             )
         )
```
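For score requests, one output is synthesized from the two embedding outputs of a pair, so the commit sums their cache hits. A usage sketch (inputs illustrative, assuming llm wraps an embedding model that supports scoring):

```python
# Each score output now carries emb_1 hits + emb_2 hits.
score_outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The Eiffel Tower is in Paris."],
)
for out in score_outputs:
    print(out.num_cached_tokens)
```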
```diff
@@ -201,14 +201,21 @@ class PoolingRequestOutput(Generic[_O]):
         request_id (str): A unique identifier for the pooling request.
         outputs (PoolingOutput): The pooling results for the given input.
         prompt_token_ids (list[int]): A list of token IDs used in the prompt.
+        num_cached_tokens (int): The number of prompt tokens with a prefix cache hit.
         finished (bool): A flag indicating whether the pooling is completed.
     """

     def __init__(
-        self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool
+        self,
+        request_id: str,
+        outputs: _O,
+        prompt_token_ids: list[int],
+        num_cached_tokens: int,
+        finished: bool,
     ):
         self.request_id = request_id
         self.prompt_token_ids = prompt_token_ids
+        self.num_cached_tokens = num_cached_tokens
         self.finished = finished
         self.outputs = outputs
```
```diff
@@ -217,6 +224,7 @@ class PoolingRequestOutput(Generic[_O]):
             f"{type(self).__name__}(request_id={self.request_id!r}, "
             f"outputs={self.outputs!r}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
+            f"num_cached_tokens={self.num_cached_tokens}, "
             f"finished={self.finished})"
         )
```
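Note that num_cached_tokens is now a required parameter, so any code constructing PoolingRequestOutput directly must be updated. A small sketch, assuming PoolingOutput accepts a plain data tensor:

```python
import torch

from vllm.outputs import PoolingOutput, PoolingRequestOutput

out = PoolingRequestOutput(
    request_id="req-0",
    outputs=PoolingOutput(data=torch.zeros(8)),  # assumed constructor
    prompt_token_ids=[1, 2, 3],
    num_cached_tokens=0,
    finished=True,
)
print(out)  # the repr now includes num_cached_tokens=0
```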
```diff
@@ -255,6 +263,7 @@ class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
             request_id=request_output.request_id,
             outputs=EmbeddingOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
```
```diff
@@ -294,6 +303,7 @@ class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
             request_id=request_output.request_id,
             outputs=ClassificationOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
```
```diff
@@ -330,5 +340,6 @@ class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
             request_id=request_output.request_id,
             outputs=ScoringOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
```
```diff
@@ -230,6 +230,7 @@ class RequestState:
         return PoolingRequestOutput(
             request_id=request_id,
             outputs=first_output,
+            num_cached_tokens=self.num_cached_tokens,
             prompt_token_ids=self.prompt_token_ids,
             finished=finished,
         )
```
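RequestState already tracks the scheduler's cached-token count for sampling outputs; this hunk forwards the same counter into pooling outputs. One thing the new field makes easy is an aggregate hit-rate metric, a minimal sketch:

```python
from vllm.outputs import PoolingRequestOutput

def prefix_cache_hit_rate(outputs: list[PoolingRequestOutput]) -> float:
    """Fraction of all prompt tokens served from the prefix cache."""
    total = sum(len(o.prompt_token_ids) for o in outputs)
    cached = sum(o.num_cached_tokens for o in outputs)
    return cached / total if total else 0.0
```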