Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-22 22:35:01 +08:00)
[Frontend] Use `matryoshka_dimensions` to control the allowed output dimensions (#16970)
This commit is contained in:
  parent b724afe343
  commit 67309a1cb5
@@ -159,14 +159,14 @@ For example, setting the `dimensions` parameter while using the `BAAI/bge-m3` model

 ### Manually enable Matryoshka Embeddings

-There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, we simply check the existence of the fields `is_matryoshka` or `matryoshka_dimensions` inside `config.json`.
+There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json`, the output may be changed to arbitrary dimensions; `matryoshka_dimensions` controls which output dimensions are allowed.

-For models that support Matryoshka Embeddings but are not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` (offline) or `--hf_overrides '{"is_matryoshka": true}'` (online).
+For models that support Matryoshka Embeddings but are not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` or `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline), or `--hf_overrides '{"is_matryoshka": true}'` or `--hf_overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'` (online).

 Here is an example to serve a model with Matryoshka Embeddings enabled.

 ```text
-vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"is_matryoshka":true}'
+vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}'
 ```

 ### Offline Inference
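For the offline path mentioned in the doc above, a minimal sketch of the equivalent override, assuming vLLM's `LLM.embed` API and the `PoolingParams(dimensions=...)` parameter this commit validates (adjust to your vLLM version):

```python
# Hedged sketch: offline counterpart of the serve command above.
from vllm import LLM, PoolingParams

llm = LLM(
    model="Snowflake/snowflake-arctic-embed-m-v1.5",
    task="embed",
    # Same override as the online example: only 256-dim outputs are allowed.
    hf_overrides={"matryoshka_dimensions": [256]},
)

outputs = llm.embed(
    ["Follow the white rabbit."],
    pooling_params=PoolingParams(dimensions=256),
)
print(len(outputs[0].outputs.embedding))  # expected: 256
```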
@@ -204,14 +204,14 @@ curl http://127.0.0.1:8000/v1/embeddings \
     "input": "Follow the white rabbit.",
     "model": "jinaai/jina-embeddings-v3",
     "encoding_format": "float",
-    "dimensions": 1
+    "dimensions": 32
   }'
 ```

 Expected output:

 ```json
-{"id":"embd-0aab28c384d348c3b8f0eb783109dc5f","object":"list","created":1744195454,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-1.0]}],"usage":{"prompt_tokens":10,"total_tokens":10,"completion_tokens":0,"prompt_tokens_details":null}}
+{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}
 ```

 An OpenAI client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py>
@@ -25,11 +25,11 @@ def main():
     responses = client.embeddings.create(
         input=["Follow the white rabbit."],
         model=model,
-        dimensions=1,
+        dimensions=32,
     )

     for data in responses.data:
-        print(data.embedding)  # List of float of len 1
+        print(data.embedding)  # List of float of len 32


 if __name__ == "__main__":
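As a quick check of the Matryoshka property behind this change, the reduced embedding should approximate the renormalized prefix of the full one. A hedged sketch, assuming the `jinaai/jina-embeddings-v3` server from the docs section above is running locally and renormalizes truncated outputs:

```python
# Hedged sketch: verify that a 32-dim embedding is (approximately) the
# renormalized prefix of the full embedding.
import numpy as np
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")
model = "jinaai/jina-embeddings-v3"
text = ["Follow the white rabbit."]

full = client.embeddings.create(input=text, model=model).data[0].embedding
small = client.embeddings.create(input=text, model=model,
                                 dimensions=32).data[0].embedding

prefix = np.array(full[:32])
prefix /= np.linalg.norm(prefix)  # renormalize the truncated prefix
cos = float(np.dot(prefix, np.array(small)))
print(f"cosine similarity: {cos:.4f}")  # expected to be close to 1.0
```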
@@ -11,11 +11,12 @@ import requests
 from vllm.entrypoints.openai.protocol import EmbeddingResponse
 from vllm.transformers_utils.tokenizer import get_tokenizer

-from ...models.embedding.utils import check_embeddings_close
+from ...models.embedding.utils import correctness_test
 from ...utils import RemoteOpenAIServer

 MODEL_NAME = "intfloat/multilingual-e5-small"
 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
+DTYPE = "bfloat16"


 @pytest.fixture(scope="module")
@@ -25,7 +26,7 @@ def server():
         "embed",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",
+        DTYPE,
         "--enforce-eager",
         "--max-model-len",
         "512",
@@ -43,9 +44,17 @@ async def client(server):
         yield async_client


+@pytest.fixture(scope="module")
+def hf_model(hf_runner):
+    with hf_runner(MODEL_NAME, dtype=DTYPE,
+                   is_sentence_transformer=True) as hf_model:
+        yield hf_model
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
+async def test_single_embedding(hf_model, client: openai.AsyncOpenAI,
+                                model_name: str):
     input_texts = [
         "The chef prepared a delicious meal.",
     ]
@@ -66,6 +75,9 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
     assert embeddings.usage.prompt_tokens == 11
     assert embeddings.usage.total_tokens == 11

+    vllm_outputs = [d.embedding for d in embeddings.data]
+    correctness_test(hf_model, input_texts, vllm_outputs)
+
     # test using token IDs
     input_tokens = [1, 1, 1, 1, 1]
     embedding_response = await client.embeddings.create(
@@ -86,7 +98,8 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
+async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI,
+                               model_name: str):
     # test list[str]
     input_texts = [
         "The cat sat on the mat.", "A feline was resting on a rug.",
@@ -107,6 +120,9 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
     assert embeddings.usage.prompt_tokens == 33
     assert embeddings.usage.total_tokens == 33

+    vllm_outputs = [d.embedding for d in embeddings.data]
+    correctness_test(hf_model, input_texts, vllm_outputs)
+
     # test list[list[int]]
     input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                     [25, 32, 64, 77]]
@@ -181,7 +197,7 @@ async def test_conversation_embedding(server: RemoteOpenAIServer,

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
+async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
                                       model_name: str):
     input_texts = [
         "Hello my name is",
@@ -192,6 +208,7 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
                                                     model=model_name,
                                                     encoding_format="float")
     float_data = [d.embedding for d in responses_float.data]
+    correctness_test(hf_model, input_texts, float_data)

     responses_base64 = await client.embeddings.create(input=input_texts,
                                                       model=model_name,
@@ -202,24 +219,13 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
             np.frombuffer(base64.b64decode(data.embedding),
                           dtype="float32").tolist())

-    check_embeddings_close(
-        embeddings_0_lst=float_data,
-        embeddings_1_lst=base64_data,
-        name_0="float",
-        name_1="base64",
-    )
+    correctness_test(hf_model, input_texts, base64_data)

     # Default response is float32 decoded from base64 by OpenAI Client
     responses_default = await client.embeddings.create(input=input_texts,
                                                        model=model_name)
     default_data = [d.embedding for d in responses_default.data]
-    check_embeddings_close(
-        embeddings_0_lst=float_data,
-        embeddings_1_lst=default_data,
-        name_0="float",
-        name_1="default",
-    )
+    correctness_test(hf_model, input_texts, default_data)


 @pytest.mark.asyncio
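The base64 branch above decodes to the same float32 values the float branch returns; a hedged client-side sketch mirroring that decode, assuming a local server for `intfloat/multilingual-e5-small`:

```python
# Hedged sketch: decoding a base64 embedding response client-side,
# mirroring the np.frombuffer call used in the test above.
import base64

import numpy as np
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")
resp = client.embeddings.create(input=["Hello my name is"],
                                model="intfloat/multilingual-e5-small",
                                encoding_format="base64")
vec = np.frombuffer(base64.b64decode(resp.data[0].embedding),
                    dtype="float32")  # server sends float32, as in the test
print(vec.shape)
```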
@@ -3,73 +3,121 @@
 Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
 """

+from typing import Optional
+
 import openai
 import pytest

 from vllm.entrypoints.openai.protocol import EmbeddingResponse

-from ...models.embedding.utils import EmbedModelInfo
+from ...conftest import HfRunner
+from ...models.embedding.utils import EmbedModelInfo, correctness_test
 from ...utils import RemoteOpenAIServer

 MODELS = [
-    EmbedModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
-    EmbedModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
+    EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
+    EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
+                   is_matryoshka=True,
+                   matryoshka_dimensions=[256]),
 ]

 input_texts = [
     "The chef prepared a delicious meal.",
-] * 3
+]


-@pytest.mark.asyncio
-@pytest.mark.parametrize("model", MODELS)
-async def test_validating_dimensions(model: EmbedModelInfo):
+@pytest.fixture(scope="module", params=MODELS)
+def model_info(request):
+    return request.param
+
+
+@pytest.fixture(scope="module", params=["bfloat16"])
+def dtype(request):
+    return request.param
+
+
+@pytest.fixture(scope="module")
+def server(model_info, dtype: str):
     args = [
         "--task",
         "embed",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",
+        dtype,
         "--enforce-eager",
         "--max-model-len",
-        "512",
-        "--trust_remote_code"
+        "512"
     ]
-    with RemoteOpenAIServer(model.name, args) as remote_server:
-        client = remote_server.get_async_client()
-
-    async def make_request(dimensions):
-        embedding_response = await client.embeddings.create(
-            model=model.name,
-            input=input_texts,
-            dimensions=dimensions,
-            encoding_format="float",
-        )
-        embeddings = EmbeddingResponse.model_validate(
-            embedding_response.model_dump(mode="json"))
-
-        assert embeddings.id is not None
-        assert len(embeddings.data) == 3
-        assert len(embeddings.data[0].embedding) > 0
-        assert embeddings.usage.completion_tokens == 0
-        assert embeddings.usage.prompt_tokens > 0
-        assert embeddings.usage.total_tokens > 0
-
-        if dimensions is not None:
-            assert len(embeddings.data[0].embedding) == dimensions
-
-    if model.is_matryoshka:
-        for dimensions in [None, 16]:
-            await make_request(dimensions)
+
+    if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5":
+        # Manually enable Matryoshka Embeddings
+        args.extend([
+            "--trust_remote_code", "--hf_overrides",
+            '{"matryoshka_dimensions":[256]}'
+        ])
+
+    with RemoteOpenAIServer(model_info.name, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def hf_model(hf_runner, model_info, dtype: str):
+    with hf_runner(model_info.name, dtype=dtype,
+                   is_sentence_transformer=True) as hf_model:
+        yield hf_model
+
+
+@pytest.mark.asyncio
+async def test_matryoshka(model_info: EmbedModelInfo,
+                          server: RemoteOpenAIServer, hf_model: HfRunner):
+    client = server.get_async_client()
+
+    async def make_request_and_correctness_test(dimensions):
+        prompts = input_texts * 3
+
+        embedding_response = await client.embeddings.create(
+            model=model_info.name,
+            input=prompts,
+            dimensions=dimensions,
+            encoding_format="float",
+        )
+        embeddings = EmbeddingResponse.model_validate(
+            embedding_response.model_dump(mode="json"))
+
+        assert embeddings.id is not None
+        assert len(embeddings.data) == 3
+        assert len(embeddings.data[0].embedding) > 0
+        assert embeddings.usage.completion_tokens == 0
+        assert embeddings.usage.prompt_tokens > 0
+        assert embeddings.usage.total_tokens > 0
+
+        if dimensions is not None:
+            assert len(embeddings.data[0].embedding) == dimensions
+
+        vllm_outputs = [d.embedding for d in embeddings.data]
+        correctness_test(hf_model, prompts, vllm_outputs, dimensions)
+
+    if model_info.is_matryoshka:
+        valid_dimensions: list[Optional[int]] = [None]
+        if model_info.matryoshka_dimensions is not None:
+            valid_dimensions += model_info.matryoshka_dimensions[:2]
+
+        for dimensions in valid_dimensions:
+            await make_request_and_correctness_test(dimensions)
+
+        invalid_dimensions: list[Optional[int]] = [-1]
+        if model_info.matryoshka_dimensions is not None:
+            assert 5 not in model_info.matryoshka_dimensions
+            invalid_dimensions.append(5)

-        with pytest.raises(openai.BadRequestError):
-            for dimensions in [-1]:
-                await make_request(dimensions)
+        for dimensions in invalid_dimensions:
+            with pytest.raises(openai.BadRequestError):
+                await make_request_and_correctness_test(dimensions)

     else:
         for dimensions in [None]:
-            await make_request(dimensions)
+            await make_request_and_correctness_test(dimensions)

-        with pytest.raises(openai.BadRequestError):
-            for dimensions in [-1, 16]:
-                await make_request(dimensions)
+        for dimensions in [-1, 16]:
+            with pytest.raises(openai.BadRequestError):
+                await make_request_and_correctness_test(dimensions)
@@ -153,14 +153,24 @@ def test_matryoshka(

     with vllm_runner(model, task="embed", dtype=dtype,
                      max_model_len=None) as vllm_model:
-        vllm_outputs = vllm_model.encode(
-            example_prompts,
-            pooling_params=PoolingParams(dimensions=dimensions))
-
-    check_embeddings_close(
-        embeddings_0_lst=hf_outputs,
-        embeddings_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-        tol=1e-2,
-    )
+        matryoshka_dimensions = (
+            vllm_model.model.llm_engine.model_config.matryoshka_dimensions)
+        assert matryoshka_dimensions is not None
+
+        if dimensions not in matryoshka_dimensions:
+            with pytest.raises(ValueError):
+                vllm_model.encode(
+                    example_prompts,
+                    pooling_params=PoolingParams(dimensions=dimensions))
+        else:
+            vllm_outputs = vllm_model.encode(
+                example_prompts,
+                pooling_params=PoolingParams(dimensions=dimensions))
+
+            check_embeddings_close(
+                embeddings_0_lst=hf_outputs,
+                embeddings_1_lst=vllm_outputs,
+                name_0="hf",
+                name_1="vllm",
+                tol=1e-2,
+            )
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from collections.abc import Sequence
-from typing import NamedTuple
+from typing import NamedTuple, Optional

 import torch
 import torch.nn.functional as F
@@ -43,5 +43,24 @@ def matryoshka_fy(tensor, dimensions):
 class EmbedModelInfo(NamedTuple):
     name: str
     is_matryoshka: bool
+    matryoshka_dimensions: Optional[list[int]] = None
     architecture: str = ""
     enable_test: bool = True
+
+
+def correctness_test(hf_model,
+                     inputs,
+                     vllm_outputs: Sequence[list[float]],
+                     dimensions: Optional[int] = None):
+
+    hf_outputs = hf_model.encode(inputs)
+    if dimensions:
+        hf_outputs = matryoshka_fy(hf_outputs, dimensions)
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
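`matryoshka_fy`, which `correctness_test` calls, is not shown in this diff; conceptually it truncates each embedding to the leading `dimensions` components and renormalizes. A hedged sketch of that idea (an illustration, not the repository's implementation):

```python
# Hedged sketch of the Matryoshka truncation idea behind matryoshka_fy.
import torch
import torch.nn.functional as F


def matryoshka_fy_sketch(tensor: torch.Tensor, dimensions: int) -> torch.Tensor:
    # Keep the leading `dimensions` components of each embedding...
    truncated = tensor[..., :dimensions]
    # ...and renormalize so cosine similarities remain meaningful.
    return F.normalize(truncated, p=2, dim=-1)


# Example: a batch of 2 full 8-dim embeddings reduced to 4 dims.
full = F.normalize(torch.randn(2, 8), dim=-1)
print(matryoshka_fy_sketch(full, 4).shape)  # torch.Size([2, 4])
```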
@@ -1248,6 +1248,10 @@ class ModelConfig:
         return (hasattr(self.hf_config, "matryoshka_dimensions")
                 or getattr(self.hf_config, "is_matryoshka", False))

+    @property
+    def matryoshka_dimensions(self):
+        return getattr(self.hf_config, "matryoshka_dimensions", None)
+

 BlockSize = Literal[1, 8, 16, 32, 64, 128]
 CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
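A hedged illustration of how the two `ModelConfig` accessors above behave once `--hf_overrides` has populated the HF config (`SimpleNamespace` stands in for the real config object):

```python
from types import SimpleNamespace

# Stand-in for hf_config after --hf_overrides '{"matryoshka_dimensions":[256]}'.
hf_config = SimpleNamespace(matryoshka_dimensions=[256])

# Mirrors ModelConfig.is_matryoshka: the field's presence is enough.
is_matryoshka = (hasattr(hf_config, "matryoshka_dimensions")
                 or getattr(hf_config, "is_matryoshka", False))
# Mirrors the new ModelConfig.matryoshka_dimensions property.
matryoshka_dimensions = getattr(hf_config, "matryoshka_dimensions", None)

print(is_matryoshka)          # True
print(matryoshka_dimensions)  # [256]
```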
@@ -35,7 +35,16 @@ class PoolingParams(
                     f'Model "{model_config.served_model_name}" does not '
                     f'support matryoshka representation, '
                     f'changing output dimensions will lead to poor results.')
-            if self.dimensions < 1:
+
+            mds = model_config.matryoshka_dimensions
+            if mds is not None:
+                if self.dimensions not in mds:
+                    raise ValueError(
+                        f'Model "{model_config.served_model_name}" '
+                        f'only supports {str(mds)} matryoshka dimensions, '
+                        f'using other output dimensions will '
+                        f'lead to poor results.')
+            elif self.dimensions < 1:
                 raise ValueError("Dimensions must be greater than 0")

     def __repr__(self) -> str:
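Sketched standalone, the validation now checks the allowed-dimensions list before falling back to the old positivity check. A simplified rendering with a stub config, not the vLLM class itself:

```python
# Hedged, simplified rendering of the PoolingParams.dimensions validation
# above, using a stub config object rather than vLLM's ModelConfig.
from types import SimpleNamespace


def verify_dimensions(dimensions: int, model_config) -> None:
    mds = model_config.matryoshka_dimensions
    if mds is not None:
        if dimensions not in mds:
            raise ValueError(
                f'Model "{model_config.served_model_name}" only supports '
                f'{mds} matryoshka dimensions; other output dimensions '
                f'will lead to poor results.')
    elif dimensions < 1:
        raise ValueError("Dimensions must be greater than 0")


cfg = SimpleNamespace(served_model_name="Snowflake/snowflake-arctic-embed-m-v1.5",
                      matryoshka_dimensions=[256])
verify_dimensions(256, cfg)    # ok: 256 is in the allowed list
try:
    verify_dimensions(5, cfg)  # rejected: not in matryoshka_dimensions
except ValueError as e:
    print(e)
```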