mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 06:34:58 +08:00
[CI] Add ci_envs for convenient local testing (#24630)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
561a0baee0
commit
d21a36f5f9
45
tests/ci_envs.py
Normal file
45
tests/ci_envs.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
These envs only work for a small part of the tests, fix what you need!
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
VLLM_CI_NO_SKIP: bool = False
|
||||||
|
VLLM_CI_DTYPE: Optional[str] = None
|
||||||
|
VLLM_CI_HEAD_DTYPE: Optional[str] = None
|
||||||
|
VLLM_CI_HF_DTYPE: Optional[str] = None
|
||||||
|
|
||||||
|
environment_variables: dict[str, Callable[[], Any]] = {
|
||||||
|
# A model family has many models with the same architecture.
|
||||||
|
# By default, a model family tests only one model.
|
||||||
|
# Through this flag, all models can be tested.
|
||||||
|
"VLLM_CI_NO_SKIP": lambda: bool(int(os.getenv("VLLM_CI_NO_SKIP", "0"))),
|
||||||
|
# Allow changing the dtype used by vllm in tests
|
||||||
|
"VLLM_CI_DTYPE": lambda: os.getenv("VLLM_CI_DTYPE", None),
|
||||||
|
# Allow changing the head dtype used by vllm in tests
|
||||||
|
"VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
|
||||||
|
# Allow changing the head dtype used by transformers in tests
|
||||||
|
"VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
# lazy evaluation of environment variables
|
||||||
|
if name in environment_variables:
|
||||||
|
return environment_variables[name]()
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def __dir__():
|
||||||
|
return list(environment_variables.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def is_set(name: str):
|
||||||
|
"""Check if an environment variable is explicitly set."""
|
||||||
|
if name in environment_variables:
|
||||||
|
return name in os.environ
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
@ -7,6 +7,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
import tests.ci_envs as ci_envs
|
||||||
from tests.models.utils import (GenerateModelInfo,
|
from tests.models.utils import (GenerateModelInfo,
|
||||||
TokensTextLogprobsPromptLogprobs)
|
TokensTextLogprobsPromptLogprobs)
|
||||||
from vllm.logprobs import Logprob
|
from vllm.logprobs import Logprob
|
||||||
@ -26,19 +27,26 @@ def wikitext_ppl_test(hf_runner,
|
|||||||
|
|
||||||
# A model family has many models with the same architecture,
|
# A model family has many models with the same architecture,
|
||||||
# and we don't need to test each one.
|
# and we don't need to test each one.
|
||||||
if not model_info.enable_test:
|
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||||
pytest.skip("Skipping test.")
|
pytest.skip("Skipping test.")
|
||||||
|
|
||||||
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
||||||
|
|
||||||
# Allow vllm to test using the given dtype, such as float32
|
# Allow vllm to test using the given dtype, such as float32
|
||||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
|
||||||
|
|
||||||
# Allow vllm to test using hf_overrides
|
# Allow vllm to test using hf_overrides
|
||||||
if model_info.hf_overrides is not None:
|
if model_info.hf_overrides is not None:
|
||||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||||
|
|
||||||
|
# Allow changing the head dtype used by vllm in tests
|
||||||
|
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||||
|
if "hf_overrides" not in vllm_extra_kwargs:
|
||||||
|
vllm_extra_kwargs["hf_overrides"] = {}
|
||||||
|
vllm_extra_kwargs["hf_overrides"][
|
||||||
|
"head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||||
|
|
||||||
with vllm_runner(model_info.name,
|
with vllm_runner(model_info.name,
|
||||||
gpu_memory_utilization=0.7,
|
gpu_memory_utilization=0.7,
|
||||||
max_model_len=max_length,
|
max_model_len=max_length,
|
||||||
@ -46,7 +54,7 @@ def wikitext_ppl_test(hf_runner,
|
|||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
**vllm_extra_kwargs) as vllm_model:
|
**vllm_extra_kwargs) as vllm_model:
|
||||||
# Use max_num_seqs=1 to avoid OOM,
|
# Use max_num_seqs=1 to avoid OOM,
|
||||||
# and batch different requests together.
|
# and avoid batch different requests together.
|
||||||
|
|
||||||
model_config = vllm_model.llm.llm_engine.model_config
|
model_config = vllm_model.llm.llm_engine.model_config
|
||||||
|
|
||||||
@ -91,12 +99,13 @@ def wikitext_ppl_test(hf_runner,
|
|||||||
n_tokens += len(token_log_probs)
|
n_tokens += len(token_log_probs)
|
||||||
vllm_ppl = float(torch.exp(nll_sum / n_tokens))
|
vllm_ppl = float(torch.exp(nll_sum / n_tokens))
|
||||||
vllm_dtype = model_config.dtype
|
vllm_dtype = model_config.dtype
|
||||||
|
head_dtype = model_config.head_dtype
|
||||||
|
|
||||||
# Accelerate ppl test by setting Transformers ppl score to a constant
|
# Accelerate ppl test by setting Transformers ppl score to a constant
|
||||||
if model_info.hf_ppl is None:
|
if model_info.hf_ppl is None:
|
||||||
with hf_runner(
|
with hf_runner(
|
||||||
model_info.name,
|
model_info.name,
|
||||||
dtype=model_info.hf_dtype,
|
dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
|
||||||
) as hf_model:
|
) as hf_model:
|
||||||
nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
|
nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
|
||||||
n_tokens = 0
|
n_tokens = 0
|
||||||
@ -121,7 +130,7 @@ def wikitext_ppl_test(hf_runner,
|
|||||||
|
|
||||||
differ = (vllm_ppl - hf_ppl) / hf_ppl
|
differ = (vllm_ppl - hf_ppl) / hf_ppl
|
||||||
print("Model:", model_info.name)
|
print("Model:", model_info.name)
|
||||||
print("VLLM:", vllm_dtype, vllm_ppl)
|
print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_ppl)
|
||||||
print("Transformers:", hf_dtype, hf_ppl)
|
print("Transformers:", hf_dtype, hf_ppl)
|
||||||
print("Difference (%):", differ * 100)
|
print("Difference (%):", differ * 100)
|
||||||
|
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import pytest
|
|||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import tests.ci_envs as ci_envs
|
||||||
from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
|
from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
|
||||||
check_embeddings_close)
|
check_embeddings_close)
|
||||||
|
|
||||||
@ -168,7 +169,7 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
atol=MTEB_EMBED_TOL):
|
atol=MTEB_EMBED_TOL):
|
||||||
# A model family has many models with the same architecture,
|
# A model family has many models with the same architecture,
|
||||||
# and we don't need to test each one.
|
# and we don't need to test each one.
|
||||||
if not model_info.enable_test:
|
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||||
pytest.skip("Skipping test.")
|
pytest.skip("Skipping test.")
|
||||||
|
|
||||||
# Test embed_dims, isnan and whether to use normalize
|
# Test embed_dims, isnan and whether to use normalize
|
||||||
@ -176,12 +177,19 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
|
|
||||||
# Allow vllm to test using the given dtype, such as float32
|
# Allow vllm to test using the given dtype, such as float32
|
||||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
|
||||||
|
|
||||||
# Allow vllm to test using hf_overrides
|
# Allow vllm to test using hf_overrides
|
||||||
if model_info.hf_overrides is not None:
|
if model_info.hf_overrides is not None:
|
||||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||||
|
|
||||||
|
# Allow changing the head dtype used by vllm in tests
|
||||||
|
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||||
|
if "hf_overrides" not in vllm_extra_kwargs:
|
||||||
|
vllm_extra_kwargs["hf_overrides"] = {}
|
||||||
|
vllm_extra_kwargs["hf_overrides"][
|
||||||
|
"head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||||
|
|
||||||
with vllm_runner(model_info.name,
|
with vllm_runner(model_info.name,
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
@ -202,6 +210,7 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
|
||||||
MTEB_EMBED_TASKS)
|
MTEB_EMBED_TASKS)
|
||||||
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
|
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
|
||||||
|
head_dtype = model_config.head_dtype
|
||||||
|
|
||||||
# Test embed_dims, isnan and whether to use normalize
|
# Test embed_dims, isnan and whether to use normalize
|
||||||
vllm_outputs = vllm_model.embed(example_prompts,
|
vllm_outputs = vllm_model.embed(example_prompts,
|
||||||
@ -211,9 +220,11 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
# Accelerate mteb test by setting
|
# Accelerate mteb test by setting
|
||||||
# SentenceTransformers mteb score to a constant
|
# SentenceTransformers mteb score to a constant
|
||||||
if model_info.mteb_score is None:
|
if model_info.mteb_score is None:
|
||||||
with hf_runner(model_info.name,
|
with hf_runner(
|
||||||
|
model_info.name,
|
||||||
is_sentence_transformer=True,
|
is_sentence_transformer=True,
|
||||||
dtype=model_info.hf_dtype) as hf_model:
|
dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
|
||||||
|
) as hf_model:
|
||||||
|
|
||||||
# e.g. setting default parameters for the encode method of hf_runner
|
# e.g. setting default parameters for the encode method of hf_runner
|
||||||
if hf_model_callback is not None:
|
if hf_model_callback is not None:
|
||||||
@ -236,7 +247,8 @@ def mteb_test_embed_models(hf_runner,
|
|||||||
st_dtype = "Constant"
|
st_dtype = "Constant"
|
||||||
|
|
||||||
print("Model:", model_info.name)
|
print("Model:", model_info.name)
|
||||||
print("VLLM:", vllm_dtype, vllm_main_score)
|
print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}",
|
||||||
|
vllm_main_score)
|
||||||
print("SentenceTransformers:", st_dtype, st_main_score)
|
print("SentenceTransformers:", st_dtype, st_main_score)
|
||||||
print("Difference:", st_main_score - vllm_main_score)
|
print("Difference:", st_main_score - vllm_main_score)
|
||||||
|
|
||||||
@ -319,17 +331,24 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
atol=MTEB_RERANK_TOL):
|
atol=MTEB_RERANK_TOL):
|
||||||
# A model family has many models with the same architecture,
|
# A model family has many models with the same architecture,
|
||||||
# and we don't need to test each one.
|
# and we don't need to test each one.
|
||||||
if not model_info.enable_test:
|
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
|
||||||
pytest.skip("Skipping test.")
|
pytest.skip("Skipping test.")
|
||||||
|
|
||||||
# Allow vllm to test using the given dtype, such as float32
|
# Allow vllm to test using the given dtype, such as float32
|
||||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
|
||||||
|
|
||||||
# Allow vllm to test using hf_overrides
|
# Allow vllm to test using hf_overrides
|
||||||
if model_info.hf_overrides is not None:
|
if model_info.hf_overrides is not None:
|
||||||
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
|
||||||
|
|
||||||
|
# Allow changing the head dtype used by vllm in tests
|
||||||
|
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
|
||||||
|
if "hf_overrides" not in vllm_extra_kwargs:
|
||||||
|
vllm_extra_kwargs["hf_overrides"] = {}
|
||||||
|
vllm_extra_kwargs["hf_overrides"][
|
||||||
|
"head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
|
||||||
|
|
||||||
with vllm_runner(model_info.name,
|
with vllm_runner(model_info.name,
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
@ -355,6 +374,7 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
tasks=MTEB_RERANK_TASKS,
|
tasks=MTEB_RERANK_TASKS,
|
||||||
languages=MTEB_RERANK_LANGS)
|
languages=MTEB_RERANK_LANGS)
|
||||||
vllm_dtype = model_config.dtype
|
vllm_dtype = model_config.dtype
|
||||||
|
head_dtype = model_config.head_dtype
|
||||||
|
|
||||||
# Accelerate mteb test by setting
|
# Accelerate mteb test by setting
|
||||||
# SentenceTransformers mteb score to a constant
|
# SentenceTransformers mteb score to a constant
|
||||||
@ -366,7 +386,8 @@ def mteb_test_rerank_models(hf_runner,
|
|||||||
st_dtype = "Constant"
|
st_dtype = "Constant"
|
||||||
|
|
||||||
print("Model:", model_info.name)
|
print("Model:", model_info.name)
|
||||||
print("VLLM:", vllm_dtype, vllm_main_score)
|
print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}",
|
||||||
|
vllm_main_score)
|
||||||
print("SentenceTransformers:", st_dtype, st_main_score)
|
print("SentenceTransformers:", st_dtype, st_main_score)
|
||||||
print("Difference:", st_main_score - vllm_main_score)
|
print("Difference:", st_main_score - vllm_main_score)
|
||||||
|
|
||||||
|
|||||||
@ -1775,16 +1775,21 @@ class ModelConfig:
|
|||||||
such as the lm_head in a generation model,
|
such as the lm_head in a generation model,
|
||||||
or the score or classifier in a classification model.
|
or the score or classifier in a classification model.
|
||||||
|
|
||||||
The default head_dtype based on runner_type.\n
|
`head_dtype` currently only supports pooling models.\n
|
||||||
- The pooling model defaults to using fp32 head,
|
- The pooling model defaults to using fp32 head,
|
||||||
you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n
|
you can use --hf-overrides '{"head_dtype": "model"}' to disable it.
|
||||||
- The generate model defaults to not using fp32 head,
|
|
||||||
you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
head_dtype = _get_head_dtype(config=self.hf_config,
|
head_dtype = _get_head_dtype(config=self.hf_config,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
runner_type=self.runner_type)
|
runner_type=self.runner_type)
|
||||||
|
|
||||||
|
if self.runner_type != "pooling" and head_dtype != self.dtype:
|
||||||
|
logger.warning_once(
|
||||||
|
"`head_dtype` currently only supports pooling models."
|
||||||
|
"fallback to model dtype [%s].", self.dtype)
|
||||||
|
return self.dtype
|
||||||
|
|
||||||
if head_dtype not in current_platform.supported_dtypes:
|
if head_dtype not in current_platform.supported_dtypes:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"The current platform does not support [%s] head dtype, "
|
"The current platform does not support [%s] head dtype, "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user