Enable Random Prefix Caching in Serving Profiling Tool (benchmark_serving.py) (#8241)

This commit is contained in:
Wei-Sheng Chin 2024-09-06 20:18:16 -07:00 committed by GitHub
parent 2f707fcb35
commit 795b662cff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -195,8 +195,16 @@ def sample_sonnet_requests(
def sample_random_requests(
input_len: int, output_len: int, num_prompts: int, range_ratio: float,
tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]:
prefix_len: int,
input_len: int,
output_len: int,
num_prompts: int,
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
prefix_token_ids = np.random.randint(0,
tokenizer.vocab_size,
size=prefix_len).tolist()
input_lens = np.random.randint(
int(input_len * range_ratio),
@ -211,10 +219,12 @@ def sample_random_requests(
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
input_requests = []
for i in range(num_prompts):
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size
prompt = tokenizer.decode(prefix_token_ids +
[(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])])
input_requests.append(
(prompt, int(input_lens[i]), int(output_lens[i])))
(prompt, int(prefix_len + input_lens[i]), int(output_lens[i])))
return input_requests
@ -567,6 +577,7 @@ def main(args: argparse.Namespace):
elif args.dataset_name == "random":
input_requests = sample_random_requests(
prefix_len=args.random_prefix_len,
input_len=args.random_input_len,
output_len=args.random_output_len,
num_prompts=args.num_prompts,
@ -765,6 +776,14 @@ if __name__ == "__main__":
help="Range of sampled ratio of input/output length, "
"used only for random sampling.",
)
parser.add_argument(
"--random-prefix-len",
type=int,
default=0,
help="Number of fixed prefix tokens before random "
" context. The length range of context in a random "
" request is [random-prefix-len, "
" random-prefix-len + random-prefix-len * random-range-ratio).")
parser.add_argument(
"--request-rate",
type=float,