Fix auto prefix bug (#3239)

ElizaWszola 2024-03-08 01:37:28 +01:00 committed by GitHub
parent 8cbba4622c
commit b35cc93420
3 changed files with 51 additions and 12 deletions

@@ -0,0 +1,34 @@
import pytest

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
def test_computed_prefix_blocks(model: str, block_size: int):
    # This test checks if we are able to run the engine to completion
    # without triggering asserts.
    # We are in a scenario where all blocks from the second request's prompt
    # are full and already computed when the second request arrives.
    prompt = (
        "You are a helpful assistant. How do I build a car from cardboard and "
        "paper clips? Is there an easy to follow video tutorial available "
        "online for free?")
    prompt2 = (
        " Please recommend to me some resources where I can learn not only to "
        "handle technical difficulties of building a car, but also "
        "decoration.")

    engine_args = EngineArgs(model=model,
                             block_size=block_size,
                             enable_prefix_caching=True)

    engine = LLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams()

    engine.add_request("0", prompt + prompt2, sampling_params)
    engine.step()
    engine.add_request("1", prompt, sampling_params)
    engine.step()
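
The scenario the test constructs is easiest to see as block arithmetic: request "0" fills and computes every cache block that request "1"'s shorter prompt later maps onto. A minimal sketch of the degenerate case, with hypothetical token counts (the real numbers depend on how the prompts tokenize):

    # Hypothetical numbers; real values depend on the tokenizer.
    block_size = 16
    prompt_len = 48  # a prompt whose length is an exact multiple of block_size

    full_blocks = prompt_len // block_size            # 3: every block is full
    uncached = prompt_len - full_blocks * block_size  # 0 tokens left to prefill
    print(full_blocks, uncached)  # prints: 3 0

If all of those blocks are already computed when the request arrives, nothing is left to prefill, which is the state that used to trip asserts in the model runner.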

vllm/core/block_manager.py

@@ -1,6 +1,6 @@
 """A block manager that manages token blocks."""
 import enum
-from itertools import count
+from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple
@@ -426,23 +426,29 @@ class BlockSpaceManager:
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_last_full_block_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
         max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
-        block_table[max_full_block].computed = True
+        for i in reversed(range(max_full_block)):
+            if block_table[i].computed:
+                break
+            block_table[i].computed = True
 
-    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
+    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
         if seq.seq_id not in self.block_tables:
             return []
         block_table = self.block_tables[seq.seq_id]
-        for block_idx in reversed(range(len(block_table))):
-            if block_table[block_idx].computed:
-                return [b.block_number for b in block_table[:block_idx + 1]]
-        return []
+        # NOTE We exclude the last block to avoid the case where the entire
+        # prompt is cached. This would cause erroneous behavior in model
+        # runner.
+        return [
+            b.block_number
+            for b in takewhile(lambda b: b.computed, block_table[:-1])
+        ]
 
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
@@ -451,14 +457,12 @@ class BlockSpaceManager:
             return []
         ids_list = [
-            self.get_all_block_ids_till_computed(seq)
+            self.get_all_computed_blocks(seq)
             for seq in iter(seq_group.seqs_dict.values())
        ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # NOTE: We only mark the last full block because with prefix caching,
-        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
-                self.compute_last_full_block_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq)
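
The renamed get_all_computed_blocks is the heart of the fix: instead of scanning backwards for any computed block and assuming everything before it is cached, it collects only the contiguous computed run from the front of the block table, and never includes the final block. A standalone sketch of that pattern, with a hypothetical Block stand-in for vLLM's physical block type:

    from dataclasses import dataclass
    from itertools import takewhile
    from typing import List


    @dataclass
    class Block:  # hypothetical stand-in for vLLM's physical block
        block_number: int
        computed: bool


    def computed_prefix(block_table: List[Block]) -> List[int]:
        # Contiguous computed blocks from the front; the final block is
        # always withheld so the model runner keeps at least one block
        # of prompt to prefill.
        return [
            b.block_number
            for b in takewhile(lambda b: b.computed, block_table[:-1])
        ]


    table = [Block(0, True), Block(1, True), Block(2, False), Block(3, True)]
    print(computed_prefix(table))  # prints: [0, 1]

On this table the removed reversed scan would have found block 3 computed and reported all four blocks as cached, wrongly including the uncomputed block 2; takewhile stops at the first gap.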

vllm/worker/model_runner.py

@@ -215,6 +215,7 @@ class ModelRunner:
                 slot_mapping[-1].append(slot)
 
         max_prompt_len = max(subquery_lens)
+        assert max_prompt_len > 0
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
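
The added assert turns a previously silent failure into a loud one. A hedged sketch of the invariant it checks, simplifying subquery_lens to a plain list of per-request uncached prompt lengths:

    def checked_max_prompt_len(subquery_lens):
        # Each entry is how many prompt tokens a request still has to
        # prefill after prefix-cache hits. Because get_all_computed_blocks()
        # never reports the final block as cached, every entry should be >= 1.
        max_prompt_len = max(subquery_lens)
        assert max_prompt_len > 0, "entire batch served from the prefix cache"
        return max_prompt_len


    print(checked_max_prompt_len([7, 16]))  # prints: 16, a normal prefill batch
    # checked_max_prompt_len([0]) raises: the empty-prefill state the old
    # marking logic could reach, where the padded token tensor would have
    # zero width.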