mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 03:02:15 +08:00
[Model] Add num_cached_tokens for PoolingRequestOutput (#27378)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
6644796bf4
commit
3729ed00ba
@ -19,14 +19,25 @@ def test_classify_models(
|
|||||||
model: str,
|
model: str,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
example_prompts = example_prompts * 2
|
# example_prompts is too short for testing prefix_caching
|
||||||
|
example_prompts = [s * 10 for s in example_prompts]
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
|
model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
cache_config = vllm_model.llm.llm_engine.cache_config
|
cache_config = vllm_model.llm.llm_engine.cache_config
|
||||||
assert cache_config.enable_prefix_caching
|
assert cache_config.enable_prefix_caching
|
||||||
vllm_outputs = vllm_model.classify(example_prompts)
|
|
||||||
|
# First Run
|
||||||
|
vllm_model.classify(example_prompts)
|
||||||
|
|
||||||
|
# assert prefix_caching works
|
||||||
|
pooling_outputs = vllm_model.llm.encode(
|
||||||
|
example_prompts, pooling_task="classify"
|
||||||
|
)
|
||||||
|
for output in pooling_outputs:
|
||||||
|
assert output.num_cached_tokens > 0
|
||||||
|
vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]
|
||||||
|
|
||||||
with hf_runner(
|
with hf_runner(
|
||||||
model, dtype=dtype, auto_cls=AutoModelForSequenceClassification
|
model, dtype=dtype, auto_cls=AutoModelForSequenceClassification
|
||||||
@ -54,7 +65,8 @@ def test_embed_models(
|
|||||||
model: str,
|
model: str,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
):
|
):
|
||||||
example_prompts = [str(s).strip() for s in example_prompts] * 2
|
# example_prompts is too short for testing prefix_caching
|
||||||
|
example_prompts = [str(s).strip() * 10 for s in example_prompts]
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model,
|
model,
|
||||||
@ -64,7 +76,15 @@ def test_embed_models(
|
|||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
cache_config = vllm_model.llm.llm_engine.cache_config
|
cache_config = vllm_model.llm.llm_engine.cache_config
|
||||||
assert cache_config.enable_prefix_caching
|
assert cache_config.enable_prefix_caching
|
||||||
vllm_outputs = vllm_model.embed(example_prompts)
|
|
||||||
|
# First Run
|
||||||
|
vllm_model.embed(example_prompts)
|
||||||
|
|
||||||
|
# assert prefix_caching works
|
||||||
|
pooling_outputs = vllm_model.llm.encode(example_prompts, pooling_task="embed")
|
||||||
|
for output in pooling_outputs:
|
||||||
|
assert output.num_cached_tokens > 0
|
||||||
|
vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]
|
||||||
|
|
||||||
with hf_runner(
|
with hf_runner(
|
||||||
model,
|
model,
|
||||||
|
|||||||
33
tests/models/language/pooling/test_extract_hidden_states.py
Normal file
33
tests/models/language/pooling/test_extract_hidden_states.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import TokensPrompt
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "model",
    ["Qwen/Qwen3-0.6B"],
)
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, model: str):
    """Check per-request metadata of ``token_embed`` pooling outputs.

    Three token-id prompts of different lengths are encoded with prefix
    caching and chunked prefill both disabled. Each returned output must
    echo back a prompt of the expected length and report zero cached
    tokens.

    Args:
        hf_runner: Requested fixture; not used by the body of this test.
        vllm_runner: Context-manager factory that yields the vLLM model
            wrapper under test.
        model: Name of the model to load.
    """
    n_prompt_tokens = [55, 56, 57]
    # Synthetic prompts: n consecutive token ids starting at 1024.
    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]

    with vllm_runner(
        model,
        max_model_len=128,
        enforce_eager=True,
        runner="pooling",
        enable_chunked_prefill=False,
        enable_prefix_caching=False,
    ) as vllm_model:
        pooling_outputs = vllm_model.llm.encode(
            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
            pooling_task="token_embed",
        )

        # zip() stops at the shorter input, so without this check the loop
        # below would pass vacuously if fewer outputs than prompts came back.
        assert len(pooling_outputs) == len(n_prompt_tokens)

        for n, output in zip(n_prompt_tokens, pooling_outputs):
            assert len(output.prompt_token_ids) == n
            # Prefix caching is disabled above, so nothing may be cached.
            assert output.num_cached_tokens == 0
|
||||||
@ -1078,6 +1078,9 @@ class LLM:
|
|||||||
PoolingRequestOutput[Any](
|
PoolingRequestOutput[Any](
|
||||||
request_id="",
|
request_id="",
|
||||||
outputs=processed_outputs,
|
outputs=processed_outputs,
|
||||||
|
num_cached_tokens=getattr(
|
||||||
|
processed_outputs, "num_cached_tokens", 0
|
||||||
|
),
|
||||||
prompt_token_ids=[],
|
prompt_token_ids=[],
|
||||||
finished=True,
|
finished=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -583,6 +583,7 @@ class EmbeddingMixin(OpenAIServing):
|
|||||||
request_id=aggregator["request_id"],
|
request_id=aggregator["request_id"],
|
||||||
prompt_token_ids=original_token_ids,
|
prompt_token_ids=original_token_ids,
|
||||||
outputs=pooling_output_data,
|
outputs=pooling_output_data,
|
||||||
|
num_cached_tokens=0,
|
||||||
finished=True,
|
finished=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -66,6 +66,7 @@ def _cosine_similarity(
|
|||||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||||
outputs=pair_score,
|
outputs=pair_score,
|
||||||
prompt_token_ids=tokens,
|
prompt_token_ids=tokens,
|
||||||
|
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
|
||||||
finished=True,
|
finished=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -201,14 +201,21 @@ class PoolingRequestOutput(Generic[_O]):
|
|||||||
request_id (str): A unique identifier for the pooling request.
|
request_id (str): A unique identifier for the pooling request.
|
||||||
outputs (PoolingOutput): The pooling results for the given input.
|
outputs (PoolingOutput): The pooling results for the given input.
|
||||||
prompt_token_ids (list[int]): A list of token IDs used in the prompt.
|
prompt_token_ids (list[int]): A list of token IDs used in the prompt.
|
||||||
|
num_cached_tokens (int): The number of tokens that hit the prefix cache.
|
||||||
finished (bool): A flag indicating whether the pooling is completed.
|
finished (bool): A flag indicating whether the pooling is completed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool
|
self,
|
||||||
|
request_id: str,
|
||||||
|
outputs: _O,
|
||||||
|
prompt_token_ids: list[int],
|
||||||
|
num_cached_tokens: int,
|
||||||
|
finished: bool,
|
||||||
):
|
):
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self.prompt_token_ids = prompt_token_ids
|
self.prompt_token_ids = prompt_token_ids
|
||||||
|
self.num_cached_tokens = num_cached_tokens
|
||||||
self.finished = finished
|
self.finished = finished
|
||||||
self.outputs = outputs
|
self.outputs = outputs
|
||||||
|
|
||||||
@ -217,6 +224,7 @@ class PoolingRequestOutput(Generic[_O]):
|
|||||||
f"{type(self).__name__}(request_id={self.request_id!r}, "
|
f"{type(self).__name__}(request_id={self.request_id!r}, "
|
||||||
f"outputs={self.outputs!r}, "
|
f"outputs={self.outputs!r}, "
|
||||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||||
|
f"num_cached_tokens={self.num_cached_tokens}, "
|
||||||
f"finished={self.finished})"
|
f"finished={self.finished})"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -255,6 +263,7 @@ class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
|
|||||||
request_id=request_output.request_id,
|
request_id=request_output.request_id,
|
||||||
outputs=EmbeddingOutput.from_base(request_output.outputs),
|
outputs=EmbeddingOutput.from_base(request_output.outputs),
|
||||||
prompt_token_ids=request_output.prompt_token_ids,
|
prompt_token_ids=request_output.prompt_token_ids,
|
||||||
|
num_cached_tokens=request_output.num_cached_tokens,
|
||||||
finished=request_output.finished,
|
finished=request_output.finished,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -294,6 +303,7 @@ class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
|
|||||||
request_id=request_output.request_id,
|
request_id=request_output.request_id,
|
||||||
outputs=ClassificationOutput.from_base(request_output.outputs),
|
outputs=ClassificationOutput.from_base(request_output.outputs),
|
||||||
prompt_token_ids=request_output.prompt_token_ids,
|
prompt_token_ids=request_output.prompt_token_ids,
|
||||||
|
num_cached_tokens=request_output.num_cached_tokens,
|
||||||
finished=request_output.finished,
|
finished=request_output.finished,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -330,5 +340,6 @@ class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
|
|||||||
request_id=request_output.request_id,
|
request_id=request_output.request_id,
|
||||||
outputs=ScoringOutput.from_base(request_output.outputs),
|
outputs=ScoringOutput.from_base(request_output.outputs),
|
||||||
prompt_token_ids=request_output.prompt_token_ids,
|
prompt_token_ids=request_output.prompt_token_ids,
|
||||||
|
num_cached_tokens=request_output.num_cached_tokens,
|
||||||
finished=request_output.finished,
|
finished=request_output.finished,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -230,6 +230,7 @@ class RequestState:
|
|||||||
return PoolingRequestOutput(
|
return PoolingRequestOutput(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
outputs=first_output,
|
outputs=first_output,
|
||||||
|
num_cached_tokens=self.num_cached_tokens,
|
||||||
prompt_token_ids=self.prompt_token_ids,
|
prompt_token_ids=self.prompt_token_ids,
|
||||||
finished=finished,
|
finished=finished,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user