diff --git a/tests/engine/test_computed_prefix_blocks.py b/tests/engine/test_computed_prefix_blocks.py
new file mode 100644
index 000000000000..ed35212cc3f1
--- /dev/null
+++ b/tests/engine/test_computed_prefix_blocks.py
@@ -0,0 +1,34 @@
+import pytest
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("block_size", [16])
+def test_computed_prefix_blocks(model: str, block_size: int):
+    # This test checks if we are able to run the engine to completion
+    # without triggering asserts.
+    # We are in a scenario where all blocks from the second request's prompt
+    # are full and already computed when the second request arrives.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+    prompt2 = (
+        " Please recommend to me some resources where I can learn not only to "
+        "handle technical difficulties of building a car, but also "
+        "decoration.")
+
+    engine_args = EngineArgs(model=model,
+                             block_size=block_size,
+                             enable_prefix_caching=True)
+
+    engine = LLMEngine.from_engine_args(engine_args)
+    sampling_params = SamplingParams()
+
+    engine.add_request("0", prompt + prompt2, sampling_params)
+    engine.step()
+    engine.add_request("1", prompt, sampling_params)
+    engine.step()
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index daf83827a7e5..52b120f227ed 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -1,6 +1,6 @@
 """A block manager that manages token blocks."""
 import enum
-from itertools import count
+from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple
 
@@ -426,23 +426,29 @@ class BlockSpaceManager:
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_last_full_block_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
         max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
-        block_table[max_full_block].computed = True
+        for i in reversed(range(max_full_block)):
+            if block_table[i].computed:
+                break
+            block_table[i].computed = True
 
-    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
+    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
         if seq.seq_id not in self.block_tables:
             return []
         block_table = self.block_tables[seq.seq_id]
-        for block_idx in reversed(range(len(block_table))):
-            if block_table[block_idx].computed:
-                return [b.block_number for b in block_table[:block_idx + 1]]
-        return []
+        # NOTE We exclude the last block to avoid the case where the entire
+        # prompt is cached. This would cause erroneous behavior in model
+        # runner.
+        return [
+            b.block_number
+            for b in takewhile(lambda b: b.computed, block_table[:-1])
+        ]
 
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
@@ -451,14 +457,12 @@ class BlockSpaceManager:
             return []
 
         ids_list = [
-            self.get_all_block_ids_till_computed(seq)
+            self.get_all_computed_blocks(seq)
             for seq in iter(seq_group.seqs_dict.values())
        ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # NOTE: We only mark the last full block because with prefix caching,
-        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
-                self.compute_last_full_block_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index b01f865f1bb0..9023b0c59b3f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -215,6 +215,7 @@ class ModelRunner:
                 slot_mapping[-1].append(slot)
 
         max_prompt_len = max(subquery_lens)
+        assert max_prompt_len > 0
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
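
For reviewers, a minimal standalone sketch (not part of the patch) of the behavior the reworked helpers rely on: takewhile collects block ids only up to the first non-computed block, and os.path.commonprefix over the per-sequence id lists yields the prefix shared by every sequence in the group. The FakeBlock class and the literal block tables below are hypothetical stand-ins for vLLM's physical token blocks, used purely for illustration.

from itertools import takewhile
from os.path import commonprefix


class FakeBlock:
    """Hypothetical stand-in for a physical token block (illustration only)."""

    def __init__(self, block_number: int, computed: bool):
        self.block_number = block_number
        self.computed = computed


# One sequence's block table: the first two blocks are cached, the third is
# not, and the final block is always dropped by the [:-1] slice so a fully
# cached prompt never reports every one of its blocks as computed.
block_table = [
    FakeBlock(0, True),
    FakeBlock(1, True),
    FakeBlock(2, False),
    FakeBlock(3, True),
]
computed_ids = [
    b.block_number for b in takewhile(lambda b: b.computed, block_table[:-1])
]
print(computed_ids)  # [0, 1]

# Across sequences in a group, the shared computed prefix is the element-wise
# common prefix of the per-sequence id lists.
ids_list = [[0, 1, 5], [0, 1], [0, 1, 7]]
print(commonprefix([ids for ids in ids_list if ids != []]))  # [0, 1]

Unlike a reverse scan that stops at the last computed block, the takewhile form never reports a block as computed past the first gap, which is what makes marking every full block in compute_full_blocks_in_seq safe.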