Add output token counting to gsm8k eval (#28594)

The eval's API client now returns (response_text, completion_tokens) from
each completion call, and the summary reports total output tokens and
output tokens per second alongside the existing accuracy and latency
metrics.

Signed-off-by: mgoin <mgoin64@gmail.com>

commit c9a3a02149
parent bc3e43069a
@@ -83,8 +83,12 @@ async def call_vllm_api(
     stop: list[str] | None = None,
     url: str | None = None,
     seed: int | None = None,
-) -> str:
-    """Call vLLM's OpenAI-compatible completions endpoint."""
+) -> tuple[str, int]:
+    """Call vLLM's OpenAI-compatible completions endpoint.
+
+    Returns:
+        Tuple of (response_text, completion_tokens)
+    """
     data = {
         "prompt": prompt,
         "temperature": temperature,
@@ -98,10 +102,12 @@ async def call_vllm_api(
         async with session.post(f"{url}/v1/completions", json=data) as response:
             response.raise_for_status()
             result = await response.json()
-            return result["choices"][0]["text"]
+            text = result["choices"][0]["text"]
+            completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
+            return text, completion_tokens
     except Exception as e:
         print(f"Error calling vLLM API: {e}")
-        return ""
+        return "", 0


 def evaluate_gsm8k(
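For context, an OpenAI-compatible /v1/completions response reports token
counts in its "usage" object, which is why the hunk above reads
completion_tokens through chained .get() calls that fall back to 0 when a
server omits the field. A minimal standalone sketch of that parsing against
an illustrative payload (the numbers are made up):

import json

sample = json.loads(
    '{"choices": [{"text": " The answer is 42."}],'
    ' "usage": {"prompt_tokens": 512, "completion_tokens": 12}}'
)
text = sample["choices"][0]["text"]
# Falls back to 0 if "usage" or "completion_tokens" is absent
completion_tokens = sample.get("usage", {}).get("completion_tokens", 0)
print(completion_tokens)  # -> 12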
@@ -146,10 +152,11 @@ def evaluate_gsm8k(
     # Run evaluation
     async def run_async_evaluation():
         states: list[str] = [""] * num_questions
+        output_tokens: list[int] = [0] * num_questions

-        async def get_answer(session: aiohttp.ClientSession, i: int) -> str:
+        async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]:
             prompt = few_shot_examples + questions[i]
-            answer = await call_vllm_api(
+            answer, tokens = await call_vllm_api(
                 session=session,
                 prompt=prompt,
                 temperature=temperature,
@@ -159,7 +166,8 @@ def evaluate_gsm8k(
                 seed=seed,
             )
             states[i] = answer
-            return answer
+            output_tokens[i] = tokens
+            return answer, tokens

         async with aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(total=600)
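Both states and output_tokens are pre-sized lists indexed by i, so each
concurrent task writes only its own slot and the gathered results need no
locking even though tasks finish out of order. A self-contained
illustration of the same pattern (the names here are illustrative, not
from the eval script):

import asyncio
import random

async def demo() -> list[int]:
    results: list[int] = [0] * 5

    async def worker(i: int) -> None:
        await asyncio.sleep(random.random() / 10)  # tasks finish out of order
        results[i] = i * i  # each task writes only index i

    await asyncio.gather(*(worker(i) for i in range(5)))
    return results

print(asyncio.run(demo()))  # -> [0, 1, 4, 9, 16]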
@@ -167,24 +175,28 @@ def evaluate_gsm8k(
             tasks = [get_answer(session, i) for i in range(num_questions)]
             await tqdm.gather(*tasks, desc="Evaluating")

-        return states
+        return states, output_tokens

     print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot")

     tic = time.perf_counter()
-    states = asyncio.run(run_async_evaluation())
+    states, output_tokens = asyncio.run(run_async_evaluation())
     latency = time.perf_counter() - tic

     # Compute metrics
     preds = [get_answer_value(state) for state in states]
     accuracy = np.mean(np.array(preds) == np.array(labels))
     invalid_rate = np.mean(np.array(preds) == INVALID)
+    total_output_tokens = sum(output_tokens)
+    tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0

     result = {
         "accuracy": accuracy,
         "invalid_rate": invalid_rate,
         "latency": latency,
         "questions_per_second": num_questions / latency,
+        "total_output_tokens": total_output_tokens,
+        "tokens_per_second": tokens_per_second,
         "num_questions": num_questions,
         "num_shots": num_shots,
         "max_tokens": max_tokens,
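As a worked example with hypothetical numbers: 200 questions producing 120
completion tokens each over 60 s of wall-clock time yield
total_output_tokens = 24000 and tokens_per_second = 400.0; the latency > 0
guard only matters in degenerate timing cases:

latency = 60.0                         # hypothetical wall-clock seconds
output_tokens = [120] * 200            # hypothetical per-question counts
total_output_tokens = sum(output_tokens)
tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0
print(total_output_tokens, tokens_per_second)  # -> 24000 400.0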
@@ -236,6 +248,8 @@ def main() -> None:
     print(f"Invalid responses: {result['invalid_rate']:.3f}")
     print(f"Total latency: {result['latency']:.3f} s")
     print(f"Questions per second: {result['questions_per_second']:.3f}")
+    print(f"Total output tokens: {result['total_output_tokens']}")
+    print(f"Output tokens per second: {result['tokens_per_second']:.3f}")

     # Optional file saving
     if args.save_results:
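A quick smoke test of the new return shape might look like the sketch
below. It assumes a vLLM server at http://localhost:8000 and that
call_vllm_api is importable from the eval script; the parameter list is
reconstructed from the fragments visible in this diff, and max_tokens is
an assumed parameter (adjust if the real signature differs):

import asyncio

import aiohttp

async def smoke_test() -> None:
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=600)
    ) as session:
        # call_vllm_api is the function patched above
        text, tokens = await call_vllm_api(
            session=session,
            prompt="Question: What is 2 + 3?\nAnswer:",
            temperature=0.0,
            max_tokens=64,  # assumed parameter, not visible in the diff
            stop=["Question:"],
            url="http://localhost:8000",
            seed=42,
        )
        print(f"{tokens} output tokens: {text!r}")

asyncio.run(smoke_test())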