[CI] Add PPL test for generation models (#24485)

Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi 2025-09-10 21:16:39 +08:00 committed by GitHub
parent d6069887c6
commit bd98842c8a
9 changed files with 211 additions and 7 deletions
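
In short, the new PPL test tokenizes WikiText-2 (wikitext-2-raw-v1, test split), splits the token stream into non-overlapping chunks of at most MAX_LENGTH (1024) tokens, scores each chunk via vLLM prompt logprobs, and compares the resulting perplexity against a Hugging Face Transformers baseline (or a precomputed constant supplied by the model entry). Perplexity over N scored tokens is the exponential of the mean negative log-likelihood, which is exactly the torch.exp(nll_sum / n_tokens) computation in ppl_utils.py below:

    PPL = exp(-(1/N) * sum_{i=1..N} log p(x_i | x_{<i}))

The assertion is one-sided: the test fails only if the vLLM PPL exceeds the baseline PPL by more than PPL_TOL (1%) in relative terms, since a lower vLLM PPL is not a concern.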

View File

@@ -604,6 +604,16 @@ steps:
    - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
    - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'

- label: Language Models Test (PPL)
  timeout_in_minutes: 110
  mirror_hardwares: [amdexperimental]
  optional: true
  source_file_dependencies:
  - vllm/
  - tests/models/language/generation_ppl_test
  commands:
    - pytest -v -s models/language/generation_ppl_test

- label: Language Models Test (Extended Pooling) # 36min
  timeout_in_minutes: 50
  mirror_hardwares: [amdexperimental]

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/docs/transformers/perplexity
from typing import Optional, cast

import pytest
import torch
from datasets import load_dataset

from tests.models.utils import (GenerateModelInfo,
                                TokensTextLogprobsPromptLogprobs)
from vllm.logprobs import Logprob

# See #24485
PPL_TOL = 0.01
MAX_LENGTH = 1024


@torch.inference_mode
def wikitext_ppl_test(hf_runner,
                      vllm_runner,
                      model_info: GenerateModelInfo,
                      max_length=MAX_LENGTH,
                      vllm_extra_kwargs=None,
                      atol=PPL_TOL):
    # A model family has many models with the same architecture,
    # and we don't need to test each one.
    if not model_info.enable_test:
        pytest.skip("Skipping test.")

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Allow vLLM to run the test with a given dtype, such as float32
    vllm_extra_kwargs = vllm_extra_kwargs or {}
    vllm_extra_kwargs["dtype"] = model_info.dtype

    # Allow vLLM to run the test with hf_overrides
    if model_info.hf_overrides is not None:
        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

    with vllm_runner(model_info.name,
                     gpu_memory_utilization=0.7,
                     max_model_len=max_length,
                     max_num_seqs=1,
                     enforce_eager=True,
                     **vllm_extra_kwargs) as vllm_model:
        # Use max_num_seqs=1 to avoid OOM and to avoid batching
        # different requests together.
        model_config = vllm_model.llm.llm_engine.model_config

        # Confirm that vLLM is using the expected architecture
        if model_info.architecture:
            assert (model_info.architecture in model_config.architectures)

        max_length = min(model_config.max_model_len - 1, max_length)
        stride = max_length

        tokenizer = vllm_model.llm.get_tokenizer()
        tokens = tokenizer.encode("\n\n".join(dataset["text"]))
        n_tokens = len(tokens)

        chunks = []
        for begin_loc in range(0, n_tokens, stride):
            end_loc = min(begin_loc + max_length, n_tokens)
            chunks.append(tokens[begin_loc:end_loc])

        outputs = vllm_model.generate_greedy_logprobs(prompts=chunks,
                                                      max_tokens=1,
                                                      num_logprobs=None,
                                                      num_prompt_logprobs=0,
                                                      use_tqdm=False)

        nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
        n_tokens = 0
        for output in outputs:
            output = cast(TokensTextLogprobsPromptLogprobs, output)
            token_datas = cast(list[Optional[dict[int, Logprob]]], output[3])

            # The first prompt token has no logprob
            assert token_datas[0] is None
            token_log_probs = []
            for token_data in token_datas[1:]:
                assert token_data is not None
                assert len(token_data) == 1
                token_log_prob = list(token_data.values())[0].logprob
                token_log_probs.append(token_log_prob)

            neg_log_likelihood = -torch.tensor(
                token_log_probs, dtype=torch.float32, device="cpu").sum()
            nll_sum += neg_log_likelihood
            n_tokens += len(token_log_probs)
        vllm_ppl = float(torch.exp(nll_sum / n_tokens))
        vllm_dtype = model_config.dtype

    # Speed up the PPL test by allowing a precomputed Transformers PPL
    # to be supplied as a constant (model_info.hf_ppl).
    if model_info.hf_ppl is None:
        with hf_runner(
                model_info.name,
                dtype=model_info.hf_dtype,
        ) as hf_model:
            nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
            n_tokens = 0
            for chunk in chunks:
                inputs = hf_model.wrap_device(
                    {"input_ids": torch.tensor([chunk])})
                input_ids = inputs["input_ids"]
                outputs = hf_model.model(input_ids, labels=input_ids)
                neg_log_likelihood = outputs.loss

                neg_log_likelihood = neg_log_likelihood.to(torch.float32).cpu()
                num_loss_tokens = len(chunk) - 1
                nll_sum += neg_log_likelihood * num_loss_tokens
                n_tokens += num_loss_tokens

            hf_ppl = float(torch.exp(nll_sum / n_tokens))
            hf_dtype = next(hf_model.model.parameters()).dtype
    else:
        hf_ppl = model_info.hf_ppl
        hf_dtype = "Constant"

    differ = (vllm_ppl - hf_ppl) / hf_ppl
    print("Model:", model_info.name)
    print("VLLM:", vllm_dtype, vllm_ppl)
    print("Transformers:", hf_dtype, hf_ppl)
    print("Difference (%):", differ * 100)

    # Lower PPL is better. We are not concerned when the vLLM PPL is
    # lower than the Transformers PPL, so we only perform a one-sided test.
    assert differ < atol
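
A small usage sketch (the hf_ppl value here is a made-up placeholder, not a measured number): a model entry in the test files below can pin the Transformers baseline via hf_ppl, which skips the Hugging Face forward passes and compares against the constant instead:

    MODELS = [
        # hf_ppl is illustrative only; supply a real measured value in practice
        GenerateModelInfo("openai-community/gpt2-large", hf_ppl=19.4),
    ]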

View File

@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from tests.models.utils import GenerateModelInfo

from .ppl_utils import wikitext_ppl_test

MODELS = [
    GenerateModelInfo("google/gemma-2b"),
    GenerateModelInfo("google/gemma-2-2b"),
    GenerateModelInfo("google/gemma-3-4b-it"),
]


@pytest.mark.parametrize("model_info", MODELS)
def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
    wikitext_ppl_test(hf_runner, vllm_runner, model_info)

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from tests.models.utils import GenerateModelInfo

from .ppl_utils import wikitext_ppl_test

MODELS = [GenerateModelInfo("openai-community/gpt2-large")]


@pytest.mark.parametrize("model_info", MODELS)
def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
    wikitext_ppl_test(hf_runner, vllm_runner, model_info)

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from tests.models.utils import GenerateModelInfo

from .ppl_utils import wikitext_ppl_test

MODELS = [
    GenerateModelInfo("Qwen/Qwen3-0.6B"),
    GenerateModelInfo("Qwen/Qwen3-0.6B-FP8"),
    # transformers:
    # Loading a GPTQ quantized model requires optimum, gptqmodel
    # GenerateModelInfo("Qwen/Qwen3-0.6B-GPTQ-Int8"),
]


@pytest.mark.parametrize("model_info", MODELS)
def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
    wikitext_ppl_test(hf_runner, vllm_runner, model_info)

View File

@@ -59,7 +59,7 @@ def correctness_test_embed_models(hf_runner,
    with hf_runner(
            model_info.name,
            dtype="float32",
            dtype=model_info.hf_dtype,
            is_sentence_transformer=True,
    ) as hf_model:

View File

@@ -213,7 +213,7 @@ def mteb_test_embed_models(hf_runner,
    if model_info.mteb_score is None:
        with hf_runner(model_info.name,
                       is_sentence_transformer=True,
                       dtype="float32") as hf_model:
                       dtype=model_info.hf_dtype) as hf_model:

            # e.g. setting default parameters for the encode method of hf_runner
            if hf_model_callback is not None:

@@ -278,9 +278,12 @@ def run_mteb_rerank(cross_encoder, tasks, languages):
    return main_score


def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None):
def mteb_test_rerank_models_hf(hf_runner,
                               model_name,
                               hf_dtype="float32",
                               hf_model_callback=None):
    with hf_runner(model_name, is_cross_encoder=True,
                   dtype="float32") as hf_model:
                   dtype=hf_dtype) as hf_model:

        original_predict = hf_model.predict

@@ -357,7 +360,7 @@ def mteb_test_rerank_models(hf_runner,
        # SentenceTransformers mteb score to a constant
        if model_info.mteb_score is None:
            st_main_score, st_dtype = mteb_test_rerank_models_hf(
                hf_runner, model_info.name, hf_model_callback)
                hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback)
        else:
            st_main_score = model_info.mteb_score
            st_dtype = "Constant"

View File

@@ -347,14 +347,15 @@ class ModelInfo:
    name: str
    architecture: str = ""
    dtype: str = "auto"
    hf_dtype: str = "float32"
    hf_overrides: Optional[dict[str, Any]] = None
    default_pooling_type: str = ""
    mteb_score: Optional[float] = None
    enable_test: bool = True


@dataclass
class EmbedModelInfo(ModelInfo):
    mteb_score: Optional[float] = None
    is_matryoshka: bool = False
    matryoshka_dimensions: Optional[list[int]] = None

@@ -371,7 +372,7 @@ class LASTPoolingEmbedModelInfo(EmbedModelInfo):
@dataclass
class RerankModelInfo(ModelInfo):
    pass
    mteb_score: Optional[float] = None


@dataclass
@@ -384,6 +385,12 @@ class LASTPoolingRerankModelInfo(RerankModelInfo):
    default_pooling_type: str = "LAST"


@dataclass
class GenerateModelInfo(ModelInfo):
    hf_dtype: str = "auto"
    hf_ppl: Optional[float] = None


def dummy_hf_overrides(
    hf_config: PretrainedConfig,
    *,