[CI] Add ci_envs for convenient local testing (#24630)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi 2025-09-12 16:52:25 +08:00 committed by GitHub
parent 561a0baee0
commit d21a36f5f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 98 additions and 18 deletions

45
tests/ci_envs.py Normal file
View File

@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
These envs only work for a small part of the tests, fix what you need!
"""
import os
from typing import TYPE_CHECKING, Any, Callable, Optional
if TYPE_CHECKING:
VLLM_CI_NO_SKIP: bool = False
VLLM_CI_DTYPE: Optional[str] = None
VLLM_CI_HEAD_DTYPE: Optional[str] = None
VLLM_CI_HF_DTYPE: Optional[str] = None
environment_variables: dict[str, Callable[[], Any]] = {
# A model family has many models with the same architecture.
# By default, a model family tests only one model.
# Through this flag, all models can be tested.
"VLLM_CI_NO_SKIP": lambda: bool(int(os.getenv("VLLM_CI_NO_SKIP", "0"))),
# Allow changing the dtype used by vllm in tests
"VLLM_CI_DTYPE": lambda: os.getenv("VLLM_CI_DTYPE", None),
# Allow changing the head dtype used by vllm in tests
"VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
# Allow changing the head dtype used by transformers in tests
"VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
}
def __getattr__(name: str):
# lazy evaluation of environment variables
if name in environment_variables:
return environment_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(environment_variables.keys())
def is_set(name: str):
"""Check if an environment variable is explicitly set."""
if name in environment_variables:
return name in os.environ
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -7,6 +7,7 @@ import pytest
import torch import torch
from datasets import load_dataset from datasets import load_dataset
import tests.ci_envs as ci_envs
from tests.models.utils import (GenerateModelInfo, from tests.models.utils import (GenerateModelInfo,
TokensTextLogprobsPromptLogprobs) TokensTextLogprobsPromptLogprobs)
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
@ -26,19 +27,26 @@ def wikitext_ppl_test(hf_runner,
# A model family has many models with the same architecture, # A model family has many models with the same architecture,
# and we don't need to test each one. # and we don't need to test each one.
if not model_info.enable_test: if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
# Allow vllm to test using the given dtype, such as float32 # Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
# Allow vllm to test using hf_overrides # Allow vllm to test using hf_overrides
if model_info.hf_overrides is not None: if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
# Allow changing the head dtype used by vllm in tests
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
if "hf_overrides" not in vllm_extra_kwargs:
vllm_extra_kwargs["hf_overrides"] = {}
vllm_extra_kwargs["hf_overrides"][
"head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7,
max_model_len=max_length, max_model_len=max_length,
@ -46,7 +54,7 @@ def wikitext_ppl_test(hf_runner,
enforce_eager=True, enforce_eager=True,
**vllm_extra_kwargs) as vllm_model: **vllm_extra_kwargs) as vllm_model:
# Use max_num_seqs=1 to avoid OOM, # Use max_num_seqs=1 to avoid OOM,
# and batch different requests together. # and avoid batch different requests together.
model_config = vllm_model.llm.llm_engine.model_config model_config = vllm_model.llm.llm_engine.model_config
@ -91,12 +99,13 @@ def wikitext_ppl_test(hf_runner,
n_tokens += len(token_log_probs) n_tokens += len(token_log_probs)
vllm_ppl = float(torch.exp(nll_sum / n_tokens)) vllm_ppl = float(torch.exp(nll_sum / n_tokens))
vllm_dtype = model_config.dtype vllm_dtype = model_config.dtype
head_dtype = model_config.head_dtype
# Accelerate ppl test by setting Transformers ppl score to a constant # Accelerate ppl test by setting Transformers ppl score to a constant
if model_info.hf_ppl is None: if model_info.hf_ppl is None:
with hf_runner( with hf_runner(
model_info.name, model_info.name,
dtype=model_info.hf_dtype, dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
) as hf_model: ) as hf_model:
nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu") nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
n_tokens = 0 n_tokens = 0
@ -121,7 +130,7 @@ def wikitext_ppl_test(hf_runner,
differ = (vllm_ppl - hf_ppl) / hf_ppl differ = (vllm_ppl - hf_ppl) / hf_ppl
print("Model:", model_info.name) print("Model:", model_info.name)
print("VLLM:", vllm_dtype, vllm_ppl) print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_ppl)
print("Transformers:", hf_dtype, hf_ppl) print("Transformers:", hf_dtype, hf_ppl)
print("Difference (%):", differ * 100) print("Difference (%):", differ * 100)

View File

@ -11,6 +11,7 @@ import pytest
import requests import requests
import torch import torch
import tests.ci_envs as ci_envs
from tests.models.utils import (EmbedModelInfo, RerankModelInfo, from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
check_embeddings_close) check_embeddings_close)
@ -168,7 +169,7 @@ def mteb_test_embed_models(hf_runner,
atol=MTEB_EMBED_TOL): atol=MTEB_EMBED_TOL):
# A model family has many models with the same architecture, # A model family has many models with the same architecture,
# and we don't need to test each one. # and we don't need to test each one.
if not model_info.enable_test: if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
# Test embed_dims, isnan and whether to use normalize # Test embed_dims, isnan and whether to use normalize
@ -176,12 +177,19 @@ def mteb_test_embed_models(hf_runner,
# Allow vllm to test using the given dtype, such as float32 # Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
# Allow vllm to test using hf_overrides # Allow vllm to test using hf_overrides
if model_info.hf_overrides is not None: if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
# Allow changing the head dtype used by vllm in tests
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
if "hf_overrides" not in vllm_extra_kwargs:
vllm_extra_kwargs["hf_overrides"] = {}
vllm_extra_kwargs["hf_overrides"][
"head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
runner="pooling", runner="pooling",
max_model_len=None, max_model_len=None,
@ -202,6 +210,7 @@ def mteb_test_embed_models(hf_runner,
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS) MTEB_EMBED_TASKS)
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
head_dtype = model_config.head_dtype
# Test embed_dims, isnan and whether to use normalize # Test embed_dims, isnan and whether to use normalize
vllm_outputs = vllm_model.embed(example_prompts, vllm_outputs = vllm_model.embed(example_prompts,
@ -211,9 +220,11 @@ def mteb_test_embed_models(hf_runner,
# Accelerate mteb test by setting # Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant # SentenceTransformers mteb score to a constant
if model_info.mteb_score is None: if model_info.mteb_score is None:
with hf_runner(model_info.name, with hf_runner(
model_info.name,
is_sentence_transformer=True, is_sentence_transformer=True,
dtype=model_info.hf_dtype) as hf_model: dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
) as hf_model:
# e.g. setting default parameters for the encode method of hf_runner # e.g. setting default parameters for the encode method of hf_runner
if hf_model_callback is not None: if hf_model_callback is not None:
@ -236,7 +247,8 @@ def mteb_test_embed_models(hf_runner,
st_dtype = "Constant" st_dtype = "Constant"
print("Model:", model_info.name) print("Model:", model_info.name)
print("VLLM:", vllm_dtype, vllm_main_score) print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}",
vllm_main_score)
print("SentenceTransformers:", st_dtype, st_main_score) print("SentenceTransformers:", st_dtype, st_main_score)
print("Difference:", st_main_score - vllm_main_score) print("Difference:", st_main_score - vllm_main_score)
@ -319,17 +331,24 @@ def mteb_test_rerank_models(hf_runner,
atol=MTEB_RERANK_TOL): atol=MTEB_RERANK_TOL):
# A model family has many models with the same architecture, # A model family has many models with the same architecture,
# and we don't need to test each one. # and we don't need to test each one.
if not model_info.enable_test: if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
# Allow vllm to test using the given dtype, such as float32 # Allow vllm to test using the given dtype, such as float32
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
# Allow vllm to test using hf_overrides # Allow vllm to test using hf_overrides
if model_info.hf_overrides is not None: if model_info.hf_overrides is not None:
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
# Allow changing the head dtype used by vllm in tests
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
if "hf_overrides" not in vllm_extra_kwargs:
vllm_extra_kwargs["hf_overrides"] = {}
vllm_extra_kwargs["hf_overrides"][
"head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
runner="pooling", runner="pooling",
max_model_len=None, max_model_len=None,
@ -355,6 +374,7 @@ def mteb_test_rerank_models(hf_runner,
tasks=MTEB_RERANK_TASKS, tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS) languages=MTEB_RERANK_LANGS)
vllm_dtype = model_config.dtype vllm_dtype = model_config.dtype
head_dtype = model_config.head_dtype
# Accelerate mteb test by setting # Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant # SentenceTransformers mteb score to a constant
@ -366,7 +386,8 @@ def mteb_test_rerank_models(hf_runner,
st_dtype = "Constant" st_dtype = "Constant"
print("Model:", model_info.name) print("Model:", model_info.name)
print("VLLM:", vllm_dtype, vllm_main_score) print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}",
vllm_main_score)
print("SentenceTransformers:", st_dtype, st_main_score) print("SentenceTransformers:", st_dtype, st_main_score)
print("Difference:", st_main_score - vllm_main_score) print("Difference:", st_main_score - vllm_main_score)

View File

@ -1775,16 +1775,21 @@ class ModelConfig:
such as the lm_head in a generation model, such as the lm_head in a generation model,
or the score or classifier in a classification model. or the score or classifier in a classification model.
The default head_dtype based on runner_type.\n `head_dtype` currently only supports pooling models.\n
- The pooling model defaults to using fp32 head, - The pooling model defaults to using fp32 head,
you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n you can use --hf-overrides '{"head_dtype": "model"}' to disable it.
- The generate model defaults to not using fp32 head,
you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
""" """
head_dtype = _get_head_dtype(config=self.hf_config, head_dtype = _get_head_dtype(config=self.hf_config,
dtype=self.dtype, dtype=self.dtype,
runner_type=self.runner_type) runner_type=self.runner_type)
if self.runner_type != "pooling" and head_dtype != self.dtype:
logger.warning_once(
"`head_dtype` currently only supports pooling models."
"fallback to model dtype [%s].", self.dtype)
return self.dtype
if head_dtype not in current_platform.supported_dtypes: if head_dtype not in current_platform.supported_dtypes:
logger.warning_once( logger.warning_once(
"The current platform does not support [%s] head dtype, " "The current platform does not support [%s] head dtype, "