mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 11:57:14 +08:00
[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
parent
d4d1a6024f
commit
f54f85129e
@ -26,6 +26,12 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
|
|||||||
python examples/offline_inference/pooling/embed_matryoshka_fy.py
|
python examples/offline_inference/pooling/embed_matryoshka_fy.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Multi-vector retrieval usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python examples/offline_inference/pooling/multi_vector_retrieval.py
|
||||||
|
```
|
||||||
|
|
||||||
## Named Entity Recognition (NER) usage
|
## Named Entity Recognition (NER) usage
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
56
examples/offline_inference/pooling/multi_vector_retrieval.py
Normal file
56
examples/offline_inference/pooling/multi_vector_retrieval.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from argparse import Namespace
|
||||||
|
|
||||||
|
from vllm import LLM, EngineArgs
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the command line into engine arguments for this example.

    Registers the `EngineArgs` CLI options on a `FlexibleArgumentParser`,
    then replaces the defaults of `model`, `runner` and `enforce_eager`
    with the values this example wants before parsing.
    """
    example_defaults = {
        "model": "BAAI/bge-m3",
        "runner": "pooling",
        "enforce_eager": True,
    }
    cli = EngineArgs.add_cli_args(FlexibleArgumentParser())
    # Set example specific arguments
    cli.set_defaults(**example_defaults)
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: Namespace):
    """Embed a few sample prompts, first per prompt and then per token.

    Args:
        args: Parsed command-line namespace; every attribute is forwarded
            to `LLM` as a keyword argument.
    """
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    llm = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = llm.embed(prompts)

    # Print the length of each embedding. The prompt itself is not needed
    # here, so iterate over the outputs directly.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        embeds = output.outputs.embedding
        print(len(embeds))

    # Generate embedding for each token. The output is a list of PoolingRequestOutput.
    outputs = llm.encode(prompts, pooling_task="token_embed")

    # Print the shape of each per-token result.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        multi_vector = output.outputs.data
        print(multi_vector.shape)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: parse the CLI, then run the example.
    main(parse_args())
|
||||||
@ -40,7 +40,7 @@ def main():
|
|||||||
model_impl="terratorch",
|
model_impl="terratorch",
|
||||||
)
|
)
|
||||||
|
|
||||||
pooling_params = PoolingParams(task="encode", softmax=False)
|
pooling_params = PoolingParams(task="token_classify", activation=False)
|
||||||
pooler_output = llm.encode(
|
pooler_output = llm.encode(
|
||||||
img_prompt,
|
img_prompt,
|
||||||
pooling_params=pooling_params,
|
pooling_params=pooling_params,
|
||||||
|
|||||||
@ -18,6 +18,12 @@ python examples/online_serving/pooling/embedding_embed_dtype_client.py
|
|||||||
python examples/online_serving/pooling/jinaai_rerank_client.py
|
python examples/online_serving/pooling/jinaai_rerank_client.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Multi-vector retrieval usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python examples/online_serving/pooling/multi_vector_retrieval_client.py
|
||||||
|
```
|
||||||
|
|
||||||
## Named Entity Recognition (NER) usage
|
## Named Entity Recognition (NER) usage
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@ -0,0 +1,54 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""
|
||||||
|
Example online usage of Pooling API for multi vector retrieval.
|
||||||
|
|
||||||
|
Run `vllm serve <model> --runner pooling`
|
||||||
|
to start up the server in vLLM. e.g.
|
||||||
|
|
||||||
|
vllm serve BAAI/bge-m3 --runner pooling
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST `prompt` as a JSON body to `api_url` and return the raw response.

    Args:
        prompt: Request payload, sent as JSON.
        api_url: Full URL of the endpoint to call.

    Returns:
        The `requests.Response` object; its status code is not checked here.
    """
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=prompt,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
|
parser.add_argument("--model", type=str, default="BAAI/bge-m3")
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Query the server's `/pooling` endpoint and print each result's shape.

    Args:
        args: Namespace providing ``host``, ``port`` and ``model``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    prompt = {"model": model_name, "input": prompts}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    # Surface HTTP errors with their status code instead of failing later
    # with an opaque KeyError on "data" when the body is an error payload.
    pooling_response.raise_for_status()

    for output in pooling_response.json()["data"]:
        multi_vector = torch.tensor(output["data"])
        print(multi_vector.shape)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: parse the CLI, then query the server.
    main(parse_args())
|
||||||
@ -1011,8 +1011,12 @@ class VllmRunner:
|
|||||||
req_outputs = self.llm.embed(inputs, *args, **kwargs)
|
req_outputs = self.llm.embed(inputs, *args, **kwargs)
|
||||||
return [req_output.outputs.embedding for req_output in req_outputs]
|
return [req_output.outputs.embedding for req_output in req_outputs]
|
||||||
|
|
||||||
def encode(self, prompts: list[str]) -> list[list[float]]:
|
def token_embed(self, prompts: list[str]) -> list[list[float]]:
|
||||||
req_outputs = self.llm.encode(prompts)
|
req_outputs = self.llm.encode(prompts, pooling_task="token_embed")
|
||||||
|
return [req_output.outputs.data for req_output in req_outputs]
|
||||||
|
|
||||||
|
def token_classify(self, prompts: list[str]) -> list[list[float]]:
|
||||||
|
req_outputs = self.llm.encode(prompts, pooling_task="token_classify")
|
||||||
return [req_output.outputs.data for req_output in req_outputs]
|
return [req_output.outputs.data for req_output in req_outputs]
|
||||||
|
|
||||||
def reward(self, prompts: list[str]) -> list[list[float]]:
|
def reward(self, prompts: list[str]) -> list[list[float]]:
|
||||||
|
|||||||
@ -63,7 +63,7 @@ def test_encode_api(llm: LLM):
|
|||||||
# chunked prefill does not support all pooling
|
# chunked prefill does not support all pooling
|
||||||
err_msg = "pooling_task must be one of.+"
|
err_msg = "pooling_task must be one of.+"
|
||||||
with pytest.raises(ValueError, match=err_msg):
|
with pytest.raises(ValueError, match=err_msg):
|
||||||
llm.encode(prompts, use_tqdm=False)
|
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
|
||||||
|
|
||||||
|
|
||||||
def test_score_api(llm: LLM):
|
def test_score_api(llm: LLM):
|
||||||
|
|||||||
@ -35,6 +35,13 @@ def llm():
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
    """`LLM.encode` with `pooling_task="token_embed"` returns a (11, 384) result
    for the first prompt."""
    results = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
    first_multi_vector = results[0].outputs.data
    assert first_multi_vector.shape == (11, 384)
|
||||||
|
|
||||||
|
|
||||||
def test_pooling_params(llm: LLM):
|
def test_pooling_params(llm: LLM):
|
||||||
def get_outputs(normalize):
|
def get_outputs(normalize):
|
||||||
outputs = llm.embed(
|
outputs = llm.embed(
|
||||||
|
|||||||
@ -57,20 +57,24 @@ def test_multiple_pooling_params(llm: LLM):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Multiple PoolingParams should be matched with each prompt
|
# Multiple PoolingParams should be matched with each prompt
|
||||||
outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
|
outputs = llm.encode(PROMPTS, pooling_params=pooling_params, pooling_task="embed")
|
||||||
assert len(PROMPTS) == len(outputs)
|
assert len(PROMPTS) == len(outputs)
|
||||||
|
|
||||||
# Exception raised, if the size of params does not match the size of prompts
|
# Exception raised, if the size of params does not match the size of prompts
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
|
outputs = llm.encode(
|
||||||
|
PROMPTS, pooling_params=pooling_params[:3], pooling_task="embed"
|
||||||
|
)
|
||||||
|
|
||||||
# Single PoolingParams should be applied to every prompt
|
# Single PoolingParams should be applied to every prompt
|
||||||
single_pooling_params = PoolingParams()
|
single_pooling_params = PoolingParams()
|
||||||
outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
|
outputs = llm.encode(
|
||||||
|
PROMPTS, pooling_params=single_pooling_params, pooling_task="embed"
|
||||||
|
)
|
||||||
assert len(PROMPTS) == len(outputs)
|
assert len(PROMPTS) == len(outputs)
|
||||||
|
|
||||||
# pooling_params is None, default params should be applied
|
# pooling_params is None, default params should be applied
|
||||||
outputs = llm.encode(PROMPTS, pooling_params=None)
|
outputs = llm.encode(PROMPTS, pooling_params=None, pooling_task="embed")
|
||||||
assert len(PROMPTS) == len(outputs)
|
assert len(PROMPTS) == len(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -36,22 +36,23 @@ def llm():
|
|||||||
cleanup_dist_env_and_memory()
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
|
||||||
def test_pooling_params(llm: LLM):
|
def test_pooling_params(llm: LLM):
|
||||||
def get_outputs(softmax):
|
def get_outputs(activation):
|
||||||
outputs = llm.reward(
|
outputs = llm.reward(
|
||||||
prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False
|
prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
|
||||||
)
|
)
|
||||||
return torch.cat([x.outputs.data for x in outputs])
|
return torch.cat([x.outputs.data for x in outputs])
|
||||||
|
|
||||||
default = get_outputs(softmax=None)
|
default = get_outputs(activation=None)
|
||||||
w_softmax = get_outputs(softmax=True)
|
w_activation = get_outputs(activation=True)
|
||||||
wo_softmax = get_outputs(softmax=False)
|
wo_activation = get_outputs(activation=False)
|
||||||
|
|
||||||
assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax."
|
assert torch.allclose(default, w_activation, atol=1e-2), (
|
||||||
assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), (
|
"Default should use activation."
|
||||||
"wo_softmax should not use softmax."
|
|
||||||
)
|
)
|
||||||
assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), (
|
assert not torch.allclose(w_activation, wo_activation, atol=1e-2), (
|
||||||
"w_softmax should be close to softmax(wo_softmax)."
|
"wo_activation should not use activation."
|
||||||
|
)
|
||||||
|
assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), (
|
||||||
|
"w_activation should be close to activation(wo_activation)."
|
||||||
)
|
)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from tests.utils import RemoteOpenAIServer
|
|||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
EMBED_DTYPE_TO_TORCH_DTYPE,
|
EMBED_DTYPE_TO_TORCH_DTYPE,
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
|
PoolingResponse,
|
||||||
)
|
)
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
@ -509,3 +510,20 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str):
|
|||||||
assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
|
assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
|
||||||
"w_normal should be close to normal(wo_normal)."
|
"w_normal should be close to normal(wo_normal)."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling(server: RemoteOpenAIServer, model_name: str):
    """POST one sentence to the `pooling` route and check the result's dimensions."""
    request_body = {
        "model": model_name,
        "input": ["The chef prepared a delicious meal."],
        "encoding_format": "float",
    }

    raw_response = requests.post(server.url_for("pooling"), json=request_body)
    poolings = PoolingResponse.model_validate(raw_response.json())

    # One input -> one entry, holding 11 rows of 384 values each.
    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == 11
    assert len(poolings.data[0].data[0]) == 384
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from vllm.entrypoints.openai.protocol import RerankResponse
|
from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
|
||||||
|
|
||||||
MODEL_NAME = "BAAI/bge-reranker-base"
|
MODEL_NAME = "BAAI/bge-reranker-base"
|
||||||
DTYPE = "bfloat16"
|
DTYPE = "bfloat16"
|
||||||
@ -159,3 +159,20 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str):
|
|||||||
assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), (
|
assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), (
|
||||||
"w_activation should be close to activation(wo_activation)."
|
"w_activation should be close to activation(wo_activation)."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling(server: RemoteOpenAIServer, model_name: str):
    """POST one sentence to the `pooling` route and check the result's dimensions."""
    request_body = {
        "model": model_name,
        "input": ["The chef prepared a delicious meal."],
        "encoding_format": "float",
    }

    raw_response = requests.post(server.url_for("pooling"), json=request_body)
    poolings = PoolingResponse.model_validate(raw_response.json())

    # One input -> one entry, holding 11 rows of a single value each.
    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == 11
    assert len(poolings.data[0].data[0]) == 1
|
||||||
|
|||||||
45
tests/models/language/pooling/test_multi_vector_retrieval.py
Normal file
45
tests/models/language/pooling/test_multi_vector_retrieval.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
|
from tests.models.utils import check_embeddings_close
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "model",
    ["BAAI/bge-m3"],
)
@pytest.mark.parametrize("dtype", ["half"])
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str, dtype: str):
    """Compare vLLM `token_embed` output with the HF model's last hidden state."""
    # NOTE(review): `dtype` is parametrized but not forwarded to either
    # runner in this test; confirm whether that is intentional.
    with vllm_runner(model, runner="pooling", max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.token_embed(example_prompts)

    with hf_runner(model, auto_cls=AutoModel) as hf_model:
        tokenizer = hf_model.tokenizer

        def hf_token_states(prompt):
            # Each prompt goes through the HF model on its own (batch of one).
            encoded = hf_model.wrap_device(tokenizer([prompt], return_tensors="pt"))
            hidden = hf_model.model(**encoded).last_hidden_state
            return hidden[0].float().cpu()

        hf_outputs = [hf_token_states(prompt) for prompt in example_prompts]

    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        check_embeddings_close(
            embeddings_0_lst=hf_output,
            embeddings_1_lst=vllm_output,
            name_0="hf",
            name_1="vllm",
            tol=1e-2,
        )
|
||||||
@ -93,7 +93,7 @@ def test_embed_models_using_normalize(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
def test_reward_models_using_softmax(
|
def test_reward_models_using_activation(
|
||||||
hf_runner,
|
hf_runner,
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
example_prompts,
|
example_prompts,
|
||||||
@ -104,22 +104,64 @@ def test_reward_models_using_softmax(
|
|||||||
model,
|
model,
|
||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
pooler_config=PoolerConfig(softmax=False),
|
pooler_config=PoolerConfig(activation=False),
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
wo_softmax = vllm_model.encode(example_prompts)
|
wo_activation = vllm_model.reward(example_prompts)
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model, max_model_len=1024, dtype=dtype, pooler_config=PoolerConfig(softmax=True)
|
model,
|
||||||
|
max_model_len=1024,
|
||||||
|
dtype=dtype,
|
||||||
|
pooler_config=PoolerConfig(activation=True),
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
w_softmax = vllm_model.encode(example_prompts)
|
w_activation = vllm_model.reward(example_prompts)
|
||||||
|
|
||||||
for wo, w in zip(wo_softmax, w_softmax):
|
for wo, w in zip(wo_activation, w_activation):
|
||||||
wo = torch.tensor(wo)
|
wo = torch.tensor(wo)
|
||||||
w = torch.tensor(w)
|
w = torch.tensor(w)
|
||||||
|
|
||||||
assert not torch.allclose(wo, w, atol=1e-2), (
|
assert not torch.allclose(wo, w, atol=1e-2), (
|
||||||
"pooler_config softmax is not working"
|
"pooler_config activation is not working"
|
||||||
)
|
)
|
||||||
assert torch.allclose(softmax(wo), w, atol=1e-2), (
|
assert torch.allclose(softmax(wo), w, atol=1e-2), (
|
||||||
"w_softmax should be close to softmax(wo_softmax)."
|
"w_activation should be close to activation(wo_activation)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "model",
    [
        "intfloat/multilingual-e5-small",
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_multi_vector_retrieval_models_using_normalize(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """`PoolerConfig(normalize=...)` must change `token_embed` output as expected."""

    def token_embeddings(normalize: bool):
        # Load the model with the requested setting and embed every prompt.
        with vllm_runner(
            model,
            max_model_len=512,
            dtype=dtype,
            pooler_config=PoolerConfig(normalize=normalize),
        ) as vllm_model:
            return vllm_model.token_embed(example_prompts)

    wo_normalize = token_embeddings(normalize=False)
    w_normalize = token_embeddings(normalize=True)

    for raw, normalized in zip(wo_normalize, w_normalize):
        assert not torch.allclose(raw, normalized, atol=1e-2), (
            "pooler_config normalize is not working"
        )
        assert torch.allclose(F.normalize(raw, p=2, dim=-1), normalized, atol=1e-2), (
            "w_normal should be close to normal(wo_normal)."
        )
|
||||||
|
|||||||
@ -19,7 +19,7 @@ def test_bert_models(
|
|||||||
dtype: str,
|
dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
|
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
|
||||||
vllm_outputs = vllm_model.encode(example_prompts)
|
vllm_outputs = vllm_model.token_classify(example_prompts)
|
||||||
|
|
||||||
with hf_runner(
|
with hf_runner(
|
||||||
model, dtype=dtype, auto_cls=AutoModelForTokenClassification
|
model, dtype=dtype, auto_cls=AutoModelForTokenClassification
|
||||||
@ -50,7 +50,7 @@ def test_modernbert_models(
|
|||||||
dtype: str,
|
dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
|
with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
|
||||||
vllm_outputs = vllm_model.encode(example_prompts)
|
vllm_outputs = vllm_model.token_classify(example_prompts)
|
||||||
|
|
||||||
with hf_runner(
|
with hf_runner(
|
||||||
model, dtype=dtype, auto_cls=AutoModelForTokenClassification
|
model, dtype=dtype, auto_cls=AutoModelForTokenClassification
|
||||||
|
|||||||
@ -39,7 +39,7 @@ def _run_test(
|
|||||||
max_num_seqs=32,
|
max_num_seqs=32,
|
||||||
default_torch_num_threads=1,
|
default_torch_num_threads=1,
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
vllm_model.encode(prompt)
|
vllm_model.llm.encode(prompt, pooling_task="token_classify")
|
||||||
|
|
||||||
|
|
||||||
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
|
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
|
||||||
|
|||||||
@ -30,7 +30,7 @@ class MyGemma2Embedding(nn.Module):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||||
"embed": Pooler.for_embed(pooler_config),
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -93,7 +93,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
|
|||||||
out_data_format="b64_json",
|
out_data_format="b64_json",
|
||||||
)
|
)
|
||||||
|
|
||||||
pooling_params = PoolingParams(task="encode", softmax=False)
|
pooling_params = PoolingParams(activation=False)
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model_name,
|
model_name,
|
||||||
@ -108,8 +108,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
|
|||||||
io_processor_plugin="prithvi_to_tiff",
|
io_processor_plugin="prithvi_to_tiff",
|
||||||
) as llm_runner:
|
) as llm_runner:
|
||||||
pooler_output = llm_runner.get_llm().encode(
|
pooler_output = llm_runner.get_llm().encode(
|
||||||
img_prompt,
|
img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
|
||||||
pooling_params=pooling_params,
|
|
||||||
)
|
)
|
||||||
output = pooler_output[0].outputs
|
output = pooler_output[0].outputs
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.models.utils import EmbedModelInfo
|
from tests.models.utils import EmbedModelInfo
|
||||||
from vllm import PoolingParams
|
from vllm import PoolingParams
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig, PoolerConfig
|
||||||
|
|
||||||
EMBEDDING_MODELS = [
|
EMBEDDING_MODELS = [
|
||||||
EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
|
EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
|
||||||
@ -15,6 +17,15 @@ EMBEDDING_MODELS = [
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
classify_parameters = ["activation"]
|
||||||
|
embed_parameters = ["dimensions", "normalize"]
|
||||||
|
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MockModelConfig:
    """Minimal stand-in that exposes only a `pooler_config` attribute.

    The tests in this module pass it as `model_config` to
    `PoolingParams.verify`.
    """

    # Pooler settings carried by the stand-in.
    pooler_config: PoolerConfig
|
||||||
|
|
||||||
|
|
||||||
def test_task():
|
def test_task():
|
||||||
pooling_params = PoolingParams()
|
pooling_params = PoolingParams()
|
||||||
@ -24,25 +35,27 @@ def test_task():
|
|||||||
pooling_params.verify(task="score")
|
pooling_params.verify(task="score")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
pooling_params.verify(task="encode")
|
pooling_params.verify(task="classify")
|
||||||
|
|
||||||
|
|
||||||
def test_embed():
|
def test_embed():
|
||||||
task = "embed"
|
task = "embed"
|
||||||
|
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
|
||||||
|
|
||||||
pooling_params = PoolingParams(normalize=None)
|
pooling_params = PoolingParams(normalize=None)
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
pooling_params = PoolingParams(normalize=True)
|
pooling_params = PoolingParams(normalize=True)
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
pooling_params = PoolingParams(normalize=False)
|
pooling_params = PoolingParams(normalize=False)
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
invalid_parameters = ["activation", "softmax"]
|
invalid_parameters = classify_parameters + step_pooling_parameters
|
||||||
for p in invalid_parameters:
|
for p in invalid_parameters:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
pooling_params = PoolingParams(**{p: True})
|
pooling_params = PoolingParams(**{p: True})
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
|
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
|
||||||
@ -73,35 +86,71 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("task", ["score", "classify"])
|
@pytest.mark.parametrize("task", ["score", "classify"])
|
||||||
def test_classify(task):
|
def test_classify(task):
|
||||||
|
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
|
||||||
|
|
||||||
pooling_params = PoolingParams(activation=None)
|
pooling_params = PoolingParams(activation=None)
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
pooling_params = PoolingParams(activation=True)
|
pooling_params = PoolingParams(activation=True)
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
pooling_params = PoolingParams(activation=False)
|
pooling_params = PoolingParams(activation=False)
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
invalid_parameters = ["dimensions", "normalize", "softmax"]
|
invalid_parameters = embed_parameters + step_pooling_parameters
|
||||||
for p in invalid_parameters:
|
for p in invalid_parameters:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
pooling_params = PoolingParams(**{p: True})
|
pooling_params = PoolingParams(**{p: True})
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
|
|
||||||
def test_encode():
|
@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
|
||||||
task = "encode"
|
def test_token_embed(pooling_type: str):
|
||||||
pooling_params = PoolingParams(softmax=None)
|
task = "token_embed"
|
||||||
pooling_params.verify(task=task)
|
model_config = MockModelConfig(
|
||||||
|
pooler_config=PoolerConfig(pooling_type=pooling_type)
|
||||||
|
)
|
||||||
|
|
||||||
pooling_params = PoolingParams(softmax=True)
|
pooling_params = PoolingParams(normalize=None)
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
pooling_params = PoolingParams(softmax=False)
|
pooling_params = PoolingParams(normalize=True)
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
|
pooling_params = PoolingParams(normalize=False)
|
||||||
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
|
invalid_parameters = classify_parameters
|
||||||
|
if pooling_type != "STEP":
|
||||||
|
invalid_parameters = classify_parameters + step_pooling_parameters
|
||||||
|
|
||||||
invalid_parameters = ["dimensions", "normalize", "activation"]
|
|
||||||
for p in invalid_parameters:
|
for p in invalid_parameters:
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
pooling_params = PoolingParams(**{p: True})
|
pooling_params = PoolingParams(**{p: True})
|
||||||
pooling_params.verify(task=task)
|
pooling_params.verify(task=task, model_config=model_config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
def test_token_classify(pooling_type: str):
    """Check which `PoolingParams` fields `verify` accepts for `token_classify`."""
    task = "token_classify"
    model_config = MockModelConfig(
        pooler_config=PoolerConfig(pooling_type=pooling_type)
    )

    # `activation` may be left unset or set either way.
    for activation in (None, True, False):
        PoolingParams(activation=activation).verify(
            task=task, model_config=model_config
        )

    # Embedding parameters must be rejected; the step-pooling parameters are
    # expected to be rejected as well unless the pooling type is "STEP".
    if pooling_type == "STEP":
        invalid_parameters = embed_parameters
    else:
        invalid_parameters = embed_parameters + step_pooling_parameters

    for name in invalid_parameters:
        with pytest.raises(ValueError):
            PoolingParams(**{name: True}).verify(
                task=task, model_config=model_config
            )
|
||||||
|
|||||||
@ -951,7 +951,7 @@ class LLM:
|
|||||||
truncate_prompt_tokens: int | None = None,
|
truncate_prompt_tokens: int | None = None,
|
||||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||||
pooling_task: PoolingTask = "encode",
|
pooling_task: PoolingTask | None = None,
|
||||||
tokenization_kwargs: dict[str, Any] | None = None,
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||||||
) -> list[PoolingRequestOutput]:
|
) -> list[PoolingRequestOutput]:
|
||||||
"""Apply pooling to the hidden states corresponding to the input
|
"""Apply pooling to the hidden states corresponding to the input
|
||||||
@ -986,25 +986,24 @@ class LLM:
|
|||||||
instead pass them via the `inputs` parameter.
|
instead pass them via the `inputs` parameter.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.supported_tasks == ["encode"] and pooling_task is None:
|
error_str = (
|
||||||
pooling_task = "encode"
|
"pooling_task required for `LLM.encode`\n"
|
||||||
|
"Please use one of the more specific methods or set the "
|
||||||
|
"pooling_task when using `LLM.encode`:\n"
|
||||||
|
" - For embeddings, use `LLM.embed(...)` "
|
||||||
|
'or `pooling_task="embed"`.\n'
|
||||||
|
" - For classification logits, use `LLM.classify(...)` "
|
||||||
|
'or `pooling_task="classify"`.\n'
|
||||||
|
" - For similarity scores, use `LLM.score(...)`.\n"
|
||||||
|
" - For rewards, use `LLM.reward(...)` "
|
||||||
|
'or `pooling_task="token_classify"`\n'
|
||||||
|
" - For token classification, "
|
||||||
|
'use `pooling_task="token_classify"`\n'
|
||||||
|
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
|
||||||
|
)
|
||||||
|
|
||||||
if pooling_task is None:
|
if pooling_task is None:
|
||||||
pooling_task = "embed" if "embed" in self.supported_tasks else "encode"
|
raise ValueError(error_str)
|
||||||
|
|
||||||
logger.warning_once(
|
|
||||||
"`LLM.encode` is currently using `pooling_task = %s`.\n"
|
|
||||||
"Please use one of the more specific methods or set the "
|
|
||||||
"task directly when using `LLM.encode`:\n"
|
|
||||||
" - For embeddings, use `LLM.embed(...)` "
|
|
||||||
'or `pooling_task="embed"`.\n'
|
|
||||||
" - For classification logits, use `LLM.classify(...)` "
|
|
||||||
'or `pooling_task="classify"`.\n'
|
|
||||||
" - For rewards, use `LLM.reward(...)` "
|
|
||||||
'or `pooling_task="reward"`\n'
|
|
||||||
" - For similarity scores, use `LLM.score(...)`.",
|
|
||||||
pooling_task,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_config = self.model_config
|
model_config = self.model_config
|
||||||
runner_type = model_config.runner_type
|
runner_type = model_config.runner_type
|
||||||
@ -1206,7 +1205,7 @@ class LLM:
|
|||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
pooling_params=pooling_params,
|
pooling_params=pooling_params,
|
||||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||||
pooling_task="encode",
|
pooling_task="token_classify",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _embedding_score(
|
def _embedding_score(
|
||||||
|
|||||||
@ -1748,16 +1748,19 @@ async def init_app_state(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
state.openai_serving_pooling = (
|
state.openai_serving_pooling = (
|
||||||
OpenAIServingPooling(
|
(
|
||||||
engine_client,
|
OpenAIServingPooling(
|
||||||
state.openai_serving_models,
|
engine_client,
|
||||||
request_logger=request_logger,
|
state.openai_serving_models,
|
||||||
chat_template=resolved_chat_template,
|
supported_tasks=supported_tasks,
|
||||||
chat_template_content_format=args.chat_template_content_format,
|
request_logger=request_logger,
|
||||||
trust_request_chat_template=args.trust_request_chat_template,
|
chat_template=resolved_chat_template,
|
||||||
log_error_stack=args.log_error_stack,
|
chat_template_content_format=args.chat_template_content_format,
|
||||||
|
trust_request_chat_template=args.trust_request_chat_template,
|
||||||
|
log_error_stack=args.log_error_stack,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if "encode" in supported_tasks
|
if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
state.openai_serving_embedding = (
|
state.openai_serving_embedding = (
|
||||||
|
|||||||
@ -1682,7 +1682,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
|
|||||||
When using plugins IOProcessor plugins, the actual input is processed
|
When using plugins IOProcessor plugins, the actual input is processed
|
||||||
by the plugin itself. Hence, we use a generic type for the request data
|
by the plugin itself. Hence, we use a generic type for the request data
|
||||||
"""
|
"""
|
||||||
softmax: bool = True
|
activation: bool = False
|
||||||
|
|
||||||
embed_dtype: str = Field(
|
embed_dtype: str = Field(
|
||||||
default="float32",
|
default="float32",
|
||||||
@ -1693,7 +1693,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def to_pooling_params(self):
|
def to_pooling_params(self):
|
||||||
return PoolingParams(task="encode", softmax=self.softmax)
|
return PoolingParams(task="token_classify", activation=self.activation)
|
||||||
|
|
||||||
|
|
||||||
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
|
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
|
||||||
|
|||||||
@ -35,6 +35,7 @@ from vllm.entrypoints.renderer import RenderConfig
|
|||||||
from vllm.entrypoints.utils import _validate_truncation_size
|
from vllm.entrypoints.utils import _validate_truncation_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||||
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.utils import merge_async_iterators
|
from vllm.utils import merge_async_iterators
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -62,6 +63,7 @@ class OpenAIServingPooling(OpenAIServing):
|
|||||||
engine_client: EngineClient,
|
engine_client: EngineClient,
|
||||||
models: OpenAIServingModels,
|
models: OpenAIServingModels,
|
||||||
*,
|
*,
|
||||||
|
supported_tasks: tuple[SupportedTask, ...],
|
||||||
request_logger: RequestLogger | None,
|
request_logger: RequestLogger | None,
|
||||||
chat_template: str | None,
|
chat_template: str | None,
|
||||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||||
@ -75,6 +77,7 @@ class OpenAIServingPooling(OpenAIServing):
|
|||||||
log_error_stack=log_error_stack,
|
log_error_stack=log_error_stack,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.supported_tasks = supported_tasks
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.chat_template_content_format: Final = chat_template_content_format
|
self.chat_template_content_format: Final = chat_template_content_format
|
||||||
self.trust_request_chat_template = trust_request_chat_template
|
self.trust_request_chat_template = trust_request_chat_template
|
||||||
@ -178,8 +181,17 @@ class OpenAIServingPooling(OpenAIServing):
|
|||||||
try:
|
try:
|
||||||
pooling_params = request.to_pooling_params()
|
pooling_params = request.to_pooling_params()
|
||||||
|
|
||||||
|
if "token_embed" in self.supported_tasks:
|
||||||
|
pooling_task = "token_embed"
|
||||||
|
elif "token_classify" in self.supported_tasks:
|
||||||
|
pooling_task = "token_classify"
|
||||||
|
else:
|
||||||
|
return self.create_error_response(
|
||||||
|
f"pooling_task must be one of {self.supported_tasks}."
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pooling_params.verify("encode", self.model_config)
|
pooling_params.verify(pooling_task, self.model_config)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
|||||||
@ -64,66 +64,6 @@ class PoolingParamsUpdate:
|
|||||||
params.requires_token_ids = self.requires_token_ids
|
params.requires_token_ids = self.requires_token_ids
|
||||||
|
|
||||||
|
|
||||||
class Pooler(nn.Module, ABC):
|
|
||||||
"""The interface required for all poolers used in pooling models in vLLM."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def for_encode(pooler_config: PoolerConfig):
|
|
||||||
if pooler_config.pooling_type == "STEP":
|
|
||||||
return StepPooler()
|
|
||||||
|
|
||||||
resolved_config = ResolvedPoolingConfig(
|
|
||||||
task="encode", pooling_type=PoolingType.ALL
|
|
||||||
)
|
|
||||||
|
|
||||||
return SimplePooler.from_config(resolved_config)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def for_embed(pooler_config: PoolerConfig):
|
|
||||||
resolved_config = ResolvedPoolingConfig.from_config(
|
|
||||||
task="embed",
|
|
||||||
pooler_config=pooler_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
return SimplePooler.from_config(resolved_config)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def for_classify(
|
|
||||||
pooler_config: PoolerConfig,
|
|
||||||
classifier: ClassifierFn | None,
|
|
||||||
):
|
|
||||||
resolved_config = ResolvedPoolingConfig.from_config(
|
|
||||||
task="classify",
|
|
||||||
pooler_config=pooler_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
|
|
||||||
|
|
||||||
return ClassifierPooler(
|
|
||||||
pooling=pooling,
|
|
||||||
classifier=classifier,
|
|
||||||
)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
|
||||||
"""Determine which pooling tasks are supported."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
|
||||||
"""
|
|
||||||
Construct the updated pooling parameters to use for a supported task.
|
|
||||||
"""
|
|
||||||
return PoolingParamsUpdate()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: list[torch.Tensor] | torch.Tensor,
|
|
||||||
pooling_metadata: PoolingMetadata,
|
|
||||||
) -> PoolerOutput:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_lens(
|
def get_prompt_lens(
|
||||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||||
pooling_metadata: PoolingMetadata,
|
pooling_metadata: PoolingMetadata,
|
||||||
@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC):
|
|||||||
|
|
||||||
class CLSPool(PoolingMethod):
|
class CLSPool(PoolingMethod):
|
||||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
return {"encode", "embed", "classify", "score"}
|
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||||
|
|
||||||
def forward_all(
|
def forward_all(
|
||||||
self,
|
self,
|
||||||
@ -253,7 +193,7 @@ class CLSPool(PoolingMethod):
|
|||||||
|
|
||||||
class LastPool(PoolingMethod):
|
class LastPool(PoolingMethod):
|
||||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
return {"encode", "embed", "classify", "score"}
|
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||||
|
|
||||||
def forward_all(
|
def forward_all(
|
||||||
self,
|
self,
|
||||||
@ -265,7 +205,7 @@ class LastPool(PoolingMethod):
|
|||||||
|
|
||||||
class AllPool(PoolingMethod):
|
class AllPool(PoolingMethod):
|
||||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
return {"encode"}
|
return {"token_embed", "token_classify"}
|
||||||
|
|
||||||
def forward_all(
|
def forward_all(
|
||||||
self,
|
self,
|
||||||
@ -284,7 +224,7 @@ class AllPool(PoolingMethod):
|
|||||||
|
|
||||||
class MeanPool(PoolingMethod):
|
class MeanPool(PoolingMethod):
|
||||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
return {"encode", "embed", "classify", "score"}
|
return {"token_embed", "token_classify", "embed", "classify", "score"}
|
||||||
|
|
||||||
def forward_all(
|
def forward_all(
|
||||||
self,
|
self,
|
||||||
@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation):
|
|||||||
return self.fn(pooled_data)
|
return self.fn(pooled_data)
|
||||||
|
|
||||||
|
|
||||||
|
class Pooler(nn.Module, ABC):
|
||||||
|
"""The interface required for all poolers used in pooling models in vLLM."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def for_token_embed(pooler_config: PoolerConfig):
|
||||||
|
head = TokenEmbeddingPoolerHead()
|
||||||
|
|
||||||
|
if pooler_config.pooling_type == "STEP":
|
||||||
|
return StepPooler(head=head)
|
||||||
|
|
||||||
|
return AllPooler(head=head)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def for_token_classify(
|
||||||
|
pooler_config: PoolerConfig,
|
||||||
|
classifier: ClassifierFn | None = None,
|
||||||
|
act_fn: PoolerActivation | str | None = None,
|
||||||
|
):
|
||||||
|
head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
|
||||||
|
|
||||||
|
if pooler_config.pooling_type == "STEP":
|
||||||
|
return StepPooler(head=head)
|
||||||
|
|
||||||
|
return AllPooler(head=head)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def for_embed(pooler_config: PoolerConfig):
|
||||||
|
resolved_config = ResolvedPoolingConfig.from_config(
|
||||||
|
task="embed",
|
||||||
|
pooler_config=pooler_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
|
||||||
|
head = EmbeddingPoolerHead()
|
||||||
|
|
||||||
|
return SimplePooler(pooling=pooling, head=head)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def for_classify(
|
||||||
|
pooler_config: PoolerConfig,
|
||||||
|
classifier: ClassifierFn | None,
|
||||||
|
act_fn: PoolerActivation | str | None = None,
|
||||||
|
):
|
||||||
|
resolved_config = ResolvedPoolingConfig.from_config(
|
||||||
|
task="classify",
|
||||||
|
pooler_config=pooler_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
|
||||||
|
|
||||||
|
return ClassifierPooler(
|
||||||
|
pooling=pooling,
|
||||||
|
classifier=classifier,
|
||||||
|
act_fn=act_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
"""Determine which pooling tasks are supported."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
|
"""
|
||||||
|
Construct the updated pooling parameters to use for a supported task.
|
||||||
|
"""
|
||||||
|
return PoolingParamsUpdate()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: list[torch.Tensor] | torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> PoolerOutput:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class PoolerHead(nn.Module):
|
class PoolerHead(nn.Module):
|
||||||
def __init__(self, activation: PoolerActivation) -> None:
|
def __init__(self, activation: PoolerActivation) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead):
|
|||||||
super().__init__(activation=PoolerNormalize())
|
super().__init__(activation=PoolerNormalize())
|
||||||
|
|
||||||
# Load ST projector if available
|
# Load ST projector if available
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.projector: nn.Module | None = (
|
self.projector: nn.Module | None = (
|
||||||
_load_st_projector(vllm_config.model_config) if vllm_config else None
|
_load_st_projector(vllm_config.model_config) if vllm_config else None
|
||||||
@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead):
|
|||||||
return pooled_data
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
class RewardPoolerHead(PoolerHead):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__(activation=PoolerClassify(static_num_labels=False))
|
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
|
||||||
self.head_dtype = vllm_config.model_config.head_dtype
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
pooled_data: list[torch.Tensor] | torch.Tensor,
|
|
||||||
pooling_metadata: PoolingMetadata,
|
|
||||||
):
|
|
||||||
if isinstance(pooled_data, list):
|
|
||||||
pooled_data = [p.to(self.head_dtype) for p in pooled_data]
|
|
||||||
else:
|
|
||||||
pooled_data = pooled_data.to(self.head_dtype)
|
|
||||||
|
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
|
||||||
|
|
||||||
# for softmax
|
|
||||||
flags = [p.softmax for p in pooling_params]
|
|
||||||
if len(set(flags)) == 1:
|
|
||||||
if flags[0]:
|
|
||||||
pooled_data = self.activation(pooled_data)
|
|
||||||
else:
|
|
||||||
pooled_data = [
|
|
||||||
self.activation(vecs) if f else vecs
|
|
||||||
for vecs, f in zip(pooled_data, flags)
|
|
||||||
]
|
|
||||||
|
|
||||||
return pooled_data
|
|
||||||
|
|
||||||
|
|
||||||
class SimplePooler(Pooler):
|
class SimplePooler(Pooler):
|
||||||
"""A layer that pools specific information from hidden states.
|
"""A layer that pools specific information from hidden states.
|
||||||
|
|
||||||
@ -513,20 +495,6 @@ class SimplePooler(Pooler):
|
|||||||
3. Returns structured results as `PoolerOutput`.
|
3. Returns structured results as `PoolerOutput`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
pooler_config: ResolvedPoolingConfig,
|
|
||||||
) -> "SimplePooler":
|
|
||||||
pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type)
|
|
||||||
if pooler_config.task == "embed":
|
|
||||||
head = EmbeddingPoolerHead()
|
|
||||||
elif pooler_config.task == "encode":
|
|
||||||
head = RewardPoolerHead()
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown task: {pooler_config.task}")
|
|
||||||
return cls(pooling, head)
|
|
||||||
|
|
||||||
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
|
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -549,58 +517,6 @@ class SimplePooler(Pooler):
|
|||||||
return pooled_data
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
class StepPooler(Pooler):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.pooling = AllPool()
|
|
||||||
self.head = RewardPoolerHead()
|
|
||||||
|
|
||||||
def extract_states(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
|
||||||
pooling_metadata: PoolingMetadata,
|
|
||||||
) -> list[torch.Tensor] | torch.Tensor:
|
|
||||||
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
|
||||||
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
|
|
||||||
|
|
||||||
pooled_data = list[torch.Tensor]()
|
|
||||||
|
|
||||||
pooling_params = get_pooling_params(pooling_metadata)
|
|
||||||
|
|
||||||
for data, token_id, pooling_param in zip(
|
|
||||||
pooled_data_lst, prompt_token_ids, pooling_params
|
|
||||||
):
|
|
||||||
step_tag_id = pooling_param.step_tag_id
|
|
||||||
returned_token_ids = pooling_param.returned_token_ids
|
|
||||||
|
|
||||||
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
|
||||||
data = data[:, returned_token_ids]
|
|
||||||
|
|
||||||
if step_tag_id is not None:
|
|
||||||
data = data[token_id == step_tag_id]
|
|
||||||
pooled_data.append(data)
|
|
||||||
|
|
||||||
return pooled_data
|
|
||||||
|
|
||||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
|
||||||
return {"encode"}
|
|
||||||
|
|
||||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
|
||||||
return PoolingParamsUpdate(requires_token_ids=True)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
|
||||||
pooling_metadata: PoolingMetadata,
|
|
||||||
) -> PoolerOutput:
|
|
||||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
|
||||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
|
||||||
return pooled_data
|
|
||||||
|
|
||||||
|
|
||||||
class ClassifierPooler(Pooler):
|
class ClassifierPooler(Pooler):
|
||||||
"""A pooling layer for classification tasks.
|
"""A pooling layer for classification tasks.
|
||||||
|
|
||||||
@ -611,26 +527,46 @@ class ClassifierPooler(Pooler):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def act_fn_for_seq_cls(config: ModelConfig):
|
def act_fn_for_seq_cls(model_config: ModelConfig):
|
||||||
return get_classification_activation_function(config.hf_config)
|
return get_classification_activation_function(model_config.hf_config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def act_fn_for_cross_encoder(config: ModelConfig):
|
def act_fn_for_cross_encoder(model_config: ModelConfig):
|
||||||
return get_cross_encoder_activation_function(config.hf_config)
|
return get_cross_encoder_activation_function(model_config.hf_config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def resolve_act_fn(
|
||||||
|
model_config: ModelConfig,
|
||||||
|
static_num_labels: bool = True,
|
||||||
|
act_fn: PoolerActivation | str | None = None,
|
||||||
|
):
|
||||||
|
if isinstance(act_fn, str):
|
||||||
|
if act_fn == "classify":
|
||||||
|
return ClassifierPooler.act_fn_for_seq_cls(model_config)
|
||||||
|
elif act_fn == "score":
|
||||||
|
return ClassifierPooler.act_fn_for_cross_encoder(model_config)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"act_fn [{act_fn=}] not supported.")
|
||||||
|
elif act_fn is None:
|
||||||
|
return PoolerClassify(static_num_labels=static_num_labels)
|
||||||
|
else:
|
||||||
|
assert callable(act_fn)
|
||||||
|
return act_fn
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pooling: PoolingFn,
|
pooling: PoolingFn,
|
||||||
classifier: ClassifierFn | None,
|
classifier: ClassifierFn | None,
|
||||||
act_fn: PoolerActivation | None = None,
|
act_fn: PoolerActivation | str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
|
|
||||||
self.pooling = pooling
|
self.pooling = pooling
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
self.act_fn = act_fn or PoolerClassify()
|
self.act_fn = self.resolve_act_fn(
|
||||||
|
vllm_config.model_config, static_num_labels=True, act_fn=act_fn
|
||||||
|
)
|
||||||
self.logit_bias: float | None = (
|
self.logit_bias: float | None = (
|
||||||
vllm_config.model_config.pooler_config.logit_bias
|
vllm_config.model_config.pooler_config.logit_bias
|
||||||
)
|
)
|
||||||
@ -672,6 +608,150 @@ class ClassifierPooler(Pooler):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
|
||||||
|
def forward(
|
||||||
|
self, pooled_data: torch.Tensor, pooling_param: PoolingParams
|
||||||
|
) -> torch.Tensor:
|
||||||
|
pooled_data = pooled_data.to(self.head_dtype)
|
||||||
|
# pooled_data shape: [n_tokens, hidden_dimension]
|
||||||
|
|
||||||
|
# Apply ST projector
|
||||||
|
if self.projector is not None:
|
||||||
|
pooled_data = self.projector(pooled_data)
|
||||||
|
# pooled_data shape: [n_tokens, embedding_dimension]
|
||||||
|
|
||||||
|
# for matryoshka representation
|
||||||
|
pooled_data = pooled_data[..., : pooling_param.dimensions]
|
||||||
|
|
||||||
|
# for normalize
|
||||||
|
if pooling_param.normalize:
|
||||||
|
pooled_data = self.activation(pooled_data)
|
||||||
|
|
||||||
|
# pooled_data shape: [n_tokens, embedding_dimension]
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
|
class TokenClassifierPoolerHead(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
classifier: ClassifierFn | None,
|
||||||
|
act_fn: PoolerActivation | str | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
|
||||||
|
self.classifier = classifier
|
||||||
|
self.act_fn = ClassifierPooler.resolve_act_fn(
|
||||||
|
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
|
||||||
|
)
|
||||||
|
self.logit_bias: float | None = (
|
||||||
|
vllm_config.model_config.pooler_config.logit_bias
|
||||||
|
)
|
||||||
|
self.head_dtype = vllm_config.model_config.head_dtype
|
||||||
|
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
return {"token_classify"}
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_param: PoolingParams,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = hidden_states.to(self.head_dtype)
|
||||||
|
# hidden_states shape: [n_token, hidden_size]
|
||||||
|
|
||||||
|
if self.classifier is not None:
|
||||||
|
scores = self.classifier(hidden_states)
|
||||||
|
else:
|
||||||
|
scores = hidden_states
|
||||||
|
# scores shape: [n_token, num_labels]
|
||||||
|
|
||||||
|
if self.logit_bias is not None:
|
||||||
|
scores -= self.logit_bias
|
||||||
|
|
||||||
|
if pooling_param.activation:
|
||||||
|
scores = self.act_fn(scores)
|
||||||
|
|
||||||
|
# scores shape: [n_token, num_labels]
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class AllPooler(Pooler):
|
||||||
|
def __init__(self, head: nn.Module | PoolerHead) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pooling = AllPool()
|
||||||
|
self.head = head
|
||||||
|
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
return {"token_embed", "token_classify"}
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> PoolerOutput:
|
||||||
|
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||||
|
pooling_params = get_pooling_params(pooling_metadata)
|
||||||
|
assert len(pooled_data) == len(pooling_params)
|
||||||
|
|
||||||
|
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
|
class StepPooler(Pooler):
|
||||||
|
def __init__(self, head: nn.Module | PoolerHead) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pooling = AllPool()
|
||||||
|
self.head = head
|
||||||
|
|
||||||
|
def extract_states(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> torch.Tensor | list[torch.Tensor]:
|
||||||
|
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
|
||||||
|
prompt_token_ids = get_prompt_token_ids(pooling_metadata)
|
||||||
|
|
||||||
|
pooled_data = list[torch.Tensor]()
|
||||||
|
|
||||||
|
pooling_params = get_pooling_params(pooling_metadata)
|
||||||
|
|
||||||
|
for data, token_id, pooling_param in zip(
|
||||||
|
pooled_data_lst, prompt_token_ids, pooling_params
|
||||||
|
):
|
||||||
|
step_tag_id = pooling_param.step_tag_id
|
||||||
|
returned_token_ids = pooling_param.returned_token_ids
|
||||||
|
|
||||||
|
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
||||||
|
data = data[:, returned_token_ids]
|
||||||
|
|
||||||
|
if step_tag_id is not None:
|
||||||
|
data = data[token_id == step_tag_id]
|
||||||
|
pooled_data.append(data)
|
||||||
|
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
return {"token_embed", "token_classify"}
|
||||||
|
|
||||||
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
|
return PoolingParamsUpdate(requires_token_ids=True)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> PoolerOutput:
|
||||||
|
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||||
|
pooling_params = get_pooling_params(pooling_metadata)
|
||||||
|
assert len(pooled_data) == len(pooling_params)
|
||||||
|
|
||||||
|
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||||
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
class DispatchPooler(Pooler):
|
class DispatchPooler(Pooler):
|
||||||
"""Dispatches calls to a sub-pooler based on the pooling task."""
|
"""Dispatches calls to a sub-pooler based on the pooling task."""
|
||||||
|
|
||||||
|
|||||||
@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T:
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||||
"embed": Pooler.for_embed(pooler_config),
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T:
|
|||||||
# Lazy import
|
# Lazy import
|
||||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||||
from vllm.model_executor.layers.pooler import (
|
from vllm.model_executor.layers.pooler import (
|
||||||
ClassifierPooler,
|
|
||||||
DispatchPooler,
|
DispatchPooler,
|
||||||
Pooler,
|
Pooler,
|
||||||
PoolingMethod,
|
|
||||||
PoolingType,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T:
|
|||||||
model_config.hidden_size,
|
model_config.hidden_size,
|
||||||
config.num_labels,
|
config.num_labels,
|
||||||
bias=False,
|
bias=False,
|
||||||
params_dtype=torch.float32,
|
params_dtype=vllm_config.model_config.head_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
return_bias=False,
|
||||||
prefix=maybe_prefix(prefix, "score"),
|
prefix=maybe_prefix(prefix, "score"),
|
||||||
)
|
)
|
||||||
|
|
||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
pooling_type_str = pooler_config.pooling_type
|
|
||||||
assert pooling_type_str is not None
|
|
||||||
pooling_type = PoolingType[pooling_type_str]
|
|
||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
"classify": ClassifierPooler(
|
pooler_config, classifier=self.score
|
||||||
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
|
||||||
classifier=self._classifier,
|
|
||||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
"score": ClassifierPooler(
|
"classify": Pooler.for_classify(
|
||||||
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
pooler_config, classifier=self.score, act_fn="classify"
|
||||||
classifier=self._classifier,
|
),
|
||||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
"score": Pooler.for_classify(
|
||||||
vllm_config.model_config
|
pooler_config, classifier=self.score, act_fn="score"
|
||||||
),
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _classifier(self, x: torch.Tensor):
|
|
||||||
x, _ = self.score(x.float())
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T:
|
|||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{"encode": Pooler.for_encode(pooler_config)},
|
{
|
||||||
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config=pooler_config
|
||||||
|
)
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")
|
ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")
|
||||||
|
|||||||
@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
|||||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||||
return DispatchPooler(
|
return DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||||
"embed": Pooler.for_embed(pooler_config),
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
|||||||
|
|
||||||
return DispatchPooler(
|
return DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||||
"embed": SPLADESparsePooler(
|
"embed": SPLADESparsePooler(
|
||||||
mlm_head=self.mlm_head,
|
mlm_head=self.mlm_head,
|
||||||
cls_token_id=cls_id,
|
cls_token_id=cls_id,
|
||||||
@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config, classifier=self.classifier
|
||||||
|
),
|
||||||
"classify": ClassifierPooler(
|
"classify": ClassifierPooler(
|
||||||
pooling=self.bert.pooler,
|
pooling=self.bert.pooler,
|
||||||
classifier=self.classifier,
|
classifier=self.classifier,
|
||||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
act_fn="classify",
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
"score": ClassifierPooler(
|
"score": ClassifierPooler(
|
||||||
pooling=self.bert.pooler,
|
pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
|
||||||
classifier=self.classifier,
|
|
||||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config=pooler_config
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config, classifier=self.classifier
|
||||||
|
),
|
||||||
"classify": ClassifierPooler(
|
"classify": ClassifierPooler(
|
||||||
pooling=self.new.pooler,
|
pooling=self.new.pooler,
|
||||||
classifier=self.classifier,
|
classifier=self.classifier,
|
||||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
act_fn="classify",
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
"score": ClassifierPooler(
|
"score": ClassifierPooler(
|
||||||
pooling=self.new.pooler,
|
pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
|
||||||
classifier=self.classifier,
|
|
||||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||||
"embed": Pooler.for_embed(pooler_config),
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
"classify": Pooler.for_classify(pooler_config, classifier=self.score),
|
pooler_config, classifier=self.score
|
||||||
|
),
|
||||||
|
"classify": Pooler.for_classify(
|
||||||
|
pooler_config, classifier=self.score, act_fn="classify"
|
||||||
|
),
|
||||||
|
"score": Pooler.for_classify(
|
||||||
|
pooler_config, classifier=self.score, act_fn="score"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM):
|
|||||||
if pooler_config is not None:
|
if pooler_config is not None:
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||||
"embed": GritLMPooler(vllm_config.model_config),
|
"embed": GritLMPooler(vllm_config.model_config),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
|||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{"encode": Pooler.for_encode(pooler_config)},
|
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config, classifier=self.score
|
||||||
|
),
|
||||||
"classify": Pooler.for_classify(
|
"classify": Pooler.for_classify(
|
||||||
pooler_config,
|
pooler_config, classifier=self.score, act_fn="classify"
|
||||||
classifier=self.score,
|
),
|
||||||
|
"score": Pooler.for_classify(
|
||||||
|
pooler_config, classifier=self.score, act_fn="score"
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -97,9 +97,15 @@ class JinaVLForSequenceClassification(
|
|||||||
self.score = JinaVLScorer(vllm_config.model_config)
|
self.score = JinaVLScorer(vllm_config.model_config)
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
"classify": Pooler.for_classify(pooler_config, classifier=self.score),
|
pooler_config, classifier=self.score
|
||||||
"score": Pooler.for_classify(pooler_config, classifier=self.score),
|
),
|
||||||
|
"classify": Pooler.for_classify(
|
||||||
|
pooler_config, classifier=self.score, act_fn="classify"
|
||||||
|
),
|
||||||
|
"score": Pooler.for_classify(
|
||||||
|
pooler_config, classifier=self.score, act_fn="score"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config, classifier=self.classifier
|
||||||
|
),
|
||||||
"classify": ClassifierPooler(
|
"classify": ClassifierPooler(
|
||||||
pooling=self.pooling,
|
pooling=self.pooling, classifier=self.classifier, act_fn="classify"
|
||||||
classifier=self.classifier,
|
|
||||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
"score": ClassifierPooler(
|
"score": ClassifierPooler(
|
||||||
pooling=self.pooling,
|
pooling=self.pooling, classifier=self.classifier, act_fn="score"
|
||||||
classifier=self.classifier,
|
|
||||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config=pooler_config
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
|||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{"encode": Pooler.for_encode(pooler_config)},
|
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
|||||||
pooler_config = vllm_config.model_config.pooler_config
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)})
|
self.pooler = DispatchPooler(
|
||||||
|
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||||
|
)
|
||||||
|
|||||||
@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module):
|
|||||||
|
|
||||||
@default_pooling_type("CLS")
|
@default_pooling_type("CLS")
|
||||||
class RobertaEmbeddingModel(BertEmbeddingModel):
|
class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||||
"""A model that uses Roberta to provide embedding functionalities.
|
"""A model that uses Roberta to provide embedding functionalities."""
|
||||||
|
|
||||||
This class encapsulates the BertModel and provides an interface for
|
|
||||||
embedding operations and customized pooling functions.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
model: An instance of BertModel used for forward operations.
|
|
||||||
_pooler: An instance of Pooler used for pooling operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config=pooler_config, classifier=self.classifier
|
||||||
|
),
|
||||||
"classify": ClassifierPooler(
|
"classify": ClassifierPooler(
|
||||||
pooling=CLSPool(),
|
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
|
||||||
classifier=self.classifier,
|
|
||||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
"score": ClassifierPooler(
|
"score": ClassifierPooler(
|
||||||
pooling=CLSPool(),
|
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
|
||||||
classifier=self.classifier,
|
|
||||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
|||||||
assert pooler_config is not None
|
assert pooler_config is not None
|
||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{"encode": Pooler.for_encode(pooler_config)},
|
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
|
|||||||
@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||||
"embed": Pooler.for_embed(pooler_config),
|
"embed": Pooler.for_embed(pooler_config),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
|
|||||||
|
|
||||||
self.pooler = DispatchPooler(
|
self.pooler = DispatchPooler(
|
||||||
{
|
{
|
||||||
"encode": Pooler.for_encode(pooler_config),
|
"token_classify": Pooler.for_token_classify(
|
||||||
|
pooler_config, classifier=self.classifier
|
||||||
|
),
|
||||||
"classify": ClassifierPooler(
|
"classify": ClassifierPooler(
|
||||||
pooling=CLSPool(),
|
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
|
||||||
classifier=self.classifier,
|
|
||||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
"score": ClassifierPooler(
|
"score": ClassifierPooler(
|
||||||
pooling=CLSPool(),
|
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
|
||||||
classifier=self.classifier,
|
|
||||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
|
||||||
vllm_config.model_config
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from vllm.sampling_params import RequestOutputKind
|
|||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig, PoolerConfig
|
||||||
|
|
||||||
|
|
||||||
class PoolingParams(
|
class PoolingParams(
|
||||||
@ -30,7 +30,6 @@ class PoolingParams(
|
|||||||
if model support matryoshka representation.
|
if model support matryoshka representation.
|
||||||
activation: Whether to apply activation function to
|
activation: Whether to apply activation function to
|
||||||
the classification outputs.
|
the classification outputs.
|
||||||
softmax: Whether to apply softmax to the reward outputs.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# --8<-- [start:common-pooling-params]
|
# --8<-- [start:common-pooling-params]
|
||||||
@ -48,32 +47,19 @@ class PoolingParams(
|
|||||||
activation: bool | None = None
|
activation: bool | None = None
|
||||||
# --8<-- [end:classification-pooling-params]
|
# --8<-- [end:classification-pooling-params]
|
||||||
|
|
||||||
## for reward models
|
## for step pooling models
|
||||||
softmax: bool | None = None
|
|
||||||
step_tag_id: int | None = None
|
step_tag_id: int | None = None
|
||||||
returned_token_ids: list[int] | None = None
|
returned_token_ids: list[int] | None = None
|
||||||
|
|
||||||
|
## Internal use only
|
||||||
task: PoolingTask | None = None
|
task: PoolingTask | None = None
|
||||||
"""Internal use only."""
|
|
||||||
|
|
||||||
requires_token_ids: bool = False
|
requires_token_ids: bool = False
|
||||||
"""Internal use only."""
|
|
||||||
|
|
||||||
extra_kwargs: dict[str, Any] | None = None
|
extra_kwargs: dict[str, Any] | None = None
|
||||||
"""Internal use only."""
|
|
||||||
|
|
||||||
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_parameters(self) -> list[str]:
|
def all_parameters(self) -> list[str]:
|
||||||
return [
|
return ["dimensions", "normalize", "activation"]
|
||||||
"dimensions",
|
|
||||||
"normalize",
|
|
||||||
"activation",
|
|
||||||
"softmax",
|
|
||||||
"step_tag_id",
|
|
||||||
"returned_token_ids",
|
|
||||||
]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def valid_parameters(self):
|
def valid_parameters(self):
|
||||||
@ -81,7 +67,8 @@ class PoolingParams(
|
|||||||
"embed": ["dimensions", "normalize"],
|
"embed": ["dimensions", "normalize"],
|
||||||
"classify": ["activation"],
|
"classify": ["activation"],
|
||||||
"score": ["activation"],
|
"score": ["activation"],
|
||||||
"encode": ["softmax", "step_tag_id", "returned_token_ids"],
|
"token_embed": ["dimensions", "normalize"],
|
||||||
|
"token_classify": ["activation"],
|
||||||
}
|
}
|
||||||
|
|
||||||
def clone(self) -> "PoolingParams":
|
def clone(self) -> "PoolingParams":
|
||||||
@ -100,7 +87,6 @@ class PoolingParams(
|
|||||||
# NOTE: Task validation needs to done against the model instance,
|
# NOTE: Task validation needs to done against the model instance,
|
||||||
# which is not available in model config. So, it's not included
|
# which is not available in model config. So, it's not included
|
||||||
# in this method
|
# in this method
|
||||||
|
|
||||||
self._merge_default_parameters(model_config)
|
self._merge_default_parameters(model_config)
|
||||||
self._set_default_parameters(model_config)
|
self._set_default_parameters(model_config)
|
||||||
self._verify_valid_parameters()
|
self._verify_valid_parameters()
|
||||||
@ -125,8 +111,34 @@ class PoolingParams(
|
|||||||
if getattr(self, k, None) is None:
|
if getattr(self, k, None) is None:
|
||||||
setattr(self, k, getattr(pooler_config, k))
|
setattr(self, k, getattr(pooler_config, k))
|
||||||
|
|
||||||
|
self._verify_step_pooling(pooler_config, valid_parameters)
|
||||||
|
|
||||||
|
def _verify_step_pooling(
|
||||||
|
self, pooler_config: "PoolerConfig", valid_parameters: list[str]
|
||||||
|
):
|
||||||
|
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
|
||||||
|
if pooler_config.pooling_type != "STEP":
|
||||||
|
invalid_parameters = []
|
||||||
|
for k in step_pooling_parameters:
|
||||||
|
if getattr(self, k, None) is not None:
|
||||||
|
invalid_parameters.append(k)
|
||||||
|
|
||||||
|
if invalid_parameters:
|
||||||
|
raise ValueError(
|
||||||
|
f"Task {self.task} only supports {valid_parameters} "
|
||||||
|
f"parameters, does not support "
|
||||||
|
f"{invalid_parameters} parameters"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for k in step_pooling_parameters:
|
||||||
|
if getattr(pooler_config, k, None) is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if getattr(self, k, None) is None:
|
||||||
|
setattr(self, k, getattr(pooler_config, k))
|
||||||
|
|
||||||
def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
|
def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
|
||||||
if self.task == "embed":
|
if self.task in ["embed", "token_embed"]:
|
||||||
if self.normalize is None:
|
if self.normalize is None:
|
||||||
self.normalize = True
|
self.normalize = True
|
||||||
|
|
||||||
@ -150,13 +162,9 @@ class PoolingParams(
|
|||||||
elif self.dimensions < 1:
|
elif self.dimensions < 1:
|
||||||
raise ValueError("Dimensions must be greater than 0")
|
raise ValueError("Dimensions must be greater than 0")
|
||||||
|
|
||||||
elif self.task in ["classify", "score"]:
|
elif self.task in ["classify", "score", "token_classify"]:
|
||||||
if self.activation is None:
|
if self.activation is None:
|
||||||
self.activation = True
|
self.activation = True
|
||||||
|
|
||||||
elif self.task == "encode":
|
|
||||||
if self.softmax is None:
|
|
||||||
self.softmax = True
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown pooling task: {self.task}")
|
raise ValueError(f"Unknown pooling task: {self.task}")
|
||||||
|
|
||||||
@ -185,7 +193,6 @@ class PoolingParams(
|
|||||||
f"normalize={self.normalize}, "
|
f"normalize={self.normalize}, "
|
||||||
f"dimensions={self.dimensions}, "
|
f"dimensions={self.dimensions}, "
|
||||||
f"activation={self.activation}, "
|
f"activation={self.activation}, "
|
||||||
f"softmax={self.softmax}, "
|
|
||||||
f"step_tag_id={self.step_tag_id}, "
|
f"step_tag_id={self.step_tag_id}, "
|
||||||
f"returned_token_ids={self.returned_token_ids}, "
|
f"returned_token_ids={self.returned_token_ids}, "
|
||||||
f"requires_token_ids={self.requires_token_ids}, "
|
f"requires_token_ids={self.requires_token_ids}, "
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from typing import Literal, get_args
|
|||||||
GenerationTask = Literal["generate", "transcription"]
|
GenerationTask = Literal["generate", "transcription"]
|
||||||
GENERATION_TASKS = get_args(GenerationTask)
|
GENERATION_TASKS = get_args(GenerationTask)
|
||||||
|
|
||||||
PoolingTask = Literal["encode", "embed", "classify", "score"]
|
PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
|
||||||
POOLING_TASKS = get_args(PoolingTask)
|
POOLING_TASKS = get_args(PoolingTask)
|
||||||
|
|
||||||
SupportedTask = Literal[GenerationTask, PoolingTask]
|
SupportedTask = Literal[GenerationTask, PoolingTask]
|
||||||
|
|||||||
@ -1926,15 +1926,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
supported_tasks = list(model.pooler.get_supported_tasks())
|
supported_tasks = list(model.pooler.get_supported_tasks())
|
||||||
|
|
||||||
if (
|
if self.scheduler_config.chunked_prefill_enabled:
|
||||||
self.scheduler_config.chunked_prefill_enabled
|
if "token_embed" in supported_tasks:
|
||||||
and "encode" in supported_tasks
|
supported_tasks.remove("token_embed")
|
||||||
):
|
if "token_classify" in supported_tasks:
|
||||||
supported_tasks.remove("encode")
|
supported_tasks.remove("token_classify")
|
||||||
|
|
||||||
logger.debug_once(
|
logger.debug_once(
|
||||||
"Chunked prefill is not supported with "
|
"Chunked prefill is not supported with "
|
||||||
"encode task which using ALL pooling. "
|
"token_embed and token_classify tasks "
|
||||||
|
"which using ALL pooling. "
|
||||||
"Please turn off chunked prefill by "
|
"Please turn off chunked prefill by "
|
||||||
"`--no-enable-chunked-prefill` before using it."
|
"`--no-enable-chunked-prefill` before using it."
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user