From f0964e29cb3b2deccdad89f5f8c068d3a629d239 Mon Sep 17 00:00:00 2001 From: Daniel Serebrenik <74646983+pliops-daniels@users.noreply.github.com> Date: Fri, 8 Aug 2025 20:28:50 +0300 Subject: [PATCH] [Benchmark] Add benchmark tool for multi turn conversations (#20267) --- benchmarks/multi_turn/README.md | 71 + benchmarks/multi_turn/bench_dataset.py | 493 ++++++ benchmarks/multi_turn/bench_utils.py | 25 + .../benchmark_serving_multi_turn.py | 1557 +++++++++++++++++ .../multi_turn/convert_sharegpt_to_openai.py | 354 ++++ .../multi_turn/generate_multi_turn.json | 35 + benchmarks/multi_turn/requirements.txt | 5 + 7 files changed, 2540 insertions(+) create mode 100644 benchmarks/multi_turn/README.md create mode 100644 benchmarks/multi_turn/bench_dataset.py create mode 100644 benchmarks/multi_turn/bench_utils.py create mode 100644 benchmarks/multi_turn/benchmark_serving_multi_turn.py create mode 100644 benchmarks/multi_turn/convert_sharegpt_to_openai.py create mode 100644 benchmarks/multi_turn/generate_multi_turn.json create mode 100644 benchmarks/multi_turn/requirements.txt diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md new file mode 100644 index 0000000000000..ae0866ae60751 --- /dev/null +++ b/benchmarks/multi_turn/README.md @@ -0,0 +1,71 @@ +# Benchmark KV Cache Offloading with Multi-Turn Conversations + +The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `requirements.txt` + +First start serving your model + +```bash +export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ + +vllm serve $MODEL_NAME --disable-log-requests +``` + +## Synthetic Multi-Turn Conversations + +Download the following text file (used for generation of synthetic conversations) + +```bash +wget https://www.gutenberg.org/ebooks/1184.txt.utf-8 +mv 1184.txt.utf-8 pg1184.txt +``` + +The filename `pg1184.txt` is used in `generate_multi_turn.json` (see `"text_files"`). + +But you may use other text files if you prefer (using this specific file is not required). + +Then run the benchmarking script + +```bash +export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ + +python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \ +--num-clients 2 --max-active-conversations 6 +``` + +You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.). 
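+
+For reference, the configuration file is a JSON object in the format accepted by `parse_input_json_file` in `bench_dataset.py`. The snippet below is only an illustrative sketch of that structure (all values are placeholders; the bundled `generate_multi_turn.json` may use different distributions and numbers):
+
+```json
+{
+    "filetype": "generate_conversations",
+    "num_conversations": 24,
+    "text_files": ["pg1184.txt"],
+    "print_stats": false,
+    "prompt_input": {
+        "num_turns": {"distribution": "uniform", "min": 2, "max": 10},
+        "common_prefix_num_tokens": {"distribution": "constant", "value": 500},
+        "prefix_num_tokens": {"distribution": "lognormal", "mean": 6, "sigma": 1, "max": 1500},
+        "num_tokens": {"distribution": "uniform", "min": 120, "max": 160}
+    },
+    "prompt_output": {
+        "num_tokens": {"distribution": "uniform", "min": 80, "max": 120}
+    }
+}
+```
+
+Supported `"distribution"` values are `constant` (requires `value`), `uniform` (requires `min`/`max`), `zipf` and `poisson` (require `alpha`), and `lognormal` (requires `mean` and `sigma`); `max` is an optional cap for the last three. The `common_prefix_num_tokens` subsection and `print_stats` are optional.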
+ +If successful, you will see the following output + +```bash +---------------------------------------------------------------------------------------------------- +Statistics summary: +runtime_sec = 215.810 +requests_per_sec = 0.769 +---------------------------------------------------------------------------------------------------- + count mean std min 25% 50% 75% 90% 99% max +ttft_ms 166.0 78.22 67.63 45.91 59.94 62.26 64.43 69.66 353.18 567.54 +tpot_ms 166.0 25.37 0.57 24.40 25.07 25.31 25.50 25.84 27.50 28.05 +latency_ms 166.0 2591.07 326.90 1998.53 2341.62 2573.01 2860.10 3003.50 3268.46 3862.94 +input_num_turns 166.0 7.43 4.57 1.00 3.00 7.00 11.00 13.00 17.00 17.00 +input_num_tokens 166.0 2006.20 893.56 522.00 1247.75 2019.00 2718.00 3233.00 3736.45 3899.00 +output_num_tokens 166.0 100.01 11.80 80.00 91.00 99.00 109.75 116.00 120.00 120.00 +output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75 115.00 119.00 119.00 +---------------------------------------------------------------------------------------------------- +``` + +## ShareGPT Conversations + +To run with the ShareGPT data, download the following ShareGPT dataset: +`https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json` + +Use the `convert_sharegpt_to_openai.py` script to convert the dataset to a format supported by `benchmark_serving_multi_turn.py` + +```bash +python convert_sharegpt_to_openai.py sharegpt_20230401_clean_lang_split.json sharegpt_conv_128.json --seed=99 --max-items=128 +``` + +The script will convert the ShareGPT dataset to a dataset with the standard user/assistant roles. + +The flag `--max-items=128` is used to sample 128 conversations from the original dataset (change as needed). + +Use the output JSON file `sharegpt_conv_128.json` as the `--input-file` for `benchmark_serving_multi_turn.py`. 
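+
+After conversion, each item in the output JSON has a string `id` and an OpenAI-style `messages` list with alternating `user`/`assistant` roles; this is the list format that `benchmark_serving_multi_turn.py` loads via `conversations_list_to_dict`. A minimal sketch of one converted item (the ID and message contents are placeholders):
+
+```json
+[
+    {
+        "id": "UzTK34D",
+        "messages": [
+            {"role": "user", "content": "First user prompt..."},
+            {"role": "assistant", "content": "First assistant answer..."},
+            {"role": "user", "content": "Second user prompt..."}
+        ]
+    }
+]
+```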
diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py new file mode 100644 index 0000000000000..411b89dd23dc6 --- /dev/null +++ b/benchmarks/multi_turn/bench_dataset.py @@ -0,0 +1,493 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from statistics import mean +from typing import Any, NamedTuple, Optional, Union + +import numpy as np # type: ignore +import pandas as pd # type: ignore +from bench_utils import ( + TEXT_SEPARATOR, + Color, + logger, +) +from transformers import AutoTokenizer # type: ignore + +# Conversation ID is a string (e.g: "UzTK34D") +ConvId = str + +# A list of dicts (dicts with keys "id" and "messages") +ShareGptConversations = list[dict[str, Any]] + +# A list of dicts (dicts with keys "role" and "content") +MessagesList = list[dict[str, str]] + +# Map conversation ID to conversation messages +ConversationsMap = list[ConvId, MessagesList] + + +class Distribution(ABC): + @abstractmethod + def sample(self, size: int = 1) -> np.ndarray: + pass + + +class UniformDistribution(Distribution): + def __init__( + self, + min_val: Union[int, float], + max_val: Union[int, float], + is_integer: bool = True, + ) -> None: + self.min_val = min_val + self.max_val = max_val + self.is_integer = is_integer + + def sample(self, size: int = 1) -> np.ndarray: + if self.is_integer: + return np.random.randint( + int(self.min_val), int(self.max_val + 1), size=size + ) + else: + return np.random.uniform(self.min_val, self.max_val, size=size) + + def __repr__(self) -> str: + return f"UniformDistribution[{self.min_val}, {self.max_val}]" + + +class ConstantDistribution(Distribution): + def __init__(self, value: Union[int, float]) -> None: + self.value = value + self.max_val = value + + def sample(self, size: int = 1) -> np.ndarray: + return np.full(shape=size, fill_value=self.value) + + def __repr__(self) -> str: + return f"Constant[{self.value}]" + + +class ZipfDistribution(Distribution): + def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: + self.alpha = alpha + self.max_val = max_val + + def sample(self, size: int = 1) -> np.ndarray: + samples = np.random.zipf(self.alpha, size=size) + if self.max_val: + samples = np.minimum(samples, self.max_val) + return samples + + def __repr__(self) -> str: + return f"ZipfDistribution[{self.alpha}]" + + +class PoissonDistribution(Distribution): + def __init__(self, alpha: float, max_val: Optional[int] = None) -> None: + self.alpha = alpha + self.max_val = max_val + + def sample(self, size: int = 1) -> np.ndarray: + samples = np.random.poisson(self.alpha, size=size) + if self.max_val: + samples = np.minimum(samples, self.max_val) + return samples + + def __repr__(self) -> str: + return f"PoissonDistribution[{self.alpha}]" + + +class LognormalDistribution(Distribution): + def __init__( + self, mean: float, sigma: float, max_val: Optional[int] = None + ) -> None: + self.mean = mean + self.sigma = sigma + self.max_val = max_val + + def sample(self, size: int = 1) -> np.ndarray: + samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size) + if self.max_val: + samples = np.minimum(samples, self.max_val) + + return np.round(samples).astype(int) + + def __repr__(self) -> str: + return f"LognormalDistribution[{self.mean}, {self.sigma}]" + + +class GenConvArgs(NamedTuple): + num_conversations: int + text_files: list[str] + input_num_turns: Distribution + input_common_prefix_num_tokens: 
Distribution + input_prefix_num_tokens: Distribution + input_num_tokens: Distribution + output_num_tokens: Distribution + print_stats: bool + + +def verify_field_exists( + conf: dict, field_name: str, section: str, subsection: str +) -> None: + if field_name not in conf: + raise ValueError( + f"Missing field '{field_name}' in {section=} and {subsection=}" + ) + + +def get_random_distribution( + conf: dict, section: str, subsection: str, optional: bool = False +) -> Distribution: + # section can be "prompt_input" or "prompt_output" (both required) + conf = conf[section] + + if optional and subsection not in conf: + # Optional subsection, if not found assume the value is always 0 + return ConstantDistribution(0) + + # subsection can be "num_turns", "num_tokens" or "prefix_num_tokens" + if subsection not in conf: + raise ValueError(f"Missing subsection {subsection} in section {section}") + + conf = conf[subsection] + + distribution = conf.get("distribution") + if distribution is None: + raise ValueError( + f"Missing field 'distribution' in {section=} and {subsection=}" + ) + + if distribution == "constant": + verify_field_exists(conf, "value", section, subsection) + return ConstantDistribution(conf["value"]) + + elif distribution == "zipf": + verify_field_exists(conf, "alpha", section, subsection) + max_val = conf.get("max", None) + return ZipfDistribution(conf["alpha"], max_val=max_val) + + elif distribution == "poisson": + verify_field_exists(conf, "alpha", section, subsection) + max_val = conf.get("max", None) + return PoissonDistribution(conf["alpha"], max_val=max_val) + + elif distribution == "lognormal": + verify_field_exists(conf, "mean", section, subsection) + verify_field_exists(conf, "sigma", section, subsection) + max_val = conf.get("max", None) + return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val) + + elif distribution == "uniform": + verify_field_exists(conf, "min", section, subsection) + verify_field_exists(conf, "max", section, subsection) + + min_value = conf["min"] + max_value = conf["max"] + + assert min_value > 0 + assert min_value <= max_value + + is_integer = isinstance(min_value, int) and isinstance(max_value, int) + return UniformDistribution(min_value, max_value, is_integer) + else: + raise ValueError(f"Unknown distribution: {distribution}") + + +def parse_input_json_file(conf: dict) -> GenConvArgs: + # Validate the input file + assert isinstance(conf, dict) + required_fields = [ + "filetype", + "num_conversations", + "text_files", + "prompt_input", + "prompt_output", + ] + for field in required_fields: + assert field in conf, f"Missing field {field} in input {conf}" + + assert conf["filetype"] == "generate_conversations" + + assert conf["num_conversations"] > 0, "num_conversations should be larger than zero" + + text_files = conf["text_files"] + + assert isinstance(text_files, list), "Field 'text_files' should be a list" + assert len(text_files) > 0, ( + "Field 'text_files' should be a list with at least one file" + ) + + # Parse the parameters for the prompt input/output workload + input_num_turns = get_random_distribution(conf, "prompt_input", "num_turns") + input_num_tokens = get_random_distribution(conf, "prompt_input", "num_tokens") + input_common_prefix_num_tokens = get_random_distribution( + conf, "prompt_input", "common_prefix_num_tokens", optional=True + ) + input_prefix_num_tokens = get_random_distribution( + conf, "prompt_input", "prefix_num_tokens" + ) + output_num_tokens = get_random_distribution(conf, "prompt_output", "num_tokens") 
+ + print_stats: bool = conf.get("print_stats", False) + assert isinstance(print_stats, bool), ( + "Field 'print_stats' should be either 'true' or 'false'" + ) + + args = GenConvArgs( + num_conversations=conf["num_conversations"], + text_files=text_files, + input_num_turns=input_num_turns, + input_common_prefix_num_tokens=input_common_prefix_num_tokens, + input_prefix_num_tokens=input_prefix_num_tokens, + input_num_tokens=input_num_tokens, + output_num_tokens=output_num_tokens, + print_stats=print_stats, + ) + return args + + +def print_conv_stats(conversations: ConversationsMap, tokenizer: AutoTokenizer) -> None: + # Collect statistics + conv_stats: list[dict[Any, Any]] = [] + req_stats: list[int] = [] + + print("\nCollecting statistics...") + for messages in conversations.values(): + # messages is a list of dicts + user_tokens: list[int] = [] + assistant_tokens: list[int] = [] + request_tokens: list[int] = [] + + req_tokens = 0 + for m in messages: + content = m["content"] + num_tokens = len(tokenizer(content).input_ids) + + if m["role"] == "user": + user_tokens.append(num_tokens) + # New user prompt including all chat history + req_tokens += num_tokens + request_tokens.append(req_tokens) + + elif m["role"] == "assistant": + assistant_tokens.append(num_tokens) + # Update assistant answer + # (will be part of chat history for the next user prompt) + req_tokens += num_tokens + + item_stats = { + "conversation_turns": len(messages), + "user_tokens": mean(user_tokens), + "assistant_tokens": mean(assistant_tokens), + } + + conv_stats.append(item_stats) + req_stats.extend(request_tokens) + + # Print statistics + percentiles = [0.25, 0.5, 0.75, 0.9, 0.99] + + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}") + print(TEXT_SEPARATOR) + df = pd.DataFrame(conv_stats) + print(df.describe(percentiles=percentiles).transpose()) + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Request statistics:{Color.RESET}") + print(TEXT_SEPARATOR) + df = pd.DataFrame(req_stats, columns=["request_tokens"]) + print(df.describe(percentiles=percentiles).transpose()) + print(TEXT_SEPARATOR) + + +def generate_conversations( + args: GenConvArgs, tokenizer: AutoTokenizer +) -> ConversationsMap: + # Text for all user prompts + # (text from the input text files will be appended to this line) + base_prompt_text = "Please rewrite the following text and add more content: " + base_prompt_token_count = len( + tokenizer.encode(base_prompt_text, add_special_tokens=False) + ) + + logger.info(f"{Color.PURPLE}Generating conversations...{Color.RESET}") + logger.info(args) + + list_of_tokens = [] + + for filename in args.text_files: + # Load text file that will be used to generate prompts + with open(filename) as file: + data = file.read() + tokens_in_file = tokenizer.encode(data, add_special_tokens=False) + list_of_tokens.extend(tokens_in_file) + + conversations: ConversationsMap = {} + conv_id = 0 + + # Generate number of turns for every conversation + turn_count: np.ndarray = args.input_num_turns.sample(args.num_conversations) + + # Turn count should be at least 2 (one user prompt and one assistant answer) + turn_count = np.maximum(turn_count, 2) + + # Round up to an even number (every user prompt should have an answer) + turn_count = turn_count + (turn_count % 2) + + # Generate number of prefix tokens for every conversation + conv_prefix_tokens: np.ndarray = args.input_prefix_num_tokens.sample( + args.num_conversations + ) + + # Used to reduce shared text between conversations + # (jump/skip over 
text sections between conversations) + base_offset = 0 + + # Common prefix size for all conversations (only 1 sample required) + common_prefix_text = "" + common_prefix_tokens: int = args.input_common_prefix_num_tokens.sample(1)[0] + if common_prefix_tokens > 0: + # Using "." at the end to separate sentences + common_prefix_text = ( + tokenizer.decode(list_of_tokens[: common_prefix_tokens - 2]) + "." + ) + base_offset += common_prefix_tokens + + for conv_id in range(args.num_conversations): + # Generate a single conversation + messages: MessagesList = [] + + nturns = turn_count[conv_id] + + # User prompt token count per turn (with lower limit) + input_token_count: np.ndarray = args.input_num_tokens.sample(nturns) + input_token_count = np.maximum(input_token_count, base_prompt_token_count) + + # Assistant answer token count per turn (with lower limit) + output_token_count: np.ndarray = args.output_num_tokens.sample(nturns) + output_token_count = np.maximum(output_token_count, 1) + + user_turn = True + for turn_id in range(nturns): + if user_turn: + role = "user" + num_tokens = input_token_count[turn_id] + + # Generate the user prompt, + # use a unique prefix (the conv_id) for each conversation + # (to avoid shared prefix between conversations) + content = f"{conv_id} is a nice number... " + + if len(common_prefix_text) > 0 and turn_id == 0: + content = common_prefix_text + content + + # Update the number of tokens left for the content + num_tokens -= len(tokenizer.encode(content, add_special_tokens=False)) + + if turn_id == 0: + prefix_num_tokens = conv_prefix_tokens[conv_id] + if prefix_num_tokens > 0: + # Add prefix text (context) to the first turn + start_offset = base_offset + end_offset = start_offset + prefix_num_tokens + assert len(list_of_tokens) > end_offset, ( + "Not enough input text to generate " + f"{prefix_num_tokens} tokens for the " + f"prefix text ({start_offset=}, {end_offset=})" + ) + + content += f"{conv_id}, " + tokenizer.decode( + list_of_tokens[start_offset:end_offset] + ) + base_offset += prefix_num_tokens + + # Add the actual user prompt/question after the prefix text + content += base_prompt_text + num_tokens -= base_prompt_token_count + + if num_tokens > 0: + # Add text from the input file (to reach the desired token count) + start_offset = base_offset + turn_id * input_token_count.max() + end_offset = start_offset + num_tokens + assert len(list_of_tokens) > end_offset, ( + f"Not enough input text to generate {num_tokens} tokens " + f"for the prompt ({start_offset=}, {end_offset=})" + ) + + # Convert tokens back to text + content += tokenizer.decode(list_of_tokens[start_offset:end_offset]) + else: + role = "assistant" + # This content will not be used as input to the LLM server + # (actual answers will be used instead). + # Content is only required to determine the min_tokens/max_tokens + # (inputs to the LLM server). 
+ num_tokens = output_token_count[turn_id] + assert len(list_of_tokens) > num_tokens, ( + f"Not enough input text to generate {num_tokens} " + "tokens for assistant content" + ) + content = tokenizer.decode(list_of_tokens[:num_tokens]) + + # Append the user/assistant message to the list of messages + messages.append({"role": role, "content": content}) + user_turn = not user_turn + + # Add the new conversation + conversations[f"CONV_ID_{conv_id}"] = messages + + # Increase base offset for the next conversation + base_offset += nturns + + if args.print_stats: + print_conv_stats(conversations, tokenizer) + + return conversations + + +def conversations_list_to_dict(input_list: ShareGptConversations) -> ConversationsMap: + conversations: ConversationsMap = {} + + for item in input_list: + conv_id: str = item["id"] + assert isinstance(conv_id, str) + + assert conv_id not in conversations, ( + f"Conversation ID {conv_id} found more than once in the input" + ) + + messages: MessagesList = item["messages"] + assert isinstance(messages, list), ( + f"Conversation messages should be a list (ID: {conv_id})" + ) + assert len(messages) > 0, f"Conversation with no messages (ID: {conv_id})" + + conversations[conv_id] = messages + + logger.info(f"Using {len(conversations)} unique conversations (IDs)") + assert len(conversations) == len(input_list) + + # Print statistics about the selected conversations + stats: list[dict[str, Any]] = [] + for conv_data in conversations.values(): + stats.append({"num_turns": len(conv_data)}) + + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}") + print(TEXT_SEPARATOR) + percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999] + conv_stats = pd.DataFrame(stats).describe(percentiles=percentiles) + print(conv_stats.transpose()) + print(TEXT_SEPARATOR) + + return conversations + + +def conversations_dict_to_list(input_dict: ConversationsMap) -> ShareGptConversations: + output: ShareGptConversations = [] + for conv_id, conv_data in input_dict.items(): + new_item = {"id": conv_id, "messages": conv_data} + output.append(new_item) + + return output diff --git a/benchmarks/multi_turn/bench_utils.py b/benchmarks/multi_turn/bench_utils.py new file mode 100644 index 0000000000000..d4d3c1ca8c52f --- /dev/null +++ b/benchmarks/multi_turn/bench_utils.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +from enum import Enum + + +class Color(str, Enum): + RED = "\033[91m" + GREEN = "\033[92m" + BLUE = "\033[94m" + PURPLE = "\033[95m" + CYAN = "\033[96m" + YELLOW = "\033[93m" + RESET = "\033[0m" + + +TEXT_SEPARATOR = "-" * 100 + +# Configure the logger +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] - %(message)s", + datefmt="%d-%m-%Y %H:%M:%S", +) +logger = logging.getLogger(__name__) diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py new file mode 100644 index 0000000000000..53c3207491d18 --- /dev/null +++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py @@ -0,0 +1,1557 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import asyncio +import json +import logging +import multiprocessing as mp +import os +import random +import time +from collections import Counter, deque +from datetime import datetime +from enum import Enum +from http import HTTPStatus +from statistics 
import mean +from typing import NamedTuple, Optional, Union + +import aiohttp # type: ignore +import numpy as np # type: ignore +import pandas as pd # type: ignore +from bench_dataset import ( + ConversationsMap, + ConvId, + GenConvArgs, + MessagesList, + ShareGptConversations, + conversations_dict_to_list, + conversations_list_to_dict, + generate_conversations, + parse_input_json_file, +) +from bench_utils import TEXT_SEPARATOR, Color, logger +from transformers import AutoTokenizer # type: ignore + +NUM_TOKENS_FROM_DATASET = 0 +TERM_SIGNAL = None + + +class ConversationSampling(str, Enum): + ROUND_ROBIN = "round_robin" + RANDOM = "random" + + def __str__(self): + return self.value + + +class ClientArgs(NamedTuple): + seed: int + max_num_requests: Optional[int] + skip_first_turn: bool + max_turns: Optional[int] + max_active_conversations: int + verbose: bool + print_content: bool + verify_output: bool + conversation_sampling: ConversationSampling + request_rate: float + + +class RequestArgs(NamedTuple): + chat_url: str + model: str + stream: bool + limit_min_tokens: int # Use negative value for no limit + limit_max_tokens: int # Use negative value for no limit + + +class BenchmarkArgs(NamedTuple): + url: str + num_clients: int + early_stop: bool + + +class ServerResponse(NamedTuple): + valid: bool + ttft_ms: float # time to first chunk + tpot_ms: float # time per output chunk (one or more tokens) + latency_ms: float + start_time_ms: float + first_chunk: str # first chunk of the content + content: str # includes the first_chunk + num_chunks: int + + def __str__(self) -> str: + return f"ttft_ms {self.ttft_ms:.2f}, tpot_ms {self.tpot_ms:.2f}, latency_ms {self.latency_ms:.2f}" # noqa: E501 + + +class RequestStats(NamedTuple): + ttft_ms: float + tpot_ms: float + latency_ms: float + start_time_ms: float + input_num_turns: int + input_num_tokens: int + output_num_tokens: int + output_num_chunks: int + output_num_first_chunk_tokens: int + approx_cached_percent: float + conversation_id: str + client_id: int + + def __str__(self) -> str: + return ( + f"ttft_ms {self.ttft_ms:.2f}, tpot_ms {self.tpot_ms:.2f}, latency_ms {self.latency_ms:.2f}, input_num_tokens {self.input_num_tokens}, " # noqa: E501 + f"output_num_tokens {self.output_num_tokens} ({self.output_num_chunks} chunks, {self.output_num_first_chunk_tokens} tokens in first chunk), " # noqa: E501 + f"approx_cached_percent {self.approx_cached_percent:.2f}%" + ) + + +class MetricStats: + def __init__(self) -> None: + self.min: Optional[float] = None + self.max: Optional[float] = None + self.avg: Optional[float] = None + self.sum = 0.0 + self.count = 0 + + def update(self, value: float) -> None: + if self.min is None: + self.min = value + else: + self.min = min(self.min, value) + + if self.max is None: + self.max = value + else: + self.max = max(self.max, value) + + self.sum += value + self.count += 1 + self.avg = self.sum / self.count + + def __repr__(self) -> str: + if self.count == 0: + return "no data" + return f"avg: {self.avg:>10.3f}, min: {self.min:>10.3f}, max: {self.max:>10.3f}" + + +class MovingAverage: + def __init__(self, window_size: int) -> None: + self.window_size = window_size + self.window = np.zeros(window_size) + self.index = 0 + self.sum = 0.0 + self.count = 0 + self.avg: Optional[float] = None + + def update(self, new_value: float) -> None: + if self.count < self.window_size: + # Filling up the window + self.sum += new_value + self.window[self.count] = new_value + self.count += 1 + else: + # Window is full, start replacing old 
values + old_value = self.window[self.index] + self.sum = self.sum - old_value + new_value + self.window[self.index] = new_value + self.index = (self.index + 1) % self.window_size + + self.avg = self.sum / self.count + + def __repr__(self) -> str: + if self.count == 0: + return "no data" + return f"avg: {self.avg:>10.3f} ({self.count} samples)" + + +class DebugStats: + def __init__(self, logger: logging.Logger, window_size: int) -> None: + self.logger = logger + self.metrics: dict[str, Union[MovingAverage, MetricStats]] = { + "moving_avg_ttft_ms": MovingAverage(window_size), + "moving_avg_tpot_ms": MovingAverage(window_size), + "ttft_ms": MetricStats(), + "tpot_ms": MetricStats(), + "latency_ms": MetricStats(), + "input_num_turns": MetricStats(), + "input_num_tokens": MetricStats(), + "output_num_tokens": MetricStats(), + } + + def update(self, data: RequestStats) -> None: + self.metrics["ttft_ms"].update(data.ttft_ms) + self.metrics["moving_avg_ttft_ms"].update(data.ttft_ms) + self.metrics["tpot_ms"].update(data.tpot_ms) + self.metrics["moving_avg_tpot_ms"].update(data.tpot_ms) + self.metrics["latency_ms"].update(data.latency_ms) + self.metrics["input_num_turns"].update(data.input_num_turns) + self.metrics["input_num_tokens"].update(data.input_num_tokens) + self.metrics["output_num_tokens"].update(data.output_num_tokens) + + def print(self) -> None: + self.logger.info("-" * 50) + for k, v in self.metrics.items(): + kv_info = f"[{k:25}] {v}" + self.logger.info(kv_info) + self.logger.info("-" * 50) + + +# Must support Python 3.8, we can't use str.removeprefix(prefix) +# introduced in Python 3.9 +def remove_prefix(text: str, prefix: str) -> str: + if text.startswith(prefix): + return text[len(prefix) :] + return text + + +def nanosec_to_millisec(value: float) -> float: + return value / 1000000.0 + + +def nanosec_to_sec(value: float) -> float: + return value / 1000000000.0 + + +async def send_request( + session: aiohttp.ClientSession, + messages: list[dict[str, str]], + chat_url: str, + model: str, + stream: bool = True, + min_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, +) -> ServerResponse: + payload = { + "model": model, + "messages": messages, + "seed": 0, + "temperature": 0.0, + } + + if stream: + payload["stream"] = True + payload["stream_options"] = {"include_usage": False} + + if min_tokens is not None: + payload["min_tokens"] = min_tokens + + if max_tokens is not None: + payload["max_tokens"] = max_tokens + + headers = {"Content-Type": "application/json"} + + # Calculate the timeout for the request + timeout_sec = 120 + if max_tokens is not None: + # Assume TPOT of 200ms and use max_tokens to determine timeout + timeout_sec = max(timeout_sec, int(max_tokens * 0.2)) + timeout = aiohttp.ClientTimeout(total=timeout_sec) + + valid_response = True + ttft: Optional[float] = None + chunk_delay: list[int] = [] + latency: Optional[float] = None + first_chunk = "" + generated_text = "" + + start_time: int = time.perf_counter_ns() + most_recent_timestamp: int = start_time + + async with session.post( + url=chat_url, json=payload, headers=headers, timeout=timeout + ) as response: + http_status = HTTPStatus(response.status) + if http_status == HTTPStatus.OK: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + if chunk == "[DONE]": + # End of stream + latency = time.perf_counter_ns() - start_time + elif stream is False: + data = json.loads(chunk) + 
message = data["choices"][0]["message"] + assert message["role"] == "assistant" + generated_text += message["content"] + else: + timestamp: int = time.perf_counter_ns() + data = json.loads(chunk) + + # Delta is the new content/text/data + delta = data["choices"][0]["delta"] + if delta.get("content", None): + if ttft is None: + # First token + first_token_time = time.perf_counter_ns() + ttft = first_token_time - start_time + first_chunk = delta["content"] + else: + # Decoding phase + chunk_delay.append(timestamp - most_recent_timestamp) + + generated_text += delta["content"] + + most_recent_timestamp = timestamp + else: + valid_response = False + content = await response.text() + logger.warning( + f"{Color.YELLOW}Received HTTP status {http_status.value} " + f"({http_status.phrase}): {content}{Color.RESET}" + ) + + if latency is None: + latency = -1.0 + if valid_response: + # Streaming is disabled, latency was not set + latency = time.perf_counter_ns() - start_time + + if ttft is None: + # The response was a single chunk + ttft = latency + + # Each chunk may include more than one token + tpot: float = mean(chunk_delay) if len(chunk_delay) > 0 else 0.0 + num_chunks: int = len(chunk_delay) + + sr = ServerResponse( + valid=valid_response, + ttft_ms=nanosec_to_millisec(ttft) if ttft > 0.0 else -1.0, + tpot_ms=nanosec_to_millisec(tpot), + latency_ms=nanosec_to_millisec(latency), + start_time_ms=nanosec_to_millisec(start_time), + first_chunk=first_chunk, + content=generated_text, + num_chunks=num_chunks, + ) + return sr + + +def get_short_string(input: str) -> str: + n = 20 + if len(input) < 400: + return input + + return f"{input[:n]}...{input[-n:]}" + + +def get_token_count(tokenizer: AutoTokenizer, text: str) -> int: + return len(tokenizer(text, add_special_tokens=False).input_ids) + + +def get_messages_token_count( + tokenizer: AutoTokenizer, messages: list[dict[str, str]] +) -> int: + token_count = 0 + for m in messages: + token_count += get_token_count(tokenizer, m["content"]) + + return token_count + + +async def send_turn( + session: aiohttp.ClientSession, + client_id: int, + conv_id: str, + conversation_messages: MessagesList, + messages_to_use: int, + tokenizer: AutoTokenizer, + req_args: RequestArgs, + verbose: bool, + verify_output: bool, +) -> Optional[RequestStats]: + assert messages_to_use > 0 + assert messages_to_use <= len(conversation_messages) + + messages = conversation_messages[:messages_to_use] + + # Index of the next message (the role should be "user") + index = messages_to_use - 1 + + # Verify that the message has only two keys, "role" and "content" + assert len(messages[index].keys()) == 2 + assert "role" in messages[index] and "content" in messages[index] + assert messages[index]["role"] == "user", ( + f"Failed on conversation ID {conv_id}, message role should be user" + ) + + if verbose: + print( + f"{Color.CYAN}Messages (conversation ID {conv_id}," + f" {len(messages)} turns):{Color.RESET}", + messages, + ) + + # None means that there is no upper/lower limit for the output token count + min_tokens = None if req_args.limit_min_tokens < 0 else req_args.limit_min_tokens + max_tokens = None if req_args.limit_max_tokens < 0 else req_args.limit_max_tokens + + if len(conversation_messages) > messages_to_use: + # The conversation contains an assistant answer for the next user prompt + if ( + min_tokens == NUM_TOKENS_FROM_DATASET + or max_tokens == NUM_TOKENS_FROM_DATASET + ): + # Compute number of tokens in the answer (from the input conversation) + assistant_answer = 
conversation_messages[messages_to_use] + answer_num_tokens = get_token_count(tokenizer, assistant_answer["content"]) + assert assistant_answer["role"] == "assistant" + + if min_tokens == NUM_TOKENS_FROM_DATASET: + min_tokens = max(1, answer_num_tokens) + + if max_tokens == NUM_TOKENS_FROM_DATASET: + max_tokens = max(1, answer_num_tokens) + + # Send the current conversation to LLM and get a response + response: ServerResponse = await send_request( + session, + messages, + req_args.chat_url, + req_args.model, + req_args.stream, + min_tokens, + max_tokens, + ) + + if response.valid is False: + # Request failed + return None + + # Compute number of tokens in input / output + input_num_tokens = get_messages_token_count(tokenizer, messages) + + # Num tokens in the user's last question + question_num_tokens = get_token_count(tokenizer, messages[index]["content"]) + + # Num tokens in the history/context of the question + assert input_num_tokens >= question_num_tokens + history_num_tokens = input_num_tokens - question_num_tokens + + # Num tokens in the LLM's answer (first chunk and full answer) + first_chunk_tokens = get_token_count(tokenizer, response.first_chunk) + + output_content = response.content + output_num_tokens = get_token_count(tokenizer, output_content) + + # Prefix caching approximated cached percent + approx_cached_percent = ( + 100.0 * (history_num_tokens / input_num_tokens) if input_num_tokens > 0 else 0.0 + ) + + # Compute the correct TTFT and TPOT (based on tokens and not chunks). + # Required because multiple output tokens may be bundled in a single chunk. + if output_num_tokens > 1 and output_num_tokens > first_chunk_tokens: + # More than one token and more than one chunk in the output + decode_ms = response.latency_ms - response.ttft_ms + decode_num_tokens = output_num_tokens - first_chunk_tokens + tpot_ms = decode_ms / decode_num_tokens + else: + # In this case: output_num_tokens == first_chunk_tokens + # Output was a single chunk (output_num_tokens > 1) + # or even a single token (output_num_tokens == 1) + tpot_ms = 0.0 + + if first_chunk_tokens > 1: + # First chunk had multiple tokens, adjust TTFT for a single token + delta_ms = (first_chunk_tokens - 1) * tpot_ms + ttft_ms = max(0.1, response.ttft_ms - delta_ms) + else: + # First chunk had only one token + ttft_ms = response.ttft_ms + + rs = RequestStats( + ttft_ms=ttft_ms, + tpot_ms=tpot_ms, + latency_ms=response.latency_ms, + start_time_ms=response.start_time_ms, + input_num_turns=len(messages), + input_num_tokens=input_num_tokens, + output_num_tokens=output_num_tokens, + output_num_chunks=response.num_chunks, + output_num_first_chunk_tokens=first_chunk_tokens, + approx_cached_percent=approx_cached_percent, + conversation_id=conv_id, + client_id=client_id, + ) + + if verbose: + print( + f"\n{Color.YELLOW}Response ({output_num_tokens} tokens):{Color.RESET}", + output_content, + ) + print(f"{Color.YELLOW}Response metrics: {rs}{Color.RESET}") + print("-" * 70) + + # Save the LLM's answer (will be used as part of the context for the next user turn) + answer_index = messages_to_use + if len(conversation_messages) > answer_index: + assert conversation_messages[answer_index]["role"] == "assistant", ( + f"Failed on conversation ID {conv_id}, message role should be assistant" + ) + + orig_content = conversation_messages[answer_index]["content"] + if verify_output: + # Compare the new answer to the answer from the input file + debug_info = ( + f"LLM/dataset answers do not match ({conv_id}):" + 
f"\n'{get_short_string(output_content)}' (len: {len(output_content)})," + f"\n'{get_short_string(orig_content)}' (len: {len(orig_content)})" + ) + if orig_content != output_content: + raise ValueError(debug_info) + + # Update the answer + conversation_messages[answer_index]["content"] = output_content + else: + # A user prompt that has no answer, add the answer as a new message + new_answer = {"role": "assistant", "content": output_content} + conversation_messages.append(new_answer) + + return rs + + +async def poisson_sleep(request_rate: float, verbose: bool = False) -> None: + # Generate a random time interval from the Poisson distribution + assert request_rate > 0 + + interval = np.random.exponential(1.0 / request_rate) + if verbose: + logger.info(f"Sleeping for {interval:.3f} seconds...") + await asyncio.sleep(interval) + + +async def client_main( + args: ClientArgs, + req_args: RequestArgs, + client_id: int, + tokenizer: AutoTokenizer, + stop_event: mp.Event, # type: ignore + task_queue: mp.Queue, + result_queue: mp.Queue, + conv_queue: mp.Queue, +) -> None: + logger.info( + f"{Color.CYAN}Started client {client_id}: max_num_requests={args.max_num_requests}, max_active_conversations={args.max_active_conversations}{Color.RESET}" # noqa: E501 + ) + + random.seed(args.seed) + np.random.seed(args.seed) + + # Active conversations + active_convs: ConversationsMap = {} + conv_id_queue: deque = deque(maxlen=args.max_active_conversations) + + # Keep track of how many messages have been used for each conversation + turns_count: Counter = Counter() + num_successes = 0 + num_failures = 0 + + # Track the timestamp (time.perf_counter()) + # of the last turn per conversation (only for debug) + time_of_last_turn: dict[ConvId, float] = {} + + # Flag that indicates that there are no new tasks (conversations) for the client + task_queue_empty = False + + async with aiohttp.ClientSession() as session: + # Print progress + + while task_queue_empty is False: + result = None + + if ( + args.max_num_requests + and num_successes + num_failures == args.max_num_requests + ): + logger.info( + f"{Color.YELLOW}Client {client_id} reached " + f"request limit{Color.RESET}" + ) + break + + if stop_event.is_set(): # type: ignore + logger.info( + f"{Color.YELLOW}Client {client_id} received " + f"a termination signal{Color.RESET}" + ) + break + + while ( + len(active_convs) < args.max_active_conversations + and task_queue_empty is False + ): + # Get a new conversation from the task queue + conv_id, messages = task_queue.get() + + if conv_id is TERM_SIGNAL: + task_queue_empty = True + break + + if args.skip_first_turn: + # Skip the first turn (both user and assistant), + # relevant if warmup was enabled. + # Default turns_count[conv_id] will be zero if conv_id + # was never inserted/updated in turns_count. 
+ turns_count[conv_id] += 2 + + if turns_count[conv_id] < len(messages): + # Add new conversation + active_convs[conv_id] = messages + conv_id_queue.append(conv_id) + + if args.verbose: + logger.info( + f"{Color.GREEN}Client {client_id} will use conversation ID {conv_id} (active conversations {len(active_convs)}){Color.RESET}" # noqa: E501 + ) + + elif args.verbose: + # No more messages (conversation finished during the warmup) + logger.info( + f"{Color.YELLOW}Client {client_id} will not use conversation ID {conv_id} (all {len(messages)} messages already sent){Color.RESET}" # noqa: E501 + ) + + if len(active_convs) == 0 or task_queue_empty: + logger.info( + f"{Color.YELLOW}Client {client_id} has no more work{Color.RESET}" + ) + break + + # Pick an active conversation for the next request + if args.conversation_sampling == ConversationSampling.ROUND_ROBIN: + conv_id = conv_id_queue.pop() + else: + # ConversationSampling.RANDOM + active_ids = list(active_convs.keys()) + conv_id = random.choice(active_ids) + + messages = active_convs[conv_id] + assert isinstance(messages, list) and len(messages) > 0 + + # Update the amount of messages to use + turns_count[conv_id] += 1 + current_turn = turns_count[conv_id] + + assert current_turn < len(messages), ( + f"Turn number {current_turn} is invalid for conversation ID {conv_id}" + f" that has only {len(messages)} messages" + ) + + if args.verbose: + curr_time_sec: float = time.perf_counter() + time_since_last_turn: Union[str, float] = "N/A" + if conv_id in time_of_last_turn: + time_since_last_turn = round( + curr_time_sec - time_of_last_turn[conv_id], 3 + ) + logger.info( + f"Client {client_id} using conversation ID {conv_id} (turn: {current_turn}, time since last turn [sec]: {time_since_last_turn})" # noqa: E501 + ) + time_of_last_turn[conv_id] = curr_time_sec + + success = True + try: + result = await send_turn( + session, + client_id, + conv_id, + messages, + current_turn, + tokenizer, + req_args, + args.print_content, + args.verify_output, + ) + if result is not None: + result_queue.put(result) + else: + # None means that the request failed, + # and should not be added to the statistics. 
+ success = False + num_failures += 1 + + logger.warning( + f"{Color.YELLOW}Client {client_id} - Request rejected during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + ) + + # Remove the conversation (should not be used again) + active_convs.pop(conv_id) + + except asyncio.exceptions.TimeoutError: + num_failures += 1 + logger.exception( + f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + ) + break # Exit gracefully instead of raising an error + + except Exception: + num_failures += 1 + logger.exception( + f"{Color.RED}Client {client_id} - Exception during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501 + ) + break # Exit gracefully instead of raising an error + + if success: + num_successes += 1 + + # Update the turns counter to include the LLM response + # The LLM response will be used as context for the next user turn + turns_count[conv_id] += 1 + + max_turns = len(messages) + if args.max_turns is not None: + # Limit the number of turns in the conversation + max_turns = min(args.max_turns, max_turns) + + if turns_count[conv_id] >= max_turns: + # Conversation has no more turns (no longer active) + # save the updated conversation (with the LLM server's answer) + conv_queue.put((conv_id, active_convs.pop(conv_id))) + if args.verbose: + logger.info( + f"{Color.GREEN}Client {client_id} finished " + f"conversation ID {conv_id}{Color.RESET}" + ) + else: + # Conversation is not finished, insert it at the back of the queue + conv_id_queue.appendleft(conv_id) + + # Sleep between requests (if lambda is positive) + if args.request_rate > 0: + await poisson_sleep(args.request_rate, args.verbose) + + # Send indication that the client is done + conv_queue.put((TERM_SIGNAL, TERM_SIGNAL)) + + logger.info( + f"{Color.CYAN}Client {client_id} is done " + f"({num_successes=}, {num_failures=}){Color.RESET}" + ) + + +def worker_function( + client_id: int, + tokenizer: AutoTokenizer, + client_args: ClientArgs, + req_args: RequestArgs, + stop_event: mp.Event, # type: ignore + task_queue: mp.Queue, + result_queue: mp.Queue, + conv_queue: mp.Queue, +) -> None: + asyncio.run( + client_main( + client_args, + req_args, + client_id, + tokenizer, + stop_event, + task_queue, + result_queue, + conv_queue, + ) + ) + + +def get_client_config( + args: argparse.Namespace, input_conv: ConversationsMap +) -> tuple[ClientArgs, RequestArgs]: + if args.num_clients < 1: + raise ValueError("Number of clients must be a positive number") + + if len(input_conv) < args.num_clients: + raise ValueError( + "Number of conversations must be equal or larger than the number of clients" + ) + + max_req_per_client: Optional[int] = None + if args.max_num_requests is not None: + # Max number of requests per client + req_per_client = args.max_num_requests // args.num_clients + if req_per_client < 1: + raise ValueError("Number of requests should be at least one per client") + max_req_per_client = req_per_client + + max_active_conversations = args.max_active_conversations + if max_active_conversations is None: + # Each client will have only one active conversation at a time + max_active_conversations = args.num_clients + + if max_active_conversations > len(input_conv): + raise ValueError( + f"Max active conversations {max_active_conversations} " + "must be equal or less than the total number of conversations" + ) + + # Max number of active conversations per client + max_active_conv_per_client = 
max_active_conversations // args.num_clients + if max_active_conv_per_client < 1: + raise ValueError( + f"Max active conversations {max_active_conversations} " + "must be equal or greater than the number of clients" + ) + + # Skip the first user turn (as part of the warmup) + skip_first_turn = args.warmup_step + + # Common arguments for all clients + client_args = ClientArgs( + seed=args.seed, + max_num_requests=max_req_per_client, + skip_first_turn=skip_first_turn, + max_turns=args.max_turns, + max_active_conversations=max_active_conv_per_client, + verbose=args.verbose, + print_content=args.print_content, + verify_output=args.verify_output, + conversation_sampling=args.conversation_sampling, + request_rate=args.request_rate, + ) + + if args.limit_min_tokens > 0 or args.limit_max_tokens > 0: + if args.limit_min_tokens < 1 or args.limit_max_tokens < 1: + raise ValueError( + "Invalid min/max tokens limits (both limits should be provided)" + ) + if args.limit_min_tokens > args.limit_max_tokens: + raise ValueError( + "Invalid min/max tokens limits (min should not be larger than max)" + ) + + # Arguments for API requests + chat_url = f"{args.url}/v1/chat/completions" + req_args = RequestArgs( + chat_url=chat_url, + model=args.model, + stream=not args.no_stream, + limit_min_tokens=args.limit_min_tokens, + limit_max_tokens=args.limit_max_tokens, + ) + + return client_args, req_args + + +async def main_mp( + client_args: ClientArgs, + req_args: RequestArgs, + bench_args: BenchmarkArgs, + tokenizer: AutoTokenizer, + input_conv: ConversationsMap, +) -> tuple[ConversationsMap, list[RequestStats]]: + # An event that will trigger graceful termination of all the clients + stop_event = mp.Event() + + # Queue for input conversations (from the input file/dataset) + task_queue: mp.Queue = mp.Queue() + + # Queue for client measurements (TTFT, TPOT, etc. 
for each request) + result_queue: mp.Queue = mp.Queue() + + # Queue for output conversations (with the LLM answers, sent by the server) + conv_queue: mp.Queue = mp.Queue() + output_conv: ConversationsMap = {} + client_metrics: list[RequestStats] = [] + + # Start all clients + start_time = time.perf_counter_ns() + logger.info(f"{Color.GREEN}Starting {bench_args.num_clients} clients{Color.RESET}") + + clients = [] + for client_id in range(bench_args.num_clients): + client = mp.Process( + name=f"client_{client_id}", + target=worker_function, + args=( + client_id, + tokenizer, + client_args, + req_args, + stop_event, + task_queue, + result_queue, + conv_queue, + ), + ) + clients.append(client) + client.start() + + # Submit all the input conversations as tasks for the clients + for conv_id, messages in input_conv.items(): + task_queue.put((conv_id, messages)) + + # Add termination signals for clients + for _ in range(bench_args.num_clients): + task_queue.put((TERM_SIGNAL, TERM_SIGNAL)) + + # Collect the updated conversations from all clients + num_clients_finished = 0 + total_convs = len(input_conv) + + debug_stats = DebugStats(logger, min(15 * bench_args.num_clients, 500)) + + while num_clients_finished < bench_args.num_clients: + # Collect updated conversation + conv_id, messages = conv_queue.get() + + # Collect results (measurements) + while not result_queue.empty(): + new_data = result_queue.get() + client_metrics.append(new_data) + debug_stats.update(new_data) + + if conv_id is TERM_SIGNAL: + num_clients_finished += 1 + logger.info( + f"{Color.CYAN}{num_clients_finished} out of " + f"{bench_args.num_clients} clients finished{Color.RESET}" + ) + + if bench_args.early_stop and not stop_event.is_set(): + # Once one client finished, stop all other clients. + # there is no reason to continue the benchmark with fewer clients. + logger.info( + f"{Color.YELLOW}Sending termination signal to clients{Color.RESET}" + ) + stop_event.set() + else: + output_conv[conv_id] = messages + + finished_convs = len(output_conv) + percent = finished_convs / total_convs + + # Tuned to control the print rate (can be changed if required) + print_cycle = max(3, int(bench_args.num_clients / 4)) + + if finished_convs % print_cycle == 0: + runtime_sec = nanosec_to_sec(time.perf_counter_ns() - start_time) + logger.info( + f"{Color.CYAN}Finished {finished_convs} out of {total_convs} conversations ({percent:.0%}), " # noqa: E501 + f"{num_clients_finished} out of {bench_args.num_clients} clients finished, collected {len(client_metrics)} measurements, runtime {runtime_sec:.3f} sec{Color.RESET}" # noqa: E501 + ) + + rps: Union[str, float] = round(len(client_metrics) / runtime_sec, 3) + if len(client_metrics) < (5 * bench_args.num_clients): + # Do not estimate the RPS if the number of samples is very low + # (threshold can be tuned if needed) + rps = "N/A" + + runtime_left_sec: Union[str, float] = round( + (runtime_sec / finished_convs) * (total_convs - finished_convs), 3 + ) + if percent < 0.05: + # If less than 5% of the conversations were not finished, + # the estimation will probably be very inaccurate + # (threshold can be tuned if needed). + runtime_left_sec = "N/A" + + logger.info( + f"{Color.CYAN}Estimated req/sec {rps}, estimated runtime left {runtime_left_sec} sec{Color.RESET}" # noqa: E501 + ) + debug_stats.print() + + logger.info( + f"{Color.CYAN}All {bench_args.num_clients} clients finished{Color.RESET}" + ) + + # At this point all the clients finished, + # collect results (TTFT, TPOT, etc.) from all the clients. 
+ # This needs to happens before calling join on the clients + # (result_queue should be emptied). + while not result_queue.empty(): + client_metrics.append(result_queue.get()) + + logger.info(f"Collected {len(client_metrics)} samples from all the clients") + + # Wait for all clients to finish + for client in clients: + logger.info( + f"{Color.CYAN}Waiting for client {client.name} " + f"(is alive: {client.is_alive()}){Color.RESET}" + ) + + client.join(timeout=120) + + if client.is_alive(): + logger.warning( + f"{Color.YELLOW}Client {client.name} will be terminated{Color.RESET}" + ) + client.terminate() + + exitcode = client.exitcode + if exitcode != 0: + logger.error( + f"{Color.RED}Client {client.name} exited " + f"with exit code {exitcode}{Color.RESET}" + ) + + logger.info( + f"All {bench_args.num_clients} clients exited (successfully " + f"finished {len(output_conv)} out of {total_convs} conversations)" + ) + + # Queues should be closed, required to avoid hang at interpreter shutdown + unfinished_tasks = 0 + while not task_queue.empty(): + task_queue.get() + unfinished_tasks += 1 + + if unfinished_tasks > 0: + # Can happen if not all tasks (conversations) have finished. + # May happen if --max-num-requests was used, + # or if an error occurred in one of the clients. + logger.debug(f"Discarding {unfinished_tasks} unfinished tasks") + + task_queue.close() + task_queue.join_thread() + + result_queue.close() + result_queue.join_thread() + + conv_queue.close() + conv_queue.join_thread() + + return output_conv, client_metrics + + +def get_filename_with_timestamp(label: str, extension: str) -> str: + time_now = datetime.now() + timestamp = time_now.strftime("%d-%m-%Y_%H-%M-%S") + filename = f"{label}__{timestamp}.{extension}" + return filename + + +def process_statistics( + client_metrics: list[RequestStats], + warmup_percentages: list[float], + test_params: dict, + verbose: bool, + gen_conv_args: Optional[GenConvArgs] = None, + excel_output: bool = False, +) -> None: + if len(client_metrics) == 0: + logger.info("No samples to process") + return + + logger.info(f"Processing {len(client_metrics)} samples...") + + raw_data = pd.DataFrame(client_metrics) + + if verbose: + # Calculate the time between user turns in each conversation (in a new column) + raw_data = raw_data.sort_values(by=["conversation_id", "start_time_ms"]) + raw_data["time_between_user_turns_sec"] = raw_data.groupby("conversation_id")[ + "start_time_ms" + ].diff() + + # Convert milliseconds to seconds + raw_data["time_between_user_turns_sec"] = ( + raw_data["time_between_user_turns_sec"] / 1000.0 + ) + + # Final raw data should be sorted by time + raw_data = raw_data.sort_values(by=["start_time_ms"]) + raw_data["end_time_ms"] = raw_data["start_time_ms"] + raw_data["latency_ms"] + + percentiles = [0.25, 0.5, 0.75, 0.9] + + # Add more percentiles if there are enough samples + if len(raw_data) >= 100: + percentiles.append(0.99) + + if len(raw_data) >= 1000: + percentiles.append(0.999) + + if len(raw_data) >= 10000: + percentiles.append(0.9999) + + # Set precision for numbers in the output text (the dataframes) + pd.set_option("display.precision", 2) + + # Exclude parameters from RequestStats + exclude = [ + "start_time_ms", + "end_time_ms", + "output_num_first_chunk_tokens", + "approx_cached_percent", + "conversation_id", + "client_id", + ] + + print(TEXT_SEPARATOR) + print(f"{Color.YELLOW}Parameters:{Color.RESET}") + for k, v in test_params.items(): + print(f"{k}={v}") + + # conversations generation parameters + if gen_conv_args is 
not None: + gen_params = { + "text_files": ", ".join(gen_conv_args.text_files), + "input_num_turns": str(gen_conv_args.input_num_turns), + "input_common_prefix_num_tokens": str( + gen_conv_args.input_common_prefix_num_tokens + ), + "input_prefix_num_tokens": str(gen_conv_args.input_prefix_num_tokens), + "input_num_tokens": str(gen_conv_args.input_num_tokens), + "output_num_tokens": str(gen_conv_args.output_num_tokens), + } + + print(f"{Color.YELLOW}Conversations Generation Parameters:{Color.RESET}") + for k, v in gen_params.items(): + print(f"{k}={v}") + + print(TEXT_SEPARATOR) + + params_list = [] + df_list = [] + for percent in warmup_percentages: + # Select samples from the end (tail) of the dataframe + warmup_count = int(percent * len(raw_data)) + tail_count = len(raw_data) - warmup_count + if tail_count == 0: + # No reason to process if the count of samples is zero + break + + df = raw_data.tail(tail_count) + + # Runtime is the diff between the end of the last request + # and the start of the first request + runtime_sec = df["end_time_ms"].iloc[-1] - df["start_time_ms"].iloc[0] + + # Convert milliseconds to seconds + runtime_sec = runtime_sec / 1000.0 + requests_per_sec = float(len(df)) / runtime_sec + + params = {"runtime_sec": runtime_sec, "requests_per_sec": requests_per_sec} + + # Generate a summary of relevant metrics (and drop irrelevant data) + df = df.drop(columns=exclude).describe(percentiles=percentiles).transpose() + + # List for Excel file + params_list.append(params) + df_list.append(df) + + # Print the statistics summary + if percent > 0 or len(warmup_percentages) > 1: + print( + f"{Color.YELLOW}Statistics summary " + f"(assuming {percent:.0%} warmup samples):{Color.RESET}" + ) + else: + print(f"{Color.YELLOW}Statistics summary:{Color.RESET}") + + for k, v in params.items(): + if isinstance(v, float): + print(f"{k} = {v:.3f}") + else: + print(f"{k} = {v}") + print(TEXT_SEPARATOR) + print(df) + print(TEXT_SEPARATOR) + + if excel_output: + prefix = f"statistics_{test_params['num_clients']}_clients" + filename = get_filename_with_timestamp(prefix, "xlsx") + + with pd.ExcelWriter(filename, engine="xlsxwriter") as writer: + startrow = 0 + test_params_df = pd.DataFrame([test_params]) + test_params_df.to_excel( + writer, sheet_name="Summary", index=False, startrow=startrow + ) + startrow += len(test_params_df) + 3 + + if gen_conv_args is not None: + gen_params_df = pd.DataFrame([gen_params]) + gen_params_df.to_excel( + writer, sheet_name="Summary", index=False, startrow=(startrow - 1) + ) + startrow += len(gen_params_df) + 3 + + for params, df_stats in zip(params_list, df_list): + df_params = pd.DataFrame([params]) + df_params.to_excel( + writer, sheet_name="Summary", index=False, startrow=startrow + ) + startrow += len(df_params) + 2 + df_stats.to_excel( + writer, sheet_name="Summary", index=True, startrow=startrow + ) + startrow += len(df_stats) + 3 + + raw_data.to_excel(writer, sheet_name="Raw data", index=False, startrow=0) + + logger.info( + f"{Color.GREEN}Client metrics exported to file: {filename}{Color.RESET}" + ) + + +async def get_server_info(url: str) -> None: + logger.info(f"{Color.BLUE}Collecting information from server: {url}{Color.RESET}") + async with aiohttp.ClientSession() as session: + # Get server version (not mandatory, "version" endpoint may not exist) + url_version = f"{url}/version" + async with session.get(url_version) as response: + if HTTPStatus(response.status) == HTTPStatus.OK: + text = await response.text() + logger.info(f"{Color.BLUE}Server 
version: {text}{Color.RESET}") + + # Get available models + url_models = f"{url}/v1/models" + async with session.get(url_models) as response: + if HTTPStatus(response.status) == HTTPStatus.OK: + text = await response.text() + logger.info(f"{Color.BLUE}Models:{Color.RESET}") + models_data = json.loads(text) + models_list = models_data["data"] + for model in models_list: + model_id = model["id"] + max_model_len = model.get("max_model_len", "N/A") + logger.info( + f"{Color.BLUE}\t{model_id=}, {max_model_len=}{Color.RESET}" + ) + else: + logger.info(f"{Color.RED}Failed to get models{Color.RESET}") + + +async def main() -> None: + parser = argparse.ArgumentParser( + prog="Benchmark serving with multi-turn conversations", + description="Benchmark online inference using REST API", + ) + parser.add_argument("--version", action="version", version="%(prog)s 1.0") + + parser.add_argument( + "-i", + "--input-file", + type=str, + required=True, + help="Input JSON file with ShareGPT conversations or " + "configuration file for generation of synthetic conversations", + ) + parser.add_argument( + "-o", + "--output-file", + type=str, + default=None, + help="Output JSON file containing conversations with updated assistant answers", + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help="Seed for random number generators (default: 0)", + ) + parser.add_argument( + "-m", "--model", type=str, required=True, help="Path of the LLM model" + ) + parser.add_argument( + "-u", + "--url", + type=str, + default="http://localhost:8000", + help="Base URL for the LLM API server", + ) + + parser.add_argument( + "-p", + "--num-clients", + type=int, + default=1, + help="Number of clients that will send requests in parallel", + ) + parser.add_argument( + "-k", + "--max-active-conversations", + type=int, + default=None, + help="Max number of active conversations at a time (for all clients)", + ) + parser.add_argument( + "-n", + "--max-num-requests", + type=int, + default=None, + help="Max number of requests to send (total for all clients)", + ) + + parser.add_argument( + "--warmup-step", + default=False, + action="store_true", + help="Run a warmup step (using only the first turn of every conversation), " + "measurements will not be included in the final benchmark results", + ) + + parser.add_argument( + "--max-turns", + type=int, + default=None, + help="Maximum number of turns/messages per conversation, " + "includes both user and assistant messages " + "(a positive number, e.g: 2, 4, 6, etc.), disabled by default", + ) + parser.add_argument( + "--no-early-stop", + default=False, + action="store_true", + help="By default, the benchmark will stop if at least one client exits." + " Use this flag to disable this behavior", + ) + + parser.add_argument( + "--limit-max-tokens", + type=int, + default=NUM_TOKENS_FROM_DATASET, + help="Set max_tokens for the output token count of each request " + "(must also set --limit-min-tokens). " + "Overrides output token count from the input dataset. " + "Use a negative value to disable this limit.", + ) + parser.add_argument( + "--limit-min-tokens", + type=int, + default=NUM_TOKENS_FROM_DATASET, + help="Set min_tokens for the output token count of each request " + "(must also set --limit-max-tokens). " + "Overrides output token count from the input dataset. " + "Use a negative value to disable this limit.", + ) + + parser.add_argument( + "--request-rate", + type=float, + default=0, + help="Expected request rate (Poisson process) per client in requests/sec." 
+        " Set to 0 for no delay between requests.",
+    )
+    parser.add_argument(
+        "--conversation-sampling",
+        type=ConversationSampling,
+        choices=list(ConversationSampling),
+        default=ConversationSampling.ROUND_ROBIN,
+        help=(
+            "Strategy for selecting which conversation to use for the next request. "
+            "Options: 'round_robin' (cycle through conversations), "
+            "'random' (pick randomly)."
+        ),
+    )
+    parser.add_argument(
+        "--verify-output",
+        default=False,
+        action="store_true",
+        help="Verify the LLM output (compare to the answers in the input JSON file)",
+    )
+
+    parser.add_argument(
+        "--no-stream",
+        default=False,
+        action="store_true",
+        help="Disable stream/streaming mode (set 'stream' to False in the API request)",
+    )
+
+    parser.add_argument(
+        "-e",
+        "--excel-output",
+        default=False,
+        action="store_true",
+        help="Export summary to Excel file (optional)",
+    )
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        default=False,
+        action="store_true",
+        help="Enable verbose output",
+    )
+    parser.add_argument(
+        "--print-content",
+        default=False,
+        action="store_true",
+        help="Print the user prompts and the server's answers",
+    )
+
+    parser.add_argument(
+        "--warmup-percentages",
+        type=str,
+        default="0%",
+        help="Ignore the first X%% of samples as warmup."
+        " A comma-separated list of percentages can be used "
+        "(for example: --warmup-percentages=0%%,50%%)",
+    )
+
+    args = parser.parse_args()
+
+    logger.info(args)
+
+    logger.info(f"{Color.GREEN}Input parameters:{Color.RESET}")
+    logger.info(f"url={args.url}")
+    logger.info(f"model={args.model}")
+    logger.info(f"num_clients={args.num_clients}")
+
+    if args.verify_output:
+        logger.info(f"{Color.PURPLE}Verify is enabled{Color.RESET}")
+
+    # Determine the fraction of samples to exclude as warmup measurements.
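+    # For example, --warmup-percentages=0%,50% yields [0.0, 0.5]: one summary
+    # over all samples and one that drops the oldest half as warmup.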
+    try:
+        warmup_percentages: list[float] = [0.0]
+        if not args.warmup_step:
+            # Warmup percentages are only applied if the warmup step was not used
+            warmup_strings: list[str] = args.warmup_percentages.split(",")
+            warmup_strings = [x.replace("%", "") for x in warmup_strings]
+            warmup_percentages = [float(x) / 100 for x in warmup_strings]
+
+            # Check for valid range (0 to 1)
+            for p in warmup_percentages:
+                assert p >= 0.0 and p < 1.0
+
+            # Sort the warmup percentages in ascending order
+            warmup_percentages.sort()
+
+            logger.info(
+                f"Warmup percentages (percentage of samples): {warmup_percentages}"
+            )
+
+    except Exception:
+        raise ValueError(
+            f"Invalid --warmup-percentages={args.warmup_percentages}"
+        ) from None
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    if not os.path.exists(args.model):
+        raise OSError(f"Path does not exist: {args.model}")
+    logger.info("Loading tokenizer")
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+
+    await get_server_info(args.url)
+
+    # Load the input file (either conversations or a configuration file)
+    logger.info(f"Reading input file: {args.input_file}")
+    with open(args.input_file) as f:
+        input_data = json.load(f)
+
+    gen_conv_args = None
+    if isinstance(input_data, list):
+        # The conversations are stored as a list of dicts
+        logger.info(f"Found {len(input_data)} items in the input file")
+
+        # Convert the list to a ConversationsMap
+        conversations = conversations_list_to_dict(input_data)
+
+    elif isinstance(input_data, dict):
+        # The input file is a configuration file
+        # (type is determined by the field 'filetype')
+        if "filetype" not in input_data:
+            raise Exception(
+                f"Input file {args.input_file} is invalid (missing 'filetype')"
+            )
+
+        logger.info(f"Using input file with filetype: {input_data['filetype']}")
+
+        gen_conv_args = parse_input_json_file(input_data)
+
+        # Disable warning from "huggingface/tokenizers"
+        # (when using python multiprocessing and tokenizers)
+        os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+        # Generate synthetic conversations
+        conversations = generate_conversations(gen_conv_args, tokenizer)
+
+    else:
+        raise Exception(f"Input file {args.input_file} is invalid")
+
+    if args.max_turns is not None:
+        if args.max_turns < 1:
+            raise ValueError("Max turns must be a positive number")
+        logger.info(
+            f"{Color.PURPLE}Max turns per conversation "
+            f"is limited to {args.max_turns}{Color.RESET}"
+        )
+
+    # Create benchmark configurations
+    client_args, req_args = get_client_config(args, conversations)
+
+    bench_args = BenchmarkArgs(
+        url=args.url, num_clients=args.num_clients, early_stop=not args.no_early_stop
+    )
+
+    # Warm-up step
+    if args.warmup_step:
+        # Only send a single user prompt from every conversation.
+        # max_active_conversations must be 1,
+        # otherwise the clients may exit after sending a single request
+        # (because the task queue is empty).
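+        # _replace below is assumed to behave like NamedTuple._replace,
+        # returning a modified copy, so the original client_args/bench_args
+        # are reused unchanged for the measured run.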
+ warmup_client_args = client_args._replace( + skip_first_turn=False, max_turns=1, max_active_conversations=1 + ) + + # Early stop should be disabled, + # all clients should finish their work before exiting + warmup_bench_args = bench_args._replace(early_stop=False) + + logger.info(f"{Color.PURPLE}Warmup start{Color.RESET}") + conversations, _ = await main_mp( + warmup_client_args, req_args, warmup_bench_args, tokenizer, conversations + ) + logger.info(f"{Color.PURPLE}Warmup done{Color.RESET}") + + # Run the benchmark + start_time = time.perf_counter_ns() + client_convs, client_metrics = await main_mp( + client_args, req_args, bench_args, tokenizer, conversations + ) + total_runtime_ms = nanosec_to_millisec(time.perf_counter_ns() - start_time) + + # Calculate requests per second + total_runtime_sec = total_runtime_ms / 1000.0 + rps = len(client_metrics) / total_runtime_sec + logger.info( + f"{Color.GREEN}All clients finished, total runtime: {total_runtime_sec:.3f} sec" + f" ({total_runtime_ms:.3f} ms), requests per second: {rps:.3f}{Color.RESET}" + ) + + # Benchmark parameters + params = { + "model": args.model, + "num_clients": args.num_clients, + "num_conversations": len(conversations), + "active_conversations": args.max_active_conversations, + "seed": args.seed, + } + + if args.limit_min_tokens > 0: + params["min_tokens"] = args.limit_min_tokens + + if args.limit_max_tokens > 0: + params["max_tokens"] = args.limit_max_tokens + + # Process and print statistics (and save excel file with the statistics) + process_statistics( + client_metrics, + test_params=params, + warmup_percentages=warmup_percentages, + verbose=args.verbose, + gen_conv_args=gen_conv_args, + excel_output=args.excel_output, + ) + + if args.output_file is not None: + # Write a JSON file with the updated conversations + # The "assistant" content will contain the answers from the tested LLM + output_data: ShareGptConversations = conversations_dict_to_list(client_convs) + logger.info( + f"{Color.GREEN}Writing conversations file: {args.output_file}{Color.RESET}" + ) + with open(args.output_file, "w") as f: + json.dump(output_data, f, indent=4) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmarks/multi_turn/convert_sharegpt_to_openai.py b/benchmarks/multi_turn/convert_sharegpt_to_openai.py new file mode 100644 index 0000000000000..c3622c99a2e53 --- /dev/null +++ b/benchmarks/multi_turn/convert_sharegpt_to_openai.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Download dataset from: +https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json + +Convert to OpenAI API: +export INPUT_FILE=sharegpt_20230401_clean_lang_split.json +python convert_sharegpt_to_openai.py $INPUT_FILE sharegpt_conv_128.json --max-items=128 +""" + +import argparse +import json +import random +from statistics import mean +from typing import Any, Optional + +import pandas as pd # type: ignore +import tqdm # type: ignore +from transformers import AutoTokenizer # type: ignore + + +def has_non_english_chars(text: str) -> bool: + return not text.isascii() + + +def content_is_valid( + content: str, min_content_len: Optional[int], max_content_len: Optional[int] +) -> bool: + if min_content_len and len(content) < min_content_len: + return False + + if max_content_len and len(content) > max_content_len: + return False + + return has_non_english_chars(content) + + +def print_stats( + conversations: 
"list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None +) -> None: + # Collect statistics + stats = [] + + print("\nCollecting statistics...") + for item in tqdm.tqdm(conversations): + # item has "id" and "messages" + messages = item["messages"] + + user_turns = 0 + assistant_turns = 0 + user_words = 0 + assistant_words = 0 + conv_chars = 0 + + user_tokens: list[int] = [] + assistant_tokens: list[int] = [] + + for m in messages: + content = m["content"] + conv_chars += len(content) + content_num_words = content.count(" ") + 1 + + num_tokens = 0 + if tokenizer: + num_tokens = len(tokenizer(m["content"]).input_ids) + + if m["role"] == "user": + user_turns += 1 + user_words += content_num_words + if tokenizer: + user_tokens.append(num_tokens) + + elif m["role"] == "assistant": + assistant_turns += 1 + assistant_words += content_num_words + if tokenizer: + assistant_tokens.append(num_tokens) + + # assert user_turns == assistant_turns, \ + # f"Invalid conversation ID {item['id']}" + + conv_words = user_words + assistant_words + item_stats = { + "user_turns": user_turns, + "assistant_turns": assistant_turns, + "user_words": user_words, + "assistant_words": assistant_words, + "conv_turns": len(messages), + "conv_words": conv_words, + "conv_characters": conv_chars, + } + + if len(user_tokens) > 0: + item_stats["user_tokens"] = int(mean(user_tokens)) + + if len(assistant_tokens) > 0: + item_stats["assistant_tokens"] = int(mean(assistant_tokens)) + + stats.append(item_stats) + + print("\nStatistics:") + percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999] + df = pd.DataFrame(stats) + print(df.describe(percentiles=percentiles).transpose()) + + +def convert_sharegpt_to_openai( + seed: int, + input_file: str, + output_file: str, + max_items: Optional[int], + min_content_len: Optional[int] = None, + max_content_len: Optional[int] = None, + min_turns: Optional[int] = None, + max_turns: Optional[int] = None, + model: Optional[str] = None, +) -> None: + if min_turns and max_turns: + assert min_turns <= max_turns + + if min_content_len and max_content_len: + # Verify that min is not larger than max if both were given + assert min_content_len <= max_content_len + + print( + f"Input parameters:\n{seed=}, {max_items=}, {min_content_len=}," + f" {max_content_len=}, {min_turns=}, {max_turns=}\n" + ) + + random.seed(seed) + + tokenizer = None + if model is not None: + print(f"Loading tokenizer from: {model}") + tokenizer = AutoTokenizer.from_pretrained(model) + + # Read the ShareGPT JSON file + print(f"Reading file: {input_file}") + with open(input_file, encoding="utf-8") as f: + # Should be a list of dicts + # Each dict should have "id" (string) and "conversations" (list of dicts) + sharegpt_data = json.load(f) + + assert isinstance(sharegpt_data, list), "Input file should contain a list of dicts" + + print(f"Total items in input file: {len(sharegpt_data):,}") + + print(f"Shuffling dataset with seed {seed}") + random.shuffle(sharegpt_data) + + # Map conversation ID to the all the messages + conversation_parts: dict[str, list[Any]] = {} + + for item in tqdm.tqdm(sharegpt_data): + assert "id" in item, "Missing key 'id'" + assert "conversations" in item, "Missing key 'conversations'" + + # Conversation ID (e.g: "hiWPlMD") and part/session (0, 1, 2, etc.) 
+ conv_id, _ = item["id"].split("_") + new_turns = item["conversations"] + + if conv_id not in conversation_parts: + # Start new conversation + conversation_parts[conv_id] = [] + elif len(conversation_parts[conv_id]) > 0 and len(new_turns) > 0: + prev_turns = conversation_parts[conv_id][-1] + if prev_turns[-1]["from"] == new_turns[0]["from"]: + new_turns = new_turns[1:] + + if len(new_turns) > 0: + # We assume that parts are in order in the ShareGPT dataset + conversation_parts[conv_id].append(new_turns) + + dataset: list[dict[str, Any]] = [] + for conv_id, conv_parts in conversation_parts.items(): + new_item = {"id": conv_id} + + conversations: list[dict[str, str]] = [] + + # Merge all parts + for conv_part in conv_parts: + conversations.extend(conv_part) + + if len(conversations) > 0: + new_item["conversations"] = conversations + dataset.append(new_item) + + print(f"Total unique conversations (IDs) in input file: {len(dataset):,}") + + # Final output data + final_openai_dataset: list[dict] = [] + + # Filter conversations from the ShareGPT dataset and convert to OpenAI format + for item in tqdm.tqdm(dataset): + messages: list[dict] = [] + + assert "id" in item, "Missing key 'id'" + assert "conversations" in item, "Missing key 'conversations'" + + conv_id = item["id"] + conversations = item["conversations"] + + if min_turns is not None and len(conversations) < min_turns: + # Skip short conversations + continue + + # Convert each message in the conversation, up to max_turns if specified + for i, turn in enumerate(conversations): + assert "from" in turn and "value" in turn, ( + f"Invalid conversation ID {conv_id} - missing 'from' or 'value'" + ) + + role = None + turn_from = turn["from"] + + if turn_from in {"human", "user"}: + role = "user" + elif turn_from in {"gpt", "bing", "chatgpt", "bard"}: + role = "assistant" + elif turn_from == "system": + role = "system" + + assert role is not None, ( + f"Invalid conversation ID {conv_id} - 'from'='{turn_from}' is invalid" + ) + + if i == 0 and role != "user": + # If the first message is from assistant (gpt), skip it. + # this happens when the conversation is a follow-up + # to a previous conversation (from the same user). 
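+                # For example, a part that starts with {"from": "gpt", ...}
+                # has no matching user prompt here; dropping it preserves the
+                # strict user/assistant alternation checked further below.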
+ continue + + if max_turns is not None and i >= max_turns: + break + + # Convert message to OpenAI format (with "role" and "content") + content = turn["value"] + messages.append({"role": role, "content": content}) + + # Add the converted conversation to the OpenAI format + if len(messages) > 0: + valid_messages = True + + # First turn should always be from the user + user_turn = True + + for m in messages: + # Make sure that turns alternate between user and assistant + if (user_turn and m["role"] != "user") or ( + not user_turn and m["role"] != "assistant" + ): + valid_messages = False + break + + user_turn = not user_turn + + content = m["content"] + valid_messages = content_is_valid( + content, min_content_len, max_content_len + ) + if not valid_messages: + break + + if valid_messages is True: + final_openai_dataset.append({"id": conv_id, "messages": messages}) + + assert len(final_openai_dataset) > 0, "Final number of conversations is zero" + + print_stats(final_openai_dataset) + + print_stats_again = False + if max_items is not None and len(final_openai_dataset) > max_items: + print(f"\n\nSampling {max_items} items from the dataset...") + print_stats_again = True + final_openai_dataset = random.sample(final_openai_dataset, max_items) + + if print_stats_again: + # Print stats after the dataset changed + print_stats(final_openai_dataset, tokenizer) + + # Write the converted data to a new JSON file + final_size = len(final_openai_dataset) + print(f"\nTotal conversations converted (after filtering): {final_size:,}") + print(f"\nWriting file: {output_file}") + with open(output_file, "w", encoding="utf-8") as f: + json.dump(final_openai_dataset, f, ensure_ascii=False, indent=2) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert ShareGPT dataset to OpenAI API format" + ) + parser.add_argument("input_file", help="Path to the input ShareGPT JSON file") + parser.add_argument( + "output_file", help="Path to the output OpenAI format JSON file" + ) + parser.add_argument( + "--seed", type=int, default=0, help="Seed for random number generators" + ) + parser.add_argument( + "--max-items", + type=int, + default=None, + help="Maximum number of items in the output file", + ) + parser.add_argument( + "--min-turns", + type=int, + default=None, + help="Minimum number of turns per conversation", + ) + parser.add_argument( + "--max-turns", + type=int, + default=None, + help="Maximum number of turns per conversation", + ) + parser.add_argument( + "--min-content-len", + type=int, + default=None, + help="Min number of characters in the messages' content", + ) + parser.add_argument( + "--max-content-len", + type=int, + default=None, + help="Max number of characters in the messages' content", + ) + parser.add_argument( + "--model", + type=str, + default=None, + help="LLM model, only the tokenizer will be used", + ) + + args = parser.parse_args() + + convert_sharegpt_to_openai( + args.seed, + args.input_file, + args.output_file, + args.max_items, + args.min_content_len, + args.max_content_len, + args.min_turns, + args.max_turns, + args.model, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/multi_turn/generate_multi_turn.json b/benchmarks/multi_turn/generate_multi_turn.json new file mode 100644 index 0000000000000..274d03c2bdb2b --- /dev/null +++ b/benchmarks/multi_turn/generate_multi_turn.json @@ -0,0 +1,35 @@ +{ + "filetype": "generate_conversations", + "num_conversations": 24, + "text_files": ["pg1184.txt"], + "print_stats": false, + "prompt_input": { + 
"num_turns": { + "distribution": "uniform", + "min": 12, + "max": 18 + }, + "common_prefix_num_tokens": { + "distribution": "constant", + "value": 500 + }, + "prefix_num_tokens": { + "distribution": "lognormal", + "mean": 6, + "sigma": 4, + "max": 1500 + }, + "num_tokens": { + "distribution": "uniform", + "min": 120, + "max": 160 + } + }, + "prompt_output": { + "num_tokens": { + "distribution": "uniform", + "min": 80, + "max": 120 + } + } +} \ No newline at end of file diff --git a/benchmarks/multi_turn/requirements.txt b/benchmarks/multi_turn/requirements.txt new file mode 100644 index 0000000000000..f0e1935914a14 --- /dev/null +++ b/benchmarks/multi_turn/requirements.txt @@ -0,0 +1,5 @@ +numpy>=1.24 +pandas>=2.0.0 +aiohttp>=3.10 +transformers>=4.46 +xlsxwriter>=3.2.1 \ No newline at end of file