Add output token counting to gsm8k eval (#28594)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin 2025-11-14 04:32:03 -05:00 committed by GitHub
parent bc3e43069a
commit c9a3a02149


@@ -83,8 +83,12 @@ async def call_vllm_api(
     stop: list[str] | None = None,
     url: str | None = None,
     seed: int | None = None,
-) -> str:
-    """Call vLLM's OpenAI-compatible completions endpoint."""
+) -> tuple[str, int]:
+    """Call vLLM's OpenAI-compatible completions endpoint.
+
+    Returns:
+        Tuple of (response_text, completion_tokens)
+    """
     data = {
         "prompt": prompt,
         "temperature": temperature,
@@ -98,10 +102,12 @@ async def call_vllm_api(
         async with session.post(f"{url}/v1/completions", json=data) as response:
             response.raise_for_status()
             result = await response.json()
-            return result["choices"][0]["text"]
+            text = result["choices"][0]["text"]
+            completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
+            return text, completion_tokens
     except Exception as e:
         print(f"Error calling vLLM API: {e}")
-        return ""
+        return "", 0


 def evaluate_gsm8k(
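
For context, vLLM's OpenAI-compatible /v1/completions response carries a standard "usage" block alongside "choices". A minimal sketch of the parsing above, with an illustrative (made-up) payload:

    # Illustrative response payload from an OpenAI-compatible /v1/completions
    # endpoint; the numbers are invented, but the "usage" shape is standard.
    result = {
        "choices": [{"text": " The answer is 42."}],
        "usage": {"prompt_tokens": 512, "completion_tokens": 37, "total_tokens": 549},
    }

    text = result["choices"][0]["text"]
    # .get() keeps the eval working if a server or config omits "usage".
    completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
    assert (text, completion_tokens) == (" The answer is 42.", 37)
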
@@ -146,10 +152,11 @@ def evaluate_gsm8k(
     # Run evaluation
     async def run_async_evaluation():
         states: list[str] = [""] * num_questions
+        output_tokens: list[int] = [0] * num_questions

-        async def get_answer(session: aiohttp.ClientSession, i: int) -> str:
+        async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]:
             prompt = few_shot_examples + questions[i]
-            answer = await call_vllm_api(
+            answer, tokens = await call_vllm_api(
                 session=session,
                 prompt=prompt,
                 temperature=temperature,
@@ -159,7 +166,8 @@ def evaluate_gsm8k(
                 seed=seed,
             )
             states[i] = answer
-            return answer
+            output_tokens[i] = tokens
+            return answer, tokens

         async with aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(total=600)
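
The per-question bookkeeping works because each coroutine writes only its own index, so the asyncio.gather fan-out never races on the shared lists. A self-contained sketch of that pattern (the sleep stands in for the HTTP call; names mirror the diff, values are illustrative):

    import asyncio

    async def main() -> None:
        n = 4
        states = [""] * n
        output_tokens = [0] * n

        async def get_answer(i: int) -> tuple[str, int]:
            await asyncio.sleep(0)  # stand-in for the awaited API call
            answer, tokens = f"answer {i}", 10 + i
            states[i] = answer          # each task owns exactly one slot
            output_tokens[i] = tokens
            return answer, tokens

        await asyncio.gather(*(get_answer(i) for i in range(n)))
        print(states)         # ['answer 0', 'answer 1', 'answer 2', 'answer 3']
        print(output_tokens)  # [10, 11, 12, 13]

    asyncio.run(main())
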
@@ -167,24 +175,28 @@ def evaluate_gsm8k(
             tasks = [get_answer(session, i) for i in range(num_questions)]
             await tqdm.gather(*tasks, desc="Evaluating")

-        return states
+        return states, output_tokens

     print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot")
     tic = time.perf_counter()
-    states = asyncio.run(run_async_evaluation())
+    states, output_tokens = asyncio.run(run_async_evaluation())
     latency = time.perf_counter() - tic

     # Compute metrics
     preds = [get_answer_value(state) for state in states]
     accuracy = np.mean(np.array(preds) == np.array(labels))
     invalid_rate = np.mean(np.array(preds) == INVALID)
+    total_output_tokens = sum(output_tokens)
+    tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0

     result = {
         "accuracy": accuracy,
         "invalid_rate": invalid_rate,
         "latency": latency,
         "questions_per_second": num_questions / latency,
+        "total_output_tokens": total_output_tokens,
+        "tokens_per_second": tokens_per_second,
         "num_questions": num_questions,
         "num_shots": num_shots,
         "max_tokens": max_tokens,
@@ -236,6 +248,8 @@ def main() -> None:
     print(f"Invalid responses: {result['invalid_rate']:.3f}")
     print(f"Total latency: {result['latency']:.3f} s")
     print(f"Questions per second: {result['questions_per_second']:.3f}")
+    print(f"Total output tokens: {result['total_output_tokens']}")
+    print(f"Output tokens per second: {result['tokens_per_second']:.3f}")

     # Optional file saving
     if args.save_results:
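
A hedged sketch of how the extended summary renders, using a hypothetical result dict containing only the keys touched by this commit (values illustrative):

    result = {
        "total_output_tokens": 10000,
        "tokens_per_second": 812.5,
    }
    print(f"Total output tokens: {result['total_output_tokens']}")
    # Total output tokens: 10000
    print(f"Output tokens per second: {result['tokens_per_second']:.3f}")
    # Output tokens per second: 812.500
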