mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 05:24:25 +08:00
Support Cross encoder models (#10400)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
parent
49628fe13e
commit
214efc2c3c
@ -44,6 +44,148 @@ We currently support the following OpenAI APIs:
|
|||||||
- This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst).
|
- This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst).
|
||||||
- *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*
|
- *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*
|
||||||
|
|
||||||
|
## Score API for Cross Encoder Models
|
||||||
|
|
||||||
|
vLLM supports *cross-encoder models* at the **/v1/score** endpoint, which is not a standard OpenAI API endpoint. You can find the documentation for this kind of model at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
|
||||||
|
|
||||||
|
A ***Cross Encoder*** takes exactly two sentences / texts as input and either predicts a score or label for this sentence pair. It can for example predict the similarity of the sentence pair on a scale of 0 … 1.
|
||||||
|
|
||||||
|
### Example of usage for a pair of a string and a list of texts
|
||||||
|
|
||||||
|
In this case, the model will compare the first given text to each of the texts in the list.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X 'POST' \
|
||||||
|
'http://127.0.0.1:8000/v1/score' \
|
||||||
|
-H 'accept: application/json' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "BAAI/bge-reranker-v2-m3",
|
||||||
|
"text_1": "What is the capital of France?",
|
||||||
|
"text_2": [
|
||||||
|
"The capital of Brazil is Brasilia.",
|
||||||
|
"The capital of France is Paris."
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Response:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "score-request-id",
|
||||||
|
"object": "list",
|
||||||
|
"created": 693570,
|
||||||
|
"model": "BAAI/bge-reranker-v2-m3",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"object": "score",
|
||||||
|
"score": [
|
||||||
|
0.001094818115234375
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"object": "score",
|
||||||
|
"score": [
|
||||||
|
1
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example of usage for a pair of two lists of texts
|
||||||
|
|
||||||
|
In this case, the model will compare the texts one by one, pairing the elements that share the same index in each list.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X 'POST' \
|
||||||
|
'http://127.0.0.1:8000/v1/score' \
|
||||||
|
-H 'accept: application/json' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "BAAI/bge-reranker-v2-m3",
|
||||||
|
"encoding_format": "float",
|
||||||
|
"text_1": [
|
||||||
|
"What is the capital of Brazil?",
|
||||||
|
"What is the capital of France?"
|
||||||
|
],
|
||||||
|
"text_2": [
|
||||||
|
"The capital of Brazil is Brasilia.",
|
||||||
|
"The capital of France is Paris."
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Response:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "score-request-id",
|
||||||
|
"object": "list",
|
||||||
|
"created": 693447,
|
||||||
|
"model": "BAAI/bge-reranker-v2-m3",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"object": "score",
|
||||||
|
"score": [
|
||||||
|
1
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"object": "score",
|
||||||
|
"score": [
|
||||||
|
1
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example of usage for a pair of two strings
|
||||||
|
|
||||||
|
In this case, the model will compare the two given strings.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X 'POST' \
|
||||||
|
'http://127.0.0.1:8000/v1/score' \
|
||||||
|
-H 'accept: application/json' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"model": "BAAI/bge-reranker-v2-m3",
|
||||||
|
"encoding_format": "float",
|
||||||
|
"text_1": "What is the capital of France?",
|
||||||
|
"text_2": "The capital of France is Paris."
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Response:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "score-request-id",
|
||||||
|
"object": "list",
|
||||||
|
"created": 693447,
|
||||||
|
"model": "BAAI/bge-reranker-v2-m3",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"object": "score",
|
||||||
|
"score": [
|
||||||
|
1
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Extra Parameters
|
## Extra Parameters
|
||||||
|
|
||||||
vLLM supports a set of parameters that are not part of the OpenAI API.
|
vLLM supports a set of parameters that are not part of the OpenAI API.
|
||||||
|
|||||||
58
examples/openai_cross_encoder_score.py
Normal file
58
examples/openai_cross_encoder_score.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
"""Example Python client for the Score API for cross-encoder models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import pprint
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST a score request and return the raw HTTP response.

    Args:
        prompt: JSON-serializable request body; it is passed to
            ``requests.post`` through the ``json=`` keyword.
        api_url: Full URL of the endpoint to post to.

    Returns:
        The ``requests.Response`` exactly as returned by ``requests.post``.
        No status check is performed here.
    """
    # The annotation used to be the ``json`` module itself, which is not a
    # type; the body is a plain dict.
    headers = {"User-Agent": "Test Client"}
    return requests.post(api_url, headers=headers, json=prompt)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3")
    args = parser.parse_args()
    api_url = f"http://{args.host}:{args.port}/v1/score"

    model_name = args.model

    # One entry per input shape demonstrated by this example:
    # (header to print, text_1, text_2).
    cases = [
        (
            "Prompt for text_1 is string and text_2 is a list:",
            "What is the capital of France?",
            [
                "The capital of Brazil is Brasilia.",
                "The capital of France is Paris.",
            ],
        ),
        (
            "Prompt for text_1 and text_2 are lists:",
            [
                "What is the capital of Brazil?",
                "What is the capital of France?",
            ],
            [
                "The capital of Brazil is Brasilia.",
                "The capital of France is Paris.",
            ],
        ),
        (
            "Prompt for text_1 and text_2 are strings:",
            "What is the capital of Brazil?",
            "The capital of Brazil is Brasilia.",
        ),
    ]

    for header, text_1, text_2 in cases:
        prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
        score_response = post_http_request(prompt=prompt, api_url=api_url)
        print(header)
        pprint.pprint(prompt)
        print("Score Response:")
        # requests.Response has no ``data`` attribute; the decoded JSON body
        # is obtained with ``.json()``.
        pprint.pprint(score_response.json())
|
||||||
@ -265,6 +265,7 @@ class HfRunner:
|
|||||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
is_embedding_model: bool = False,
|
is_embedding_model: bool = False,
|
||||||
is_sentence_transformer: bool = False,
|
is_sentence_transformer: bool = False,
|
||||||
|
is_cross_encoder: bool = False,
|
||||||
skip_tokenizer_init: bool = False,
|
skip_tokenizer_init: bool = False,
|
||||||
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
|
auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
|
||||||
postprocess_inputs: Callable[..., BatchEncoding] = identity,
|
postprocess_inputs: Callable[..., BatchEncoding] = identity,
|
||||||
@ -282,6 +283,14 @@ class HfRunner:
|
|||||||
device="cpu",
|
device="cpu",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
).to(dtype=torch_dtype))
|
).to(dtype=torch_dtype))
|
||||||
|
elif is_cross_encoder:
|
||||||
|
# Lazy init required for AMD CI
|
||||||
|
from sentence_transformers import CrossEncoder
|
||||||
|
self.model = CrossEncoder(model_name,
|
||||||
|
device="cpu",
|
||||||
|
trust_remote_code=True)
|
||||||
|
self.model.model = self.wrap_device(self.model.model)\
|
||||||
|
.to(dtype=torch_dtype)
|
||||||
else:
|
else:
|
||||||
model_kwargs = model_kwargs if model_kwargs is not None else {}
|
model_kwargs = model_kwargs if model_kwargs is not None else {}
|
||||||
self.model = self.wrap_device(
|
self.model = self.wrap_device(
|
||||||
@ -625,6 +634,9 @@ class HfRunner:
|
|||||||
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
|
def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
|
||||||
return self.model.encode(prompts)
|
return self.model.encode(prompts)
|
||||||
|
|
||||||
|
def predict(self, prompts: List[List[str]]) -> torch.Tensor:
    """Run the wrapped model's ``predict`` on the given text pairs.

    Args:
        prompts: List of text pairs, each pair given as a list of strings.

    Returns:
        Whatever ``self.model.predict`` returns when called with
        ``convert_to_tensor=True``.
    """
    scores = self.model.predict(prompts, convert_to_tensor=True)
    return scores
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -898,6 +910,14 @@ class VllmRunner:
|
|||||||
req_outputs = self.model.encode(inputs)
|
req_outputs = self.model.encode(inputs)
|
||||||
return [req_output.outputs.embedding for req_output in req_outputs]
|
return [req_output.outputs.embedding for req_output in req_outputs]
|
||||||
|
|
||||||
|
def score(
    self,
    text_1: Union[str, List[str]],
    text_2: Union[str, List[str]],
) -> List[List[float]]:
    """Score ``text_1`` against ``text_2`` with the wrapped model.

    Calls ``self.model.score`` and collects ``outputs.embedding`` from each
    returned request output, preserving order.
    """
    results = self.model.score(text_1, text_2)
    scores = []
    for result in results:
        scores.append(result.outputs.embedding)
    return scores
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
93
tests/entrypoints/openai/test_score.py
Normal file
93
tests/entrypoints/openai/test_score.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import ScoreResponse
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
MODEL_NAME = "BAAI/bge-reranker-v2-m3"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def server():
    """Yield one ``RemoteOpenAIServer`` for ``MODEL_NAME``, shared per module."""
    cli_args = ["--enforce-eager"]

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
                                      model_name: str):
    """A single query against a list of texts yields one score per text."""
    query = "What is the capital of France?"
    candidates = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
    ]
    payload = {
        "model": model_name,
        "text_1": query,
        "text_2": candidates,
    }

    score_response = requests.post(server.url_for("v1/score"), json=payload)
    score_response.raise_for_status()
    score = ScoreResponse.model_validate(score_response.json())

    assert score.id is not None
    assert score.data is not None
    assert len(score.data) == 2
    # The unrelated sentence must score near zero, the matching one near one.
    assert score.data[0].score[0] <= 0.01
    assert score.data[1].score[0] >= 0.9
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
                                       model_name: str):
    """Two equal-length lists are scored pairwise by index."""
    queries = [
        "What is the capital of the United States?",
        "What is the capital of France?",
    ]
    candidates = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
    ]
    payload = {
        "model": model_name,
        "text_1": queries,
        "text_2": candidates,
    }

    score_response = requests.post(server.url_for("v1/score"), json=payload)
    score_response.raise_for_status()
    score = ScoreResponse.model_validate(score_response.json())

    assert score.id is not None
    assert score.data is not None
    assert len(score.data) == 2
    # Pair 0 is a mismatch (US question vs. Brazil answer); pair 1 matches.
    assert score.data[0].score[0] <= 0.01
    assert score.data[1].score[0] >= 0.9
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
                                     model_name: str):
    """Two plain strings produce exactly one score entry."""
    query = "What is the capital of France?"
    candidate = "The capital of France is Paris."
    payload = {
        "model": model_name,
        "text_1": query,
        "text_2": candidate,
    }

    score_response = requests.post(server.url_for("v1/score"), json=payload)
    score_response.raise_for_status()
    score = ScoreResponse.model_validate(score_response.json())

    assert score.id is not None
    assert score.data is not None
    assert len(score.data) == 1
    assert score.data[0].score[0] >= 0.9
|
||||||
95
tests/models/embedding/language/test_scoring.py
Normal file
95
tests/models/embedding/language/test_scoring.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
"""Compare the cross-encoder scores of HF and vLLM models.
|
||||||
|
|
||||||
|
Run `pytest tests/models/embedding/language/test_scoring.py`.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
"cross-encoder/ms-marco-MiniLM-L-6-v2", # Bert
|
||||||
|
"BAAI/bge-reranker-v2-m3", # Roberta
|
||||||
|
]
|
||||||
|
|
||||||
|
TEXTS_1 = [
|
||||||
|
"What is the capital of France?",
|
||||||
|
"What is the capital of Germany?",
|
||||||
|
]
|
||||||
|
|
||||||
|
TEXTS_2 = [
|
||||||
|
"The capital of France is Paris.",
|
||||||
|
"The capital of Germany is Berlin.",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=MODELS)
def model_name(request):
    """Parametrized fixture providing each entry of ``MODELS`` in turn."""
    # No teardown is needed, so a plain return is equivalent to the yield.
    return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
    """One text pair: the vLLM score must match the HF prediction within 1%."""
    query, candidate = TEXTS_1[0], TEXTS_2[0]

    with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
        hf_outputs = hf_model.predict([[query, candidate]]).tolist()

    with vllm_runner(model_name,
                     task="embedding",
                     dtype=dtype,
                     max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(query, candidate)

    assert len(vllm_outputs) == 1
    assert len(hf_outputs) == 1

    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
    """One query against every text in ``TEXTS_2``: vLLM must match HF."""
    query = TEXTS_1[0]
    # HF takes explicit pairs; vLLM replicates the single query itself.
    text_pairs = [[query, candidate] for candidate in TEXTS_2]

    with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
        hf_outputs = hf_model.predict(text_pairs).tolist()

    with vllm_runner(model_name,
                     task="embedding",
                     dtype=dtype,
                     max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(query, TEXTS_2)

    assert len(vllm_outputs) == 2
    assert len(hf_outputs) == 2

    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert math.isclose(hf_score, vllm_score[0], rel_tol=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):
    """Index-wise pairs of ``TEXTS_1`` and ``TEXTS_2``: vLLM must match HF."""
    text_pairs = [[query, candidate]
                  for query, candidate in zip(TEXTS_1, TEXTS_2)]

    with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
        hf_outputs = hf_model.predict(text_pairs).tolist()

    with vllm_runner(model_name,
                     task="embedding",
                     dtype=dtype,
                     max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2)

    assert len(vllm_outputs) == 2
    assert len(hf_outputs) == 2

    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert math.isclose(hf_score, vllm_score[0], rel_tol=0.01)
|
||||||
@ -135,6 +135,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
|||||||
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
|
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
|
||||||
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
|
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
|
||||||
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
|
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
|
||||||
|
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
|
||||||
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"),
|
"XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"),
|
||||||
# [Multimodal]
|
# [Multimodal]
|
||||||
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
|
"LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
|
||||||
@ -143,6 +144,13 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
|||||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
|
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_CROSS_ENCODER_EXAMPLE_MODELS = {
|
||||||
|
# [Text-only]
|
||||||
|
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
|
||||||
|
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
|
||||||
|
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
|
||||||
|
}
|
||||||
|
|
||||||
_MULTIMODAL_EXAMPLE_MODELS = {
|
_MULTIMODAL_EXAMPLE_MODELS = {
|
||||||
# [Decoder-only]
|
# [Decoder-only]
|
||||||
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
|
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
|
||||||
@ -195,6 +203,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
|||||||
_EXAMPLE_MODELS = {
|
_EXAMPLE_MODELS = {
|
||||||
**_TEXT_GENERATION_EXAMPLE_MODELS,
|
**_TEXT_GENERATION_EXAMPLE_MODELS,
|
||||||
**_EMBEDDING_EXAMPLE_MODELS,
|
**_EMBEDDING_EXAMPLE_MODELS,
|
||||||
|
**_CROSS_ENCODER_EXAMPLE_MODELS,
|
||||||
**_MULTIMODAL_EXAMPLE_MODELS,
|
**_MULTIMODAL_EXAMPLE_MODELS,
|
||||||
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
|
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,7 +6,10 @@ import torch.cuda
|
|||||||
from vllm.model_executor.models import (is_embedding_model,
|
from vllm.model_executor.models import (is_embedding_model,
|
||||||
is_text_generation_model,
|
is_text_generation_model,
|
||||||
supports_multimodal)
|
supports_multimodal)
|
||||||
from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS,
|
||||||
|
_EMBEDDING_MODELS,
|
||||||
_MULTIMODAL_MODELS,
|
_MULTIMODAL_MODELS,
|
||||||
_SPECULATIVE_DECODING_MODELS,
|
_SPECULATIVE_DECODING_MODELS,
|
||||||
_TEXT_GENERATION_MODELS,
|
_TEXT_GENERATION_MODELS,
|
||||||
@ -29,22 +32,28 @@ def test_registry_imports(model_arch):
|
|||||||
model_arch in _TEXT_GENERATION_MODELS
|
model_arch in _TEXT_GENERATION_MODELS
|
||||||
or model_arch in _MULTIMODAL_MODELS)
|
or model_arch in _MULTIMODAL_MODELS)
|
||||||
|
|
||||||
|
embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS}
|
||||||
assert is_embedding_model(model_cls) is (model_arch
|
assert is_embedding_model(model_cls) is (model_arch
|
||||||
in _EMBEDDING_MODELS)
|
in embedding_models)
|
||||||
|
|
||||||
assert supports_multimodal(model_cls) is (model_arch
|
assert supports_multimodal(model_cls) is (model_arch
|
||||||
in _MULTIMODAL_MODELS)
|
in _MULTIMODAL_MODELS)
|
||||||
|
|
||||||
|
|
||||||
@fork_new_process_for_each_test
|
@fork_new_process_for_each_test
|
||||||
@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [
|
@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [
|
||||||
("LlamaForCausalLM", False, False),
|
("LlamaForCausalLM", False, False, False),
|
||||||
("MllamaForConditionalGeneration", True, False),
|
("MllamaForConditionalGeneration", True, False, False),
|
||||||
("LlavaForConditionalGeneration", True, True),
|
("LlavaForConditionalGeneration", True, True, False),
|
||||||
|
("BertForSequenceClassification", False, False, True),
|
||||||
|
("RobertaForSequenceClassification", False, False, True),
|
||||||
|
("XLMRobertaForSequenceClassification", False, False, True),
|
||||||
])
|
])
|
||||||
def test_registry_is_multimodal(model_arch, is_mm, init_cuda):
|
def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce):
|
||||||
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm
|
assert ModelRegistry.is_multimodal_model(model_arch) is is_mm
|
||||||
|
|
||||||
|
assert ModelRegistry.is_cross_encoder_model(model_arch) is is_ce
|
||||||
|
|
||||||
if init_cuda and current_platform.is_cuda_alike():
|
if init_cuda and current_platform.is_cuda_alike():
|
||||||
assert not torch.cuda.is_initialized()
|
assert not torch.cuda.is_initialized()
|
||||||
|
|
||||||
|
|||||||
@ -712,6 +712,11 @@ class ModelConfig:
|
|||||||
def is_multimodal_model(self) -> bool:
|
def is_multimodal_model(self) -> bool:
|
||||||
return self.multimodal_config is not None
|
return self.multimodal_config is not None
|
||||||
|
|
||||||
|
@property
def is_cross_encoder(self) -> bool:
    """Whether the model registry treats this model as a cross encoder.

    Reads ``architectures`` from ``self.hf_config`` (an empty list when the
    attribute is missing) and passes it to
    ``ModelRegistry.is_cross_encoder_model``.
    """
    hf_architectures = getattr(self.hf_config, "architectures", [])
    is_ce = ModelRegistry.is_cross_encoder_model(hf_architectures)
    return is_ce
|
||||||
|
|
||||||
|
|
||||||
class CacheConfig:
|
class CacheConfig:
|
||||||
"""Configuration for the KV cache.
|
"""Configuration for the KV cache.
|
||||||
|
|||||||
@ -1357,6 +1357,7 @@ class Scheduler:
|
|||||||
encoder_seq_data=encoder_seq_data,
|
encoder_seq_data=encoder_seq_data,
|
||||||
cross_block_table=cross_block_table,
|
cross_block_table=cross_block_table,
|
||||||
state=seq_group.state,
|
state=seq_group.state,
|
||||||
|
token_type_ids=seq_group.token_type_ids,
|
||||||
# `multi_modal_data` will only be present for the 1st comm
|
# `multi_modal_data` will only be present for the 1st comm
|
||||||
# between engine and worker.
|
# between engine and worker.
|
||||||
# the subsequent comms can still use delta, but
|
# the subsequent comms can still use delta, but
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
|||||||
apply_mistral_chat_template,
|
apply_mistral_chat_template,
|
||||||
parse_chat_messages,
|
parse_chat_messages,
|
||||||
resolve_chat_template_content_format)
|
resolve_chat_template_content_format)
|
||||||
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
|
||||||
from vllm.inputs.parse import parse_and_batch_prompt
|
from vllm.inputs.parse import parse_and_batch_prompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -817,6 +817,128 @@ class LLM:
|
|||||||
return self.engine_class.validate_outputs(outputs,
|
return self.engine_class.validate_outputs(outputs,
|
||||||
EmbeddingRequestOutput)
|
EmbeddingRequestOutput)
|
||||||
|
|
||||||
|
def score(
    self,
    text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
    text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
    /,
    truncate_prompt_tokens: Optional[int] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
    """Generates similarity scores for all pairs <text,text_pair>.

    The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 -> N case
    the text_1 sentence will be replicated N times to pair with the text_2
    sentences. The input pairs are used to build a list of prompts for the
    cross encoder model. This class automatically batches the prompts,
    considering the memory constraint. For the best performance, put all
    of your texts into a single list and pass it to this method.

    Args:
        text_1: can be a single prompt or a list of prompts, in which
            case it has to have the same length as the text_2 list
        text_2: The texts to pair with the query to form the input
            to the LLM. See :class:`~vllm.inputs.PromptType` for
            more details about the format of each prompts.
        truncate_prompt_tokens: If given, each tokenized pair is truncated
            by the tokenizer with this value as ``max_length``.
        use_tqdm: Whether to use tqdm to display the progress bar.
        lora_request: LoRA request to use for generation, if any.
        prompt_adapter_request: Prompt Adapter request to use for
            generation, if any.

    Returns:
        A list of ``EmbeddingRequestOutput`` objects containing the
        generated scores in the same order as the input prompts.

    Raises:
        ValueError: If the model is not running the embedding task, is not
            a cross encoder, uses a ``MistralTokenizer``, if a prompt is
            multi-modal or cannot be reduced to a string, or if the input
            lengths are not 1:1, 1:N or N:N.
    """
    task = self.llm_engine.model_config.task
    if task != "embedding":
        messages = ["LLM.score() is only supported for embedding models."]

        supported_tasks = self.llm_engine.model_config.supported_tasks
        if "embedding" in supported_tasks:
            messages.append(
                "Your model supports the 'embedding' task, but is "
                f"currently initialized for the '{task}' task. Please "
                "initialize the model using `--task embedding`.")

        raise ValueError(" ".join(messages))

    if not self.llm_engine.model_config.is_cross_encoder:
        raise ValueError("Your model does not support the cross encoding")

    tokenizer = self.llm_engine.get_tokenizer()

    if isinstance(tokenizer, MistralTokenizer):
        raise ValueError(
            "MistralTokenizer not supported for cross-encoding")

    # the tokenizer for models such as
    # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
    # lists of tokens to the `text` and `text_pair` kwargs
    def ensure_str(prompt: SingletonPrompt):
        if isinstance(prompt, dict):
            if "multi_modal_data" in prompt:
                raise ValueError("Multi-modal prompt is not "
                                 "supported for cross encoding")
            elif "prompt_token_ids" in prompt:
                prompt = tokenizer.decode(
                    cast(TokensPrompt, prompt)["prompt_token_ids"])
            elif "prompt" in prompt:
                prompt = cast(TextPrompt, prompt)["prompt"]
        # Explicit check instead of `assert`, which is stripped under -O
        # and would let a non-string prompt reach the tokenizer.
        if not isinstance(prompt, str):
            raise ValueError("Prompt must be a string, a text prompt or a "
                             "tokens prompt for cross encoding")
        return prompt

    if isinstance(text_1, (str, dict)):
        # Convert a single prompt to a list.
        text_1 = [text_1]
    text_1 = [ensure_str(t) for t in text_1]

    if isinstance(text_2, (str, dict)):
        # Convert a single prompt to a list.
        text_2 = [text_2]
    text_2 = [ensure_str(t) for t in text_2]

    if len(text_1) > 1 and len(text_1) != len(text_2):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if len(text_1) == 0:
        raise ValueError("At least one text element must be given")
    if len(text_2) == 0:
        raise ValueError("At least one text_pair element must be given")

    if len(text_1) == 1:
        # 1 -> N: replicate the single query so it pairs with every text_2.
        text_1 = text_1 * len(text_2)

    input_pairs = list(zip(text_1, text_2))
    pooling_params = PoolingParams()

    tokenization_kwargs: Dict[str, Any] = {}
    if truncate_prompt_tokens is not None:
        tokenization_kwargs["truncation"] = True
        tokenization_kwargs["max_length"] = truncate_prompt_tokens

    parsed_prompts = []

    for q, t in input_pairs:
        # Tokenize the pair jointly via the tokenizer's text/text_pair kwargs.
        prompt_inputs = tokenizer(text=q,
                                  text_pair=t,
                                  **tokenization_kwargs)
        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["input_ids"],
            token_type_ids=prompt_inputs.get("token_type_ids"))
        parsed_prompts.append(engine_prompt)

    self._validate_and_add_requests(
        prompts=parsed_prompts,
        params=pooling_params,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )

    outputs = self._run_engine(use_tqdm=use_tqdm)
    return self.engine_class.validate_outputs(outputs,
                                              EmbeddingRequestOutput)
|
||||||
|
|
||||||
def start_profile(self) -> None:
|
def start_profile(self) -> None:
|
||||||
self.llm_engine.start_profile()
|
self.llm_engine.start_profile()
|
||||||
|
|
||||||
|
|||||||
@ -45,6 +45,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
EmbeddingRequest,
|
EmbeddingRequest,
|
||||||
EmbeddingResponse, ErrorResponse,
|
EmbeddingResponse, ErrorResponse,
|
||||||
LoadLoraAdapterRequest,
|
LoadLoraAdapterRequest,
|
||||||
|
ScoreRequest, ScoreResponse,
|
||||||
TokenizeRequest,
|
TokenizeRequest,
|
||||||
TokenizeResponse,
|
TokenizeResponse,
|
||||||
UnloadLoraAdapterRequest)
|
UnloadLoraAdapterRequest)
|
||||||
@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
|||||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||||
|
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||||
from vllm.entrypoints.openai.serving_tokenization import (
|
from vllm.entrypoints.openai.serving_tokenization import (
|
||||||
OpenAIServingTokenization)
|
OpenAIServingTokenization)
|
||||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
@ -280,6 +282,10 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
|
|||||||
return request.app.state.openai_serving_embedding
|
return request.app.state.openai_serving_embedding
|
||||||
|
|
||||||
|
|
||||||
|
def score(request: Request) -> Optional[OpenAIServingScores]:
    """Return the Score API handler stored on the application state.

    The stored value may be ``None``; callers are expected to check for
    that before using the handler.
    """
    app_state = request.app.state
    return app_state.openai_serving_scores
|
||||||
|
|
||||||
|
|
||||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||||
return request.app.state.openai_serving_tokenization
|
return request.app.state.openai_serving_tokenization
|
||||||
|
|
||||||
@ -391,6 +397,23 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
|||||||
assert_never(generator)
|
assert_never(generator)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/score")
async def create_score(request: ScoreRequest, raw_request: Request):
    """Handle ``POST /v1/score``.

    Looks up the score handler on the application state. If none is
    configured, an error response built by the base handler is returned.
    Otherwise the request is delegated to the handler and its result is
    wrapped in a ``JSONResponse`` (using the error's own status code when
    the handler reports an ``ErrorResponse``).
    """
    scorer = score(raw_request)
    if scorer is None:
        return base(raw_request).create_error_response(
            message="The model does not support Score API")

    result = await scorer.create_score(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, ScoreResponse):
        return JSONResponse(content=result.model_dump())

    # Any other result type is a programming error.
    assert_never(result)
|
||||||
|
|
||||||
|
|
||||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||||
@ -466,8 +489,9 @@ def build_app(args: Namespace) -> FastAPI:
|
|||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
async def validation_exception_handler(_, exc):
|
async def validation_exception_handler(_, exc):
|
||||||
chat = app.state.openai_serving_chat
|
err = ErrorResponse(message=str(exc),
|
||||||
err = chat.create_error_response(message=str(exc))
|
type="BadRequestError",
|
||||||
|
code=HTTPStatus.BAD_REQUEST)
|
||||||
return JSONResponse(err.model_dump(),
|
return JSONResponse(err.model_dump(),
|
||||||
status_code=HTTPStatus.BAD_REQUEST)
|
status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
@ -565,6 +589,13 @@ def init_app_state(
|
|||||||
chat_template=resolved_chat_template,
|
chat_template=resolved_chat_template,
|
||||||
chat_template_content_format=args.chat_template_content_format,
|
chat_template_content_format=args.chat_template_content_format,
|
||||||
) if model_config.task == "embedding" else None
|
) if model_config.task == "embedding" else None
|
||||||
|
state.openai_serving_scores = OpenAIServingScores(
|
||||||
|
engine_client,
|
||||||
|
model_config,
|
||||||
|
base_model_paths,
|
||||||
|
request_logger=request_logger
|
||||||
|
) if (model_config.task == "embedding" \
|
||||||
|
and model_config.is_cross_encoder) else None
|
||||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||||
engine_client,
|
engine_client,
|
||||||
model_config,
|
model_config,
|
||||||
|
|||||||
@ -806,6 +806,27 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
|||||||
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
|
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreRequest(OpenAIBaseModel):
    """Request body accepted by the Score API."""

    # Name of the model the request targets.
    model: str
    # First element(s) of the pairs: one string or a list of strings.
    text_1: Union[List[str], str]
    # Second element(s) of the pairs: one string or a list of strings.
    text_2: Union[List[str], str]
    # Optional positive token limit used to truncate the tokenized pair.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-chat-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-chat-embedding-pooling-params

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    def to_pooling_params(self):
        """Build the ``PoolingParams`` carrying this request's extra data."""
        params = PoolingParams(additional_data=self.additional_data)
        return params
|
||||||
|
|
||||||
|
|
||||||
class CompletionLogProbs(OpenAIBaseModel):
|
class CompletionLogProbs(OpenAIBaseModel):
|
||||||
text_offset: List[int] = Field(default_factory=list)
|
text_offset: List[int] = Field(default_factory=list)
|
||||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||||
@ -876,6 +897,21 @@ class EmbeddingResponse(OpenAIBaseModel):
|
|||||||
usage: UsageInfo
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreResponseData(OpenAIBaseModel):
    """One scored entry inside a ``ScoreResponse``."""

    # Position of this entry within the response's ``data`` list.
    index: int
    # Object type tag; always "score" unless overridden by the caller.
    object: str = "score"
    # Score payload for the pair: a list of floats or a string.
    score: Union[List[float], str]
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreResponse(OpenAIBaseModel):
    """Top-level response body returned by the Score API."""

    # Response identifier; a random one is generated when not supplied.
    # NOTE(review): the default prefix is "embd-" although this is a score
    # response - confirm whether that is intentional.
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    # Object type tag of the response envelope.
    object: str = "list"
    # Creation time; defaults to the current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Name of the model that produced the scores.
    model: str
    # One entry per scored pair.
    data: List[ScoreResponseData]
    # Token usage accounting for the request.
    usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
class FunctionCall(OpenAIBaseModel):
|
class FunctionCall(OpenAIBaseModel):
|
||||||
name: str
|
name: str
|
||||||
arguments: str
|
arguments: str
|
||||||
|
|||||||
215
vllm/entrypoints/openai/serving_score.py
Normal file
215
vllm/entrypoints/openai/serving_score.py
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
|
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
|
||||||
|
ScoreResponse, ScoreResponseData,
|
||||||
|
UsageInfo)
|
||||||
|
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||||
|
from vllm.inputs.data import TokensPrompt
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.outputs import EmbeddingRequestOutput
|
||||||
|
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||||
|
from vllm.utils import merge_async_iterators, random_uuid
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def request_output_to_score_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str) -> ScoreResponse:
    """Convert a batch of engine outputs into a ``ScoreResponse``.

    Args:
        final_res_batch: Engine outputs, one per scored pair, in request
            order. ``None`` entries are skipped.
        request_id: Identifier reported as the response ``id``.
        created_time: Timestamp reported as the response ``created``.
        model_name: Model name reported in the response.

    Returns:
        A ``ScoreResponse`` with one ``ScoreResponseData`` per non-``None``
        output (indexed by its position in the batch) and a usage record
        counting the prompt tokens of those outputs.
    """
    data: List[ScoreResponseData] = []
    num_prompt_tokens = 0

    for idx, final_res in enumerate(final_res_batch):
        if final_res is None:
            continue
        data.append(
            ScoreResponseData(index=idx, score=final_res.outputs.embedding))
        # Previously the counter was never updated, so usage always
        # reported zero tokens.
        num_prompt_tokens += len(final_res.prompt_token_ids)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return ScoreResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
                                                            str]) -> List:
    """Pair up two text inputs for cross-encoder scoring.

    Each argument may be a single item or a sequence of items. Supported
    shapes are 1:1, 1:N (the single ``text_1`` item is repeated for every
    ``text_2`` item) and N:N (items are paired positionally).

    Returns:
        A list of ``(text_1_item, text_2_item)`` tuples.

    Raises:
        ValueError: If either side is empty, or if ``text_1`` has more than
            one item and its length differs from that of ``text_2``.
    """
    # A lone prompt is wrapped in a list; sequences are copied into lists.
    left = [text_1] if isinstance(text_1, (str, dict)) else list(text_1)
    right = [text_2] if isinstance(text_2, (str, dict)) else list(text_2)

    n_left, n_right = len(left), len(right)

    if n_left > 1 and n_left != n_right:
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if n_left == 0:
        raise ValueError("At least one text element must be given")
    if n_right == 0:
        raise ValueError("At least one text_pair element must be given")

    if n_left == 1:
        # Broadcast the single query across every candidate.
        left = left * n_right

    return list(zip(left, right))
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIServingScores(OpenAIServing):
    """Serving handler that scores text pairs with a cross-encoder model."""

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        request_logger: Optional[RequestLogger],
    ) -> None:
        """Initialise the base serving layer without LoRA/prompt adapters."""
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=request_logger)

    async def create_score(
        self,
        request: ScoreRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[ScoreResponse, ErrorResponse]:
        """
        Score API similar to Sentence Transformers cross encoder

        See https://sbert.net/docs/package_reference/cross_encoder

        Args:
            request: The score request (model, ``text_1``, ``text_2`` and
                optional truncation / priority settings).
            raw_request: The underlying HTTP request, used for trace headers
                and client-disconnect detection when provided.

        Returns:
            A ``ScoreResponse`` on success, or an ``ErrorResponse`` when the
            model check fails, the inputs are invalid (``ValueError``), or
            the client disconnects.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        model_name = request.model
        request_id = f"score-{random_uuid()}"
        # Wall-clock Unix time, consistent with ScoreResponse.created's
        # default; time.monotonic() has an arbitrary reference point.
        created_time = int(time.time())
        truncate_prompt_tokens = request.truncate_prompt_tokens

        request_prompts: List[str] = []
        engine_prompts: List[TokensPrompt] = []

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "MistralTokenizer not supported for cross-encoding")

            if not self.model_config.is_cross_encoder:
                raise ValueError("Model is not cross encoder.")

            # The tokenizer options are identical for every pair.
            tokenization_kwargs: Dict[str, Any] = {}
            if truncate_prompt_tokens is not None:
                tokenization_kwargs["truncation"] = True
                tokenization_kwargs["max_length"] = truncate_prompt_tokens

            # Pairing and tokenization happen inside the try block so that
            # an invalid input shape is reported as an error response
            # instead of escaping as an unhandled exception.
            for q, t in make_pairs(request.text_1, request.text_2):
                # Human-readable form of the pair, used only for logging.
                request_prompt = f"{q}{tokenizer.sep_token}{t}"

                prompt_inputs = tokenizer(text=q,
                                          text_pair=t,
                                          **tokenization_kwargs)
                engine_prompt = TokensPrompt(
                    prompt_token_ids=prompt_inputs["input_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids"))

                request_prompts.append(request_prompt)
                engine_prompts.append(engine_prompt)

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []

        try:
            pooling_params = request.to_pooling_params()

            # Trace headers depend only on the HTTP request, not on the
            # individual pair, so they are resolved once.
            trace_headers = (None if raw_request is None else await
                             self._get_trace_headers(raw_request.headers))

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(
            *generators,
            is_cancelled=raw_request.is_disconnected if raw_request else None,
        )

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * num_prompts

        try:
            # Results arrive tagged with the index of their generator, so
            # they are stored back in request order.
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
                                           final_res_batch)

            response = request_output_to_score_response(
                final_res_batch_checked, request_id, created_time, model_name)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response
|
||||||
@ -38,6 +38,9 @@ class TokensPrompt(TypedDict):
|
|||||||
prompt_token_ids: List[int]
|
prompt_token_ids: List[int]
|
||||||
"""A list of token IDs to pass to the model."""
|
"""A list of token IDs to pass to the model."""
|
||||||
|
|
||||||
|
token_type_ids: NotRequired[List[int]]
|
||||||
|
"""A list of token type IDs to pass to the cross encoder model."""
|
||||||
|
|
||||||
multi_modal_data: NotRequired["MultiModalDataDict"]
|
multi_modal_data: NotRequired["MultiModalDataDict"]
|
||||||
"""
|
"""
|
||||||
DEPRECATED: Optional multi-modal data to pass to the model,
|
DEPRECATED: Optional multi-modal data to pass to the model,
|
||||||
@ -133,6 +136,9 @@ class TokenInputs(TypedDict):
|
|||||||
prompt_token_ids: List[int]
|
prompt_token_ids: List[int]
|
||||||
"""The token IDs of the prompt."""
|
"""The token IDs of the prompt."""
|
||||||
|
|
||||||
|
token_type_ids: NotRequired[List[int]]
|
||||||
|
"""The token type IDs of the prompt."""
|
||||||
|
|
||||||
prompt: NotRequired[str]
|
prompt: NotRequired[str]
|
||||||
"""
|
"""
|
||||||
The original prompt text corresponding to the token IDs, if available.
|
The original prompt text corresponding to the token IDs, if available.
|
||||||
@ -160,6 +166,7 @@ class TokenInputs(TypedDict):
|
|||||||
|
|
||||||
def token_inputs(
|
def token_inputs(
|
||||||
prompt_token_ids: List[int],
|
prompt_token_ids: List[int],
|
||||||
|
token_type_ids: Optional[List[int]] = None,
|
||||||
prompt: Optional[str] = None,
|
prompt: Optional[str] = None,
|
||||||
multi_modal_data: Optional["MultiModalDataDict"] = None,
|
multi_modal_data: Optional["MultiModalDataDict"] = None,
|
||||||
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
|
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
|
||||||
@ -170,6 +177,8 @@ def token_inputs(
|
|||||||
|
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
inputs["prompt"] = prompt
|
inputs["prompt"] = prompt
|
||||||
|
if token_type_ids is not None:
|
||||||
|
inputs["token_type_ids"] = token_type_ids
|
||||||
if multi_modal_data is not None:
|
if multi_modal_data is not None:
|
||||||
inputs["multi_modal_data"] = multi_modal_data
|
inputs["multi_modal_data"] = multi_modal_data
|
||||||
if multi_modal_placeholders is not None:
|
if multi_modal_placeholders is not None:
|
||||||
@ -234,6 +243,15 @@ class SingletonInputsAdapter:
|
|||||||
|
|
||||||
assert_never(inputs)
|
assert_never(inputs)
|
||||||
|
|
||||||
|
@cached_property
def token_type_ids(self) -> List[int]:
    """Return the token type IDs of the wrapped inputs.

    Falls back to an empty list when the inputs carry no
    ``token_type_ids`` entry.
    """
    inputs = self.inputs

    if inputs["type"] == "token" or inputs["type"] == "multimodal":
        type_ids = inputs.get("token_type_ids", [])
        return type_ids

    # Every other input type is unexpected here.
    assert_never(inputs)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
||||||
inputs = self.inputs
|
inputs = self.inputs
|
||||||
|
|||||||
@ -305,6 +305,7 @@ class InputPreprocessor:
|
|||||||
tokens_content = parsed["content"]
|
tokens_content = parsed["content"]
|
||||||
|
|
||||||
prompt_token_ids = tokens_content["prompt_token_ids"]
|
prompt_token_ids = tokens_content["prompt_token_ids"]
|
||||||
|
token_type_ids = tokens_content.get("token_type_ids")
|
||||||
multi_modal_data = tokens_content.get("multi_modal_data")
|
multi_modal_data = tokens_content.get("multi_modal_data")
|
||||||
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
|
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
|
||||||
|
|
||||||
@ -318,6 +319,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
return token_inputs(
|
return token_inputs(
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
multi_modal_data=multi_modal_data,
|
multi_modal_data=multi_modal_data,
|
||||||
mm_processor_kwargs=mm_processor_kwargs,
|
mm_processor_kwargs=mm_processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,11 +3,14 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.config import PoolerConfig
|
from vllm.config import PoolerConfig
|
||||||
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
|
||||||
PoolingTensors)
|
PoolingTensors)
|
||||||
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
|
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
|
||||||
|
from vllm.transformers_utils.config import (
|
||||||
|
get_cross_encoder_activation_function)
|
||||||
|
|
||||||
|
|
||||||
class PoolingType(IntEnum):
|
class PoolingType(IntEnum):
|
||||||
@ -152,3 +155,64 @@ class Pooler(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
return PoolerOutput(outputs=pooled_outputs)
|
return PoolerOutput(outputs=pooled_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEncodingPooler(nn.Module):
    """A layer that turns hidden states into per-prompt pair scores.

    This layer does the following:
    1. Slices ``hidden_states`` into one consecutive segment per prompt,
       using the prompt lengths from the pooling metadata.
    2. Runs each segment through ``pooler`` when one is given, otherwise
       directly through ``classifier``.
    3. Stacks the per-prompt results; when a pooler was used, applies
       ``classifier`` once to the stacked batch.
    4. Applies the activation function selected from the model config and
       returns the results as ``PoolerOutput``.

    Attributes:
        classifier: Module producing the classification output.
        pooler: Optional module applied to each prompt's hidden states
            before classification.
        default_activation_function: Activation applied to the classifier
            output, obtained via ``get_cross_encoder_activation_function``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        classifier: nn.Module,
        pooler: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.classifier = classifier
        self.pooler = pooler
        self.default_activation_function = \
            get_cross_encoder_activation_function(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pools sentence pair scores from the hidden_states."""

        prompt_lens = PoolingTensors.from_pooling_metadata(
            pooling_metadata, hidden_states.device).prompt_lens

        has_pooler = self.pooler is not None
        # Per-prompt step: the pooler if present, else the classifier.
        per_prompt_module = self.pooler if has_pooler else self.classifier

        start = 0
        per_prompt_results = []
        for prompt_len in prompt_lens:
            end = start + prompt_len
            segment = hidden_states[start:end]
            per_prompt_results.append(per_prompt_module(segment))
            start = end

        stacked = torch.stack(per_prompt_results)

        if has_pooler:
            # apply classifier once on the full batch if possible
            stacked = self.classifier(stacked)

        scores = self.default_activation_function(stacked)

        return PoolerOutput(outputs=[
            EmbeddingSequenceGroupOutput(row.tolist()) for row in scores
        ])
|
||||||
|
|||||||
@ -11,14 +11,18 @@ from vllm.model_executor.layers.activation import get_act_fn
|
|||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler,
|
||||||
|
PoolingType)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||||
|
from vllm.transformers_utils.config import (
|
||||||
|
get_cross_encoder_activation_function)
|
||||||
|
|
||||||
from .utils import maybe_prefix
|
from .utils import maybe_prefix
|
||||||
|
|
||||||
@ -48,7 +52,9 @@ class BertEmbedding(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
seq_lens: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
|
|
||||||
@ -58,17 +64,34 @@ class BertEmbedding(nn.Module):
|
|||||||
# Position embeddings.
|
# Position embeddings.
|
||||||
position_embeddings = self.position_embeddings(position_ids)
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
|
||||||
# Token type embeddings. (TODO: move off hotpath?)
|
if token_type_ids is None:
|
||||||
token_type_embeddings = self.token_type_embeddings(
|
token_type_ids = torch.zeros(input_shape,
|
||||||
torch.zeros(input_shape,
|
dtype=torch.long,
|
||||||
dtype=torch.long,
|
device=inputs_embeds.device)
|
||||||
device=inputs_embeds.device))
|
|
||||||
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
|
||||||
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
||||||
embeddings = self.LayerNorm(embeddings)
|
embeddings = self.LayerNorm(embeddings)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class BertPooler(nn.Module):
    """Pool a sequence by transforming the hidden state of its first token.

    The first row of ``hidden_states`` is passed through a square linear
    layer (``hidden_size`` -> ``hidden_size``) followed by ``tanh``.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        return self.activation(self.dense(hidden_states[0, :]))
|
||||||
|
|
||||||
|
|
||||||
class BertEncoder(nn.Module):
|
class BertEncoder(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -309,7 +332,8 @@ class BertModel(nn.Module):
|
|||||||
*,
|
*,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
embedding_class: type = BertEmbedding):
|
embedding_class: type = BertEmbedding,
|
||||||
|
add_pooling_layer: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
@ -319,6 +343,7 @@ class BertModel(nn.Module):
|
|||||||
cache_config,
|
cache_config,
|
||||||
quant_config,
|
quant_config,
|
||||||
prefix=f"{prefix}.encoder")
|
prefix=f"{prefix}.encoder")
|
||||||
|
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -328,13 +353,17 @@ class BertModel(nn.Module):
|
|||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
else:
|
else:
|
||||||
hidden_states = self.embeddings(input_ids=input_ids,
|
assert hasattr(attn_metadata, "seq_lens_tensor")
|
||||||
position_ids=position_ids)
|
hidden_states = self.embeddings(
|
||||||
|
input_ids=input_ids,
|
||||||
|
seq_lens=attn_metadata.seq_lens_tensor,
|
||||||
|
position_ids=position_ids,
|
||||||
|
token_type_ids=token_type_ids)
|
||||||
return self.encoder(hidden_states, kv_caches, attn_metadata)
|
return self.encoder(hidden_states, kv_caches, attn_metadata)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
@ -349,7 +378,7 @@ class BertModel(nn.Module):
|
|||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: Set[str] = set()
|
loaded_params: Set[str] = set()
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "pooler" in name:
|
if self.pooler is None and "pooler" in name:
|
||||||
continue
|
continue
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
@ -430,3 +459,78 @@ class BertEmbeddingModel(nn.Module):
|
|||||||
pooling_type=PoolingType.CLS,
|
pooling_type=PoolingType.CLS,
|
||||||
normalize=True,
|
normalize=True,
|
||||||
softmax=False)
|
softmax=False)
|
||||||
|
|
||||||
|
|
||||||
|
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """A Bert model with a classification head, used for cross encoding.

    This class wraps a ``BertModel`` (built with its pooling layer enabled)
    together with a linear classifier, and exposes pooling through a
    ``CrossEncodingPooler``.

    Attributes:
        bert: An instance of BertModel used for forward operations.
        classifier: Linear layer mapping ``hidden_size`` to ``num_labels``.
        _pooler: A CrossEncodingPooler combining the Bert pooler and the
            classifier.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        self.default_activation_function = \
            get_cross_encoder_activation_function(hf_config)

        self.num_labels = hf_config.num_labels
        self.bert = BertModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "bert"),
                              embedding_class=BertEmbedding,
                              add_pooling_layer=True)
        self.classifier = nn.Linear(hf_config.hidden_size,
                                    hf_config.num_labels)
        self._pooler = CrossEncodingPooler(hf_config, self.classifier,
                                           self.bert.pooler)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the Bert backbone and classifier.

        Weights whose name starts with ``bert.`` are forwarded (with that
        prefix stripped) to the backbone; of the remaining ones, only those
        whose name starts with ``classifier`` are loaded.
        """
        bert_prefix = "bert."
        head_weights = []

        def backbone_weights():
            # Consumed lazily by the backbone loader; everything that is not
            # a backbone weight is set aside while streaming.
            for name, tensor in weights:
                if name.startswith(bert_prefix):
                    yield (name[len(bert_prefix):], tensor)
                else:
                    head_weights.append((name, tensor))

        self.bert.load_weights(backbone_weights())

        own_params = dict(self.named_parameters())

        for name, tensor in head_weights:
            if not name.startswith("classifier"):
                continue
            param = own_params[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Delegate pooling to the internal ``CrossEncodingPooler``."""
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Bert backbone and return its output unchanged."""
        backbone_output = self.bert(input_ids=input_ids,
                                    position_ids=positions,
                                    kv_caches=kv_caches,
                                    inputs_embeds=inputs_embeds,
                                    intermediate_tensors=intermediate_tensors,
                                    attn_metadata=attn_metadata,
                                    token_type_ids=token_type_ids)
        return backbone_output
|
||||||
|
|||||||
@ -7,6 +7,8 @@ from typing_extensions import TypeIs
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import supports_kw
|
from vllm.utils import supports_kw
|
||||||
|
|
||||||
|
from .interfaces_base import is_embedding_model
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
|
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
@ -350,3 +352,37 @@ def is_attention_free(
|
|||||||
return isinstance(model, _IsAttentionFreeType)
|
return isinstance(model, _IsAttentionFreeType)
|
||||||
|
|
||||||
return isinstance(model, IsAttentionFree)
|
return isinstance(model, IsAttentionFree)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """Marker protocol for models that support cross encoding.

    The protocol is runtime-checkable and has a single member, the class
    level flag ``supports_cross_encoding``, so ``isinstance`` checks against
    it test for the presence of that attribute.
    """

    # Flag attribute that identifies a cross-encoding capable model.
    supports_cross_encoding: ClassVar[Literal[True]] = True
|
||||||
|
|
||||||
|
|
||||||
|
@overload
def supports_cross_encoding(
        model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def _supports_cross_encoding(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Return whether ``model`` satisfies ``SupportsCrossEncoding``.

    ``model`` may be a model class or a model instance; both are checked
    with the same ``isinstance`` test against the protocol.
    """
    return isinstance(model, SupportsCrossEncoding)


def supports_cross_encoding(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Return whether ``model`` is an embedding model that cross-encodes.

    The protocol check is only evaluated when ``is_embedding_model(model)``
    is truthy.
    """
    embedding_capable = is_embedding_model(model)
    return embedding_capable and _supports_cross_encoding(model)
|
||||||
|
|||||||
@ -21,7 +21,8 @@ from vllm.logger import init_logger
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .interfaces import (has_inner_state, is_attention_free,
|
from .interfaces import (has_inner_state, is_attention_free,
|
||||||
supports_multimodal, supports_pp)
|
supports_cross_encoding, supports_multimodal,
|
||||||
|
supports_pp)
|
||||||
from .interfaces_base import is_embedding_model, is_text_generation_model
|
from .interfaces_base import is_embedding_model, is_text_generation_model
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -100,6 +101,7 @@ _EMBEDDING_MODELS = {
|
|||||||
# [Text-only]
|
# [Text-only]
|
||||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||||
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||||
|
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
|
||||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||||
@ -121,6 +123,14 @@ _EMBEDDING_MODELS = {
|
|||||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
|
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_CROSS_ENCODER_MODELS = {
|
||||||
|
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
||||||
|
"RobertaForSequenceClassification": ("roberta",
|
||||||
|
"RobertaForSequenceClassification"),
|
||||||
|
"XLMRobertaForSequenceClassification": ("roberta",
|
||||||
|
"RobertaForSequenceClassification"),
|
||||||
|
}
|
||||||
|
|
||||||
_MULTIMODAL_MODELS = {
|
_MULTIMODAL_MODELS = {
|
||||||
# [Decoder-only]
|
# [Decoder-only]
|
||||||
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
|
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
|
||||||
@ -159,6 +169,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
|||||||
_VLLM_MODELS = {
|
_VLLM_MODELS = {
|
||||||
**_TEXT_GENERATION_MODELS,
|
**_TEXT_GENERATION_MODELS,
|
||||||
**_EMBEDDING_MODELS,
|
**_EMBEDDING_MODELS,
|
||||||
|
**_CROSS_ENCODER_MODELS,
|
||||||
**_MULTIMODAL_MODELS,
|
**_MULTIMODAL_MODELS,
|
||||||
**_SPECULATIVE_DECODING_MODELS,
|
**_SPECULATIVE_DECODING_MODELS,
|
||||||
}
|
}
|
||||||
@ -193,6 +204,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
|
|||||||
class _ModelInfo:
|
class _ModelInfo:
|
||||||
is_text_generation_model: bool
|
is_text_generation_model: bool
|
||||||
is_embedding_model: bool
|
is_embedding_model: bool
|
||||||
|
supports_cross_encoding: bool
|
||||||
supports_multimodal: bool
|
supports_multimodal: bool
|
||||||
supports_pp: bool
|
supports_pp: bool
|
||||||
has_inner_state: bool
|
has_inner_state: bool
|
||||||
@ -203,6 +215,7 @@ class _ModelInfo:
|
|||||||
return _ModelInfo(
|
return _ModelInfo(
|
||||||
is_text_generation_model=is_text_generation_model(model),
|
is_text_generation_model=is_text_generation_model(model),
|
||||||
is_embedding_model=is_embedding_model(model),
|
is_embedding_model=is_embedding_model(model),
|
||||||
|
supports_cross_encoding=supports_cross_encoding(model),
|
||||||
supports_multimodal=supports_multimodal(model),
|
supports_multimodal=supports_multimodal(model),
|
||||||
supports_pp=supports_pp(model),
|
supports_pp=supports_pp(model),
|
||||||
has_inner_state=has_inner_state(model),
|
has_inner_state=has_inner_state(model),
|
||||||
@ -415,6 +428,12 @@ class _ModelRegistry:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
return self.inspect_model_cls(architectures).is_embedding_model
|
return self.inspect_model_cls(architectures).is_embedding_model
|
||||||
|
|
||||||
|
def is_cross_encoder_model(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Return the ``supports_cross_encoding`` flag for ``architectures``.

    The flag is read from the result of ``self.inspect_model_cls``.
    """
    model_info = self.inspect_model_cls(architectures)
    return model_info.supports_cross_encoding
|
||||||
|
|
||||||
def is_multimodal_model(
|
def is_multimodal_model(
|
||||||
self,
|
self,
|
||||||
architectures: Union[str, List[str]],
|
architectures: Union[str, List[str]],
|
||||||
@ -489,4 +508,4 @@ def _run() -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
_run()
|
_run()
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -6,10 +6,17 @@ from transformers import RobertaConfig
|
|||||||
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.model_executor.layers.pooler import CrossEncodingPooler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
|
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||||
|
from vllm.model_executor.models.utils import maybe_prefix
|
||||||
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
|
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||||
|
from vllm.transformers_utils.config import (
|
||||||
|
get_cross_encoder_activation_function)
|
||||||
|
|
||||||
|
|
||||||
class RobertaEmbedding(nn.Module):
|
class RobertaEmbedding(nn.Module):
|
||||||
@ -39,34 +46,93 @@ class RobertaEmbedding(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
seq_lens: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
|
|
||||||
# Input embeddings.
|
|
||||||
inputs_embeds = self.word_embeddings(input_ids)
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
# TODO: figure out if there is a better way
|
# Replace position ids because in RoBERTa models
|
||||||
# to make to make position ids start at padding_idx + 1
|
# they have to start at padding_idx + 1 and ignore
|
||||||
|
# existing padding tokens
|
||||||
# References:
|
# References:
|
||||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
|
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
|
||||||
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
|
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
|
||||||
position_ids += self.padding_idx + 1
|
pos_list = []
|
||||||
|
token_list = []
|
||||||
|
offset = 0
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
pos_list.append(position_ids[offset:offset + seq_len])
|
||||||
|
token_list.append(input_ids[offset:offset + seq_len])
|
||||||
|
offset += seq_len
|
||||||
|
|
||||||
|
new_pos_list = []
|
||||||
|
for positions, tokens in zip(pos_list, token_list):
|
||||||
|
# Verify assumption that incoming position are
|
||||||
|
# always a sequence from 0 to N.
|
||||||
|
expected_pos = torch.arange(positions.size()[0],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=inputs_embeds.device)
|
||||||
|
assert torch.equal(positions, expected_pos)
|
||||||
|
new_pos_list.append(
|
||||||
|
create_position_ids_from_input_ids(tokens, self.padding_idx))
|
||||||
|
position_ids = torch.cat(new_pos_list)
|
||||||
|
|
||||||
# Position embeddings.
|
# Position embeddings.
|
||||||
position_embeddings = self.position_embeddings(position_ids)
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = torch.zeros(input_shape,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=inputs_embeds.device)
|
||||||
|
|
||||||
# Token type embeddings. (TODO: move off hotpath?)
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
token_type_embeddings = self.token_type_embeddings(
|
|
||||||
torch.zeros(input_shape,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=inputs_embeds.device))
|
|
||||||
|
|
||||||
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
||||||
embeddings = self.LayerNorm(embeddings)
|
embeddings = self.LayerNorm(embeddings)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers
|
||||||
|
def create_position_ids_from_input_ids(input_ids,
                                       padding_idx,
                                       past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols keep the
    value padding_idx. This is modified from fairseq's
    `utils.make_positions`.

    Args:
        input_ids: torch.Tensor of token ids; positions are counted with
            a cumulative sum along dim 0.
        padding_idx: id of the padding token.
        past_key_values_length: offset added to the position of every
            non-padding token.

    Returns: torch.Tensor (dtype long) with the same shape as input_ids.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()

    # 1-based running count of non-padding tokens seen so far.
    running_count = torch.cumsum(not_padding, dim=0).type_as(not_padding)

    # Multiplying by the mask zeroes the entries at padding positions, so
    # they end up equal to padding_idx after the final shift.
    shifted = (running_count + past_key_values_length) * not_padding

    return shifted.long() + padding_idx
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers
|
||||||
|
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Applies a dense layer, a tanh and an output projection to the first
    row of the input features.
    """

    def __init__(self, config: RobertaConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        # take <s> token (equiv. to [CLS])
        first_token = features[0, :]
        projected = torch.tanh(self.dense(first_token))
        return self.out_proj(projected)
|
||||||
|
|
||||||
|
|
||||||
class RobertaEmbeddingModel(BertEmbeddingModel):
|
class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||||
"""A model that uses Roberta to provide embedding functionalities.
|
"""A model that uses Roberta to provide embedding functionalities.
|
||||||
|
|
||||||
@ -85,6 +151,62 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
|||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
embedding_class=RobertaEmbedding)
|
embedding_class=RobertaEmbedding)
|
||||||
|
|
||||||
|
|
||||||
|
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||||
|
"""A model that uses Roberta to provide embedding functionalities.
|
||||||
|
|
||||||
|
This class encapsulates the BertModel and provides an interface for
|
||||||
|
embedding operations and customized pooling functions.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
roberta: An instance of BertModel used for forward operations.
|
||||||
|
_pooler: An instance of Pooler used for pooling operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the encoder, the classification head and the pooler.

    Args:
        vllm_config: Engine configuration; the Hugging Face model config
            is read from ``vllm_config.model_config.hf_config``.
        prefix: Name prefix passed on to the encoder.
    """
    super().__init__()
    hf_config = vllm_config.model_config.hf_config

    self.default_activation_function = (
        get_cross_encoder_activation_function(hf_config))
    self.num_labels = hf_config.num_labels

    # NOTE(review): the encoder is stored as ``self.roberta`` but built
    # with the "bert" prefix -- confirm this is intentional.
    encoder_prefix = maybe_prefix(prefix, "bert")
    self.roberta = BertModel(vllm_config=vllm_config,
                             prefix=encoder_prefix,
                             embedding_class=RobertaEmbedding,
                             add_pooling_layer=False)

    self.classifier = RobertaClassificationHead(hf_config)
    self._pooler = CrossEncodingPooler(hf_config, self.classifier)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights into the encoder and the classifier head.

    Entries whose name starts with ``"roberta."`` are handed to
    ``self.roberta.load_weights`` with that prefix removed. Of the
    remaining entries, only those whose name starts with ``"classifier"``
    are loaded; everything else is ignored.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.
    """
    encoder_prefix = "roberta."
    head_weights = []

    def encoder_weights():
        # Split the stream lazily: encoder entries are yielded, all other
        # entries are parked in ``head_weights`` for the loop below.
        for full_name, tensor in weights:
            if not full_name.startswith(encoder_prefix):
                head_weights.append((full_name, tensor))
                continue
            yield (full_name[len(encoder_prefix):], tensor)

    # NOTE(review): ``head_weights`` is only complete once the generator is
    # exhausted -- presumably ``self.roberta.load_weights`` drains it;
    # confirm.
    self.roberta.load_weights(encoder_weights())

    own_params = dict(self.named_parameters())

    for full_name, tensor in head_weights:
        if not full_name.startswith("classifier"):
            continue
        target = own_params[full_name]
        # Prefer a parameter-specific loader when the parameter has one.
        load_fn = getattr(target, "weight_loader", default_weight_loader)
        load_fn(target, tensor)
|
||||||
|
|
||||||
|
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Apply the cross-encoding pooler to the encoder output.

    Delegates to ``self._pooler`` and returns its result unchanged.
    """
    pooled = self._pooler(hidden_states, pooling_metadata)
    return pooled
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor],
|
input_ids: Optional[torch.Tensor],
|
||||||
@ -93,25 +215,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
|||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: AttentionMetadata,
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
return self.roberta(input_ids=input_ids,
|
||||||
# Verify assumption that position are always a sequence from
|
position_ids=positions,
|
||||||
# 0 to N. (Actually here we just check 0 and N to simplify).
|
kv_caches=kv_caches,
|
||||||
# This is important to fix the position which are assumed to
|
inputs_embeds=inputs_embeds,
|
||||||
# start from padding_idx + 1 instead of 0 in the Roberta models.
|
intermediate_tensors=intermediate_tensors,
|
||||||
assert hasattr(attn_metadata, "seq_lens_tensor")
|
attn_metadata=attn_metadata,
|
||||||
cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
|
token_type_ids=token_type_ids)
|
||||||
start_pos = torch.cat(
|
|
||||||
(torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
|
|
||||||
cumulative[:-1]))
|
|
||||||
assert len(torch.nonzero(positions[start_pos])) == 0
|
|
||||||
end_pos = cumulative - 1
|
|
||||||
last_tokens = attn_metadata.seq_lens_tensor - 1
|
|
||||||
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0
|
|
||||||
|
|
||||||
return super().forward(input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
kv_caches=kv_caches,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
inputs_embeds=inputs_embeds)
|
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.types
|
import torch.types
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import NotRequired, TypeAlias
|
||||||
|
|
||||||
from vllm.utils import JSONTree, is_list_of, json_map_leaves
|
from vllm.utils import JSONTree, is_list_of, json_map_leaves
|
||||||
|
|
||||||
@ -208,6 +208,9 @@ class MultiModalInputsV2(TypedDict):
|
|||||||
prompt_token_ids: List[int]
|
prompt_token_ids: List[int]
|
||||||
"""The processed token IDs which includes placeholder tokens."""
|
"""The processed token IDs which includes placeholder tokens."""
|
||||||
|
|
||||||
|
token_type_ids: NotRequired[List[int]]
|
||||||
|
"""The token type IDs of the prompt."""
|
||||||
|
|
||||||
mm_kwargs: MultiModalKwargs
|
mm_kwargs: MultiModalKwargs
|
||||||
"""Keyword arguments to be directly passed to the model after batching."""
|
"""Keyword arguments to be directly passed to the model after batching."""
|
||||||
|
|
||||||
|
|||||||
@ -60,7 +60,6 @@ class EmbeddingOutput:
|
|||||||
embedding: The embedding vector, which is a list of floats. The
|
embedding: The embedding vector, which is a list of floats. The
|
||||||
length of vector depends on the model as listed in the embedding guide.
|
length of vector depends on the model as listed in the embedding guide.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
embedding: List[float]
|
embedding: List[float]
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@ -363,6 +362,50 @@ class EmbeddingRequestOutput:
|
|||||||
f"finished={self.finished})")
|
f"finished={self.finished})")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ScoreOutput:
    """The output data of one score output of a request.

    Args:
        score: The score, which is a list of floats.
        index: The correspondent text index of the score.
    """
    index: int
    score: List[float]

    def __repr__(self) -> str:
        # Both fields sit inside one balanced pair of parentheses; the
        # previous format closed the parenthesis right after ``score``.
        return (f"ScoreOutput("
                f"score={self.score}, "
                f"index={self.index})")
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreRequestOutput:
    """
    The output data of a score request to the LLM.

    Args:
        request_id (str): A unique identifier for the score request.
        outputs (ScoreOutput): The score results for the given input.
    """

    def __init__(self, request_id: str, outputs: "ScoreOutput"):
        self.request_id = request_id
        self.outputs = outputs

    def __repr__(self):
        """
        Returns a string representation of a ScoreRequestOutput instance.

        The representation includes the request_id and the repr of the
        outputs, providing a quick overview of the score request's results.

        Returns:
            str: A string representation of the ScoreRequestOutput instance.
        """
        # The trailing ")" was missing, leaving the representation
        # unbalanced.
        return (f"ScoreRequestOutput(request_id='{self.request_id}', "
                f"outputs={self.outputs!r})")
|
||||||
|
|
||||||
|
|
||||||
class RequestOutputFactory:
|
class RequestOutputFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -449,6 +449,10 @@ class Sequence:
|
|||||||
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
||||||
return self.inputs.prompt_embeds
|
return self.inputs.prompt_embeds
|
||||||
|
|
||||||
|
@property
def token_type_ids(self) -> List[int]:
    """Token type ids stored on this sequence's inputs."""
    sequence_inputs = self.inputs
    return sequence_inputs.token_type_ids
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def multi_modal_data(self) -> "MultiModalDataDict":
|
def multi_modal_data(self) -> "MultiModalDataDict":
|
||||||
return self.inputs.multi_modal_data
|
return self.inputs.multi_modal_data
|
||||||
@ -687,6 +691,10 @@ class SequenceGroup:
|
|||||||
return (self.encoder_seq.prompt_token_ids
|
return (self.encoder_seq.prompt_token_ids
|
||||||
if self.encoder_seq is not None else None)
|
if self.encoder_seq is not None else None)
|
||||||
|
|
||||||
|
@property
def token_type_ids(self) -> Optional[List[int]]:
    """Token type ids of the first sequence in the group."""
    leading_seq = self.first_seq
    return leading_seq.token_type_ids
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def multi_modal_data(self) -> MultiModalDataDict:
|
def multi_modal_data(self) -> MultiModalDataDict:
|
||||||
return self.first_seq.multi_modal_data
|
return self.first_seq.multi_modal_data
|
||||||
@ -909,6 +917,7 @@ class SequenceGroupMetadata(
|
|||||||
default_factory=lambda: SequenceGroupState())
|
default_factory=lambda: SequenceGroupState())
|
||||||
# "MultiModalDataDict" types. We have to use Any due to msgspec
|
# "MultiModalDataDict" types. We have to use Any due to msgspec
|
||||||
# doesn't allow to have union of 2 different dicts.
|
# doesn't allow to have union of 2 different dicts.
|
||||||
|
token_type_ids: Optional[List[int]] = None
|
||||||
multi_modal_data: Optional[Any] = None
|
multi_modal_data: Optional[Any] = None
|
||||||
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
|
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
|
||||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from huggingface_hub import (file_exists, hf_hub_download,
|
|||||||
from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError,
|
from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError,
|
||||||
RepositoryNotFoundError,
|
RepositoryNotFoundError,
|
||||||
RevisionNotFoundError)
|
RevisionNotFoundError)
|
||||||
|
from torch import nn
|
||||||
from transformers import GenerationConfig, PretrainedConfig
|
from transformers import GenerationConfig, PretrainedConfig
|
||||||
from transformers.models.auto.image_processing_auto import (
|
from transformers.models.auto.image_processing_auto import (
|
||||||
get_image_processor_config)
|
get_image_processor_config)
|
||||||
@ -31,6 +32,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
|
|||||||
UltravoxConfig)
|
UltravoxConfig)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
|
from vllm.utils import resolve_obj_by_qualname
|
||||||
|
|
||||||
if VLLM_USE_MODELSCOPE:
|
if VLLM_USE_MODELSCOPE:
|
||||||
from modelscope import AutoConfig
|
from modelscope import AutoConfig
|
||||||
@ -577,3 +579,16 @@ def try_get_generation_config(
|
|||||||
return GenerationConfig.from_model_config(config)
|
return GenerationConfig.from_model_config(config)
|
||||||
except OSError: # Not found
|
except OSError: # Not found
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation module applied to cross-encoder scores.

    When ``config.sbert_ce_default_activation_function`` is set, it is
    taken as a qualified name, resolved with ``resolve_obj_by_qualname``
    and instantiated without arguments. Otherwise the result is
    ``nn.Sigmoid()`` when ``config.num_labels == 1`` and ``nn.Identity()``
    in every other case.

    Args:
        config: Model configuration; ``num_labels`` is read when no
            activation function is configured.

    Returns:
        An instance of the selected activation.

    Raises:
        ValueError: If the configured name does not start with
            ``"torch.nn.modules."``.
    """
    function_name = getattr(config, "sbert_ce_default_activation_function",
                            None)
    if function_name is None:
        return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()

    # This check is a security boundary (the name comes from the model
    # config), so it must be an explicit raise: an ``assert`` is removed
    # when Python runs with -O.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError("Loading of activation functions is restricted to "
                         "torch.nn.modules for security reasons")
    return resolve_obj_by_qualname(function_name)()
|
||||||
|
|||||||
@ -50,6 +50,9 @@ class CPUEmbeddingModelRunner(
|
|||||||
]
|
]
|
||||||
|
|
||||||
model_executable = self.model
|
model_executable = self.model
|
||||||
|
cross_enc_kwargs = {}
|
||||||
|
if model_input.token_type_ids is not None:
|
||||||
|
cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
|
||||||
execute_model_kwargs = {
|
execute_model_kwargs = {
|
||||||
"input_ids":
|
"input_ids":
|
||||||
model_input.input_tokens,
|
model_input.input_tokens,
|
||||||
@ -61,6 +64,7 @@ class CPUEmbeddingModelRunner(
|
|||||||
model_input.attn_metadata,
|
model_input.attn_metadata,
|
||||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||||
device=self.device),
|
device=self.device),
|
||||||
|
**cross_enc_kwargs,
|
||||||
"intermediate_tensors":
|
"intermediate_tensors":
|
||||||
intermediate_tensors,
|
intermediate_tensors,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase):
|
|||||||
"""
|
"""
|
||||||
input_tokens: Optional[torch.Tensor] = None
|
input_tokens: Optional[torch.Tensor] = None
|
||||||
input_positions: Optional[torch.Tensor] = None
|
input_positions: Optional[torch.Tensor] = None
|
||||||
|
token_type_ids: Optional[torch.Tensor] = None
|
||||||
attn_metadata: Optional["AttentionMetadata"] = None
|
attn_metadata: Optional["AttentionMetadata"] = None
|
||||||
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
|
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
|
||||||
virtual_engine: Optional[int] = None
|
virtual_engine: Optional[int] = None
|
||||||
@ -54,6 +55,7 @@ class ModelInputForCPU(ModelRunnerInputBase):
|
|||||||
tensor_dict = {
|
tensor_dict = {
|
||||||
"input_tokens": self.input_tokens,
|
"input_tokens": self.input_tokens,
|
||||||
"input_positions": self.input_positions,
|
"input_positions": self.input_positions,
|
||||||
|
"token_type_ids": self.token_type_ids,
|
||||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||||
}
|
}
|
||||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||||
@ -83,6 +85,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
|
|||||||
tensor_dict = {
|
tensor_dict = {
|
||||||
"input_tokens": self.input_tokens,
|
"input_tokens": self.input_tokens,
|
||||||
"input_positions": self.input_positions,
|
"input_positions": self.input_positions,
|
||||||
|
"token_type_ids": self.token_type_ids,
|
||||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||||
}
|
}
|
||||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||||
@ -112,6 +115,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
|||||||
self.input_tokens: List[int] = []
|
self.input_tokens: List[int] = []
|
||||||
self.input_positions: Optional[
|
self.input_positions: Optional[
|
||||||
List[int]] = [] if not self.use_mrope else None
|
List[int]] = [] if not self.use_mrope else None
|
||||||
|
self.token_type_ids: Optional[List[int]] = []
|
||||||
self.seq_lens: List[int] = []
|
self.seq_lens: List[int] = []
|
||||||
self.query_lens: List[int] = []
|
self.query_lens: List[int] = []
|
||||||
self.prefill_block_tables: List[List[int]] = []
|
self.prefill_block_tables: List[List[int]] = []
|
||||||
@ -165,6 +169,10 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
|||||||
if not input_data.use_mrope else input_data.input_mrope_positions,
|
if not input_data.use_mrope else input_data.input_mrope_positions,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
|
token_type_ids = torch.tensor(input_data.token_type_ids,
|
||||||
|
dtype=torch.long,
|
||||||
|
device="cpu") \
|
||||||
|
if input_data.token_type_ids else None
|
||||||
|
|
||||||
# For multi-modal models
|
# For multi-modal models
|
||||||
multi_modal_kwargs = None
|
multi_modal_kwargs = None
|
||||||
@ -178,6 +186,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
|||||||
return self.model_input_cls(
|
return self.model_input_cls(
|
||||||
input_tokens=input_tokens,
|
input_tokens=input_tokens,
|
||||||
input_positions=input_positions,
|
input_positions=input_positions,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
seq_lens=input_data.seq_lens,
|
seq_lens=input_data.seq_lens,
|
||||||
query_lens=input_data.query_lens,
|
query_lens=input_data.query_lens,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
@ -285,6 +294,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
|||||||
tokens = seq_data.get_token_ids()
|
tokens = seq_data.get_token_ids()
|
||||||
tokens = tokens[context_len:seq_len]
|
tokens = tokens[context_len:seq_len]
|
||||||
token_positions = range(context_len, seq_len)
|
token_positions = range(context_len, seq_len)
|
||||||
|
token_types = seq_group_metadata.token_type_ids
|
||||||
|
|
||||||
# For encoder-only models, the block_table is None,
|
# For encoder-only models, the block_table is None,
|
||||||
# and there is no need to initialize the slot_mapping.
|
# and there is no need to initialize the slot_mapping.
|
||||||
@ -301,6 +311,9 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
|||||||
if data.input_positions is not None:
|
if data.input_positions is not None:
|
||||||
data.input_positions.extend(token_positions)
|
data.input_positions.extend(token_positions)
|
||||||
|
|
||||||
|
if data.token_type_ids is not None:
|
||||||
|
data.token_type_ids.extend(token_types if token_types else [])
|
||||||
|
|
||||||
# Update fields
|
# Update fields
|
||||||
data.input_tokens.extend(tokens)
|
data.input_tokens.extend(tokens)
|
||||||
data.num_prefills += 1
|
data.num_prefills += 1
|
||||||
|
|||||||
@ -97,6 +97,10 @@ class EmbeddingModelRunner(
|
|||||||
model_forward_end = torch.cuda.Event(enable_timing=True)
|
model_forward_end = torch.cuda.Event(enable_timing=True)
|
||||||
model_forward_start.record()
|
model_forward_start.record()
|
||||||
|
|
||||||
|
cross_enc_kwargs = {}
|
||||||
|
if model_input.token_types is not None:
|
||||||
|
cross_enc_kwargs["token_type_ids"] = model_input.token_types
|
||||||
|
|
||||||
with set_forward_context(model_input.attn_metadata, self.vllm_config):
|
with set_forward_context(model_input.attn_metadata, self.vllm_config):
|
||||||
hidden_or_intermediate_states = model_executable(
|
hidden_or_intermediate_states = model_executable(
|
||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
@ -105,7 +109,8 @@ class EmbeddingModelRunner(
|
|||||||
attn_metadata=model_input.attn_metadata,
|
attn_metadata=model_input.attn_metadata,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||||
device=self.device))
|
device=self.device),
|
||||||
|
**cross_enc_kwargs)
|
||||||
|
|
||||||
if (self.observability_config is not None
|
if (self.observability_config is not None
|
||||||
and self.observability_config.collect_model_forward_time):
|
and self.observability_config.collect_model_forward_time):
|
||||||
|
|||||||
@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
|||||||
"""
|
"""
|
||||||
input_tokens: Optional[torch.Tensor] = None
|
input_tokens: Optional[torch.Tensor] = None
|
||||||
input_positions: Optional[torch.Tensor] = None
|
input_positions: Optional[torch.Tensor] = None
|
||||||
|
token_types: Optional[torch.Tensor] = None
|
||||||
seq_lens: Optional[List[int]] = None
|
seq_lens: Optional[List[int]] = None
|
||||||
query_lens: Optional[List[int]] = None
|
query_lens: Optional[List[int]] = None
|
||||||
lora_mapping: Optional["LoRAMapping"] = None
|
lora_mapping: Optional["LoRAMapping"] = None
|
||||||
@ -200,6 +201,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
def simple_reinit(self):
|
def simple_reinit(self):
|
||||||
self.input_tokens[0].clear() # type: ignore
|
self.input_tokens[0].clear() # type: ignore
|
||||||
self.input_positions[0].clear() # type: ignore
|
self.input_positions[0].clear() # type: ignore
|
||||||
|
self.token_types[0].clear() # type: ignore
|
||||||
self.mrope_input_positions = None # type: ignore
|
self.mrope_input_positions = None # type: ignore
|
||||||
self.seq_lens[0] = 0 # type: ignore
|
self.seq_lens[0] = 0 # type: ignore
|
||||||
self.orig_seq_lens[0] = 0 # type: ignore
|
self.orig_seq_lens[0] = 0 # type: ignore
|
||||||
@ -226,6 +228,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
# Input tokens and positions.
|
# Input tokens and positions.
|
||||||
input_tokens: Optional[List[List[int]]] = None,
|
input_tokens: Optional[List[List[int]]] = None,
|
||||||
input_positions: Optional[List[List[int]]] = None,
|
input_positions: Optional[List[List[int]]] = None,
|
||||||
|
token_types: Optional[List[List[int]]] = None,
|
||||||
mrope_input_positions: Optional[List[List[List[int]]]] = None,
|
mrope_input_positions: Optional[List[List[List[int]]]] = None,
|
||||||
|
|
||||||
# The sequence length (may be capped to the sliding window).
|
# The sequence length (may be capped to the sliding window).
|
||||||
@ -291,6 +294,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
for seq_id in range(len(self.seq_ids)):
|
for seq_id in range(len(self.seq_ids)):
|
||||||
self.input_positions[seq_id].clear()
|
self.input_positions[seq_id].clear()
|
||||||
|
|
||||||
|
if token_types:
|
||||||
|
self.token_types = token_types
|
||||||
|
else:
|
||||||
|
for seq_id in range(len(self.seq_ids)):
|
||||||
|
self.token_types[seq_id].clear()
|
||||||
|
|
||||||
self.mrope_input_positions = None
|
self.mrope_input_positions = None
|
||||||
|
|
||||||
if seq_lens:
|
if seq_lens:
|
||||||
@ -354,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
else:
|
else:
|
||||||
self.input_tokens = input_tokens or []
|
self.input_tokens = input_tokens or []
|
||||||
self.input_positions = input_positions or []
|
self.input_positions = input_positions or []
|
||||||
|
self.token_types = token_types or []
|
||||||
self.mrope_input_positions = mrope_input_positions or None
|
self.mrope_input_positions = mrope_input_positions or None
|
||||||
self.seq_lens = seq_lens or []
|
self.seq_lens = seq_lens or []
|
||||||
self.orig_seq_lens = orig_seq_lens or []
|
self.orig_seq_lens = orig_seq_lens or []
|
||||||
@ -386,6 +396,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
|
|
||||||
self.input_tokens = [[] for _ in range(self.n_seqs)]
|
self.input_tokens = [[] for _ in range(self.n_seqs)]
|
||||||
self.input_positions = [[] for _ in range(self.n_seqs)]
|
self.input_positions = [[] for _ in range(self.n_seqs)]
|
||||||
|
self.token_types = [[] for _ in range(self.n_seqs)]
|
||||||
self.mrope_input_positions = None
|
self.mrope_input_positions = None
|
||||||
self.seq_lens = [0] * self.n_seqs
|
self.seq_lens = [0] * self.n_seqs
|
||||||
self.orig_seq_lens = [0] * self.n_seqs
|
self.orig_seq_lens = [0] * self.n_seqs
|
||||||
@ -498,12 +509,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
|
|
||||||
# Compute tokens.
|
# Compute tokens.
|
||||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||||
|
token_types = seq_group_metadata.token_type_ids
|
||||||
|
|
||||||
inter_data.seq_lens[seq_idx] = seq_len
|
inter_data.seq_lens[seq_idx] = seq_len
|
||||||
inter_data.orig_seq_lens[seq_idx] = seq_len
|
inter_data.orig_seq_lens[seq_idx] = seq_len
|
||||||
inter_data.context_lens[seq_idx] = context_len
|
inter_data.context_lens[seq_idx] = context_len
|
||||||
inter_data.input_tokens[seq_idx].extend(tokens)
|
inter_data.input_tokens[seq_idx].extend(tokens)
|
||||||
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
|
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
|
||||||
|
inter_data.token_types[seq_idx].extend(
|
||||||
|
token_types if token_types else [])
|
||||||
inter_data.query_lens[seq_idx] = seq_len - context_len
|
inter_data.query_lens[seq_idx] = seq_len - context_len
|
||||||
|
|
||||||
if seq_data.mrope_position_delta is not None:
|
if seq_data.mrope_position_delta is not None:
|
||||||
@ -561,6 +575,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
seq_idx][uncomputed_start:]
|
seq_idx][uncomputed_start:]
|
||||||
inter_data.input_positions[seq_idx] = inter_data.input_positions[
|
inter_data.input_positions[seq_idx] = inter_data.input_positions[
|
||||||
seq_idx][uncomputed_start:]
|
seq_idx][uncomputed_start:]
|
||||||
|
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
|
||||||
|
uncomputed_start:]
|
||||||
context_len = prefix_cache_len
|
context_len = prefix_cache_len
|
||||||
|
|
||||||
inter_data.context_lens[seq_idx] = context_len
|
inter_data.context_lens[seq_idx] = context_len
|
||||||
@ -575,6 +591,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
seq_idx][-1:]
|
seq_idx][-1:]
|
||||||
inter_data.input_positions[seq_idx] = inter_data.input_positions[
|
inter_data.input_positions[seq_idx] = inter_data.input_positions[
|
||||||
seq_idx][-1:]
|
seq_idx][-1:]
|
||||||
|
inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][
|
||||||
|
-1:]
|
||||||
inter_data.query_lens[seq_idx] = 1
|
inter_data.query_lens[seq_idx] = 1
|
||||||
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
|
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
|
||||||
|
|
||||||
@ -803,9 +821,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
"""
|
"""
|
||||||
# Combine and flatten intermediate data.
|
# Combine and flatten intermediate data.
|
||||||
input_tokens = []
|
input_tokens = []
|
||||||
|
token_types = []
|
||||||
for inter_data in self.inter_data_list:
|
for inter_data in self.inter_data_list:
|
||||||
for cur_input_tokens in inter_data.input_tokens:
|
for cur_input_tokens in inter_data.input_tokens:
|
||||||
input_tokens.extend(cur_input_tokens)
|
input_tokens.extend(cur_input_tokens)
|
||||||
|
for cur_token_types in inter_data.token_types:
|
||||||
|
token_types.extend(cur_token_types)
|
||||||
|
|
||||||
if not input_tokens:
|
if not input_tokens:
|
||||||
# This may happen when all prefill requests hit
|
# This may happen when all prefill requests hit
|
||||||
@ -874,6 +895,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
|
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
|
||||||
self.runner.device,
|
self.runner.device,
|
||||||
self.runner.pin_memory)
|
self.runner.pin_memory)
|
||||||
|
|
||||||
|
token_types_tensor = async_tensor_h2d(token_types, torch.long,
|
||||||
|
self.runner.device,
|
||||||
|
self.runner.pin_memory) \
|
||||||
|
if token_types else None
|
||||||
|
|
||||||
if mrope_input_positions is not None:
|
if mrope_input_positions is not None:
|
||||||
for idx in range(3):
|
for idx in range(3):
|
||||||
mrope_input_positions[idx].extend(
|
mrope_input_positions[idx].extend(
|
||||||
@ -952,6 +979,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
return self.model_input_cls(
|
return self.model_input_cls(
|
||||||
input_tokens=input_tokens_tensor,
|
input_tokens=input_tokens_tensor,
|
||||||
input_positions=input_positions_tensor,
|
input_positions=input_positions_tensor,
|
||||||
|
token_types=token_types_tensor,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
query_lens=query_lens,
|
query_lens=query_lens,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user