# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/docs/transformers/perplexity
from typing import Optional, cast

import pytest
import torch
from datasets import load_dataset

import tests.ci_envs as ci_envs
from tests.models.utils import (GenerateModelInfo,
                                TokensTextLogprobsPromptLogprobs)
from vllm.logprobs import Logprob

# See #24485
PPL_TOL = 0.01
MAX_LENGTH = 1024


@torch.inference_mode
def wikitext_ppl_test(hf_runner,
                      vllm_runner,
                      model_info: GenerateModelInfo,
                      max_length=MAX_LENGTH,
                      vllm_extra_kwargs=None,
                      atol=PPL_TOL):
    # A model family has many models with the same architecture,
    # and we don't need to test each one.
    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
        pytest.skip("Skipping test.")

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Allow vllm to test using the given dtype, such as float32
    vllm_extra_kwargs = vllm_extra_kwargs or {}
    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype

    # Allow vllm to test using hf_overrides
    if model_info.hf_overrides is not None:
        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

    # Allow changing the head dtype used by vllm in tests
    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
        if "hf_overrides" not in vllm_extra_kwargs:
            vllm_extra_kwargs["hf_overrides"] = {}
        vllm_extra_kwargs["hf_overrides"][
            "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE

    with vllm_runner(model_info.name,
                     gpu_memory_utilization=0.7,
                     max_model_len=max_length,
                     max_num_seqs=1,
                     enforce_eager=True,
                     **vllm_extra_kwargs) as vllm_model:
        # Use max_num_seqs=1 to avoid OOM
        # and to avoid batching different requests together.
        model_config = vllm_model.llm.llm_engine.model_config

        # Confirm that vllm is using the correct architecture
        if model_info.architecture:
            assert model_info.architecture in model_config.architectures

        max_length = min(model_config.max_model_len - 1, max_length)
        stride = max_length

        tokenizer = vllm_model.llm.get_tokenizer()
        tokens = tokenizer.encode("\n\n".join(dataset["text"]))
        n_tokens = len(tokens)

        chunks = []
        for begin_loc in range(0, n_tokens, stride):
            end_loc = min(begin_loc + max_length, n_tokens)
            chunks.append(tokens[begin_loc:end_loc])

        outputs = vllm_model.generate_greedy_logprobs(prompts=chunks,
                                                      max_tokens=1,
                                                      num_logprobs=None,
                                                      num_prompt_logprobs=0,
                                                      use_tqdm=False)
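
        # Perplexity over the scored prompt tokens: with N tokens that have a
        # prompt logprob, PPL = exp(-(1 / N) * sum_i log p(token_i | token_<i)).
        # The first position of each chunk has no preceding context, so its
        # prompt logprob is None and is skipped below.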
        nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
        n_tokens = 0
        for output in outputs:
            output = cast(TokensTextLogprobsPromptLogprobs, output)
            token_datas = cast(list[Optional[dict[int, Logprob]]], output[3])

            assert token_datas[0] is None
            token_log_probs = []
            for token_data in token_datas[1:]:
                assert token_data is not None
                assert len(token_data) == 1
                token_log_prob = list(token_data.values())[0].logprob
                token_log_probs.append(token_log_prob)

            neg_log_likelihood = -torch.tensor(
                token_log_probs, dtype=torch.float32, device="cpu").sum()
            nll_sum += neg_log_likelihood
            n_tokens += len(token_log_probs)
        vllm_ppl = float(torch.exp(nll_sum / n_tokens))
        vllm_dtype = model_config.dtype
        head_dtype = model_config.head_dtype

    # Accelerate the ppl test by using a constant, precomputed Transformers
    # ppl score when model_info.hf_ppl is provided.
    if model_info.hf_ppl is None:
        with hf_runner(
                model_info.name,
                dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
        ) as hf_model:
            nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
            n_tokens = 0
            for chunk in chunks:
                inputs = hf_model.wrap_device(
                    {"input_ids": torch.tensor([chunk])})
                input_ids = inputs["input_ids"]
                outputs = hf_model.model(input_ids, labels=input_ids)
                neg_log_likelihood = outputs.loss
                neg_log_likelihood = neg_log_likelihood.to(torch.float32).cpu()

                # outputs.loss is the mean NLL over the chunk's shifted
                # targets (len(chunk) - 1 of them), so re-weight it by the
                # token count before accumulating.
                num_loss_tokens = len(chunk) - 1
                nll_sum += neg_log_likelihood * num_loss_tokens
                n_tokens += num_loss_tokens

            hf_ppl = float(torch.exp(nll_sum / n_tokens))
            hf_dtype = next(hf_model.model.parameters()).dtype
    else:
        hf_ppl = model_info.hf_ppl
        hf_dtype = "Constant"

    differ = (vllm_ppl - hf_ppl) / hf_ppl
    print("Model:", model_info.name)
    print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_ppl)
    print("Transformers:", hf_dtype, hf_ppl)
    print("Difference (%):", differ * 100)

    # The smaller the PPL, the better.
    # A vllm PPL that is lower than the Transformers PPL is not a concern,
    # so we only perform a one-sided test.
    assert differ < atol
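

# A minimal usage sketch (hypothetical, not part of the original file),
# assuming the standard vLLM hf_runner/vllm_runner pytest fixtures and a
# GenerateModelInfo entry constructed from a model name:
#
#     import pytest
#     from tests.models.utils import GenerateModelInfo
#
#     MODELS = [GenerateModelInfo("Qwen/Qwen2.5-0.5B-Instruct")]
#
#     @pytest.mark.parametrize("model_info", MODELS)
#     def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
#         wikitext_ppl_test(hf_runner, vllm_runner, model_info)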