[Benchmark] Add benchmark tool for multi turn conversations (#20267)
commit f0964e29cb (parent e789cad6b8)
benchmarks/multi_turn/README.md (new file, 71 lines)
@@ -0,0 +1,71 @@
# Benchmark KV Cache Offloading with Multi-Turn Conversations

The pip requirements for `benchmark_serving_multi_turn.py` can be found in `requirements.txt`.

First, start serving your model:

```bash
export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/

vllm serve $MODEL_NAME --disable-log-requests
```

## Synthetic Multi-Turn Conversations

Download the following text file (used to generate the synthetic conversations):

```bash
wget https://www.gutenberg.org/ebooks/1184.txt.utf-8
mv 1184.txt.utf-8 pg1184.txt
```

The filename `pg1184.txt` is used in `generate_multi_turn.json` (see `"text_files"`),
but you may use other text files if you prefer (this specific file is not required).

Then run the benchmarking script:

```bash
export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/

python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \
--num-clients 2 --max-active-conversations 6
```

You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.).
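As a minimal sketch (the values below are purely illustrative, assuming the default `generate_multi_turn.json` from this directory), the parameters can also be tweaked programmatically instead of editing the JSON by hand:

```python
import json

# Load the benchmark configuration shipped with this directory
with open("generate_multi_turn.json") as f:
    cfg = json.load(f)

# Example: more conversations, each with up to 24 turns
# ("num_turns" is sampled per conversation from the configured distribution)
cfg["num_conversations"] = 48
cfg["prompt_input"]["num_turns"]["max"] = 24

with open("generate_multi_turn.json", "w") as f:
    json.dump(cfg, f, indent=4)
```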
If successful, you will see output like the following:

```bash
----------------------------------------------------------------------------------------------------
Statistics summary:
runtime_sec = 215.810
requests_per_sec = 0.769
----------------------------------------------------------------------------------------------------
                    count     mean     std      min      25%      50%      75%      90%      99%      max
ttft_ms             166.0    78.22   67.63    45.91    59.94    62.26    64.43    69.66   353.18   567.54
tpot_ms             166.0    25.37    0.57    24.40    25.07    25.31    25.50    25.84    27.50    28.05
latency_ms          166.0  2591.07  326.90  1998.53  2341.62  2573.01  2860.10  3003.50  3268.46  3862.94
input_num_turns     166.0     7.43    4.57     1.00     3.00     7.00    11.00    13.00    17.00    17.00
input_num_tokens    166.0  2006.20  893.56   522.00  1247.75  2019.00  2718.00  3233.00  3736.45  3899.00
output_num_tokens   166.0   100.01   11.80    80.00    91.00    99.00   109.75   116.00   120.00   120.00
output_num_chunks   166.0    99.01   11.80    79.00    90.00    98.00   108.75   115.00   119.00   119.00
----------------------------------------------------------------------------------------------------
```

## ShareGPT Conversations

To run with ShareGPT data, download the following ShareGPT dataset:
`https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json`

Use the `convert_sharegpt_to_openai.py` script to convert the dataset to the format supported by `benchmark_serving_multi_turn.py`:

```bash
python convert_sharegpt_to_openai.py sharegpt_20230401_clean_lang_split.json sharegpt_conv_128.json --seed=99 --max-items=128
```

The script converts the ShareGPT dataset to a dataset with the standard user/assistant roles.
The flag `--max-items=128` samples 128 conversations from the original dataset (change as needed).

Use the output JSON file `sharegpt_conv_128.json` as the `--input-file` for `benchmark_serving_multi_turn.py`.
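The converted file is a JSON list of conversations, each holding an `id` and OpenAI-style `messages`. A quick way to sanity-check the output (a minimal sketch; the printed values depend on your sample) is:

```python
import json

with open("sharegpt_conv_128.json") as f:
    conversations = json.load(f)

# Each item has the form:
# {"id": "<conversation id>",
#  "messages": [{"role": "user", "content": "..."},
#               {"role": "assistant", "content": "..."}, ...]}
print(f"{len(conversations)} conversations")

first = conversations[0]
print(first["id"], len(first["messages"]), "messages")

# Turns alternate between user and assistant, starting with the user
assert first["messages"][0]["role"] == "user"
```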
benchmarks/multi_turn/bench_dataset.py (new file, 493 lines)
@@ -0,0 +1,493 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from statistics import mean
from typing import Any, NamedTuple, Optional, Union

import numpy as np  # type: ignore
import pandas as pd  # type: ignore
from bench_utils import (
    TEXT_SEPARATOR,
    Color,
    logger,
)
from transformers import AutoTokenizer  # type: ignore

# Conversation ID is a string (e.g. "UzTK34D")
ConvId = str

# A list of dicts (dicts with keys "id" and "messages")
ShareGptConversations = list[dict[str, Any]]

# A list of dicts (dicts with keys "role" and "content")
MessagesList = list[dict[str, str]]

# Map conversation ID to conversation messages
ConversationsMap = dict[ConvId, MessagesList]


class Distribution(ABC):
    @abstractmethod
    def sample(self, size: int = 1) -> np.ndarray:
        pass


class UniformDistribution(Distribution):
    def __init__(
        self,
        min_val: Union[int, float],
        max_val: Union[int, float],
        is_integer: bool = True,
    ) -> None:
        self.min_val = min_val
        self.max_val = max_val
        self.is_integer = is_integer

    def sample(self, size: int = 1) -> np.ndarray:
        if self.is_integer:
            return np.random.randint(
                int(self.min_val), int(self.max_val + 1), size=size
            )
        else:
            return np.random.uniform(self.min_val, self.max_val, size=size)

    def __repr__(self) -> str:
        return f"UniformDistribution[{self.min_val}, {self.max_val}]"


class ConstantDistribution(Distribution):
    def __init__(self, value: Union[int, float]) -> None:
        self.value = value
        self.max_val = value

    def sample(self, size: int = 1) -> np.ndarray:
        return np.full(shape=size, fill_value=self.value)

    def __repr__(self) -> str:
        return f"Constant[{self.value}]"


class ZipfDistribution(Distribution):
    def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
        self.alpha = alpha
        self.max_val = max_val

    def sample(self, size: int = 1) -> np.ndarray:
        samples = np.random.zipf(self.alpha, size=size)
        if self.max_val:
            samples = np.minimum(samples, self.max_val)
        return samples

    def __repr__(self) -> str:
        return f"ZipfDistribution[{self.alpha}]"


class PoissonDistribution(Distribution):
    def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
        self.alpha = alpha
        self.max_val = max_val

    def sample(self, size: int = 1) -> np.ndarray:
        samples = np.random.poisson(self.alpha, size=size)
        if self.max_val:
            samples = np.minimum(samples, self.max_val)
        return samples

    def __repr__(self) -> str:
        return f"PoissonDistribution[{self.alpha}]"


class LognormalDistribution(Distribution):
    def __init__(
        self, mean: float, sigma: float, max_val: Optional[int] = None
    ) -> None:
        self.mean = mean
        self.sigma = sigma
        self.max_val = max_val

    def sample(self, size: int = 1) -> np.ndarray:
        samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size)
        if self.max_val:
            samples = np.minimum(samples, self.max_val)

        return np.round(samples).astype(int)

    def __repr__(self) -> str:
        return f"LognormalDistribution[{self.mean}, {self.sigma}]"


class GenConvArgs(NamedTuple):
    num_conversations: int
    text_files: list[str]
    input_num_turns: Distribution
    input_common_prefix_num_tokens: Distribution
    input_prefix_num_tokens: Distribution
    input_num_tokens: Distribution
    output_num_tokens: Distribution
    print_stats: bool


def verify_field_exists(
    conf: dict, field_name: str, section: str, subsection: str
) -> None:
    if field_name not in conf:
        raise ValueError(
            f"Missing field '{field_name}' in {section=} and {subsection=}"
        )


def get_random_distribution(
    conf: dict, section: str, subsection: str, optional: bool = False
) -> Distribution:
    # section can be "prompt_input" or "prompt_output" (both required)
    conf = conf[section]

    if optional and subsection not in conf:
        # Optional subsection, if not found assume the value is always 0
        return ConstantDistribution(0)

    # subsection can be "num_turns", "num_tokens" or "prefix_num_tokens"
    if subsection not in conf:
        raise ValueError(f"Missing subsection {subsection} in section {section}")

    conf = conf[subsection]

    distribution = conf.get("distribution")
    if distribution is None:
        raise ValueError(
            f"Missing field 'distribution' in {section=} and {subsection=}"
        )

    if distribution == "constant":
        verify_field_exists(conf, "value", section, subsection)
        return ConstantDistribution(conf["value"])

    elif distribution == "zipf":
        verify_field_exists(conf, "alpha", section, subsection)
        max_val = conf.get("max", None)
        return ZipfDistribution(conf["alpha"], max_val=max_val)

    elif distribution == "poisson":
        verify_field_exists(conf, "alpha", section, subsection)
        max_val = conf.get("max", None)
        return PoissonDistribution(conf["alpha"], max_val=max_val)

    elif distribution == "lognormal":
        verify_field_exists(conf, "mean", section, subsection)
        verify_field_exists(conf, "sigma", section, subsection)
        max_val = conf.get("max", None)
        return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val)

    elif distribution == "uniform":
        verify_field_exists(conf, "min", section, subsection)
        verify_field_exists(conf, "max", section, subsection)

        min_value = conf["min"]
        max_value = conf["max"]

        assert min_value > 0
        assert min_value <= max_value

        is_integer = isinstance(min_value, int) and isinstance(max_value, int)
        return UniformDistribution(min_value, max_value, is_integer)
    else:
        raise ValueError(f"Unknown distribution: {distribution}")


def parse_input_json_file(conf: dict) -> GenConvArgs:
    # Validate the input file
    assert isinstance(conf, dict)
    required_fields = [
        "filetype",
        "num_conversations",
        "text_files",
        "prompt_input",
        "prompt_output",
    ]
    for field in required_fields:
        assert field in conf, f"Missing field {field} in input {conf}"

    assert conf["filetype"] == "generate_conversations"

    assert conf["num_conversations"] > 0, "num_conversations should be larger than zero"

    text_files = conf["text_files"]

    assert isinstance(text_files, list), "Field 'text_files' should be a list"
    assert len(text_files) > 0, (
        "Field 'text_files' should be a list with at least one file"
    )

    # Parse the parameters for the prompt input/output workload
    input_num_turns = get_random_distribution(conf, "prompt_input", "num_turns")
    input_num_tokens = get_random_distribution(conf, "prompt_input", "num_tokens")
    input_common_prefix_num_tokens = get_random_distribution(
        conf, "prompt_input", "common_prefix_num_tokens", optional=True
    )
    input_prefix_num_tokens = get_random_distribution(
        conf, "prompt_input", "prefix_num_tokens"
    )
    output_num_tokens = get_random_distribution(conf, "prompt_output", "num_tokens")

    print_stats: bool = conf.get("print_stats", False)
    assert isinstance(print_stats, bool), (
        "Field 'print_stats' should be either 'true' or 'false'"
    )

    args = GenConvArgs(
        num_conversations=conf["num_conversations"],
        text_files=text_files,
        input_num_turns=input_num_turns,
        input_common_prefix_num_tokens=input_common_prefix_num_tokens,
        input_prefix_num_tokens=input_prefix_num_tokens,
        input_num_tokens=input_num_tokens,
        output_num_tokens=output_num_tokens,
        print_stats=print_stats,
    )
    return args


def print_conv_stats(conversations: ConversationsMap, tokenizer: AutoTokenizer) -> None:
    # Collect statistics
    conv_stats: list[dict[Any, Any]] = []
    req_stats: list[int] = []

    print("\nCollecting statistics...")
    for messages in conversations.values():
        # messages is a list of dicts
        user_tokens: list[int] = []
        assistant_tokens: list[int] = []
        request_tokens: list[int] = []

        req_tokens = 0
        for m in messages:
            content = m["content"]
            num_tokens = len(tokenizer(content).input_ids)

            if m["role"] == "user":
                user_tokens.append(num_tokens)
                # New user prompt including all chat history
                req_tokens += num_tokens
                request_tokens.append(req_tokens)

            elif m["role"] == "assistant":
                assistant_tokens.append(num_tokens)
                # Update assistant answer
                # (will be part of chat history for the next user prompt)
                req_tokens += num_tokens

        item_stats = {
            "conversation_turns": len(messages),
            "user_tokens": mean(user_tokens),
            "assistant_tokens": mean(assistant_tokens),
        }

        conv_stats.append(item_stats)
        req_stats.extend(request_tokens)

    # Print statistics
    percentiles = [0.25, 0.5, 0.75, 0.9, 0.99]

    print(TEXT_SEPARATOR)
    print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
    print(TEXT_SEPARATOR)
    df = pd.DataFrame(conv_stats)
    print(df.describe(percentiles=percentiles).transpose())
    print(TEXT_SEPARATOR)
    print(f"{Color.YELLOW}Request statistics:{Color.RESET}")
    print(TEXT_SEPARATOR)
    df = pd.DataFrame(req_stats, columns=["request_tokens"])
    print(df.describe(percentiles=percentiles).transpose())
    print(TEXT_SEPARATOR)


def generate_conversations(
    args: GenConvArgs, tokenizer: AutoTokenizer
) -> ConversationsMap:
    # Text for all user prompts
    # (text from the input text files will be appended to this line)
    base_prompt_text = "Please rewrite the following text and add more content: "
    base_prompt_token_count = len(
        tokenizer.encode(base_prompt_text, add_special_tokens=False)
    )

    logger.info(f"{Color.PURPLE}Generating conversations...{Color.RESET}")
    logger.info(args)

    list_of_tokens = []

    for filename in args.text_files:
        # Load text file that will be used to generate prompts
        with open(filename) as file:
            data = file.read()
            tokens_in_file = tokenizer.encode(data, add_special_tokens=False)
            list_of_tokens.extend(tokens_in_file)

    conversations: ConversationsMap = {}
    conv_id = 0

    # Generate number of turns for every conversation
    turn_count: np.ndarray = args.input_num_turns.sample(args.num_conversations)

    # Turn count should be at least 2 (one user prompt and one assistant answer)
    turn_count = np.maximum(turn_count, 2)

    # Round up to an even number (every user prompt should have an answer)
    turn_count = turn_count + (turn_count % 2)

    # Generate number of prefix tokens for every conversation
    conv_prefix_tokens: np.ndarray = args.input_prefix_num_tokens.sample(
        args.num_conversations
    )

    # Used to reduce shared text between conversations
    # (jump/skip over text sections between conversations)
    base_offset = 0

    # Common prefix size for all conversations (only 1 sample required)
    common_prefix_text = ""
    common_prefix_tokens: int = args.input_common_prefix_num_tokens.sample(1)[0]
    if common_prefix_tokens > 0:
        # Using "." at the end to separate sentences
        common_prefix_text = (
            tokenizer.decode(list_of_tokens[: common_prefix_tokens - 2]) + "."
        )
        base_offset += common_prefix_tokens

    for conv_id in range(args.num_conversations):
        # Generate a single conversation
        messages: MessagesList = []

        nturns = turn_count[conv_id]

        # User prompt token count per turn (with lower limit)
        input_token_count: np.ndarray = args.input_num_tokens.sample(nturns)
        input_token_count = np.maximum(input_token_count, base_prompt_token_count)

        # Assistant answer token count per turn (with lower limit)
        output_token_count: np.ndarray = args.output_num_tokens.sample(nturns)
        output_token_count = np.maximum(output_token_count, 1)

        user_turn = True
        for turn_id in range(nturns):
            if user_turn:
                role = "user"
                num_tokens = input_token_count[turn_id]

                # Generate the user prompt,
                # use a unique prefix (the conv_id) for each conversation
                # (to avoid shared prefix between conversations)
                content = f"{conv_id} is a nice number... "

                if len(common_prefix_text) > 0 and turn_id == 0:
                    content = common_prefix_text + content

                # Update the number of tokens left for the content
                num_tokens -= len(tokenizer.encode(content, add_special_tokens=False))

                if turn_id == 0:
                    prefix_num_tokens = conv_prefix_tokens[conv_id]
                    if prefix_num_tokens > 0:
                        # Add prefix text (context) to the first turn
                        start_offset = base_offset
                        end_offset = start_offset + prefix_num_tokens
                        assert len(list_of_tokens) > end_offset, (
                            "Not enough input text to generate "
                            f"{prefix_num_tokens} tokens for the "
                            f"prefix text ({start_offset=}, {end_offset=})"
                        )

                        content += f"{conv_id}, " + tokenizer.decode(
                            list_of_tokens[start_offset:end_offset]
                        )
                        base_offset += prefix_num_tokens

                # Add the actual user prompt/question after the prefix text
                content += base_prompt_text
                num_tokens -= base_prompt_token_count

                if num_tokens > 0:
                    # Add text from the input file (to reach the desired token count)
                    start_offset = base_offset + turn_id * input_token_count.max()
                    end_offset = start_offset + num_tokens
                    assert len(list_of_tokens) > end_offset, (
                        f"Not enough input text to generate {num_tokens} tokens "
                        f"for the prompt ({start_offset=}, {end_offset=})"
                    )

                    # Convert tokens back to text
                    content += tokenizer.decode(list_of_tokens[start_offset:end_offset])
            else:
                role = "assistant"
                # This content will not be used as input to the LLM server
                # (actual answers will be used instead).
                # Content is only required to determine the min_tokens/max_tokens
                # (inputs to the LLM server).
                num_tokens = output_token_count[turn_id]
                assert len(list_of_tokens) > num_tokens, (
                    f"Not enough input text to generate {num_tokens} "
                    "tokens for assistant content"
                )
                content = tokenizer.decode(list_of_tokens[:num_tokens])

            # Append the user/assistant message to the list of messages
            messages.append({"role": role, "content": content})
            user_turn = not user_turn

        # Add the new conversation
        conversations[f"CONV_ID_{conv_id}"] = messages

        # Increase base offset for the next conversation
        base_offset += nturns

    if args.print_stats:
        print_conv_stats(conversations, tokenizer)

    return conversations


def conversations_list_to_dict(input_list: ShareGptConversations) -> ConversationsMap:
    conversations: ConversationsMap = {}

    for item in input_list:
        conv_id: str = item["id"]
        assert isinstance(conv_id, str)

        assert conv_id not in conversations, (
            f"Conversation ID {conv_id} found more than once in the input"
        )

        messages: MessagesList = item["messages"]
        assert isinstance(messages, list), (
            f"Conversation messages should be a list (ID: {conv_id})"
        )
        assert len(messages) > 0, f"Conversation with no messages (ID: {conv_id})"

        conversations[conv_id] = messages

    logger.info(f"Using {len(conversations)} unique conversations (IDs)")
    assert len(conversations) == len(input_list)

    # Print statistics about the selected conversations
    stats: list[dict[str, Any]] = []
    for conv_data in conversations.values():
        stats.append({"num_turns": len(conv_data)})

    print(TEXT_SEPARATOR)
    print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
    print(TEXT_SEPARATOR)
    percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
    conv_stats = pd.DataFrame(stats).describe(percentiles=percentiles)
    print(conv_stats.transpose())
    print(TEXT_SEPARATOR)

    return conversations


def conversations_dict_to_list(input_dict: ConversationsMap) -> ShareGptConversations:
    output: ShareGptConversations = []
    for conv_id, conv_data in input_dict.items():
        new_item = {"id": conv_id, "messages": conv_data}
        output.append(new_item)

    return output
benchmarks/multi_turn/bench_utils.py (new file, 25 lines)
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from enum import Enum


class Color(str, Enum):
    RED = "\033[91m"
    GREEN = "\033[92m"
    BLUE = "\033[94m"
    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    YELLOW = "\033[93m"
    RESET = "\033[0m"


TEXT_SEPARATOR = "-" * 100

# Configure the logger
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] - %(message)s",
    datefmt="%d-%m-%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)
benchmarks/multi_turn/benchmark_serving_multi_turn.py (new file, 1557 lines)
(File diff suppressed because it is too large.)
benchmarks/multi_turn/convert_sharegpt_to_openai.py (new file, 354 lines)
@@ -0,0 +1,354 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Download dataset from:
https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json

Convert to OpenAI API:
export INPUT_FILE=sharegpt_20230401_clean_lang_split.json
python convert_sharegpt_to_openai.py $INPUT_FILE sharegpt_conv_128.json --max-items=128
"""

import argparse
import json
import random
from statistics import mean
from typing import Any, Optional

import pandas as pd  # type: ignore
import tqdm  # type: ignore
from transformers import AutoTokenizer  # type: ignore


def has_non_english_chars(text: str) -> bool:
    return not text.isascii()


def content_is_valid(
    content: str, min_content_len: Optional[int], max_content_len: Optional[int]
) -> bool:
    if min_content_len and len(content) < min_content_len:
        return False

    if max_content_len and len(content) > max_content_len:
        return False

    return not has_non_english_chars(content)

def print_stats(
    conversations: "list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None
) -> None:
    # Collect statistics
    stats = []

    print("\nCollecting statistics...")
    for item in tqdm.tqdm(conversations):
        # item has "id" and "messages"
        messages = item["messages"]

        user_turns = 0
        assistant_turns = 0
        user_words = 0
        assistant_words = 0
        conv_chars = 0

        user_tokens: list[int] = []
        assistant_tokens: list[int] = []

        for m in messages:
            content = m["content"]
            conv_chars += len(content)
            content_num_words = content.count(" ") + 1

            num_tokens = 0
            if tokenizer:
                num_tokens = len(tokenizer(m["content"]).input_ids)

            if m["role"] == "user":
                user_turns += 1
                user_words += content_num_words
                if tokenizer:
                    user_tokens.append(num_tokens)

            elif m["role"] == "assistant":
                assistant_turns += 1
                assistant_words += content_num_words
                if tokenizer:
                    assistant_tokens.append(num_tokens)

        # assert user_turns == assistant_turns, \
        #     f"Invalid conversation ID {item['id']}"

        conv_words = user_words + assistant_words
        item_stats = {
            "user_turns": user_turns,
            "assistant_turns": assistant_turns,
            "user_words": user_words,
            "assistant_words": assistant_words,
            "conv_turns": len(messages),
            "conv_words": conv_words,
            "conv_characters": conv_chars,
        }

        if len(user_tokens) > 0:
            item_stats["user_tokens"] = int(mean(user_tokens))

        if len(assistant_tokens) > 0:
            item_stats["assistant_tokens"] = int(mean(assistant_tokens))

        stats.append(item_stats)

    print("\nStatistics:")
    percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
    df = pd.DataFrame(stats)
    print(df.describe(percentiles=percentiles).transpose())


def convert_sharegpt_to_openai(
    seed: int,
    input_file: str,
    output_file: str,
    max_items: Optional[int],
    min_content_len: Optional[int] = None,
    max_content_len: Optional[int] = None,
    min_turns: Optional[int] = None,
    max_turns: Optional[int] = None,
    model: Optional[str] = None,
) -> None:
    if min_turns and max_turns:
        assert min_turns <= max_turns

    if min_content_len and max_content_len:
        # Verify that min is not larger than max if both were given
        assert min_content_len <= max_content_len

    print(
        f"Input parameters:\n{seed=}, {max_items=}, {min_content_len=},"
        f" {max_content_len=}, {min_turns=}, {max_turns=}\n"
    )

    random.seed(seed)

    tokenizer = None
    if model is not None:
        print(f"Loading tokenizer from: {model}")
        tokenizer = AutoTokenizer.from_pretrained(model)

    # Read the ShareGPT JSON file
    print(f"Reading file: {input_file}")
    with open(input_file, encoding="utf-8") as f:
        # Should be a list of dicts
        # Each dict should have "id" (string) and "conversations" (list of dicts)
        sharegpt_data = json.load(f)

    assert isinstance(sharegpt_data, list), "Input file should contain a list of dicts"

    print(f"Total items in input file: {len(sharegpt_data):,}")

    print(f"Shuffling dataset with seed {seed}")
    random.shuffle(sharegpt_data)

    # Map conversation ID to all of its messages
    conversation_parts: dict[str, list[Any]] = {}

    for item in tqdm.tqdm(sharegpt_data):
        assert "id" in item, "Missing key 'id'"
        assert "conversations" in item, "Missing key 'conversations'"

        # Conversation ID (e.g. "hiWPlMD") and part/session (0, 1, 2, etc.)
        conv_id, _ = item["id"].split("_")
        new_turns = item["conversations"]

        if conv_id not in conversation_parts:
            # Start new conversation
            conversation_parts[conv_id] = []
        elif len(conversation_parts[conv_id]) > 0 and len(new_turns) > 0:
            prev_turns = conversation_parts[conv_id][-1]
            if prev_turns[-1]["from"] == new_turns[0]["from"]:
                new_turns = new_turns[1:]

        if len(new_turns) > 0:
            # We assume that parts are in order in the ShareGPT dataset
            conversation_parts[conv_id].append(new_turns)

    dataset: list[dict[str, Any]] = []
    for conv_id, conv_parts in conversation_parts.items():
        new_item = {"id": conv_id}

        conversations: list[dict[str, str]] = []

        # Merge all parts
        for conv_part in conv_parts:
            conversations.extend(conv_part)

        if len(conversations) > 0:
            new_item["conversations"] = conversations
            dataset.append(new_item)

    print(f"Total unique conversations (IDs) in input file: {len(dataset):,}")

    # Final output data
    final_openai_dataset: list[dict] = []

    # Filter conversations from the ShareGPT dataset and convert to OpenAI format
    for item in tqdm.tqdm(dataset):
        messages: list[dict] = []

        assert "id" in item, "Missing key 'id'"
        assert "conversations" in item, "Missing key 'conversations'"

        conv_id = item["id"]
        conversations = item["conversations"]

        if min_turns is not None and len(conversations) < min_turns:
            # Skip short conversations
            continue

        # Convert each message in the conversation, up to max_turns if specified
        for i, turn in enumerate(conversations):
            assert "from" in turn and "value" in turn, (
                f"Invalid conversation ID {conv_id} - missing 'from' or 'value'"
            )

            role = None
            turn_from = turn["from"]

            if turn_from in {"human", "user"}:
                role = "user"
            elif turn_from in {"gpt", "bing", "chatgpt", "bard"}:
                role = "assistant"
            elif turn_from == "system":
                role = "system"

            assert role is not None, (
                f"Invalid conversation ID {conv_id} - 'from'='{turn_from}' is invalid"
            )

            if i == 0 and role != "user":
                # If the first message is from the assistant (gpt), skip it.
                # This happens when the conversation is a follow-up
                # to a previous conversation (from the same user).
                continue

            if max_turns is not None and i >= max_turns:
                break

            # Convert message to OpenAI format (with "role" and "content")
            content = turn["value"]
            messages.append({"role": role, "content": content})

        # Add the converted conversation to the OpenAI format
        if len(messages) > 0:
            valid_messages = True

            # First turn should always be from the user
            user_turn = True

            for m in messages:
                # Make sure that turns alternate between user and assistant
                if (user_turn and m["role"] != "user") or (
                    not user_turn and m["role"] != "assistant"
                ):
                    valid_messages = False
                    break

                user_turn = not user_turn

                content = m["content"]
                valid_messages = content_is_valid(
                    content, min_content_len, max_content_len
                )
                if not valid_messages:
                    break

            if valid_messages is True:
                final_openai_dataset.append({"id": conv_id, "messages": messages})

    assert len(final_openai_dataset) > 0, "Final number of conversations is zero"

    print_stats(final_openai_dataset)

    print_stats_again = False
    if max_items is not None and len(final_openai_dataset) > max_items:
        print(f"\n\nSampling {max_items} items from the dataset...")
        print_stats_again = True
        final_openai_dataset = random.sample(final_openai_dataset, max_items)

    if print_stats_again:
        # Print stats after the dataset changed
        print_stats(final_openai_dataset, tokenizer)

    # Write the converted data to a new JSON file
    final_size = len(final_openai_dataset)
    print(f"\nTotal conversations converted (after filtering): {final_size:,}")
    print(f"\nWriting file: {output_file}")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(final_openai_dataset, f, ensure_ascii=False, indent=2)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Convert ShareGPT dataset to OpenAI API format"
    )
    parser.add_argument("input_file", help="Path to the input ShareGPT JSON file")
    parser.add_argument(
        "output_file", help="Path to the output OpenAI format JSON file"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Seed for random number generators"
    )
    parser.add_argument(
        "--max-items",
        type=int,
        default=None,
        help="Maximum number of items in the output file",
    )
    parser.add_argument(
        "--min-turns",
        type=int,
        default=None,
        help="Minimum number of turns per conversation",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=None,
        help="Maximum number of turns per conversation",
    )
    parser.add_argument(
        "--min-content-len",
        type=int,
        default=None,
        help="Min number of characters in the messages' content",
    )
    parser.add_argument(
        "--max-content-len",
        type=int,
        default=None,
        help="Max number of characters in the messages' content",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="LLM model, only the tokenizer will be used",
    )

    args = parser.parse_args()

    convert_sharegpt_to_openai(
        args.seed,
        args.input_file,
        args.output_file,
        args.max_items,
        args.min_content_len,
        args.max_content_len,
        args.min_turns,
        args.max_turns,
        args.model,
    )


if __name__ == "__main__":
    main()
benchmarks/multi_turn/generate_multi_turn.json (new file, 35 lines)
@@ -0,0 +1,35 @@
{
    "filetype": "generate_conversations",
    "num_conversations": 24,
    "text_files": ["pg1184.txt"],
    "print_stats": false,
    "prompt_input": {
        "num_turns": {
            "distribution": "uniform",
            "min": 12,
            "max": 18
        },
        "common_prefix_num_tokens": {
            "distribution": "constant",
            "value": 500
        },
        "prefix_num_tokens": {
            "distribution": "lognormal",
            "mean": 6,
            "sigma": 4,
            "max": 1500
        },
        "num_tokens": {
            "distribution": "uniform",
            "min": 120,
            "max": 160
        }
    },
    "prompt_output": {
        "num_tokens": {
            "distribution": "uniform",
            "min": 80,
            "max": 120
        }
    }
}
benchmarks/multi_turn/requirements.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
numpy>=1.24
pandas>=2.0.0
aiohttp>=3.10
transformers>=4.46
xlsxwriter>=3.2.1