diff --git a/tests/ci_envs.py b/tests/ci_envs.py
new file mode 100644
index 0000000000000..d16ecce1ef8dd
--- /dev/null
+++ b/tests/ci_envs.py
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+These envs are only honored by a small subset of the tests; extend as needed!
+"""
+
+import os
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+if TYPE_CHECKING:
+    VLLM_CI_NO_SKIP: bool = False
+    VLLM_CI_DTYPE: Optional[str] = None
+    VLLM_CI_HEAD_DTYPE: Optional[str] = None
+    VLLM_CI_HF_DTYPE: Optional[str] = None
+
+environment_variables: dict[str, Callable[[], Any]] = {
+    # A model family has many models with the same architecture.
+    # By default, only one model per family is tested.
+    # Set this flag to test all of them.
+    "VLLM_CI_NO_SKIP": lambda: bool(int(os.getenv("VLLM_CI_NO_SKIP", "0"))),
+    # Allow changing the dtype used by vllm in tests
+    "VLLM_CI_DTYPE": lambda: os.getenv("VLLM_CI_DTYPE", None),
+    # Allow changing the head dtype used by vllm in tests
+    "VLLM_CI_HEAD_DTYPE": lambda: os.getenv("VLLM_CI_HEAD_DTYPE", None),
+    # Allow changing the dtype used by transformers in tests
+    "VLLM_CI_HF_DTYPE": lambda: os.getenv("VLLM_CI_HF_DTYPE", None),
+}
+
+
+def __getattr__(name: str):
+    # Lazy evaluation of environment variables.
+    if name in environment_variables:
+        return environment_variables[name]()
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__():
+    return list(environment_variables.keys())
+
+
+def is_set(name: str):
+    """Check if an environment variable is explicitly set."""
+    if name in environment_variables:
+        return name in os.environ
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
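For reviewers, a quick sketch of the lazy-lookup semantics (not part of the diff; it assumes the repo root is on `sys.path` so `tests.ci_envs` is importable):

```python
import os

# Set before the assertions so they are deterministic; because the module's
# __getattr__ re-reads os.environ on every attribute access, setting these
# after import works just as well.
os.environ["VLLM_CI_DTYPE"] = "float32"
os.environ.pop("VLLM_CI_NO_SKIP", None)

import tests.ci_envs as ci_envs

assert ci_envs.VLLM_CI_DTYPE == "float32"    # evaluated on access
assert ci_envs.is_set("VLLM_CI_DTYPE")       # explicitly set in os.environ
assert ci_envs.VLLM_CI_NO_SKIP is False      # unset -> default "0" -> False
assert not ci_envs.is_set("VLLM_CI_NO_SKIP")
```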
diff --git a/tests/models/language/generation_ppl_test/ppl_utils.py b/tests/models/language/generation_ppl_test/ppl_utils.py
index 550e874cf8579..6225bbe3377bd 100644
--- a/tests/models/language/generation_ppl_test/ppl_utils.py
+++ b/tests/models/language/generation_ppl_test/ppl_utils.py
@@ -7,6 +7,7 @@
 import pytest
 import torch
 from datasets import load_dataset
+import tests.ci_envs as ci_envs
 from tests.models.utils import (GenerateModelInfo,
                                 TokensTextLogprobsPromptLogprobs)
 from vllm.logprobs import Logprob
@@ -26,19 +27,26 @@ def wikitext_ppl_test(hf_runner,
     # A model family has many models with the same architecture,
     # and we don't need to test each one.
-    if not model_info.enable_test:
+    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
         pytest.skip("Skipping test.")
     dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
 
     # Allow vllm to test using the given dtype, such as float32
     vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = model_info.dtype
+    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
 
     # Allow vllm to test using hf_overrides
     if model_info.hf_overrides is not None:
         vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
 
+    # Allow changing the head dtype used by vllm in tests
+    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
+        if "hf_overrides" not in vllm_extra_kwargs:
+            vllm_extra_kwargs["hf_overrides"] = {}
+        vllm_extra_kwargs["hf_overrides"][
+            "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
+
     with vllm_runner(model_info.name,
                      gpu_memory_utilization=0.7,
                      max_model_len=max_length,
@@ -46,7 +54,7 @@ def wikitext_ppl_test(hf_runner,
                      enforce_eager=True,
                      **vllm_extra_kwargs) as vllm_model:
         # Use max_num_seqs=1 to avoid OOM,
-        # and batch different requests together.
+        # and avoid batching different requests together.
         model_config = vllm_model.llm.llm_engine.model_config
@@ -91,12 +99,13 @@ def wikitext_ppl_test(hf_runner,
             n_tokens += len(token_log_probs)
         vllm_ppl = float(torch.exp(nll_sum / n_tokens))
         vllm_dtype = model_config.dtype
+        head_dtype = model_config.head_dtype
 
     # Accelerate ppl test by setting Transformers ppl score to a constant
     if model_info.hf_ppl is None:
         with hf_runner(
                 model_info.name,
-                dtype=model_info.hf_dtype,
+                dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
         ) as hf_model:
             nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
             n_tokens = 0
@@ -121,7 +130,7 @@ def wikitext_ppl_test(hf_runner,
 
     differ = (vllm_ppl - hf_ppl) / hf_ppl
     print("Model:", model_info.name)
-    print("VLLM:", vllm_dtype, vllm_ppl)
+    print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_ppl)
     print("Transformers:", hf_dtype, hf_ppl)
     print("Difference (%):", differ * 100)
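For context on the number being compared above: `wikitext_ppl_test` scores perplexity as the exponential of the mean per-token negative log-likelihood. A tiny self-contained illustration with made-up values (not taken from the test):

```python
import torch

# Perplexity = exp(total NLL / token count), mirroring
# vllm_ppl = float(torch.exp(nll_sum / n_tokens)) above.
nll_sum = torch.tensor(6.0, dtype=torch.float32)  # sum of -log p(token)
n_tokens = 3
ppl = float(torch.exp(nll_sum / n_tokens))
print(ppl)  # exp(2.0) ~= 7.389
```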
diff --git a/tests/models/language/pooling_mteb_test/mteb_utils.py b/tests/models/language/pooling_mteb_test/mteb_utils.py
index 56a105e96e5ee..7b3c02fbbd9f8 100644
--- a/tests/models/language/pooling_mteb_test/mteb_utils.py
+++ b/tests/models/language/pooling_mteb_test/mteb_utils.py
@@ -11,6 +11,7 @@
 import pytest
 import requests
 import torch
+import tests.ci_envs as ci_envs
 from tests.models.utils import (EmbedModelInfo, RerankModelInfo,
                                 check_embeddings_close)
 
@@ -168,7 +169,7 @@ def mteb_test_embed_models(hf_runner,
                            atol=MTEB_EMBED_TOL):
     # A model family has many models with the same architecture,
     # and we don't need to test each one.
-    if not model_info.enable_test:
+    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
         pytest.skip("Skipping test.")
 
     # Test embed_dims, isnan and whether to use normalize
@@ -176,12 +177,19 @@ def mteb_test_embed_models(hf_runner,
 
     # Allow vllm to test using the given dtype, such as float32
     vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = model_info.dtype
+    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
 
     # Allow vllm to test using hf_overrides
     if model_info.hf_overrides is not None:
         vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
 
+    # Allow changing the head dtype used by vllm in tests
+    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
+        if "hf_overrides" not in vllm_extra_kwargs:
+            vllm_extra_kwargs["hf_overrides"] = {}
+        vllm_extra_kwargs["hf_overrides"][
+            "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
+
     with vllm_runner(model_info.name,
                      runner="pooling",
                      max_model_len=None,
@@ -202,6 +210,7 @@ def mteb_test_embed_models(hf_runner,
         vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                               MTEB_EMBED_TASKS)
         vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
+        head_dtype = vllm_model.llm.llm_engine.model_config.head_dtype
 
         # Test embed_dims, isnan and whether to use normalize
         vllm_outputs = vllm_model.embed(example_prompts,
@@ -211,9 +220,11 @@ def mteb_test_embed_models(hf_runner,
     # Accelerate mteb test by setting
     # SentenceTransformers mteb score to a constant
     if model_info.mteb_score is None:
-        with hf_runner(model_info.name,
-                       is_sentence_transformer=True,
-                       dtype=model_info.hf_dtype) as hf_model:
+        with hf_runner(
+                model_info.name,
+                is_sentence_transformer=True,
+                dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype,
+        ) as hf_model:
 
             # e.g. setting default parameters for the encode method of hf_runner
             if hf_model_callback is not None:
@@ -236,7 +247,8 @@ def mteb_test_embed_models(hf_runner,
         st_dtype = "Constant"
 
     print("Model:", model_info.name)
-    print("VLLM:", vllm_dtype, vllm_main_score)
+    print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}",
+          vllm_main_score)
     print("SentenceTransformers:", st_dtype, st_main_score)
     print("Difference:", st_main_score - vllm_main_score)
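A style note on the `hf_overrides` guard added above (and repeated for the rerank path below): `dict.setdefault` would express the same behavior in one statement. A sketch, purely a suggestion with identical semantics:

```python
import tests.ci_envs as ci_envs

vllm_extra_kwargs: dict = {}
# Equivalent to: create "hf_overrides" if absent, then set "head_dtype".
if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
    vllm_extra_kwargs.setdefault(
        "hf_overrides", {})["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
```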
@@ -319,17 +331,24 @@ def mteb_test_rerank_models(hf_runner,
                             atol=MTEB_RERANK_TOL):
     # A model family has many models with the same architecture,
     # and we don't need to test each one.
-    if not model_info.enable_test:
+    if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test:
         pytest.skip("Skipping test.")
 
     # Allow vllm to test using the given dtype, such as float32
     vllm_extra_kwargs = vllm_extra_kwargs or {}
-    vllm_extra_kwargs["dtype"] = model_info.dtype
+    vllm_extra_kwargs["dtype"] = ci_envs.VLLM_CI_DTYPE or model_info.dtype
 
     # Allow vllm to test using hf_overrides
     if model_info.hf_overrides is not None:
         vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides
 
+    # Allow changing the head dtype used by vllm in tests
+    if ci_envs.VLLM_CI_HEAD_DTYPE is not None:
+        if "hf_overrides" not in vllm_extra_kwargs:
+            vllm_extra_kwargs["hf_overrides"] = {}
+        vllm_extra_kwargs["hf_overrides"][
+            "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE
+
     with vllm_runner(model_info.name,
                      runner="pooling",
                      max_model_len=None,
@@ -355,6 +374,7 @@ def mteb_test_rerank_models(hf_runner,
                                           tasks=MTEB_RERANK_TASKS,
                                           languages=MTEB_RERANK_LANGS)
         vllm_dtype = model_config.dtype
+        head_dtype = model_config.head_dtype
 
     # Accelerate mteb test by setting
     # SentenceTransformers mteb score to a constant
@@ -366,7 +386,8 @@ def mteb_test_rerank_models(hf_runner,
         st_dtype = "Constant"
 
     print("Model:", model_info.name)
-    print("VLLM:", vllm_dtype, vllm_main_score)
+    print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}",
+          vllm_main_score)
     print("SentenceTransformers:", st_dtype, st_main_score)
     print("Difference:", st_main_score - vllm_main_score)
 
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 8026d4c9e202c..ee58802766c4c 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -1775,16 +1775,21 @@ class ModelConfig:
         such as the lm_head in a generation model,
         or the score or classifier in a classification model.
 
-        The default head_dtype based on runner_type.\n
+        `head_dtype` currently only supports pooling models.\n
         - The pooling model defaults to using fp32 head,
-        you can use --hf-overrides '{"head_dtype": "model"}' to disable it.\n
-        - The generate model defaults to not using fp32 head,
-        you can use --hf-overrides '{"head_dtype": "float32"}' to enable it.
+        you can use --hf-overrides '{"head_dtype": "model"}' to disable it.
         """
+
         head_dtype = _get_head_dtype(config=self.hf_config,
                                      dtype=self.dtype,
                                      runner_type=self.runner_type)
+        if self.runner_type != "pooling" and head_dtype != self.dtype:
+            logger.warning_once(
+                "`head_dtype` currently only supports pooling models. "
+                "Falling back to model dtype [%s].", self.dtype)
+            return self.dtype
+
         if head_dtype not in current_platform.supported_dtypes:
             logger.warning_once(
                 "The current platform does not support [%s] head dtype, "
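To round off the docstring change: a sketch of what disabling the fp32 head looks like from user code, per the `--hf-overrides` form quoted in the docstring (the model name and the offline `LLM` keyword arguments here are illustrative assumptions, not part of the diff):

```python
from vllm import LLM

# Pooling models now default to an fp32 head; overriding "head_dtype" to
# "model" keeps the head in the model's own dtype instead.
llm = LLM(model="BAAI/bge-base-en-v1.5",  # arbitrary example pooling model
          runner="pooling",
          hf_overrides={"head_dtype": "model"})
```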