# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm import TokensPrompt


@pytest.mark.parametrize(
    "model",
    ["Qwen/Qwen3-0.6B"],
)
@torch.inference_mode
def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
    n_prompt_tokens = [55, 56, 57]
    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]

    with vllm_runner(
        model,
        max_model_len=128,
        enforce_eager=True,
        runner="pooling",
        enable_prefix_caching=True,
    ) as vllm_model:
        # First token_embed pass: every prompt token must get a hidden state,
        # and nothing has been served from the prefix cache yet.
        pooling_outputs = vllm_model.llm.encode(
            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
            pooling_task="token_embed",
        )

        for n, output in zip(n_prompt_tokens, pooling_outputs):
            assert len(output.prompt_token_ids) == n
            assert len(output.outputs.data) == n
            assert output.num_cached_tokens == 0

        # Test enable_prefix_caching together with ALL-token pooling: this
        # request must skip reading the prefix cache via
        # request.skip_reading_prefix_cache, so even a repeated prompt
        # reports no cached tokens.
        pooling_outputs = vllm_model.llm.encode(
            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
            pooling_task="token_embed",
        )

        for n, output in zip(n_prompt_tokens, pooling_outputs):
            assert len(output.prompt_token_ids) == n
            assert len(output.outputs.data) == n
            assert output.num_cached_tokens == 0

        # skip_reading_prefix_cache still writes to the prefix cache,
        # which accelerates the following requests.
        pooling_outputs = vllm_model.llm.encode(
            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
            pooling_task="embed",
        )

        for n, output in zip(n_prompt_tokens, pooling_outputs):
            assert len(output.prompt_token_ids) == n
            assert output.num_cached_tokens > 0
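

# A minimal standalone sketch of the same flow outside the pytest fixtures,
# guarded by __main__ so it does not run during test collection. This is a
# hedged illustration, not part of the test: it assumes LLM() accepts the same
# keyword arguments that vllm_runner forwards above (runner="pooling",
# enable_prefix_caching=True), and that pooling_task="token_embed" yields one
# vector per prompt token, as the assertions in the test expect.
if __name__ == "__main__":
    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen3-0.6B",
        max_model_len=128,
        enforce_eager=True,
        runner="pooling",
        enable_prefix_caching=True,
    )
    outputs = llm.encode(
        [TokensPrompt(prompt_token_ids=[1024 + i for i in range(55)])],
        pooling_task="token_embed",
    )
    # outputs[0].outputs.data holds the per-token hidden states for the prompt.
    print(len(outputs[0].prompt_token_ids), len(outputs[0].outputs.data))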