From c9a3a02149d83cc2840769228c4e591d39351bb6 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 14 Nov 2025 04:32:03 -0500 Subject: [PATCH] Add output token counting to gsm8k eval (#28594) Signed-off-by: mgoin --- tests/evals/gsm8k/gsm8k_eval.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py index c7799607912b..0421f8bb1859 100644 --- a/tests/evals/gsm8k/gsm8k_eval.py +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -83,8 +83,12 @@ async def call_vllm_api( stop: list[str] | None = None, url: str | None = None, seed: int | None = None, -) -> str: - """Call vLLM's OpenAI-compatible completions endpoint.""" +) -> tuple[str, int]: + """Call vLLM's OpenAI-compatible completions endpoint. + + Returns: + Tuple of (response_text, completion_tokens) + """ data = { "prompt": prompt, "temperature": temperature, @@ -98,10 +102,12 @@ async def call_vllm_api( async with session.post(f"{url}/v1/completions", json=data) as response: response.raise_for_status() result = await response.json() - return result["choices"][0]["text"] + text = result["choices"][0]["text"] + completion_tokens = result.get("usage", {}).get("completion_tokens", 0) + return text, completion_tokens except Exception as e: print(f"Error calling vLLM API: {e}") - return "" + return "", 0 def evaluate_gsm8k( @@ -146,10 +152,11 @@ def evaluate_gsm8k( # Run evaluation async def run_async_evaluation(): states: list[str] = [""] * num_questions + output_tokens: list[int] = [0] * num_questions - async def get_answer(session: aiohttp.ClientSession, i: int) -> str: + async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]: prompt = few_shot_examples + questions[i] - answer = await call_vllm_api( + answer, tokens = await call_vllm_api( session=session, prompt=prompt, temperature=temperature, @@ -159,7 +166,8 @@ def evaluate_gsm8k( seed=seed, ) states[i] = answer - return answer + output_tokens[i] = tokens + return answer, tokens async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=600) @@ -167,24 +175,28 @@ def evaluate_gsm8k( tasks = [get_answer(session, i) for i in range(num_questions)] await tqdm.gather(*tasks, desc="Evaluating") - return states + return states, output_tokens print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot") tic = time.perf_counter() - states = asyncio.run(run_async_evaluation()) + states, output_tokens = asyncio.run(run_async_evaluation()) latency = time.perf_counter() - tic # Compute metrics preds = [get_answer_value(state) for state in states] accuracy = np.mean(np.array(preds) == np.array(labels)) invalid_rate = np.mean(np.array(preds) == INVALID) + total_output_tokens = sum(output_tokens) + tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0 result = { "accuracy": accuracy, "invalid_rate": invalid_rate, "latency": latency, "questions_per_second": num_questions / latency, + "total_output_tokens": total_output_tokens, + "tokens_per_second": tokens_per_second, "num_questions": num_questions, "num_shots": num_shots, "max_tokens": max_tokens, @@ -236,6 +248,8 @@ def main() -> None: print(f"Invalid responses: {result['invalid_rate']:.3f}") print(f"Total latency: {result['latency']:.3f} s") print(f"Questions per second: {result['questions_per_second']:.3f}") + print(f"Total output tokens: {result['total_output_tokens']}") + print(f"Output tokens per second: {result['tokens_per_second']:.3f}") # Optional file saving if args.save_results: