diff --git a/benchmarks/multi_turn/bench_dataset.py b/benchmarks/multi_turn/bench_dataset.py index 2674899d1cc56..8cb8a2f386a97 100644 --- a/benchmarks/multi_turn/bench_dataset.py +++ b/benchmarks/multi_turn/bench_dataset.py @@ -11,6 +11,7 @@ from bench_utils import ( Color, logger, ) +from tqdm import tqdm from transformers import AutoTokenizer # type: ignore # Conversation ID is a string (e.g: "UzTK34D") @@ -417,6 +418,10 @@ def generate_conversations( data = file.read() tokens_in_file = tokenizer.encode(data, add_special_tokens=False) list_of_tokens.extend(tokens_in_file) + logger.info( + f"Loaded {len(tokens_in_file)} tokens from file {filename}, " + f"total tokens so far: {len(list_of_tokens)}" + ) conversations: ConversationsMap = {} conv_id = 0 @@ -449,18 +454,25 @@ def generate_conversations( ) base_offset += common_prefix_tokens - for conv_id in range(args.num_conversations): + for conv_id in tqdm( + range(args.num_conversations), + total=args.num_conversations, + desc="Generating conversations", + unit="conv", + ): # Generate a single conversation messages: MessagesList = [] nturns = turn_count[conv_id] # User prompt token count per turn (with lower limit) - input_token_count: np.ndarray = args.input_num_tokens.sample(nturns) + input_token_count: np.ndarray = args.input_num_tokens.sample(nturns).astype(int) input_token_count = np.maximum(input_token_count, base_prompt_token_count) # Assistant answer token count per turn (with lower limit) - output_token_count: np.ndarray = args.output_num_tokens.sample(nturns) + output_token_count: np.ndarray = args.output_num_tokens.sample(nturns).astype( + int + ) output_token_count = np.maximum(output_token_count, 1) user_turn = True diff --git a/benchmarks/multi_turn/requirements.txt b/benchmarks/multi_turn/requirements.txt index f0e1935914a14..bae656a5c5c4b 100644 --- a/benchmarks/multi_turn/requirements.txt +++ b/benchmarks/multi_turn/requirements.txt @@ -2,4 +2,5 @@ numpy>=1.24 pandas>=2.0.0 aiohttp>=3.10 transformers>=4.46 -xlsxwriter>=3.2.1 \ No newline at end of file +xlsxwriter>=3.2.1 +tqdm>=4.66