Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 23:34:27 +08:00)
[V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (#13949)
This commit is contained in:
parent 66e16a038e
commit ef64044079
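Before the diff itself, a brief orientation: the behavioral core of this change is in KVCacheManager.get_computed_blocks (final hunks below), which now returns no cached blocks when a request asks for prompt logprobs, so such a request is recomputed from the first prompt token even if a matching prefix is already cached. The snippet below is a minimal, self-contained sketch of that guard; FakeRequest and FakeCacheManager are illustrative stand-ins, not vLLM classes.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeRequest:
    prompt_token_ids: list[int]
    prompt_logprobs: Optional[int] = None  # None means no prompt logprobs


@dataclass
class FakeCacheManager:
    block_size: int = 16
    # Maps a full block's token tuple to a cached "block id".
    cached: dict[tuple[int, ...], int] = field(default_factory=dict)

    def get_computed_blocks(self, req: FakeRequest) -> tuple[list[int], int]:
        if req.prompt_logprobs is not None:
            # Prompt-logprobs requests skip the prefix-cache lookup entirely.
            return [], 0
        hits: list[int] = []
        toks = req.prompt_token_ids
        num_full = len(toks) - len(toks) % self.block_size
        for start in range(0, num_full, self.block_size):
            key = tuple(toks[start:start + self.block_size])
            if key in self.cached:
                hits.append(self.cached[key])
            else:
                break
        return hits, len(hits) * self.block_size


if __name__ == "__main__":
    mgr = FakeCacheManager()
    mgr.cached[tuple([0] * 16)] = 0  # pretend block 0 was cached earlier
    plain = FakeRequest([0] * 16 + [1] * 4)
    plp = FakeRequest([0] * 16 + [1] * 4, prompt_logprobs=5)
    print(mgr.get_computed_blocks(plain))  # ([0], 16): prefix-cache hit
    print(mgr.get_computed_blocks(plp))    # ([], 0): lookup skipped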
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the behavior with and without prefix caching."""

from typing import Optional

import pytest

from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -15,7 +17,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
def make_request(request_id,
                 prompt_token_ids,
                 mm_positions=None,
                 mm_hashes=None):
                 mm_hashes=None,
                 prompt_logprobs: Optional[int] = None):
    if mm_positions is None:
        multi_modal_inputs = None
    else:
@@ -28,7 +31,8 @@ def make_request(request_id,
        multi_modal_inputs=multi_modal_inputs,
        multi_modal_hashes=mm_hashes,
        multi_modal_placeholders=mm_positions,
        sampling_params=SamplingParams(max_tokens=17),
        sampling_params=SamplingParams(max_tokens=17,
                                       prompt_logprobs=prompt_logprobs),
        eos_token_id=100,
        arrival_time=0,
        lora_request=None,
@@ -144,6 +148,110 @@ def test_prefill():
    assert manager.block_pool.free_block_queue.free_list_tail is None


def test_prefill_plp():
    '''Test prefill with APC and some prompt logprobs (plp) requests.

    1. Schedule plp request and validate APC block allocation
    2. Schedule non-plp request and validate blocks
    3. Schedule plp request; no hit should occur; validate blocks
    '''
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(16)]

    # Request #0 is a prompt logprobs request
    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req0, 55, computed_blocks)
    assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
    req0_block_hashes = [b.block_hash for b in blocks]

    # Check full block metadata
    parent_block_hash = None
    for block_id in (0, 1, 2):
        block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
        block_hash = hash_block_tokens(parent_block_hash, block_tokens)
        assert manager.block_pool.blocks[block_id].block_hash == block_hash
        assert manager.block_pool.blocks[block_id].ref_cnt == 1
        parent_block_hash = block_hash.hash_value

    # Check partial/preallocated block metadata
    for block_id in (3, 4):
        assert manager.block_pool.blocks[block_id].block_hash is None
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    # Request #1 is a non-prompt-logprobs request:
    # Cache hit in the common prefix when the original block is still in use.
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [3] * 5
    req1 = make_request("1", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    assert [b.block_id for b in computed_blocks] == [0, 1, 2]
    assert num_computed_tokens == 3 * 16
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
    assert [b.block_id for b in blocks] == [5, 6]
    for block in computed_blocks:
        assert block.ref_cnt == 2

    # At this point, we should have 3 free blocks left.
    assert manager.block_pool.free_block_queue.num_free_blocks == 3

    manager.free(req0)
    manager.free(req1)

    # All blocks should be available.
    assert manager.block_pool.free_block_queue.num_free_blocks == 10
    # The order should be
    # [unallocated (7, 8, 9)]
    # [unique_req0 (4, 3)]
    # [unique_req1 (6, 5)]
    # [common (2, 1, 0)]
    assert [
        b.block_id
        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]

    # Request #2 is a prompt-logprobs request:
    # NO cache hit in the common prefix; duplicates request #0 cached blocks
    unique_token_ids = [3] * 6
    req2 = make_request("2",
                        common_token_ids + unique_token_ids,
                        prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req2, 55, computed_blocks)
    block_ids = [b.block_id for b in blocks]
    # Duplicate cached blocks have different ids but the same hashes as request #0
    assert [b.block_hash for b in blocks] == req0_block_hashes
    assert block_ids != [0, 1, 2, 3, 4]

    # Request #2 block hashes are valid since request #0 hashes are.
    # Check block reference counts.
    for block_id in block_ids:
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    manager.free(req2)
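

# --- Illustrative aside (not part of this commit): the chained block hashing
# that test_prefill_plp exercises above. hash_block_tokens() in
# vllm.v1.core.kv_cache_utils is the real helper; toy_hash_block_tokens below
# is a stand-in that only shows the chaining idea, i.e. why a prefix-cache
# scan can stop at the first miss: block i can only match if blocks 0..i-1
# matched as well.
def toy_hash_block_tokens(parent_hash, block_tokens):
    # Each full block's hash folds in the previous block's hash.
    return hash((parent_hash, block_tokens))


def chained_hashes(token_ids, block_size=16):
    parent = None
    hashes = []
    for start in range(0, len(token_ids) - len(token_ids) % block_size,
                       block_size):
        parent = toy_hash_block_tokens(
            parent, tuple(token_ids[start:start + block_size]))
        hashes.append(parent)
    return hashes  # e.g. 3 hashes for the 48 common tokens used above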


def test_decode():
    manager = KVCacheManager(
        block_size=16,
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import pytest

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
@@ -16,7 +18,21 @@ def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
) -> Scheduler:
    '''Create scheduler under test.

    Args:
        model: model under test
        max_num_seqs: max sequences to schedule
        max_num_batched_tokens: max number of tokens to batch
        enable_prefix_caching: optionally force APC config
                               (True/False) or use default
                               (None)

    Returns:
        :class:`Scheduler` instance
    '''
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
@@ -31,11 +47,16 @@ def create_scheduler(
        dtype="float16",
        seed=42,
    )
    # Cache config, optionally force APC
    kwargs_cache = ({} if enable_prefix_caching is None else {
        'enable_prefix_caching': enable_prefix_caching
    })
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
@@ -54,16 +75,16 @@ def create_scheduler(
    )


def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[list[PlaceholderRange]] = None,
    max_tokens: int = 16,
    stop_token_ids: Optional[list[int]] = None,
):
def create_requests(num_requests: int,
                    num_tokens: int = 10,
                    mm_positions: Optional[list[PlaceholderRange]] = None,
                    max_tokens: int = 16,
                    stop_token_ids: Optional[list[int]] = None,
                    prompt_logprobs: Optional[int] = None):
    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids)
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
@@ -122,9 +143,18 @@ def test_get_num_unfinished_requests():
        assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


def test_schedule():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
    (None, None),
    (True, 5),
])
def test_schedule(enable_prefix_caching: Optional[bool],
                  prompt_logprobs: Optional[int]):
    '''Test scheduling.
    Two cases: default APC with no prompt logprobs, and APC=True with
    prompt logprobs.
    '''
    scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
    requests = create_requests(num_requests=10,
                               prompt_logprobs=prompt_logprobs)
    for request in requests:
        scheduler.add_request(request)

@@ -427,14 +457,21 @@ def test_stop_via_update_from_output():
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]


def test_schedule_concurrent_batches():
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
    (None, None),
    (True, 5),
])
def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
                                     prompt_logprobs: Optional[int]):
    scheduler = create_scheduler(
        max_num_batched_tokens=1024,
        max_num_seqs=2,
        enable_prefix_caching=enable_prefix_caching,
    )
    requests = create_requests(
        num_requests=2,
        num_tokens=512,
        prompt_logprobs=prompt_logprobs,
    )

    # Schedule the first request.

@@ -6,7 +6,6 @@ from typing import Optional

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import SamplingParams
from vllm.assets.image import ImageAsset
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -72,41 +71,6 @@ async def generate(engine: AsyncLLM,
    return count, request_id


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_async_llm_refuses_prompt_logprobs_with_apc(
        monkeypatch, output_kind: RequestOutputKind):
    """Test passes if AsyncLLM raises an exception when it is configured
    for automatic prefix caching and it receives a request with
    prompt_logprobs enabled, which is incompatible."""
    # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
    # better way to test V1 so that in the future when we switch, we don't
    # have to change all the tests.
    monkeypatch.setenv("VLLM_USE_V1", "1")
    # Create AsyncLLM engine with APC
    apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                                      enable_prefix_caching=True,
                                      gpu_memory_utilization=0.8,
                                      disable_log_requests=True)
    engine = AsyncLLM.from_engine_args(apc_engine_args)
    try:
        with pytest.raises(ValueError) as excinfo:
            # Issue a request with prompt logprobs enabled, which should fail
            await asyncio.create_task(
                generate(engine,
                         "request-0",
                         TEXT_PROMPT,
                         output_kind,
                         10,
                         prompt_logprobs=5))
        # Validate exception string is correct
        assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
    finally:
        # Shut down engine
        engine.shutdown()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args_and_prompt",

@@ -5,7 +5,6 @@ from typing import Optional

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams

MODEL = "facebook/opt-125m"
@@ -98,17 +97,3 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
        raise AssertionError(
            f"{len(completion_counts)} unique completions; expected"
            f" {n}. Repeats: {repeats}")


def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
    """Test passes if LLMEngine raises an exception when it is configured
    for automatic prefix caching and it receives a request with
    prompt_logprobs enabled, which is incompatible."""
    model: LLM = vllm_model_apc.model
    with pytest.raises(ValueError) as excinfo:
        model.generate(
            "Hello, my name is",
            SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))

    # Validate exception string is correct
    assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG

@@ -30,9 +30,6 @@ FULL_STRINGS = [
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
PROMPT_LEN = 5

PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet "
                           "supported on VLLM V1.")

random.seed(42)


@@ -1,24 +1,34 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
from collections.abc import Generator

import pytest
import torch

from tests.kernels.utils import override_backend_env_variable
from tests.v1.sample.utils import (
    BatchLogprobsComposition, BatchLogprobsSpecType,
    assert_incr_detok_str_matches_non_incr_detok_str,
    compute_correct_cumulative_logprob, get_test_batch)
from vllm import SamplingParams

from ...conftest import VllmRunner
from ...conftest import HfRunner, VllmRunner

MODEL = "meta-llama/Llama-3.2-1B-Instruct"
DTYPE = "half"

NONE = BatchLogprobsComposition.NONE
SAMPLE = BatchLogprobsComposition.SAMPLE
PROMPT = BatchLogprobsComposition.PROMPT
SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT


@pytest.fixture(scope="module")
def vllm_model(vllm_runner):

@pytest.fixture(
    scope="module",
    # Parameterize APC
    params=[False, True])
def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
    with vllm_runner(
            MODEL,
            dtype=DTYPE,
@@ -31,22 +41,22 @@ def vllm_model(vllm_runner):
            enforce_eager=True,
            # TODO: enable this once we support it for
            # prompt logprobs.
            enable_prefix_caching=False,
            enable_prefix_caching=request.param,
            gpu_memory_utilization=0.5,
    ) as vllm_model:
        yield vllm_model


@pytest.fixture(scope="module")
def hf_model(hf_runner):
def hf_model(hf_runner) -> Generator[HfRunner, None, None]:
    with hf_runner(MODEL, dtype=DTYPE) as hf_model:
        yield hf_model


def _repeat_logprob_config(
    test_prompts,
    logprob_prompt_logprob_list: list[tuple],
) -> list[tuple]:
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
) -> BatchLogprobsSpecType:
    """Ensure each test prompt has a logprob config.

    A logprob config specifies the optional (i.e.
@@ -91,42 +101,17 @@ def _repeat_logprob_config(
    return logprob_prompt_logprob_list


def _test_case_get_logprobs_and_prompt_logprobs(
    hf_model,
    vllm_model,
    batch_logprobs_composition: str,
def _run_and_validate(
    vllm_model: VllmRunner,
    test_prompts: list[str],
    vllm_sampling_params: SamplingParams,
    hf_logprobs: list[list[torch.Tensor]],
    hf_outputs: list[tuple[list[int], str]],
    logprob_prompt_logprob_list: BatchLogprobsSpecType,
    temperature: float,
    example_prompts,
    max_tokens: int,
    do_apc: bool,
) -> None:
    test_prompts = example_prompts

    max_tokens = 5
    hf_outputs = hf_model.generate_greedy(
        test_prompts,
        max_tokens=max_tokens,
    )
    hf_logprobs = hf_model.generate_greedy_logprobs(
        test_prompts,
        max_tokens=max_tokens,
    )

    # Batch has mixed sample params
    # (different logprobs/prompt logprobs combos)
    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)

    # Ensure that each test prompt has a logprob config for testing
    logprob_prompt_logprob_list = _repeat_logprob_config(
        test_prompts, logprob_prompt_logprob_list)
    # Generate SamplingParams
    vllm_sampling_params = [
        SamplingParams(max_tokens=max_tokens,
                       logprobs=num_lp,
                       prompt_logprobs=num_plp,
                       temperature=temperature,
                       seed=1984)
        for num_lp, num_plp in logprob_prompt_logprob_list
    ]

    vllm_results = vllm_model.model.generate(
        test_prompts, sampling_params=vllm_sampling_params)

@@ -267,14 +252,13 @@ def _test_case_get_logprobs_and_prompt_logprobs(
            assert vllm_result.prompt_logprobs is None


# @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("batch_logprobs_composition",
                         ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"])
                         [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
    hf_model,
    vllm_model,
    batch_logprobs_composition: str,
    batch_logprobs_composition: BatchLogprobsComposition,
    temperature: float,
    example_prompts,
) -> None:
@@ -292,25 +276,70 @@ def test_get_logprobs_and_prompt_logprobs(
    batch_logprobs_composition controls the logprobs configurations for
    requests in the batch under test.

    APC tests run two test iterations so that cache hits occur.

    To save time, only test one APC-enabled scenario
    (sample & prompt logprobs enabled, temperature>0.0).

    Args:
      hf_model
      vllm_model
      hf_model: HuggingFace reference model fixture
      vllm_model: vLLM model fixture
      batch_logprobs_composition: logprobs configuration for test batch
      example_prompts
      monkeypatch
      temperature: "temperature" sampling parameter
      example_prompts: example prompt fixture
    """
    _test_case_get_logprobs_and_prompt_logprobs(
        hf_model=hf_model,
        vllm_model=vllm_model,
        batch_logprobs_composition=batch_logprobs_composition,
        temperature=temperature,
        example_prompts=example_prompts)
    do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching
    if do_apc and (temperature < 2.0
                   or batch_logprobs_composition != SAMPLE_PROMPT):
        # Skip some test-cases to save time.
        pytest.skip()
    test_prompts = example_prompts

    max_tokens = 5
    hf_outputs = hf_model.generate_greedy(
        test_prompts,
        max_tokens=max_tokens,
    )
    hf_logprobs = hf_model.generate_greedy_logprobs(
        test_prompts,
        max_tokens=max_tokens,
    )

    # Batch has mixed sample params
    # (different logprobs/prompt logprobs combos)
    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)

    # Ensure that each test prompt has a logprob config for testing
    logprob_prompt_logprob_list = _repeat_logprob_config(
        test_prompts, logprob_prompt_logprob_list)
    # Generate SamplingParams
    vllm_sampling_params = [
        SamplingParams(max_tokens=max_tokens,
                       logprobs=num_lp,
                       prompt_logprobs=num_plp,
                       temperature=temperature,
                       seed=1984)
        for num_lp, num_plp in logprob_prompt_logprob_list
    ]
    for _ in range(2 if do_apc else 1):
        _run_and_validate(
            vllm_model=vllm_model,
            test_prompts=test_prompts,
            vllm_sampling_params=vllm_sampling_params,
            hf_logprobs=hf_logprobs,
            hf_outputs=hf_outputs,
            logprob_prompt_logprob_list=logprob_prompt_logprob_list,
            temperature=temperature,
            max_tokens=max_tokens,
            do_apc=do_apc)


def test_max_logprobs(monkeypatch):
    """vLLM v1 engine should fail a request with `logprobs > max_logprobs`

    Should also fail for `prompt_logprobs > max_logprobs`

    APC should not matter as this test checks basic request validation.

    Args:
      monkeypatch
@@ -330,14 +359,12 @@ def test_max_logprobs(monkeypatch):
    runner.generate(["Hello world"], sampling_params=bad_sampling_params)


def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
def test_none_logprobs(vllm_model, example_prompts):
    """Engine should return `logprobs` and `prompt_logprobs` as `None`

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
      monkeypatch: supports editing env vars and rolling back changes
                   after the test
    """
    max_tokens = 5

@@ -356,14 +383,12 @@ def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
        assert results_logprobs_none[i].prompt_logprobs is None


def test_zero_logprobs(vllm_model, example_prompts, monkeypatch):
def test_zero_logprobs(vllm_model, example_prompts):
    """Engine should return sampled token and prompt token logprobs

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
      monkeypatch: supports editing env vars and rolling back changes
                   after the test
    """
    max_tokens = 5


@@ -1,27 +1,42 @@
# SPDX-License-Identifier: Apache-2.0

import re
from enum import Enum
from typing import Optional

from vllm import CompletionOutput


def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
class BatchLogprobsComposition(Enum):
    """Types of logprobs configs to include in test batch"""
    NONE = 0
    SAMPLE = 1
    PROMPT = 2
    SAMPLE_PROMPT = 3


BatchLogprobsSpecType = list[tuple[Optional[int], Optional[int]]]


def get_test_batch(
    batch_logprobs_composition: BatchLogprobsComposition
) -> BatchLogprobsSpecType:
    """Generate logprobs configs for a batch of requests

    A given request's logprobs configuration is (1) num_sample_logprobs and (2)
    num_prompt_logprobs. The batch logprobs configuration is the list of request
    logprobs configs.

    batch_logprobs_composition == "NONE" yields a batch with no sample or prompt
    batch_logprobs_composition == NONE yields a batch with no sample or prompt
    logprobs

    batch_logprobs_composition == "SAMPLE" yields a batch with some requests
    batch_logprobs_composition == SAMPLE yields a batch with some requests
    configured for sample logprobs only, and others configured for no logprobs

    batch_logprobs_composition == "PROMPT" yields a batch with some requests
    batch_logprobs_composition == PROMPT yields a batch with some requests
    configured for prompt logprobs only, and others configured for no logprobs

    batch_logprobs_composition == "SAMPLE_PROMPT" yields a batch with some
    batch_logprobs_composition == SAMPLE_PROMPT yields a batch with some
    requests configured for sample logprobs and prompt logprobs, some configured
    for only sample logprobs or only prompt logprobs, and some configured for
    no logprobs
@@ -34,10 +49,10 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
      list of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
      tuples
    """
    if batch_logprobs_composition == "NONE":
    if batch_logprobs_composition == BatchLogprobsComposition.NONE:
        # No requests with sample or prompt logprobs
        return [(None, None)]
    elif batch_logprobs_composition == "SAMPLE":
    elif batch_logprobs_composition == BatchLogprobsComposition.SAMPLE:
        # Requests requiring sample logprobs or no logprobs
        return [
            (None, None),
@@ -45,7 +60,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
            (5, None),
            (3, None),
        ]
    elif batch_logprobs_composition == "PROMPT":
    elif batch_logprobs_composition == BatchLogprobsComposition.PROMPT:
        # Requests requiring prompt logprobs or no logprobs
        return [
            (None, None),
@@ -53,7 +68,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
            (None, 6),
            (None, 5),
        ]
    elif batch_logprobs_composition == "SAMPLE_PROMPT":
    elif batch_logprobs_composition == BatchLogprobsComposition.SAMPLE_PROMPT:
        # Requests requiring either no logprobs, just
        # sample logprobs, just prompt logprobs, or
        # both sample and prompt logprobs
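
# --- Illustrative aside (not part of this commit): the call shape callers use
# with the enum-based helper above. The exact tuples returned for each
# composition are defined in get_test_batch(); this note assumes the file is
# importable as tests.v1.sample.utils, as the test_logprobs.py hunk earlier in
# this diff does.
#
#     from tests.v1.sample.utils import (BatchLogprobsComposition,
#                                        get_test_batch)
#
#     batch_spec = get_test_batch(BatchLogprobsComposition.SAMPLE)
#     for num_sample_logprobs, num_prompt_logprobs in batch_spec:
#         # None means "no logprobs of that kind requested" for that request.
#         ...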
@@ -105,8 +105,6 @@ class KVCacheManager:
            # Prefix caching is disabled.
            return [], 0

        computed_blocks = []

        # The block hashes for the request may already be computed
        # if the scheduler has tried to schedule the request before.
        block_hashes = self.req_to_block_hashes[request.request_id]
@@ -114,24 +112,31 @@
            block_hashes = hash_request_tokens(self.block_size, request)
            self.req_to_block_hashes[request.request_id] = block_hashes

        for block_hash in block_hashes:
            # block_hashes is a chain of block hashes. If a block hash is not
            # in the cached_block_hash_to_id, the following block hashes are
            # not computed yet for sure.
            if cached_block := self.block_pool.get_cached_block(block_hash):
                computed_blocks.append(cached_block)
            else:
                break

        self.prefix_cache_stats.requests += 1
        self.prefix_cache_stats.queries += len(block_hashes)
        self.prefix_cache_stats.hits += len(computed_blocks)
        if request.sampling_params.prompt_logprobs is None:
            # Check for cache hits
            computed_blocks = []
            for block_hash in block_hashes:
                # block_hashes is a chain of block hashes. If a block hash
                # is not in the cached_block_hash_to_id, the following
                # block hashes are not computed yet for sure.
                if cached_block := self.block_pool.get_cached_block(
                        block_hash):
                    computed_blocks.append(cached_block)
                else:
                    break

        # NOTE(woosuk): Since incomplete blocks are not eligible for
        # sharing, `num_computed_tokens` is always a multiple of
        # `block_size`.
        num_computed_tokens = len(computed_blocks) * self.block_size
        return computed_blocks, num_computed_tokens
            self.prefix_cache_stats.queries += len(block_hashes)
            self.prefix_cache_stats.hits += len(computed_blocks)

            # NOTE(woosuk): Since incomplete blocks are not eligible for
            # sharing, `num_computed_tokens` is always a multiple of
            # `block_size`.
            num_computed_tokens = len(computed_blocks) * self.block_size
            return computed_blocks, num_computed_tokens
        else:
            # Skip cache hits for prompt logprobs
            return [], 0

    def allocate_slots(
        self,

@@ -72,12 +72,6 @@ class Processor:
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # TODO(andy): enable this in follow up by recomputing.
        if (params.prompt_logprobs is not None
                and self.cache_config.enable_prefix_caching):
            raise ValueError("Prefix caching with prompt logprobs not yet "
                             "supported on VLLM V1.")

    def _validate_sampling_params(
        self,
        params: SamplingParams,

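Taken together with the validation removed in the hunk above, the user-visible effect of this commit is that a prompt-logprobs request on an APC-enabled engine is no longer rejected; it simply bypasses the prefix cache. A hedged usage sketch follows (VLLM_USE_V1=1 is assumed, as in the tests above; the model name and parameter values are only examples):

from vllm import LLM, SamplingParams

# Engine started with automatic prefix caching enabled.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

# Before this change, V1 rejected this request with "Prefix caching with
# prompt logprobs not yet supported on VLLM V1." Now it runs; per the
# KVCacheManager hunk above, it just skips the prefix-cache lookup.
outputs = llm.generate(
    "Hello, my name is",
    SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5),
)
print(outputs[0].prompt_logprobs)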