# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import pytest

from ...utils import EmbedModelInfo

MODELS = [
    EmbedModelInfo("nomic-ai/nomic-embed-text-v1"),
    # EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5"),
    # EmbedModelInfo("nomic-ai/CodeRankEmbed"),
    EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe"),
    # EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long"),
]
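
# YaRN rope-scaling parameters shared by the tests below; with these values
# the scaled context length is int(2048 * 4.0) = 8192 tokens.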
rope_theta = 1000
factor = 4.0
original_max_position_embeddings = 2048
max_model_len = int(original_max_position_embeddings * factor)

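
# With no explicit max_model_len, the engine should fall back to the length
# implied by the model's own config: 512 for v2-moe (via
# sentence_bert_config.json), otherwise original_max_position_embeddings.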
@pytest.mark.parametrize("model_info", MODELS)
def test_default(model_info, vllm_runner):
    with vllm_runner(
        model_info.name, runner="pooling", max_model_len=None
    ) as vllm_model:
        model_config = vllm_model.llm.llm_engine.model_config
        if model_info.name == "nomic-ai/nomic-embed-text-v2-moe":
            # For nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            assert model_config.max_model_len == 512
        else:
            assert model_config.max_model_len == original_max_position_embeddings

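
# A user-supplied max_model_len within the model's limit is accepted as-is;
# for v2-moe the 512-token cap from sentence_bert_config.json makes 1024
# illegal.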
@pytest.mark.parametrize("model_info", MODELS)
def test_set_max_model_len_legal(model_info, vllm_runner):
    # set max_model_len <= 512
    with vllm_runner(
        model_info.name, runner="pooling", max_model_len=256
    ) as vllm_model:
        model_config = vllm_model.llm.llm_engine.model_config
        assert model_config.max_model_len == 256

    # set 512 < max_model_len <= 2048
    if model_info.name == "nomic-ai/nomic-embed-text-v2-moe":
        # For nomic-embed-text-v2-moe the length is set to 512
        # by sentence_bert_config.json.
        with pytest.raises(ValueError):
            with vllm_runner(model_info.name, runner="pooling", max_model_len=1024):
                pass
    else:
        with vllm_runner(
            model_info.name, runner="pooling", max_model_len=1024
        ) as vllm_model:
            model_config = vllm_model.llm.llm_engine.model_config
            assert model_config.max_model_len == 1024

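
# Anything beyond original_max_position_embeddings (2048) must be rejected,
# whether passed as an argument or smuggled in through hf_overrides.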
@pytest.mark.parametrize("model_info", MODELS)
def test_set_max_model_len_illegal(model_info, vllm_runner):
    # set max_model_len > 2048
    with pytest.raises(ValueError):
        with vllm_runner(model_info.name, runner="pooling", max_model_len=4096):
            pass

    # set max_model_len > 2048 by hf_overrides
    hf_overrides = {"max_model_len": 4096}
    with pytest.raises(ValueError):
        with vllm_runner(
            model_info.name,
            runner="pooling",
            max_model_len=None,
            hf_overrides=hf_overrides,
        ):
            pass

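
# Once YaRN rope scaling is applied through hf_overrides, the scaled length
# (8192) becomes a legal max_model_len.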
@pytest.mark.parametrize("model_info", MODELS)
def test_use_rope_scaling_legal(model_info, vllm_runner):
    hf_overrides = {
        "rope_theta": rope_theta,
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings": original_max_position_embeddings,
        },
        "max_model_len": max_model_len,
    }

    with vllm_runner(
        model_info.name, runner="pooling", max_model_len=None, hf_overrides=hf_overrides
    ):
        pass

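
# Even with rope scaling in place, exceeding the scaled limit by one token
# must still raise, via either the argument or hf_overrides.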
@pytest.mark.parametrize("model_info", MODELS)
def test_use_rope_scaling_illegal(model_info, vllm_runner):
    hf_overrides = {
        "rope_theta": rope_theta,
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings": original_max_position_embeddings,
        },
    }
    # illegal max_model_len
    with pytest.raises(ValueError):
        with vllm_runner(
            model_info.name,
            runner="pooling",
            max_model_len=max_model_len + 1,
            hf_overrides=hf_overrides,
        ):
            pass

    hf_overrides = {
        "rope_theta": rope_theta,
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings": original_max_position_embeddings,
        },
        "max_model_len": max_model_len + 1,
    }
    # illegal max_model_len by hf_overrides
    with pytest.raises(ValueError):
        with vllm_runner(
            model_info.name,
            runner="pooling",
            max_model_len=None,
            hf_overrides=hf_overrides,
        ):
            pass