mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
Add output token counting to gsm8k eval (#28594)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
bc3e43069a
commit
c9a3a02149
@ -83,8 +83,12 @@ async def call_vllm_api(
|
||||
stop: list[str] | None = None,
|
||||
url: str | None = None,
|
||||
seed: int | None = None,
|
||||
) -> str:
|
||||
"""Call vLLM's OpenAI-compatible completions endpoint."""
|
||||
) -> tuple[str, int]:
|
||||
"""Call vLLM's OpenAI-compatible completions endpoint.
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, completion_tokens)
|
||||
"""
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"temperature": temperature,
|
||||
@ -98,10 +102,12 @@ async def call_vllm_api(
|
||||
async with session.post(f"{url}/v1/completions", json=data) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
return result["choices"][0]["text"]
|
||||
text = result["choices"][0]["text"]
|
||||
completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
|
||||
return text, completion_tokens
|
||||
except Exception as e:
|
||||
print(f"Error calling vLLM API: {e}")
|
||||
return ""
|
||||
return "", 0
|
||||
|
||||
|
||||
def evaluate_gsm8k(
|
||||
@ -146,10 +152,11 @@ def evaluate_gsm8k(
|
||||
# Run evaluation
|
||||
async def run_async_evaluation():
|
||||
states: list[str] = [""] * num_questions
|
||||
output_tokens: list[int] = [0] * num_questions
|
||||
|
||||
async def get_answer(session: aiohttp.ClientSession, i: int) -> str:
|
||||
async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]:
|
||||
prompt = few_shot_examples + questions[i]
|
||||
answer = await call_vllm_api(
|
||||
answer, tokens = await call_vllm_api(
|
||||
session=session,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
@ -159,7 +166,8 @@ def evaluate_gsm8k(
|
||||
seed=seed,
|
||||
)
|
||||
states[i] = answer
|
||||
return answer
|
||||
output_tokens[i] = tokens
|
||||
return answer, tokens
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=600)
|
||||
@ -167,24 +175,28 @@ def evaluate_gsm8k(
|
||||
tasks = [get_answer(session, i) for i in range(num_questions)]
|
||||
await tqdm.gather(*tasks, desc="Evaluating")
|
||||
|
||||
return states
|
||||
return states, output_tokens
|
||||
|
||||
print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot")
|
||||
|
||||
tic = time.perf_counter()
|
||||
states = asyncio.run(run_async_evaluation())
|
||||
states, output_tokens = asyncio.run(run_async_evaluation())
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute metrics
|
||||
preds = [get_answer_value(state) for state in states]
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
invalid_rate = np.mean(np.array(preds) == INVALID)
|
||||
total_output_tokens = sum(output_tokens)
|
||||
tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0
|
||||
|
||||
result = {
|
||||
"accuracy": accuracy,
|
||||
"invalid_rate": invalid_rate,
|
||||
"latency": latency,
|
||||
"questions_per_second": num_questions / latency,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"tokens_per_second": tokens_per_second,
|
||||
"num_questions": num_questions,
|
||||
"num_shots": num_shots,
|
||||
"max_tokens": max_tokens,
|
||||
@ -236,6 +248,8 @@ def main() -> None:
|
||||
print(f"Invalid responses: {result['invalid_rate']:.3f}")
|
||||
print(f"Total latency: {result['latency']:.3f} s")
|
||||
print(f"Questions per second: {result['questions_per_second']:.3f}")
|
||||
print(f"Total output tokens: {result['total_output_tokens']}")
|
||||
print(f"Output tokens per second: {result['tokens_per_second']:.3f}")
|
||||
|
||||
# Optional file saving
|
||||
if args.save_results:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user