[V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (#13949)

afeldman-nm 2025-03-07 20:48:12 -05:00 committed by GitHub
parent 66e16a038e
commit ef64044079
9 changed files with 291 additions and 161 deletions
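
The user-visible effect of this change, as a minimal sketch (not part of the diff, and assuming the V1 engine is in use): a request that asks for prompt logprobs is no longer rejected when automatic prefix caching (APC) is enabled; it simply bypasses prefix-cache hits instead. The model name and token counts below are illustrative only.

```python
from vllm import LLM, SamplingParams

# APC enabled; before this commit the generate() call below raised
# ValueError("Prefix caching with prompt logprobs not yet supported on VLLM V1.")
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

params = SamplingParams(max_tokens=16, prompt_logprobs=5)
# Now runs; the request skips prefix-cache hits rather than erroring out.
outputs = llm.generate(["Hello, my name is"], params)
```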

View File

@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """Compare the with and without prefix caching."""
+from typing import Optional
+
 import pytest

 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -15,7 +17,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
 def make_request(request_id,
                  prompt_token_ids,
                  mm_positions=None,
-                 mm_hashes=None):
+                 mm_hashes=None,
+                 prompt_logprobs: Optional[int] = None):
     if mm_positions is None:
         multi_modal_inputs = None
     else:
@@ -28,7 +31,8 @@ def make_request(request_id,
         multi_modal_inputs=multi_modal_inputs,
         multi_modal_hashes=mm_hashes,
         multi_modal_placeholders=mm_positions,
-        sampling_params=SamplingParams(max_tokens=17),
+        sampling_params=SamplingParams(max_tokens=17,
+                                       prompt_logprobs=prompt_logprobs),
         eos_token_id=100,
         arrival_time=0,
         lora_request=None,
@@ -144,6 +148,110 @@ def test_prefill():
     assert manager.block_pool.free_block_queue.free_list_tail is None


+def test_prefill_plp():
+    '''Test prefill with APC and some prompt logprobs (plp) requests.
+
+    1. Schedule plp request and validate APC block allocation
+    2. Schedule non-plp request and validate blocks
+    3. Schedule plp request; no hit should occur; validate blocks
+    '''
+    manager = KVCacheManager(
+        block_size=16,
+        num_gpu_blocks=10,
+        max_model_len=8192,
+        sliding_window=None,
+        enable_caching=True,
+        num_preallocate_tokens=16,
+    )
+
+    # Complete 3 blocks (48 tokens)
+    common_token_ids = [i for i in range(3) for _ in range(16)]
+
+    # Request #0 is a prompt logprobs request
+    # Fully cache miss
+    # Incomplete 1 block (7 tokens)
+    unique_token_ids = [3] * 7
+    all_token_ids = common_token_ids + unique_token_ids
+    req0 = make_request("0", all_token_ids, prompt_logprobs=5)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
+    assert not computed_blocks
+    assert num_computed_tokens == 0
+    blocks = manager.allocate_slots(req0, 55, computed_blocks)
+    assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
+    req0_block_hashes = [b.block_hash for b in blocks]
+
+    # Check full block metadata
+    parent_block_hash = None
+    for block_id in (0, 1, 2):
+        block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
+        block_hash = hash_block_tokens(parent_block_hash, block_tokens)
+        assert manager.block_pool.blocks[block_id].block_hash == block_hash
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
+        parent_block_hash = block_hash.hash_value
+
+    # Check partial/preallocated block metadata
+    for block_id in (3, 4):
+        assert manager.block_pool.blocks[block_id].block_hash is None
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
+
+    # Request #1 is a non-prompt-logprobs request:
+    # Cache hit in the common prefix when the original block is still in use.
+    # Incomplete 1 block (5 tokens)
+    unique_token_ids = [3] * 5
+    req1 = make_request("1", common_token_ids + unique_token_ids)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
+    assert [b.block_id for b in computed_blocks] == [0, 1, 2]
+    assert num_computed_tokens == 3 * 16
+    num_new_tokens = 53 - 3 * 16
+    blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
+    assert [b.block_id for b in blocks] == [5, 6]
+    for block in computed_blocks:
+        assert block.ref_cnt == 2
+
+    # At this point, we should have 3 free blocks left.
+    assert manager.block_pool.free_block_queue.num_free_blocks == 3
+
+    manager.free(req0)
+    manager.free(req1)
+
+    # All blocks should be available.
+    assert manager.block_pool.free_block_queue.num_free_blocks == 10
+    # The order should be
+    # [unallocated (7, 8, 9)]
+    # [unique_req0 (4, 3)]
+    # [unique_req1 (6, 5)]
+    # [common (2, 1, 0)]
+    assert [
+        b.block_id
+        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
+    ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
+
+    # Request #2 is a prompt-logprobs request:
+    # NO cache hit in the common prefix; duplicates request #0 cached blocks
+    unique_token_ids = [3] * 6
+    req2 = make_request("2",
+                        common_token_ids + unique_token_ids,
+                        prompt_logprobs=5)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
+    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
+    assert not computed_blocks
+    assert num_computed_tokens == 0
+    blocks = manager.allocate_slots(req2, 55, computed_blocks)
+    block_ids = [b.block_id for b in blocks]
+    # Duplicate cached blocks have different ids but same hashes vs request #0
+    assert [b.block_hash for b in blocks] == req0_block_hashes
+    assert block_ids != [0, 1, 2, 3, 4]
+
+    # Request #2 block hashes are valid since request #0 hashes are.
+    # Check block reference counts.
+    for block_id in block_ids:
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
+
+    manager.free(req2)
+
+
 def test_decode():
     manager = KVCacheManager(
         block_size=16,

View File

@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 from typing import Optional

+import pytest
+
 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
@@ -16,7 +18,21 @@ def create_scheduler(
     model: str = "facebook/opt-125m",
     max_num_seqs: int = 16,
     max_num_batched_tokens: int = 8192,
+    enable_prefix_caching: Optional[bool] = None,
 ) -> Scheduler:
+    '''Create scheduler under test.
+
+    Args:
+      model: model under test
+      max_num_seqs: max sequences to schedule
+      max_num_batch_tokens: max num tokens to batch
+      enable_prefix_caching: optionally force APC config
+                             (True/False) or use default
+                             (None)
+
+    Returns:
+      :class:`Scheduler` instance
+    '''
     scheduler_config = SchedulerConfig(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
@@ -31,11 +47,16 @@ def create_scheduler(
         dtype="float16",
         seed=42,
     )
+    # Cache config, optionally force APC
+    kwargs_cache = ({} if enable_prefix_caching is None else {
+        'enable_prefix_caching': enable_prefix_caching
+    })
     cache_config = CacheConfig(
         block_size=16,
         gpu_memory_utilization=0.9,
         swap_space=0,
         cache_dtype="auto",
+        **kwargs_cache,
     )
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
@@ -54,16 +75,16 @@ def create_scheduler(
     )


-def create_requests(
-    num_requests: int,
+def create_requests(num_requests: int,
                     num_tokens: int = 10,
                     mm_positions: Optional[list[PlaceholderRange]] = None,
                     max_tokens: int = 16,
                     stop_token_ids: Optional[list[int]] = None,
-):
+                    prompt_logprobs: Optional[int] = None):
     sampling_params = SamplingParams(ignore_eos=False,
                                      max_tokens=max_tokens,
-                                     stop_token_ids=stop_token_ids)
+                                     stop_token_ids=stop_token_ids,
+                                     prompt_logprobs=prompt_logprobs)
     requests = []
     for i in range(num_requests):
         if mm_positions is not None:
@@ -122,9 +143,18 @@ def test_get_num_unfinished_requests():
         assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


-def test_schedule():
-    scheduler = create_scheduler()
-    requests = create_requests(num_requests=10)
+@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
+    (None, None),
+    (True, 5),
+])
+def test_schedule(enable_prefix_caching: Optional[bool],
+                  prompt_logprobs: Optional[int]):
+    '''Test scheduling.
+    Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
+    '''
+    scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
+    requests = create_requests(num_requests=10,
+                               prompt_logprobs=prompt_logprobs)
     for request in requests:
         scheduler.add_request(request)
@@ -427,14 +457,21 @@ def test_stop_via_update_from_output():
     assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]


-def test_schedule_concurrent_batches():
+@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
+    (None, None),
+    (True, 5),
+])
+def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
+                                     prompt_logprobs: Optional[int]):
     scheduler = create_scheduler(
         max_num_batched_tokens=1024,
         max_num_seqs=2,
+        enable_prefix_caching=enable_prefix_caching,
     )
     requests = create_requests(
         num_requests=2,
         num_tokens=512,
+        prompt_logprobs=prompt_logprobs,
     )

     # Schedule the first request.

View File

@@ -6,7 +6,6 @@ from typing import Optional

 import pytest

-from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
 from vllm import SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -72,41 +71,6 @@ async def generate(engine: AsyncLLM,
     return count, request_id


-@pytest.mark.parametrize(
-    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-@pytest.mark.asyncio
-async def test_async_llm_refuses_prompt_logprobs_with_apc(
-        monkeypatch, output_kind: RequestOutputKind):
-    """Test passes if AsyncLLM raises an exception when it is configured
-    for automatic prefix caching and it receives a request with
-    prompt_logprobs enabled, which is incompatible."""
-    # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
-    # better way to test V1 so that in the future when we switch, we don't
-    # have to change all the tests.
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    # Create AsyncLLM engine with APC
-    apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m",
-                                      enable_prefix_caching=True,
-                                      gpu_memory_utilization=0.8,
-                                      disable_log_requests=True)
-    engine = AsyncLLM.from_engine_args(apc_engine_args)
-    try:
-        with pytest.raises(ValueError) as excinfo:
-            # Issue a request with prompt logprobs enabled, which should fail
-            await asyncio.create_task(
-                generate(engine,
-                         "request-0",
-                         TEXT_PROMPT,
-                         output_kind,
-                         10,
-                         prompt_logprobs=5))
-        # Validate exception string is correct
-        assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
-    finally:
-        # Shut down engine
-        engine.shutdown()
-
-
 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.parametrize("engine_args_and_prompt",

View File

@@ -5,7 +5,6 @@ from typing import Optional

 import pytest

-from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
 from vllm import LLM, SamplingParams

 MODEL = "facebook/opt-125m"
@@ -98,17 +97,3 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
             raise AssertionError(
                 f"{len(completion_counts)} unique completions; expected"
                 f" {n}. Repeats: {repeats}")
-
-
-def test_llm_engine_refuses_prompt_logprobs_with_apc(vllm_model_apc):
-    """Test passes if LLMEngine raises an exception when it is configured
-    for automatic prefix caching and it receives a request with
-    prompt_logprobs enabled, which is incompatible."""
-    model: LLM = vllm_model_apc.model
-    with pytest.raises(ValueError) as excinfo:
-        model.generate(
-            "Hello, my name is",
-            SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))
-    # Validate exception string is correct
-    assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG

View File

@@ -30,9 +30,6 @@ FULL_STRINGS = [
 STOP_STRINGS = ["I love working on", "company by far", "brother in"]
 PROMPT_LEN = 5

-PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet "
-                           "supported on VLLM V1.")
-
 random.seed(42)

View File

@@ -1,24 +1,34 @@
 # SPDX-License-Identifier: Apache-2.0

 import itertools
+from collections.abc import Generator

 import pytest
 import torch

 from tests.kernels.utils import override_backend_env_variable
 from tests.v1.sample.utils import (
+    BatchLogprobsComposition, BatchLogprobsSpecType,
     assert_incr_detok_str_matches_non_incr_detok_str,
     compute_correct_cumulative_logprob, get_test_batch)
 from vllm import SamplingParams

-from ...conftest import VllmRunner
+from ...conftest import HfRunner, VllmRunner

 MODEL = "meta-llama/Llama-3.2-1B-Instruct"
 DTYPE = "half"

+NONE = BatchLogprobsComposition.NONE
+SAMPLE = BatchLogprobsComposition.SAMPLE
+PROMPT = BatchLogprobsComposition.PROMPT
+SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT
+

-@pytest.fixture(scope="module")
-def vllm_model(vllm_runner):
+@pytest.fixture(
+    scope="module",
+    # Parameterize APC
+    params=[False, True])
+def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
     with vllm_runner(
             MODEL,
             dtype=DTYPE,
@@ -31,22 +41,22 @@ def vllm_model(vllm_runner):
             enforce_eager=True,
             #TODO: enable this once we support it for
             # prompt logprobs.
-            enable_prefix_caching=False,
+            enable_prefix_caching=request.param,
             gpu_memory_utilization=0.5,
     ) as vllm_model:
         yield vllm_model


 @pytest.fixture(scope="module")
-def hf_model(hf_runner):
+def hf_model(hf_runner) -> Generator[HfRunner, None, None]:
     with hf_runner(MODEL, dtype=DTYPE) as hf_model:
         yield hf_model


 def _repeat_logprob_config(
     test_prompts,
-    logprob_prompt_logprob_list: list[tuple],
-) -> list[tuple]:
+    logprob_prompt_logprob_list: BatchLogprobsSpecType,
+) -> BatchLogprobsSpecType:
     """Ensure each test prompt has a logprob config.

     A logprob config specifies the optional (i.e.
@@ -91,42 +101,17 @@ def _repeat_logprob_config(
     return logprob_prompt_logprob_list


-def _test_case_get_logprobs_and_prompt_logprobs(
-    hf_model,
-    vllm_model,
-    batch_logprobs_composition: str,
+def _run_and_validate(
+    vllm_model: VllmRunner,
+    test_prompts: list[str],
+    vllm_sampling_params: SamplingParams,
+    hf_logprobs: list[list[torch.Tensor]],
+    hf_outputs: list[tuple[list[int], str]],
+    logprob_prompt_logprob_list: BatchLogprobsSpecType,
     temperature: float,
-    example_prompts,
+    max_tokens: int,
+    do_apc: bool,
 ) -> None:
-    test_prompts = example_prompts
-
-    max_tokens = 5
-    hf_outputs = hf_model.generate_greedy(
-        test_prompts,
-        max_tokens=max_tokens,
-    )
-    hf_logprobs = hf_model.generate_greedy_logprobs(
-        test_prompts,
-        max_tokens=max_tokens,
-    )
-
-    # Batch has mixed sample params
-    # (different logprobs/prompt logprobs combos)
-    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
-
-    # Ensure that each test prompt has a logprob config for testing
-    logprob_prompt_logprob_list = _repeat_logprob_config(
-        test_prompts, logprob_prompt_logprob_list)
-
-    # Generate SamplingParams
-    vllm_sampling_params = [
-        SamplingParams(max_tokens=max_tokens,
-                       logprobs=num_lp,
-                       prompt_logprobs=num_plp,
-                       temperature=temperature,
-                       seed=1984)
-        for num_lp, num_plp in logprob_prompt_logprob_list
-    ]
-
     vllm_results = vllm_model.model.generate(
         test_prompts, sampling_params=vllm_sampling_params)
@@ -267,14 +252,13 @@ def _test_case_get_logprobs_and_prompt_logprobs(
             assert vllm_result.prompt_logprobs is None


-#@pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("batch_logprobs_composition",
-                         ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"])
+                         [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT])
 @pytest.mark.parametrize("temperature", [0.0, 2.0])
 def test_get_logprobs_and_prompt_logprobs(
     hf_model,
     vllm_model,
-    batch_logprobs_composition: str,
+    batch_logprobs_composition: BatchLogprobsComposition,
     temperature: float,
     example_prompts,
 ) -> None:
@@ -292,19 +276,62 @@ def test_get_logprobs_and_prompt_logprobs(
     batch_logprobs_composition controls the logprobs configurations for
     requests in the batch under test.

+    APC tests run two test iterations so that cache hits occur.
+    To save time, only test one APC-enabled scenario
+    (sample & prompt logprobs enabled, temperature>0.0).
+
     Args:
-      hf_model
-      vllm_model
+      hf_model: HuggingFace reference model fixture
+      vllm_model: vLLM model fixture
       batch_logprobs_composition: logprobs configuration for test batch
-      example_prompts
-      monkeypatch
+      temperature: "temperature" sampling parameter
+      example_prompts: example prompt fixture
     """
-    _test_case_get_logprobs_and_prompt_logprobs(
-        hf_model=hf_model,
-        vllm_model=vllm_model,
-        batch_logprobs_composition=batch_logprobs_composition,
-        temperature=temperature,
-        example_prompts=example_prompts)
+    do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching
+    if do_apc and (temperature < 2.0
+                   or batch_logprobs_composition != SAMPLE_PROMPT):
+        # Skip some test-cases to save time.
+        pytest.skip()
+    test_prompts = example_prompts
+
+    max_tokens = 5
+    hf_outputs = hf_model.generate_greedy(
+        test_prompts,
+        max_tokens=max_tokens,
+    )
+    hf_logprobs = hf_model.generate_greedy_logprobs(
+        test_prompts,
+        max_tokens=max_tokens,
+    )
+
+    # Batch has mixed sample params
+    # (different logprobs/prompt logprobs combos)
+    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
+
+    # Ensure that each test prompt has a logprob config for testing
+    logprob_prompt_logprob_list = _repeat_logprob_config(
+        test_prompts, logprob_prompt_logprob_list)
+
+    # Generate SamplingParams
+    vllm_sampling_params = [
+        SamplingParams(max_tokens=max_tokens,
+                       logprobs=num_lp,
+                       prompt_logprobs=num_plp,
+                       temperature=temperature,
+                       seed=1984)
+        for num_lp, num_plp in logprob_prompt_logprob_list
+    ]
+
+    for _ in range(2 if do_apc else 1):
+        _run_and_validate(
+            vllm_model=vllm_model,
+            test_prompts=test_prompts,
+            vllm_sampling_params=vllm_sampling_params,
+            hf_logprobs=hf_logprobs,
+            hf_outputs=hf_outputs,
+            logprob_prompt_logprob_list=logprob_prompt_logprob_list,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            do_apc=do_apc)


 def test_max_logprobs(monkeypatch):
@@ -312,6 +339,8 @@ def test_max_logprobs(monkeypatch):

     Should also fail for `prompt_logprobs > max_logprobs`

+    APC should not matter as this test checks basic request validation.
+
     Args:
       monkeypatch
     """
@@ -330,14 +359,12 @@ def test_max_logprobs(monkeypatch):
         runner.generate(["Hello world"], sampling_params=bad_sampling_params)


-def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
+def test_none_logprobs(vllm_model, example_prompts):
     """Engine should return `logprobs` and `prompt_logprobs` as `None`

     Args:
       vllm_model: vLLM model fixture
       example_prompts: list of example prompts (test fixture)
-      monkeypatch: supports editing env vars and rolling back changes
-                   after the test
     """

     max_tokens = 5
@@ -356,14 +383,12 @@ def test_none_logprobs(vllm_model, example_prompts):
         assert results_logprobs_none[i].prompt_logprobs is None


-def test_zero_logprobs(vllm_model, example_prompts, monkeypatch):
+def test_zero_logprobs(vllm_model, example_prompts):
     """Engine should return sampled token and prompt token logprobs

     Args:
       vllm_model: vLLM model fixture
       example_prompts: list of example prompts (test fixture)
-      monkeypatch: supports editing env vars and rolling back changes
-                   after the test
     """

     max_tokens = 5

View File

@@ -1,27 +1,42 @@
 # SPDX-License-Identifier: Apache-2.0

 import re
+from enum import Enum
+from typing import Optional

 from vllm import CompletionOutput


-def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
+class BatchLogprobsComposition(Enum):
+    """Types of logprobs configs to include in test batch"""
+    NONE = 0
+    SAMPLE = 1
+    PROMPT = 2
+    SAMPLE_PROMPT = 3
+
+
+BatchLogprobsSpecType = list[tuple[Optional[int], Optional[int]]]
+
+
+def get_test_batch(
+    batch_logprobs_composition: BatchLogprobsComposition
+) -> BatchLogprobsSpecType:
     """Generate logprobs configs for a batch of requests

     A given request's logprobs configuration is (1) num_sample_logprobs and (2)
     num_prompt_logprobs. The batch logprobs configuration is the list of request
     logprobs configs.

-    batch_logprobs_composition == "NONE" yields a batch with no sample or prompt
+    batch_logprobs_composition == NONE yields a batch with no sample or prompt
     logprobs

-    batch_logprobs_composition == "SAMPLE" yields a batch with some requests
+    batch_logprobs_composition == SAMPLE yields a batch with some requests
     configured for sample logprobs only, and others configured for no logprobs

-    batch_logprobs_composition == "PROMPT" yields a batch with some requests
+    batch_logprobs_composition == PROMPT yields a batch with some requests
     configured for prompt logprobs only, and others configured for no logprobs

-    batch_logprobs_composition == "SAMPLE_PROMPT" yields a batch with some
+    batch_logprobs_composition == SAMPLE_PROMPT yields a batch with some
     requests configured for sample logprobs and prompt logprobs, some configured
     for only sample logprobs or only prompt logprobs, and some configured for
     no logprobs
@@ -34,10 +49,10 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
       list of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
       tuples
     """
-    if batch_logprobs_composition == "NONE":
+    if batch_logprobs_composition == BatchLogprobsComposition.NONE:
         # No requests with sample or prompt logprobs
         return [(None, None)]
-    elif batch_logprobs_composition == "SAMPLE":
+    elif batch_logprobs_composition == BatchLogprobsComposition.SAMPLE:
         # Requests requiring sample logprobs or no logprobs
         return [
             (None, None),
@@ -45,7 +60,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
             (5, None),
             (3, None),
         ]
-    elif batch_logprobs_composition == "PROMPT":
+    elif batch_logprobs_composition == BatchLogprobsComposition.PROMPT:
         # Requests requiring prompt logprobs or no logprobs
         return [
             (None, None),
@@ -53,7 +68,7 @@ def get_test_batch(batch_logprobs_composition: str) -> list[tuple]:
             (None, 6),
             (None, 5),
         ]
-    elif batch_logprobs_composition == "SAMPLE_PROMPT":
+    elif batch_logprobs_composition == BatchLogprobsComposition.SAMPLE_PROMPT:
         # Requests requiring either no logprobs, just
         # sample logprobs, just prompt logprobs, or
         # both sample and prompt logprobs
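
A brief usage sketch of the enum-based helper introduced above (illustrative, not part of the diff): get_test_batch now takes a BatchLogprobsComposition member instead of a string and returns a BatchLogprobsSpecType, i.e. a list of (num_sample_logprobs, num_prompt_logprobs) tuples where None disables that kind of logprob.

```python
from tests.v1.sample.utils import BatchLogprobsComposition, get_test_batch

batch = get_test_batch(BatchLogprobsComposition.SAMPLE_PROMPT)
for num_sample_logprobs, num_prompt_logprobs in batch:
    # Each entry configures one request in the test batch.
    print(num_sample_logprobs, num_prompt_logprobs)
```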

View File

@@ -105,8 +105,6 @@ class KVCacheManager:
             # Prefix caching is disabled.
             return [], 0

-        computed_blocks = []
-
         # The block hashes for the request may already be computed
         # if the scheduler has tried to schedule the request before.
         block_hashes = self.req_to_block_hashes[request.request_id]
@@ -114,16 +112,20 @@ class KVCacheManager:
             block_hashes = hash_request_tokens(self.block_size, request)
             self.req_to_block_hashes[request.request_id] = block_hashes

-        for block_hash in block_hashes:
-            # block_hashes is a chain of block hashes. If a block hash is not
-            # in the cached_block_hash_to_id, the following block hashes are
-            # not computed yet for sure.
-            if cached_block := self.block_pool.get_cached_block(block_hash):
-                computed_blocks.append(cached_block)
-            else:
-                break
-
         self.prefix_cache_stats.requests += 1
-        self.prefix_cache_stats.queries += len(block_hashes)
-        self.prefix_cache_stats.hits += len(computed_blocks)
+        if request.sampling_params.prompt_logprobs is None:
+            # Check for cache hits
+            computed_blocks = []
+            for block_hash in block_hashes:
+                # block_hashes is a chain of block hashes. If a block hash
+                # is not in the cached_block_hash_to_id, the following
+                # block hashes are not computed yet for sure.
+                if cached_block := self.block_pool.get_cached_block(
+                        block_hash):
+                    computed_blocks.append(cached_block)
+                else:
+                    break
+
+            self.prefix_cache_stats.queries += len(block_hashes)
+            self.prefix_cache_stats.hits += len(computed_blocks)
@@ -132,6 +134,9 @@ class KVCacheManager:
             # `block_size`.
             num_computed_tokens = len(computed_blocks) * self.block_size
             return computed_blocks, num_computed_tokens
+        else:
+            # Skip cache hits for prompt logprobs
+            return [], 0

     def allocate_slots(
         self,
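
A condensed sketch of the gating added above (not the actual vLLM code; the lookup callable and the block size of 16 are stand-ins for illustration): requests with prompt_logprobs set never consult the prefix cache, so they always report zero computed tokens back to the scheduler and the full prompt is recomputed.

```python
def get_computed_blocks_sketch(request, lookup_cached_prefix, block_size=16):
    # Prompt-logprobs requests skip prefix-cache hits entirely, so prompt
    # logprobs can be produced for every prompt token.
    if request.sampling_params.prompt_logprobs is not None:
        return [], 0
    # Otherwise reuse the longest chain of cached full blocks.
    computed_blocks = lookup_cached_prefix(request)
    return computed_blocks, len(computed_blocks) * block_size
```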

View File

@@ -72,12 +72,6 @@ class Processor:
                 f"Requested prompt logprobs of {params.prompt_logprobs}, "
                 f"which is greater than max allowed: {max_logprobs}")

-        # TODO(andy): enable this in follow up by recomputing.
-        if (params.prompt_logprobs is not None
-                and self.cache_config.enable_prefix_caching):
-            raise ValueError("Prefix caching with prompt logprobs not yet "
-                             "supported on VLLM V1.")
-
     def _validate_sampling_params(
         self,
         params: SamplingParams,