[FIX] Fix styles in automatic prefix caching & add a automatic prefix caching benchmark (#3158)

This commit is contained in:
Zhuohan Li 2024-03-03 14:37:18 -08:00 committed by GitHub
parent d65fac2738
commit 996d095c54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 69 additions and 18 deletions

View File

@ -0,0 +1,59 @@
import argparse
import time
from vllm import LLM
from vllm import SamplingParams
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n"
def test_prefix(llm=None, sampling_params=None, prompts=None, prefix_len=None):
start_time = time.time()
# whether use Prefix
if prefix_len != None:
# start inference
llm.generate(prompts,
sampling_params=sampling_params,
prefix_pos=prefix_len)
else:
llm.generate(prompts, sampling_params=sampling_params)
end_time = time.time()
print(f"cost time {end_time - start_time}")
def main(args):
llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
tokenizer_mode='auto',
trust_remote_code=True,
enforce_eager=True,
enable_prefix_caching=args.enable_prefix_caching)
num_prompts = 100
prompts = [PROMPT] * num_prompts
sampling_params = SamplingParams(temperature=0, max_tokens=100)
print("------warm up------")
test_prefix(
llm=llm,
prompts=prompts[:1],
sampling_params=sampling_params,
)
print("------start generating------")
test_prefix(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Benchmark the performance with or without automatic '
'prefix caching.')
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
args = parser.parse_args()
main(args)

View File

@ -303,7 +303,10 @@ if __name__ == "__main__":
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument("--enable_prefix_caching", action='store_true')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model

View File

@ -236,13 +236,6 @@ class BlockSpaceManager:
token_ids_len = len(seq.data.get_token_ids())
return token_ids_len > 0 and token_ids_len % seq.block_size == 0
def _is_last_block(
self,
seq: Sequence,
index: int,
) -> bool:
return index == len(seq.logical_token_blocks) - 1
def _maybe_promote_last_block(
self,
seq: Sequence,
@ -436,7 +429,7 @@ class BlockSpaceManager:
def compute_last_full_block_in_seq(self, seq: Sequence):
if seq.seq_id not in self.block_tables:
return
max_full_block = seq.get_len() // seq.block_size - 1
max_full_block = seq.get_len() // self.block_size - 1
block_table = self.block_tables[seq.seq_id]
if max_full_block == -1:
return
@ -451,9 +444,9 @@ class BlockSpaceManager:
return [b.block_number for b in block_table[:block_idx + 1]]
return []
# Can return non-empty result only with prefix caching enabled.
def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
# Can return non-empty result only with prefix caching enabled.
if not self.enable_caching:
return []
@ -463,9 +456,9 @@ class BlockSpaceManager:
]
return commonprefix([ids for ids in ids_list if ids != []])
# We only mark the last full block because with prefix caching,
# all blocks until the marked one are guaranteed to be computed.
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# NOTE: We only mark the last full block because with prefix caching,
# all blocks until the marked one are guaranteed to be computed.
if self.enable_caching:
for seq in seq_group.seqs_dict.values():
self.compute_last_full_block_in_seq(seq)

View File

@ -160,10 +160,10 @@ class Sequence:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
# TODO The current hashing function is O(L^2). We should optimize this in
# the future.
def hash_of_block(self, logical_idx: int) -> int:
# Compute the number of tokens in the sequence
# TODO: The current hashing function is O(L^2). We should optimize
# this in the future.
num_tokens = self.num_hashed_tokens_of_block(logical_idx)
return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
@ -308,10 +308,6 @@ class SequenceGroup:
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids
@property
def block_size(self) -> int:
return next(iter(self.seqs_dict.values())).block_size
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0