vllm/tests/evals/gsm8k/test_gsm8k_correctness.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GSM8K evaluation using vLLM server and isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control.
Usage:
pytest -s -v test_gsm8k_correctness.py \
--config-list-file=configs/models-small.txt \
--tp-size=1
"""
import yaml

from tests.utils import RemoteOpenAIServer

from .gsm8k_eval import evaluate_gsm8k

# Tolerance for the accuracy comparison; despite the name, it is applied as an
# absolute offset below (measured >= expected - RTOL).
RTOL = 0.08
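# For example, with accuracy_threshold: 0.70 in the config, any measured
# accuracy of at least 0.62 passes the assertion at the bottom of this file.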


def launch_gsm8k_eval(eval_config, server_url, tp_size):
    """Launch GSM8K evaluation using our isolated script."""
    # tp_size is accepted for signature parity with the test but is not used
    # here; tensor parallelism is configured on the server side.
    # Extract host and port from the server URL
    if "://" in server_url:
        server_url = server_url.split("://")[1]

    host_port = server_url.split("/")[0]  # Remove path if present
    if ":" in host_port:
        host, port = host_port.split(":")
        port = int(port)
    else:
        host = host_port
        port = 8000

    # Add http:// prefix if not present
    if not host.startswith("http"):
        host = f"http://{host}"

    # Run GSM8K evaluation
    results = evaluate_gsm8k(
        num_questions=eval_config["num_questions"],
        num_shots=eval_config["num_fewshot"],
        host=host,
        port=port,
    )

    return results
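
# Illustrative example of the parsing above: for a server_url of
# "http://127.0.0.1:8000/v1", the function strips the scheme and path, then
# calls evaluate_gsm8k(host="http://127.0.0.1", port=8000, ...).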


def test_gsm8k_correctness_param(config_filename, tp_size):
    """Test GSM8K correctness for a given model configuration."""
    eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))

    # Server arguments
    server_args = [
        "--max-model-len",
        str(eval_config.get("max_model_len", 4096)),
        "--enforce-eager",
        "--trust-remote-code",
        "--tensor-parallel-size",
        str(tp_size),
    ]
    env_dict = eval_config.get("env", None)

    # Launch server and run evaluation
    with RemoteOpenAIServer(
        eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480
    ) as remote_server:
        server_url = remote_server.url_for("v1")
        results = launch_gsm8k_eval(eval_config, server_url, tp_size)

        # Check accuracy against threshold
        measured_accuracy = results["accuracy"]
        expected_accuracy = eval_config["accuracy_threshold"]

        print(f"GSM8K Results for {eval_config['model_name']}:")
        print(f"  Accuracy: {measured_accuracy:.3f}")
        print(f"  Expected: {expected_accuracy:.3f}")
        print(f"  Questions: {results['num_questions']}")
        print(f"  Invalid rate: {results['invalid_rate']:.3f}")
        print(f"  Latency: {results['latency']:.1f}s")
        print(f"  QPS: {results['questions_per_second']:.1f}")

        # Verify accuracy is within tolerance
        assert measured_accuracy >= expected_accuracy - RTOL, (
            f"Accuracy too low: {measured_accuracy:.3f} < "
            f"{expected_accuracy:.3f} - {RTOL:.3f}"
        )
        print(f"✅ GSM8K test passed for {eval_config['model_name']}")