[FIX] Fix styles in automatic prefix caching & add an automatic prefix caching benchmark (#3158)
commit 996d095c54
parent d65fac2738

benchmarks/benchmark_prefix_caching.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import argparse
import time

from vllm import LLM
from vllm import SamplingParams

PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n"


def test_prefix(llm=None, sampling_params=None, prompts=None, prefix_len=None):
    start_time = time.time()
    # Whether to use an explicit prefix; otherwise rely on automatic caching.
    if prefix_len is not None:
        # Start inference with the shared prefix position passed through.
        llm.generate(prompts,
                     sampling_params=sampling_params,
                     prefix_pos=prefix_len)
    else:
        llm.generate(prompts, sampling_params=sampling_params)

    end_time = time.time()
    print(f"cost time {end_time - start_time}")


def main(args):
    llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
              tokenizer_mode='auto',
              trust_remote_code=True,
              enforce_eager=True,
              enable_prefix_caching=args.enable_prefix_caching)

    num_prompts = 100
    prompts = [PROMPT] * num_prompts
    sampling_params = SamplingParams(temperature=0, max_tokens=100)

    print("------warm up------")
    test_prefix(
        llm=llm,
        prompts=prompts[:1],
        sampling_params=sampling_params,
    )

    print("------start generating------")
    test_prefix(
        llm=llm,
        prompts=prompts,
        sampling_params=sampling_params,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance with or without automatic '
        'prefix caching.')
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    args = parser.parse_args()
    main(args)
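
Running the comparison is just the script with and without its only flag (this assumes the Baichuan2-13B-Chat weights can be fetched and a CUDA-capable GPU is available):

    python benchmarks/benchmark_prefix_caching.py
    python benchmarks/benchmark_prefix_caching.py --enable-prefix-caching

The second run should report a noticeably lower "cost time" for the 100 identical prompts once the shared prompt blocks are cached by the warm-up call.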

@@ -303,7 +303,10 @@ if __name__ == "__main__":
         default="cuda",
         choices=["cuda"],
         help='device type for vLLM execution, supporting CUDA only currently.')
-    parser.add_argument("--enable_prefix_caching", action='store_true')
+    parser.add_argument(
+        "--enable-prefix-caching",
+        action='store_true',
+        help="enable automatic prefix caching for vLLM backend.")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
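
One detail worth noting about the flag rename above: argparse replaces hyphens in a long option with underscores when it builds the attribute name, so code that reads `args.enable_prefix_caching` keeps working with the new `--enable-prefix-caching` spelling. A minimal standalone check:

import argparse

# Hyphenated long options are exposed with underscores on the Namespace.
parser = argparse.ArgumentParser()
parser.add_argument("--enable-prefix-caching", action="store_true")
args = parser.parse_args(["--enable-prefix-caching"])
assert args.enable_prefix_caching is True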

@@ -236,13 +236,6 @@ class BlockSpaceManager:
         token_ids_len = len(seq.data.get_token_ids())
         return token_ids_len > 0 and token_ids_len % seq.block_size == 0
 
-    def _is_last_block(
-        self,
-        seq: Sequence,
-        index: int,
-    ) -> bool:
-        return index == len(seq.logical_token_blocks) - 1
-
     def _maybe_promote_last_block(
         self,
         seq: Sequence,

@@ -436,7 +429,7 @@ class BlockSpaceManager:
     def compute_last_full_block_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
-        max_full_block = seq.get_len() // seq.block_size - 1
+        max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
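
The one-line change above swaps `seq.block_size` for `self.block_size`; the quotient itself is unchanged. A small worked example of the index it computes, with illustrative numbers rather than anything from the code:

# With a block size of 16 tokens, a 33-token sequence completely fills
# logical blocks 0 and 1 and only partially fills block 2, so the last
# *full* block has index 33 // 16 - 1 == 1. A sequence shorter than one
# block gives -1, which the early return above treats as "nothing to mark".
block_size = 16
seq_len = 33
max_full_block = seq_len // block_size - 1
assert max_full_block == 1
assert 10 // block_size - 1 == -1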

@@ -451,9 +444,9 @@
                 return [b.block_number for b in block_table[:block_idx + 1]]
         return []
 
-    # Can return non-empty result only with prefix caching enabled.
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
+        # Can return non-empty result only with prefix caching enabled.
         if not self.enable_caching:
             return []
 

@@ -463,9 +456,9 @@
         ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
-    # We only mark the last full block because with prefix caching,
-    # all blocks until the marked one are guaranteed to be computed.
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+        # NOTE: We only mark the last full block because with prefix caching,
+        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
                 self.compute_last_full_block_in_seq(seq)
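
To see what the `commonprefix` call in the context above returns, here is a minimal standalone sketch with made-up block numbers, assuming the helper is `os.path.commonprefix` (which compares element-wise and works on lists as well as strings):

from os.path import commonprefix

# Hypothetical per-sequence lists of already-computed physical block numbers.
ids_list = [
    [7, 3, 9, 12],  # sequence A
    [7, 3, 9],      # sequence B
    [7, 3, 4, 5],   # sequence C
]
# Only the leading blocks shared by every sequence can be reused as a cached
# prefix for the whole group.
print(commonprefix([ids for ids in ids_list if ids != []]))  # [7, 3]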

@@ -160,10 +160,10 @@ class Sequence:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
-    # TODO The current hashing function is O(L^2). We should optimize this in
-    # the future.
     def hash_of_block(self, logical_idx: int) -> int:
         # Compute the number of tokens in the sequence
+        # TODO: The current hashing function is O(L^2). We should optimize
+        # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
         return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
 
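
The relocated TODO points at the cost of this scheme: the hash for block i covers every token from the start of the sequence, so hashing all blocks of an L-token sequence touches O(L^2) tokens overall. A stripped-down sketch of the idea (a standalone function, not the `Sequence` method itself):

def hash_of_block(token_ids, block_size, logical_idx):
    # The hash for logical block `logical_idx` covers every token from the
    # beginning of the sequence through the end of that block, so sequences
    # sharing a prompt prefix produce identical hashes for their leading
    # blocks, which is what makes cached blocks reusable across requests.
    num_tokens = (logical_idx + 1) * block_size
    return hash(tuple(token_ids[:num_tokens]))

assert hash_of_block([1, 2, 3, 4, 5, 6], block_size=2, logical_idx=1) == \
       hash_of_block([1, 2, 3, 4, 9, 9], block_size=2, logical_idx=1)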

@@ -308,10 +308,6 @@ class SequenceGroup:
         # We use the prompt of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).data.prompt_token_ids
 
-    @property
-    def block_size(self) -> int:
-        return next(iter(self.seqs_dict.values())).block_size
-
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0