[CI] Generalize gsm8k test args and add Qwen3-Next MTP B200 test (#30723)
Signed-off-by: mgoin <mgoin64@gmail.com>
Commit 10ee1c64cf (parent 66c3537e5d)

@@ -654,7 +654,7 @@ steps:
   - vllm/model_executor/layers/quantization
   autorun_on_main: true
   commands:
-  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
+  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt

 - label: OpenAI API correctness # 22min
   timeout_in_minutes: 30
@@ -1064,7 +1064,7 @@ steps:
   - csrc/
   - vllm/model_executor/layers/quantization
   commands:
-  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
+  - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt

 ##### 1 GPU test #####
 ##### multi gpus test #####

@@ -7,9 +7,8 @@ This directory contains a replacement for the lm-eval-harness GSM8K evaluation,
 ### Run tests with pytest (like buildkite)

 ```bash
-pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
-    --config-list-file=configs/models-small.txt \
-    --tp-size=1
+pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
+    --config-list-file=configs/models-small.txt
 ```

 ### Run standalone evaluation script
@@ -31,5 +30,11 @@ model_name: "Qwen/Qwen2.5-1.5B-Instruct"
 accuracy_threshold: 0.54 # Minimum expected accuracy
 num_questions: 1319 # Number of questions (default: full test set)
 num_fewshot: 5 # Few-shot examples from train set
-max_model_len: 4096 # Model context length
+server_args: "--max-model-len 4096 --tensor-parallel-size 2" # Server arguments
+env: # Environment variables (optional)
+  VLLM_USE_FLASHINFER_MOE_FP4: "1"
 ```
+
+The `server_args` field accepts any arguments that can be passed to `vllm serve`.
+
+The `env` field accepts a dictionary of environment variables to set for the server process.
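
As a sketch of how these two fields compose (the config path and assembled command are hypothetical; only the YAML/`shlex` handling mirrors the test code further down):

```python
import os
import shlex

import yaml

# Hypothetical config path; the field names match the README snippet above.
cfg = yaml.safe_load(open("configs/example.yaml", encoding="utf-8"))
server_args = shlex.split(cfg.get("server_args", ""))  # quoted values stay intact
env = {**os.environ, **{k: str(v) for k, v in (cfg.get("env") or {}).items()}}
cmd = ["vllm", "serve", cfg["model_name"], *server_args]
print("would launch:", " ".join(shlex.quote(c) for c in cmd))
```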

@@ -2,5 +2,4 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 accuracy_threshold: 0.72
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
-
+server_args: "--enforce-eager --max-model-len 4096"

@@ -2,4 +2,4 @@ model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
 accuracy_threshold: 0.74
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"

@@ -2,4 +2,4 @@ model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
 accuracy_threshold: 0.31
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"

@@ -2,4 +2,4 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
 accuracy_threshold: 0.45
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"

@@ -2,4 +2,4 @@ model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
 accuracy_threshold: 0.60
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"

@@ -2,4 +2,4 @@ model_name: "Qwen/Qwen3-0.6B-FP8"
 accuracy_threshold: 0.375
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
+server_args: "--enforce-eager --max-model-len 4096"

@@ -2,5 +2,4 @@ model_name: "nvidia/Qwen3-30B-A3B-FP4"
 accuracy_threshold: 0.89
 num_questions: 1319
 num_fewshot: 5
-max_model_len: 4096
-
+server_args: "--enforce-eager --max-model-len 4096"

tests/evals/gsm8k/configs/Qwen3-Next-80B-A3B-NVFP4-EP2.yaml (new file, 12 lines)
@@ -0,0 +1,12 @@
+model_name: "nm-testing/Qwen3-Next-80B-A3B-Instruct-NVFP4"
+accuracy_threshold: 0.75
+num_questions: 1319
+num_fewshot: 5
+server_args: >-
+  --enforce-eager
+  --max-model-len 4096
+  --tensor-parallel-size 2
+  --enable-expert-parallel
+  --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}'
+env:
+  VLLM_USE_FLASHINFER_MOE_FP4: "1"
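
A note on this config (an illustrative check, not part of the commit): YAML's `>-` folded scalar joins the argument lines with single spaces, and `shlex.split` then keeps the quoted `--speculative-config` JSON together as one argv token:

```python
import shlex

import yaml

snippet = """
server_args: >-
  --enforce-eager
  --speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}'
"""
args = shlex.split(yaml.safe_load(snippet)["server_args"])
# Folding produced one line; shlex preserved the quoted JSON as a single token.
assert args == [
    "--enforce-eager",
    "--speculative-config",
    '{"method":"qwen3_next_mtp","num_speculative_tokens":1}',
]
```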

@@ -3,3 +3,4 @@ Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
 Qwen1.5-MoE-W4A16-CT.yaml
 DeepSeek-V2-Lite-Instruct-FP8.yaml
 Qwen3-30B-A3B-NVFP4.yaml
+Qwen3-Next-80B-A3B-NVFP4-EP2.yaml

@@ -11,14 +11,12 @@ def pytest_addoption(parser):
         default="configs/models-small.txt",
         help="File containing list of config files to test",
     )
-    parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size")


 def pytest_generate_tests(metafunc):
     """Generate test parameters from config files."""
     if "config_filename" in metafunc.fixturenames:
         config_list_file = metafunc.config.getoption("--config-list-file")
-        tp_size = metafunc.config.getoption("--tp-size")

         # Handle both relative and absolute paths
         config_list_path = Path(config_list_file)
@@ -55,9 +53,9 @@ def pytest_generate_tests(metafunc):
         # Generate test parameters
         if config_files:
             metafunc.parametrize(
-                ["config_filename", "tp_size"],
-                [(config_file, int(tp_size)) for config_file in config_files],
-                ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files],
+                "config_filename",
+                config_files,
+                ids=[config_file.stem for config_file in config_files],
             )
         else:
             print("No config files found, test will be skipped")
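
For reference, a minimal self-contained sketch (hypothetical config list) of what the simplified parametrization collects: one test per config file, with ids taken from the file stem rather than the old `{stem}-tp{tp_size}` form:

```python
from pathlib import Path

import pytest

CONFIGS = [Path("configs/Qwen3-0.6B-FP8.yaml")]  # hypothetical subset

@pytest.mark.parametrize("config_filename", CONFIGS, ids=[c.stem for c in CONFIGS])
def test_id_shape(config_filename):
    # Collected id is now "Qwen3-0.6B-FP8" (no -tp1 suffix); the tensor-parallel
    # size travels inside each config's server_args instead.
    assert config_filename.suffix == ".yaml"
```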

@@ -5,30 +5,31 @@ GSM8K evaluation using vLLM server and isolated GSM8K script.
 Replacement for lm-eval-harness with better performance and control.

 Usage:
-    pytest -s -v test_gsm8k_correctness.py \
-        --config-list-file=configs/models-small.txt \
-        --tp-size=1
+    pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
+        --config-list-file=configs/models-small.txt
 """

+import shlex
+
 import yaml

 from tests.utils import RemoteOpenAIServer

 from .gsm8k_eval import evaluate_gsm8k

-RTOL = 0.08 # Relative tolerance for accuracy comparison
+TOL = 0.08 # Absolute tolerance for accuracy comparison


-def launch_gsm8k_eval(eval_config, server_url, tp_size):
-    """Launch GSM8K evaluation using our isolated script."""
+def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
+    """Run GSM8K evaluation using our isolated script."""
     # Extract host and port from server URL
     if "://" in server_url:
         server_url = server_url.split("://")[1]

     host_port = server_url.split("/")[0] # Remove path if present
     if ":" in host_port:
-        host, port = host_port.split(":")
-        port = int(port)
+        host, p = host_port.split(":")
+        port = int(p)
     else:
         host = host_port
         port = 8000
@@ -48,46 +49,57 @@ def launch_gsm8k_eval(eval_config, server_url, tp_size):
     return results


-def test_gsm8k_correctness_param(config_filename, tp_size):
+def test_gsm8k_correctness(config_filename):
     """Test GSM8K correctness for a given model configuration."""
     eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))

-    # Server arguments
-    server_args = [
-        "--max-model-len",
-        str(eval_config.get("max_model_len", 4096)),
-        "--enforce-eager",
-        "--trust-remote-code",
-        "--tensor-parallel-size",
-        str(tp_size),
-    ]
+    # Parse server arguments from config (use shlex to handle quoted strings)
+    server_args_str = eval_config.get("server_args", "")
+    server_args = shlex.split(server_args_str) if server_args_str else []
+
+    # Add standard server arguments
+    server_args.extend(
+        [
+            "--trust-remote-code",
+        ]
+    )

     env_dict = eval_config.get("env", None)

+    print(f"Starting GSM8K evaluation for model: {eval_config['model_name']}")
+    print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
+    print(f"Number of questions: {eval_config['num_questions']}")
+    print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
+    print(f"Server args: {' '.join(server_args)}")
+
     # Launch server and run evaluation
     with RemoteOpenAIServer(
-        eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480
+        eval_config["model_name"],
+        server_args,
+        env_dict=env_dict,
+        max_wait_seconds=600,
     ) as remote_server:
         server_url = remote_server.url_for("v1")
+        print(f"Server started at: {server_url}")

-        results = launch_gsm8k_eval(eval_config, server_url, tp_size)
+        results = run_gsm8k_eval(eval_config, server_url)

-        # Check accuracy against threshold
-        measured_accuracy = results["accuracy"]
-        expected_accuracy = eval_config["accuracy_threshold"]
+        measured_metric = results["accuracy"]
+        expected_metric = eval_config["accuracy_threshold"]

         print(f"GSM8K Results for {eval_config['model_name']}:")
-        print(f"  Accuracy: {measured_accuracy:.3f}")
-        print(f"  Expected: {expected_accuracy:.3f}")
+        print(f"  Measured metric: {measured_metric:.4f}")
+        print(f"  Expected metric: {expected_metric:.4f}")
+        print(f"  Tolerance: {TOL:.4f}")
         print(f"  Questions: {results['num_questions']}")
         print(f"  Invalid rate: {results['invalid_rate']:.3f}")
         print(f"  Latency: {results['latency']:.1f}s")
         print(f"  QPS: {results['questions_per_second']:.1f}")

-        # Verify accuracy is within tolerance
-        assert measured_accuracy >= expected_accuracy - RTOL, (
-            f"Accuracy too low: {measured_accuracy:.3f} < "
-            f"{expected_accuracy:.3f} - {RTOL:.3f}"
+        # Verify metric is within tolerance
+        assert measured_metric >= expected_metric - TOL, (
+            f"GSM8K metric too low: {measured_metric:.4f} < "
+            f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
         )

         print(f"✅ GSM8K test passed for {eval_config['model_name']}")
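
To make the new pass/fail boundary concrete, a worked example using the `accuracy_threshold: 0.75` from the Qwen3-Next config above:

```python
# Worked example of the absolute-tolerance check used in the test.
TOL = 0.08
expected_metric = 0.75          # accuracy_threshold from the config
floor = expected_metric - TOL   # 0.67
assert 0.70 >= floor            # a measured accuracy of 0.70 passes
assert not (0.65 >= floor)      # a measured accuracy of 0.65 would fail
```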

@@ -626,17 +626,11 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
             )
         else:
+            # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
+            # only (no EP).
             from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4

-            assert layer.expert_map is None, (
-                "Expert Parallelism / expert_map "
-                "is currently not supported for "
-                "CompressedTensorsW4A4Nvfp4MoEMethod."
-            )
             assert self.moe_quant_config is not None
-
-            # Cutlass moe takes in activations in BF16/Half precision
-            # and fp4 quantized weights loaded from the checkpoint
             return cutlass_moe_fp4(
                 a=x,
                 w1_fp4=layer.w13_weight,
@@ -644,6 +638,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 quant_config=self.moe_quant_config,
+                expert_map=layer.expert_map,
                 apply_router_weight_on_input=layer.apply_router_weight_on_input,
                 # TODO(bnell): derive these from arguments
                 m=x.shape[0],
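
The removed assert and the added `expert_map=layer.expert_map` are what let the EP2 config above exercise this path. For background, an illustrative sketch of the usual expert-map convention (assumed here, not taken from this diff): each expert-parallel rank owns a slice of the global experts, and the map sends global expert ids to local ids, with -1 marking experts owned by other ranks:

```python
import torch

# 8 global experts split across 2 expert-parallel ranks; rank 0 owns experts 0-3.
num_experts, ep_size, ep_rank = 8, 2, 0
local = num_experts // ep_size
expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map[ep_rank * local : (ep_rank + 1) * local] = torch.arange(
    local, dtype=torch.int32
)
print(expert_map)  # tensor([ 0,  1,  2,  3, -1, -1, -1, -1], dtype=torch.int32)
```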