[BUGFIX]: return fast when request requires prompt logprobs (#17251)

Ning Xie 2025-05-09 12:25:41 +08:00 committed by GitHub
parent 5e6f939484
commit d310e6de98
2 changed files with 7 additions and 7 deletions
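
The change itself is small: the prompt-logprobs check moves out of the middle of KVCacheManager.get_computed_blocks and into its initial guard, so affected requests return immediately, before any block hashes are computed or prefix-cache stats are updated. Besides skipping the wasted hashing work (the test below now expects zero entries in req_to_block_hashes), this keeps requests that can never hit the prefix cache out of the hit-rate statistics. A minimal, self-contained simulation of that second effect follows; Stats and the serve_* functions are illustrative stand-ins, not vLLM APIs:

class Stats:
    def __init__(self) -> None:
        self.requests = 0
        self.hits = 0

def serve_old(stats: Stats, wants_prompt_logprobs: bool) -> None:
    # Old order: every request is counted before the logprobs check...
    stats.requests += 1
    if wants_prompt_logprobs:
        return  # ...even though it can never use the prefix cache.
    stats.hits += 1  # assume every cache-eligible request hits

def serve_new(stats: Stats, wants_prompt_logprobs: bool) -> None:
    # New order: fast return first; ineligible requests are never counted.
    if wants_prompt_logprobs:
        return
    stats.requests += 1
    stats.hits += 1  # assume every cache-eligible request hits

old, new = Stats(), Stats()
for plp in (True, False, True, False):
    serve_old(old, plp)
    serve_new(new, plp)

print(old.hits / old.requests)  # 0.5 -- hit rate diluted by logprobs requests
print(new.hits / new.requests)  # 1.0 -- only real cache candidates counted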


@@ -194,7 +194,7 @@ def test_prefill_plp():
     all_token_ids = common_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids, prompt_logprobs=5)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
+    assert len(manager.req_to_block_hashes[req0.request_id]) == 0
     assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
@@ -256,7 +256,7 @@ def test_prefill_plp():
                         common_token_ids + unique_token_ids,
                         prompt_logprobs=5)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
+    assert len(manager.req_to_block_hashes[req2.request_id]) == 0
     assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req2, 55, computed_blocks)


@@ -126,8 +126,11 @@ class KVCacheManager:
             - A list of blocks that are computed for the request.
             - The number of computed tokens.
         """
-        if not self.enable_caching:
-            # Prefix caching is disabled.
+        # Skip prefix caching when it is disabled or when the request
+        # requires prompt logprobs.
+        if (not self.enable_caching
+                or request.sampling_params.prompt_logprobs is not None):
             return KVCacheBlocks.create_empty(), 0
         # The block hashes for the request may already be computed
@@ -141,9 +144,6 @@
         if self.log_stats:
             assert self.prefix_cache_stats is not None
             self.prefix_cache_stats.requests += 1
-        # When the request requires prompt logprobs, we skip prefix caching.
-        if request.sampling_params.prompt_logprobs is not None:
-            return KVCacheBlocks.create_empty(), 0
         if len(block_hashes) * self.block_size == request.num_tokens:
             # When prompt length is divisible by the block size and all
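
For reference, a minimal runnable sketch of the new fast-return behavior, mirroring the updated test assertions. Every class here is a simplified stand-in for the real vLLM type of the same name, with everything except the patched guard stubbed out:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class SamplingParams:
    prompt_logprobs: Optional[int] = None

@dataclass
class Request:
    request_id: str
    sampling_params: SamplingParams = field(default_factory=SamplingParams)

@dataclass
class KVCacheBlocks:
    blocks: list = field(default_factory=list)

    @classmethod
    def create_empty(cls) -> "KVCacheBlocks":
        return cls(blocks=[])

class KVCacheManager:
    def __init__(self, enable_caching: bool = True) -> None:
        self.enable_caching = enable_caching
        self.req_to_block_hashes: dict[str, list] = {}

    def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:
        # Skip prefix caching when it is disabled or when the request
        # requires prompt logprobs: return fast, before any block
        # hashing or stats bookkeeping takes place.
        if (not self.enable_caching
                or request.sampling_params.prompt_logprobs is not None):
            return KVCacheBlocks.create_empty(), 0
        raise NotImplementedError  # real hashing/lookup elided

manager = KVCacheManager()
req = Request("0", SamplingParams(prompt_logprobs=5))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
# Block hashes were never computed for this request:
assert len(manager.req_to_block_hashes.get(req.request_id, [])) == 0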