[CI] improve embed testing (#18747)

This commit is contained in:
wang.yuqi 2025-05-28 15:16:35 +08:00 committed by GitHub
parent 0c492b7824
commit de65fc8e1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 248 additions and 178 deletions

View File

@ -4,6 +4,7 @@ import os
import pytest import pytest
from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS, from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS,
MTEB_EMBED_TOL,
OpenAIClientMtebEncoder, OpenAIClientMtebEncoder,
run_mteb_embed_task, run_mteb_embed_task,
run_mteb_embed_task_st) run_mteb_embed_task_st)
@ -38,4 +39,4 @@ def test_mteb(server):
print("SentenceTransformer main score: ", st_main_score) print("SentenceTransformer main score: ", st_main_score)
print("Difference: ", st_main_score - vllm_main_score) print("Difference: ", st_main_score - vllm_main_score)
assert st_main_score == pytest.approx(vllm_main_score, rel=1e-4) assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

View File

@ -11,7 +11,8 @@ import requests
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from ...models.utils import run_embedding_correctness_test from ...models.language.pooling.embed_utils import (
run_embedding_correctness_test)
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "intfloat/multilingual-e5-small" MODEL_NAME = "intfloat/multilingual-e5-small"

View File

@ -11,7 +11,9 @@ import pytest
from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.entrypoints.openai.protocol import EmbeddingResponse
from ...conftest import HfRunner from ...conftest import HfRunner
from ...models.utils import EmbedModelInfo, run_embedding_correctness_test from ...models.language.pooling.embed_utils import (
run_embedding_correctness_test)
from ...models.utils import EmbedModelInfo
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODELS = [ MODELS = [

View File

@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import Optional
import pytest
from tests.conftest import HfRunner
from tests.models.utils import (EmbedModelInfo, check_embeddings_close,
matryoshka_fy)
def run_embedding_correctness_test(
hf_model: "HfRunner",
inputs: list[str],
vllm_outputs: Sequence[list[float]],
dimensions: Optional[int] = None,
):
hf_outputs = hf_model.encode(inputs)
if dimensions:
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
def correctness_test_embed_models(hf_runner,
vllm_runner,
model_info: EmbedModelInfo,
example_prompts,
vllm_extra_kwargs=None,
hf_model_callback=None):
if not model_info.enable_test:
# A model family has many models with the same architecture,
# and we don't need to test each one.
pytest.skip("Skipping test.")
# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
# This makes the input_ids different between hf_model and vllm_model.
# So we need to strip the input texts to avoid test failing.
example_prompts = [str(s).strip() for s in example_prompts]
vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype
with vllm_runner(model_info.name,
task="embed",
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
model_dtype = getattr(
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
vllm_dtype)
with hf_runner(
model_info.name,
dtype=model_dtype,
is_sentence_transformer=True,
) as hf_model:
if hf_model_callback is not None:
hf_model_callback(hf_model)
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

View File

@ -80,18 +80,19 @@ def run_mteb_embed_task_st(model_name, tasks):
def mteb_test_embed_models(hf_runner, def mteb_test_embed_models(hf_runner,
vllm_runner, vllm_runner,
model_info: EmbedModelInfo, model_info: EmbedModelInfo,
vllm_extra_kwargs=None): vllm_extra_kwargs=None,
hf_model_callback=None):
if not model_info.enable_test: if not model_info.enable_test:
# A model family has many models with the same architecture, # A model family has many models with the same architecture,
# and we don't need to test each one. # and we don't need to test each one.
pytest.skip("Skipping test.") pytest.skip("Skipping test.")
vllm_extra_kwargs = vllm_extra_kwargs or {} vllm_extra_kwargs = vllm_extra_kwargs or {}
vllm_extra_kwargs["dtype"] = model_info.dtype
with vllm_runner(model_info.name, with vllm_runner(model_info.name,
task="embed", task="embed",
max_model_len=None, max_model_len=None,
dtype=model_info.dtype,
**vllm_extra_kwargs) as vllm_model: **vllm_extra_kwargs) as vllm_model:
if model_info.architecture: if model_info.architecture:
@ -108,10 +109,14 @@ def mteb_test_embed_models(hf_runner,
with set_default_torch_dtype(model_dtype) and hf_runner( with set_default_torch_dtype(model_dtype) and hf_runner(
model_info.name, is_sentence_transformer=True, model_info.name, is_sentence_transformer=True,
dtype=model_dtype) as hf_model: dtype=model_dtype) as hf_model:
if hf_model_callback is not None:
hf_model_callback(hf_model)
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
print("VLLM:", vllm_dtype, vllm_main_score) print("VLLM:", vllm_dtype, vllm_main_score)
print("SentenceTransformer:", model_dtype, st_main_score) print("SentenceTransformer:", model_dtype, st_main_score)
print("Difference:", st_main_score - vllm_main_score) print("Difference:", st_main_score - vllm_main_score)
assert st_main_score == pytest.approx(vllm_main_score, rel=MTEB_EMBED_TOL) assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

View File

@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [
########## BertModel
EmbedModelInfo("BAAI/bge-base-en",
architecture="BertModel",
enable_test=True),
EmbedModelInfo("BAAI/bge-base-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-en",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-en",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-zh",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-zh-noinstruct",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-base-en-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-base-zh-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-en-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-small-zh-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-en-v1.5",
architecture="BertModel",
enable_test=False),
EmbedModelInfo("BAAI/bge-large-zh-v1.5",
architecture="BertModel",
enable_test=False),
########## XLMRobertaModel
EmbedModelInfo("BAAI/bge-m3",
architecture="XLMRobertaModel",
enable_test=True),
########## Qwen2Model
EmbedModelInfo("BAAI/bge-code-v1",
architecture="Qwen2Model",
dtype="float32",
enable_test=True),
]
@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_correctness(hf_runner, vllm_runner,
model_info: EmbedModelInfo,
example_prompts) -> None:
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts)

View File

@ -3,7 +3,8 @@ from typing import Any
import pytest import pytest
from ...utils import EmbedModelInfo, run_embedding_correctness_test from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [ MODELS = [
########## BertModel ########## BertModel
@ -53,9 +54,8 @@ MODELS = [
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner, def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None: model_info: EmbedModelInfo) -> None:
from .mteb_utils import mteb_test_embed_models
vllm_extra_kwargs: dict[str, Any] = {} vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel": if model_info.architecture == "GteNewModel":
@ -66,28 +66,13 @@ def test_models_mteb(hf_runner, vllm_runner,
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, def test_embed_models_correctness(hf_runner, vllm_runner,
example_prompts) -> None: model_info: EmbedModelInfo,
if not model_info.enable_test: example_prompts) -> None:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
vllm_extra_kwargs: dict[str, Any] = {} vllm_extra_kwargs: dict[str, Any] = {}
if model_info.architecture == "GteNewModel": if model_info.architecture == "GteNewModel":
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
with vllm_runner(model_info.name, correctness_test_embed_models(hf_runner, vllm_runner, model_info,
task="embed", example_prompts, vllm_extra_kwargs)
dtype=model_info.dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

View File

@ -1,9 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import partial
import pytest import pytest
from vllm import PoolingParams from vllm import PoolingParams
from ...utils import check_embeddings_close, matryoshka_fy from .embed_utils import (EmbedModelInfo, check_embeddings_close,
correctness_test_embed_models, matryoshka_fy)
from .mteb_utils import mteb_test_embed_models
SCORING_MODELS = [ SCORING_MODELS = [
"jinaai/jina-reranker-v2-base-multilingual", # Roberta "jinaai/jina-reranker-v2-base-multilingual", # Roberta
@ -25,16 +29,10 @@ TEXTS_2 = [
] ]
EMBEDDING_MODELS = [ EMBEDDING_MODELS = [
"jinaai/jina-embeddings-v3", EmbedModelInfo("jinaai/jina-embeddings-v3",
] architecture="XLMRobertaModel",
is_matryoshka=True,
EMBEDDING_PROMPTS = [ dtype="float32")
"Follow the white rabbit.", # English
"Sigue al conejo blanco.", # Spanish
"Suis le lapin blanc.", # French
"跟着白兔走。", # Chinese
"اتبع الأرنب الأبيض.", # Arabic
"Folge dem weißen Kaninchen.", # German
] ]
@ -80,73 +78,66 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
@pytest.fixture(scope="module", params=EMBEDDING_MODELS) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def emb_model_name(request): def test_embed_models_mteb(hf_runner, vllm_runner,
yield request.param model_info: EmbedModelInfo) -> None:
def hf_model_callback(model):
model.encode = partial(model.encode, task="text-matching")
mteb_test_embed_models(hf_runner,
vllm_runner,
model_info,
hf_model_callback=hf_model_callback)
def test_is_matryoshka(vllm_runner, emb_model_name): @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
with vllm_runner(emb_model_name, task="embed", def test_embed_models_correctness(hf_runner, vllm_runner,
max_model_len=None) as vllm_model: model_info: EmbedModelInfo,
assert vllm_model.model.llm_engine.model_config.is_matryoshka example_prompts) -> None:
def hf_model_callback(model):
model.encode = partial(model.encode, task="text-matching")
correctness_test_embed_models(hf_runner,
vllm_runner,
model_info,
example_prompts,
hf_model_callback=hf_model_callback)
@pytest.mark.parametrize("model", EMBEDDING_MODELS) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_embeddings(
hf_runner,
vllm_runner,
model,
dtype: str,
monkeypatch,
) -> None:
example_prompts = EMBEDDING_PROMPTS
with hf_runner(
model,
dtype=dtype,
is_sentence_transformer=True,
) as hf_model:
hf_outputs = hf_model.encode(example_prompts, task="text-matching")
with vllm_runner(model, task="embed", dtype=dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
@pytest.mark.parametrize("model", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dimensions", [16, 32]) @pytest.mark.parametrize("dimensions", [16, 32])
def test_matryoshka( def test_matryoshka(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
model, model_info,
dtype: str, dtype: str,
dimensions: int, dimensions: int,
example_prompts,
monkeypatch, monkeypatch,
) -> None: ) -> None:
if not model_info.is_matryoshka:
pytest.skip("Model is not matryoshka")
example_prompts = EMBEDDING_PROMPTS # ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
with hf_runner( with hf_runner(
model, model_info.name,
dtype=dtype, dtype=dtype,
is_sentence_transformer=True, is_sentence_transformer=True,
) as hf_model: ) as hf_model:
hf_outputs = hf_model.encode(example_prompts, task="text-matching") hf_outputs = hf_model.encode(example_prompts, task="text-matching")
hf_outputs = matryoshka_fy(hf_outputs, dimensions) hf_outputs = matryoshka_fy(hf_outputs, dimensions)
with vllm_runner(model, task="embed", dtype=dtype, with vllm_runner(model_info.name,
task="embed",
dtype=dtype,
max_model_len=None) as vllm_model: max_model_len=None) as vllm_model:
assert vllm_model.model.llm_engine.model_config.is_matryoshka
matryoshka_dimensions = ( matryoshka_dimensions = (
vllm_model.model.llm_engine.model_config.matryoshka_dimensions) vllm_model.model.llm_engine.model_config.matryoshka_dimensions)
assert matryoshka_dimensions is not None assert matryoshka_dimensions is not None

View File

@ -2,7 +2,8 @@
import pytest import pytest
from ...utils import EmbedModelInfo, run_embedding_correctness_test from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [ MODELS = [
EmbedModelInfo("nomic-ai/nomic-embed-text-v1", EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
@ -13,6 +14,9 @@ MODELS = [
architecture="NomicBertModel", architecture="NomicBertModel",
dtype="float32", dtype="float32",
enable_test=False), enable_test=False),
EmbedModelInfo("nomic-ai/CodeRankEmbed",
architecture="NomicBertModel",
enable_test=False),
EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
architecture="NomicBertModel", architecture="NomicBertModel",
dtype="float32", dtype="float32",
@ -21,30 +25,14 @@ MODELS = [
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb(hf_runner, vllm_runner, def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None: model_info: EmbedModelInfo) -> None:
from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info) mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, def test_embed_models_correctness(hf_runner, vllm_runner,
example_prompts) -> None: model_info: EmbedModelInfo,
if not model_info.enable_test: example_prompts) -> None:
pytest.skip("Skipping test.") correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts)
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

View File

@ -2,7 +2,8 @@
import pytest import pytest
from ...utils import EmbedModelInfo, run_embedding_correctness_test from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models
MODELS = [ MODELS = [
EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
@ -41,37 +42,14 @@ MODELS = [
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
def test_models_mteb( def test_embed_models_mteb(hf_runner, vllm_runner,
hf_runner, model_info: EmbedModelInfo) -> None:
vllm_runner,
model_info: EmbedModelInfo,
) -> None:
from .mteb_utils import mteb_test_embed_models
mteb_test_embed_models(hf_runner, vllm_runner, model_info) mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", MODELS) @pytest.mark.parametrize("model_info", MODELS)
def test_models_correctness( def test_embed_models_correctness(hf_runner, vllm_runner,
hf_runner, model_info: EmbedModelInfo,
vllm_runner, example_prompts) -> None:
model_info: EmbedModelInfo, correctness_test_embed_models(hf_runner, vllm_runner, model_info,
example_prompts, example_prompts)
) -> None:
if not model_info.enable_test:
pytest.skip("Skipping test.")
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
with vllm_runner(model_info.name,
task="embed",
dtype=model_info.dtype,
max_model_len=None) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
with hf_runner(
model_info.name,
dtype=model_info.dtype,
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)

View File

@ -283,7 +283,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
"ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
trust_remote_code=True), trust_remote_code=True),
"NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long", # noqa: E501 "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
trust_remote_code=True), trust_remote_code=True),
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),

View File

@ -2,7 +2,7 @@
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union from typing import Any, NamedTuple, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -13,9 +13,6 @@ from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from .registry import HF_EXAMPLE_MODELS from .registry import HF_EXAMPLE_MODELS
if TYPE_CHECKING:
from ..conftest import HfRunner
TokensText = tuple[list[int], str] TokensText = tuple[list[int], str]
@ -337,22 +334,3 @@ class EmbedModelInfo(NamedTuple):
architecture: str = "" architecture: str = ""
dtype: str = "auto" dtype: str = "auto"
enable_test: bool = True enable_test: bool = True
def run_embedding_correctness_test(
hf_model: "HfRunner",
inputs: list[str],
vllm_outputs: Sequence[list[float]],
dimensions: Optional[int] = None,
):
hf_outputs = hf_model.encode(inputs)
if dimensions:
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)

View File

@ -572,13 +572,7 @@ class ModelConfig:
sliding_window = None sliding_window = None
self.original_max_model_len = self.max_model_len self.original_max_model_len = self.max_model_len
self.max_model_len = _get_and_verify_max_len( self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
hf_config=self.hf_text_config,
max_model_len=self.max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=self.spec_target_max_model_len,
encoder_config=self.encoder_config)
self.served_model_name = get_served_model_name(self.model, self.served_model_name = get_served_model_name(self.model,
self.served_model_name) self.served_model_name)
self.multimodal_config = self._init_multimodal_config() self.multimodal_config = self._init_multimodal_config()
@ -1382,6 +1376,16 @@ class ModelConfig:
def matryoshka_dimensions(self): def matryoshka_dimensions(self):
return getattr(self.hf_config, "matryoshka_dimensions", None) return getattr(self.hf_config, "matryoshka_dimensions", None)
def get_and_verify_max_len(self, max_model_len: int):
max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=self.spec_target_max_model_len,
encoder_config=self.encoder_config)
return max_model_len
BlockSize = Literal[1, 8, 16, 32, 64, 128] BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
@ -4469,13 +4473,7 @@ class VllmConfig:
def recalculate_max_model_len(self, max_model_len: int): def recalculate_max_model_len(self, max_model_len: int):
model_config = self.model_config model_config = self.model_config
max_model_len = _get_and_verify_max_len( max_model_len = model_config.get_and_verify_max_len(max_model_len)
hf_config=model_config.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=model_config.disable_sliding_window,
sliding_window_len=model_config.get_hf_config_sliding_window(),
spec_target_max_model_len=model_config.spec_target_max_model_len,
encoder_config=model_config.encoder_config)
self.model_config.max_model_len = max_model_len self.model_config.max_model_len = max_model_len
self.scheduler_config.max_model_len = max_model_len self.scheduler_config.max_model_len = max_model_len
self.compute_hash() self.compute_hash()