diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 781b8e0fa0095..2f7f1db75bfb9 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -451,13 +451,11 @@ steps:
 
 - label: LM Eval Small Models # 53min
   mirror_hardwares: [amdexperimental]
-  working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
   source_file_dependencies:
   - csrc/
   - vllm/model_executor/layers/quantization
   commands:
-  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-  - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
+  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
 
 - label: OpenAI API correctness
   mirror_hardwares: [amdexperimental]
diff --git a/tests/evals/gsm8k/README.md b/tests/evals/gsm8k/README.md
new file mode 100644
index 0000000000000..58572c3a6fbc1
--- /dev/null
+++ b/tests/evals/gsm8k/README.md
@@ -0,0 +1,35 @@
+# GSM8K Accuracy Evaluation
+
+This directory contains a replacement for the lm-eval-harness GSM8K evaluation, using an isolated GSM8K script and a vLLM server for better performance and control.
+
+## Usage
+
+### Run tests with pytest (as in Buildkite)
+
+```bash
+pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
+    --config-list-file=configs/models-small.txt \
+    --tp-size=1
+```
+
+### Run standalone evaluation script
+
+```bash
+# Start vLLM server first
+vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000
+
+# Run evaluation
+python tests/evals/gsm8k/gsm8k_eval.py --port 8000
+```
+
+## Configuration Format
+
+Model configs in the `configs/` directory use this YAML format:
+
+```yaml
+model_name: "Qwen/Qwen2.5-1.5B-Instruct"
+accuracy_threshold: 0.54  # Minimum expected accuracy
+num_questions: 1319       # Number of questions (default: full test set)
+num_fewshot: 5            # Few-shot examples from the train set
+max_model_len: 4096       # Model context length
+```
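For reference (a sketch, not part of the patch): the standalone script builds a plain few-shot completion prompt from the train split and appends the test question, as shown below with placeholder records in place of real GSM8K data.

```python
# Sketch of the completion prompt gsm8k_eval.py assembles; these records are
# placeholders, not real GSM8K data.
train_data = [{"question": "2 + 2?", "answer": "2 + 2 = 4\n#### 4"}]
test_question = "3 + 5?"

few_shot_examples = ""
for ex in train_data:  # the script uses the first num_fewshot train items
    few_shot_examples += (f"Question: {ex['question']}\n"
                          f"Answer: {ex['answer']}\n\n")

prompt = few_shot_examples + f"Question: {test_question}\nAnswer:"
print(prompt)  # sent to /v1/completions with stop strings such as "Question"
```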
diff --git a/tests/evals/gsm8k/__init__.py b/tests/evals/gsm8k/__init__.py
new file mode 100644
index 0000000000000..0fec1fe5bcdfd
--- /dev/null
+++ b/tests/evals/gsm8k/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml
new file mode 100644
index 0000000000000..caa0448f23d48
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Llama-3-8B-Instruct-nonuniform-CT.yaml
@@ -0,0 +1,5 @@
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
+accuracy_threshold: 0.74
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml
new file mode 100644
index 0000000000000..615aa69a2d2b6
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Llama-3.2-1B-Instruct-INT8-CT.yaml
@@ -0,0 +1,5 @@
+model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
+accuracy_threshold: 0.31
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml
new file mode 100644
index 0000000000000..c5dbceeeb2b45
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml
@@ -0,0 +1,5 @@
+model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
+accuracy_threshold: 0.45
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
new file mode 100644
index 0000000000000..5319ada30f645
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
@@ -0,0 +1,5 @@
+model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
+accuracy_threshold: 0.60
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml
new file mode 100644
index 0000000000000..c39fb979d98ac
--- /dev/null
+++ b/tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml
@@ -0,0 +1,5 @@
+model_name: "Qwen/Qwen3-0.6B-FP8"
+accuracy_threshold: 0.375
+num_questions: 1319
+num_fewshot: 5
+max_model_len: 4096
\ No newline at end of file
diff --git a/tests/evals/gsm8k/configs/models-small.txt b/tests/evals/gsm8k/configs/models-small.txt
new file mode 100644
index 0000000000000..afd1065b9191b
--- /dev/null
+++ b/tests/evals/gsm8k/configs/models-small.txt
@@ -0,0 +1,5 @@
+Qwen3-0.6B-FP8.yaml
+Llama-3.2-1B-Instruct-INT8-CT.yaml
+Llama-3-8B-Instruct-nonuniform-CT.yaml
+Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
+Qwen1.5-MoE-W4A16-CT.yaml
diff --git a/tests/evals/gsm8k/conftest.py b/tests/evals/gsm8k/conftest.py
new file mode 100644
index 0000000000000..d96b0a66ede2b
--- /dev/null
+++ b/tests/evals/gsm8k/conftest.py
@@ -0,0 +1,66 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from pathlib import Path
+
+
+def pytest_addoption(parser):
+    """Add custom command line options."""
+    parser.addoption("--config-list-file",
+                     default="configs/models-small.txt",
+                     help="File containing list of config files to test")
+    parser.addoption("--tp-size",
+                     default=1,
+                     type=int,
+                     help="Tensor parallel size")
+
+
+def pytest_generate_tests(metafunc):
+    """Generate test parameters from config files."""
+    if "config_filename" in metafunc.fixturenames:
+        config_list_file = metafunc.config.getoption("--config-list-file")
+        tp_size = metafunc.config.getoption("--tp-size")
+
+        # Handle both relative and absolute paths
+        config_list_path = Path(config_list_file)
+        if not config_list_path.is_absolute():
+            # If relative, try relative to the test directory first
+            test_dir_path = Path(__file__).parent / config_list_file
+            if test_dir_path.exists():
+                config_list_path = test_dir_path
+            else:
+                # Fall back to the current working directory
+                config_list_path = Path.cwd() / config_list_file
+
+        print(f"Looking for config list at: {config_list_path}")
+
+        config_files = []
+        if config_list_path.exists():
+            # Determine config directory (same directory as the list file)
+            config_dir = config_list_path.parent
+
+            with open(config_list_path) as f:
+                for line in f:
+                    line = line.strip()
+                    if line and not line.startswith("#"):
+                        config_path = config_dir / line
+                        print(f"Checking config file: {config_path}")
+                        if config_path.exists():
+                            config_files.append(config_path)
+                            print(f"  ✓ Found: {config_path}")
+                        else:
+                            print(f"  ✗ Missing: {config_path}")
+        else:
+            print(f"Config list file not found: {config_list_path}")
+
+        # Generate test parameters
+        if config_files:
+            metafunc.parametrize(["config_filename", "tp_size"],
+                                 [(config_file, int(tp_size))
+                                  for config_file in config_files],
+                                 ids=[
+                                     f"{config_file.stem}-tp{tp_size}"
+                                     for config_file in config_files
+                                 ])
+        else:
+            print("No config files found, test will be skipped")
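For reference (a sketch, not part of the patch): the parametrization in `conftest.py` yields one test case per config file, with IDs built from the file stem and the tensor-parallel size, as shown below under assumed file names.

```python
# Sketch of the test IDs pytest_generate_tests derives from a config list;
# the file names mirror configs/models-small.txt, tp_size comes from --tp-size.
from pathlib import Path

config_files = [
    Path("configs/Qwen3-0.6B-FP8.yaml"),
    Path("configs/Llama-3.2-1B-Instruct-INT8-CT.yaml"),
]
tp_size = 1

ids = [f"{cfg.stem}-tp{tp_size}" for cfg in config_files]
print(ids)  # ['Qwen3-0.6B-FP8-tp1', 'Llama-3.2-1B-Instruct-INT8-CT-tp1']
```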
diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py
new file mode 100644
index 0000000000000..7d0ce25f75dd4
--- /dev/null
+++ b/tests/evals/gsm8k/gsm8k_eval.py
@@ -0,0 +1,252 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Isolated GSM8K evaluation script for the vLLM serve endpoint.
+"""
+
+import argparse
+import ast
+import asyncio
+import json
+import os
+import time
+from collections.abc import Generator
+from typing import Optional, Union
+
+import aiohttp
+import numpy as np
+import regex as re
+import requests
+from tqdm.asyncio import tqdm
+
+INVALID = -9999999
+
+
+def download_and_cache_file(url: str, filename: Optional[str] = None) -> str:
+    """Download and cache a file from a URL."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
+
+    if os.path.exists(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
+    response = requests.get(url, stream=True)
+    response.raise_for_status()
+
+    with open(filename, "wb") as f:
+        for chunk in response.iter_content(chunk_size=1024):
+            f.write(chunk)
+
+    return filename
+
+
+def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
+    """Load GSM8K train and test data."""
+    train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl"
+    test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+
+    train_file = download_and_cache_file(train_url)
+    test_file = download_and_cache_file(test_url)
+
+    train_data = list(read_jsonl(train_file))
+    test_data = list(read_jsonl(test_file))
+
+    return train_data, test_data
+
+
+def read_jsonl(filename: str) -> Generator[dict, None, None]:
+    """Read a JSONL file."""
+    with open(filename) as fin:
+        for line in fin:
+            if not line.startswith("#"):
+                yield json.loads(line)
+
+
+def get_answer_value(answer_str: str) -> int:
+    """Extract the numerical answer from the response."""
+    answer_str = answer_str.replace(",", "")
+    numbers = re.findall(r"\d+", answer_str)
+    if len(numbers) < 1:
+        return INVALID
+    try:
+        return ast.literal_eval(numbers[-1])
+    except SyntaxError:
+        return INVALID
+
+
+async def call_vllm_api(session: aiohttp.ClientSession,
+                        prompt: str,
+                        temperature: float,
+                        max_tokens: int,
+                        stop: Optional[list[str]] = None,
+                        url: Optional[str] = None,
+                        seed: Optional[int] = None) -> str:
+    """Call vLLM's OpenAI-compatible completions endpoint."""
+    data = {
+        "prompt": prompt,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+        "stop": stop,
+    }
+    if seed is not None:
+        data["seed"] = seed
+
+    try:
+        async with session.post(f"{url}/v1/completions",
+                                json=data) as response:
+            response.raise_for_status()
+            result = await response.json()
+            return result["choices"][0]["text"]
+    except Exception as e:
+        print(f"Error calling vLLM API: {e}")
+        return ""
+ """ + base_url = f"{host}:{port}" + + # Load GSM8K train and test data + train_data, test_data = load_gsm8k_data() + + # Limit to available test questions + num_questions = min(num_questions, len(test_data)) + + # Build few-shot examples from train split (like lm-eval does) + few_shot_examples = "" + for i in range(num_shots): + few_shot_examples += (f"Question: {train_data[i]['question']}\n" + f"Answer: {train_data[i]['answer']}\n\n") + + # Prepare test questions and labels from test split + questions = [] + labels = [] + for i in range(num_questions): + questions.append(f"Question: {test_data[i]['question']}\nAnswer:") + labels.append(get_answer_value(test_data[i]["answer"])) + + assert all(label != INVALID for label in labels), "Some labels are invalid" + + # Run evaluation + async def run_async_evaluation(): + states: list[str] = [""] * num_questions + + async def get_answer(session: aiohttp.ClientSession, i: int) -> str: + prompt = few_shot_examples + questions[i] + answer = await call_vllm_api( + session=session, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + stop=["Question", "Assistant:", "<|separator|>"], + url=base_url, + seed=seed, + ) + states[i] = answer + return answer + + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( + total=600)) as session: + tasks = [get_answer(session, i) for i in range(num_questions)] + await tqdm.gather(*tasks, desc="Evaluating") + + return states + + print(f"Running GSM8K evaluation: {num_questions} questions, " + f"{num_shots}-shot") + + tic = time.perf_counter() + states = asyncio.run(run_async_evaluation()) + latency = time.perf_counter() - tic + + # Compute metrics + preds = [get_answer_value(state) for state in states] + accuracy = np.mean(np.array(preds) == np.array(labels)) + invalid_rate = np.mean(np.array(preds) == INVALID) + + result = { + "accuracy": accuracy, + "invalid_rate": invalid_rate, + "latency": latency, + "questions_per_second": num_questions / latency, + "num_questions": num_questions, + "num_shots": num_shots, + "max_tokens": max_tokens, + "timestamp": time.time(), + } + + return result + + +def main() -> None: + parser = argparse.ArgumentParser( + description="GSM8K evaluation for vLLM serve") + parser.add_argument("--num-shots", + type=int, + default=5, + help="Number of few-shot examples") + parser.add_argument("--num-questions", + type=int, + default=1319, + help="Number of questions to evaluate") + parser.add_argument("--max-tokens", + type=int, + default=256, + help="Max tokens for generation") + parser.add_argument("--host", + type=str, + default="http://127.0.0.1", + help="Host URL") + parser.add_argument("--port", type=int, default=8000, help="Port number") + parser.add_argument("--temperature", + type=float, + default=0.0, + help="Temperature for generation") + parser.add_argument("--seed", + type=int, + default=42, + help="Random seed for reproducibility") + parser.add_argument("--save-results", + type=str, + help="Save results to JSON file") + + args = parser.parse_args() + + result = evaluate_gsm8k( + num_questions=args.num_questions, + num_shots=args.num_shots, + max_tokens=args.max_tokens, + host=args.host, + port=args.port, + temperature=args.temperature, + seed=args.seed, + ) + + # Print results to terminal + print("\nResults:") + print(f"Accuracy: {result['accuracy']:.3f}") + print(f"Invalid responses: {result['invalid_rate']:.3f}") + print(f"Total latency: {result['latency']:.3f} s") + print(f"Questions per second: {result['questions_per_second']:.3f}") + + # 
diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py
new file mode 100644
index 0000000000000..a12dd49dbea6d
--- /dev/null
+++ b/tests/evals/gsm8k/test_gsm8k_correctness.py
@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+GSM8K evaluation using a vLLM server and an isolated GSM8K script.
+A replacement for lm-eval-harness with better performance and control.
+
+Usage:
+pytest -s -v test_gsm8k_correctness.py \
+    --config-list-file=configs/models-small.txt \
+    --tp-size=1
+"""
+
+import yaml
+
+from tests.utils import RemoteOpenAIServer
+
+from .gsm8k_eval import evaluate_gsm8k
+
+RTOL = 0.08  # Accuracy tolerance (subtracted as an absolute margin)
+
+
+def launch_gsm8k_eval(eval_config, server_url, tp_size):
+    """Launch GSM8K evaluation using our isolated script."""
+    # Extract host and port from server URL
+    if "://" in server_url:
+        server_url = server_url.split("://")[1]
+
+    host_port = server_url.split("/")[0]  # Remove path if present
+    if ":" in host_port:
+        host, port = host_port.split(":")
+        port = int(port)
+    else:
+        host = host_port
+        port = 8000
+
+    # Add http:// prefix if not present
+    if not host.startswith("http"):
+        host = f"http://{host}"
+
+    # Run GSM8K evaluation
+    results = evaluate_gsm8k(
+        num_questions=eval_config["num_questions"],
+        num_shots=eval_config["num_fewshot"],
+        host=host,
+        port=port,
+    )
+
+    return results
+
+
+def test_gsm8k_correctness_param(config_filename, tp_size):
+    """Test GSM8K correctness for a given model configuration."""
+    eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
+
+    # Server arguments
+    server_args = [
+        "--max-model-len",
+        str(eval_config.get("max_model_len", 4096)),
+        "--enforce-eager",
+        "--trust-remote-code",
+        "--tensor-parallel-size",
+        str(tp_size),
+    ]
+
+    # Launch server and run evaluation
+    with RemoteOpenAIServer(eval_config["model_name"],
+                            server_args,
+                            max_wait_seconds=480) as remote_server:
+        server_url = remote_server.url_for("v1")
+
+        results = launch_gsm8k_eval(eval_config, server_url, tp_size)
+
+        # Check accuracy against threshold
+        measured_accuracy = results["accuracy"]
+        expected_accuracy = eval_config["accuracy_threshold"]
+
+        print(f"GSM8K Results for {eval_config['model_name']}:")
+        print(f"  Accuracy: {measured_accuracy:.3f}")
+        print(f"  Expected: {expected_accuracy:.3f}")
+        print(f"  Questions: {results['num_questions']}")
+        print(f"  Invalid rate: {results['invalid_rate']:.3f}")
+        print(f"  Latency: {results['latency']:.1f}s")
+        print(f"  QPS: {results['questions_per_second']:.1f}")
+
+        # Verify accuracy is within tolerance
+        assert measured_accuracy >= expected_accuracy - RTOL, (
+            f"Accuracy too low: {measured_accuracy:.3f} < "
+            f"{expected_accuracy:.3f} - {RTOL:.3f}")
+
+        print(f"✅ GSM8K test passed for {eval_config['model_name']}")
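For reference (a sketch, not part of the patch): despite the `RTOL` name, the check subtracts a fixed margin from the configured threshold. A worked example using the threshold from `configs/Qwen3-0.6B-FP8.yaml` and a hypothetical measurement:

```python
# Worked example of the pass/fail rule in test_gsm8k_correctness.py; the
# threshold comes from configs/Qwen3-0.6B-FP8.yaml, the measured value
# is hypothetical.
RTOL = 0.08
expected_accuracy = 0.375  # accuracy_threshold from the YAML config
measured_accuracy = 0.41   # hypothetical eval result

# Passes because 0.41 >= 0.375 - 0.08 = 0.295
assert measured_accuracy >= expected_accuracy - RTOL
```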