mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 18:45:35 +08:00)
Feature/vllm/input embedding completion api (#17590)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Nan2018 <nan@protopia.ai>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Andrew Sansom <andrew@protopia.ai>
Co-authored-by: Andrew Sansom <qthequartermasterman@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent 9da1095daf, commit 221cfc2fea
```diff
@@ -119,6 +119,7 @@ serving/offline_inference
 serving/openai_compatible_server
 serving/serve_args
 serving/multimodal_inputs
+serving/prompt_embeds
 serving/distributed_serving
 serving/metrics
 serving/engine_args
```
docs/source/serving/prompt_embeds.md (new file, 142 lines)
# Prompt Embedding Inputs

This page teaches you how to pass prompt embedding inputs to vLLM.

## What are prompt embeddings?

The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer) and then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), the conversion from token ids to prompt embeddings happens via a look-up into a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary.
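In code, this look-up amounts to indexing a learned embedding matrix with the token ids. A minimal, model-agnostic sketch (all sizes and ids below are illustrative, not taken from any particular model):

```python
import torch

vocab_size, hidden_size = 128256, 2048          # illustrative numbers only
embedding_matrix = torch.nn.Embedding(vocab_size, hidden_size)  # learned during training

token_ids = torch.tensor([[1, 15043, 3186]])    # (1, sequence_length), made-up ids
prompt_embeds = embedding_matrix(token_ids)     # (1, sequence_length, hidden_size)

# Nothing forces the model to consume only rows of this matrix: any tensor of
# shape (sequence_length, hidden_size) can be supplied as prompt embeddings.
```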

:::{note}
Prompt embeddings are currently only supported in the v0 engine.
:::

## Offline Inference

To pass prompt embedding inputs, follow this schema in {class}`vllm.inputs.EmbedsPrompt`:

- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings. It has the shape (sequence_length, hidden_size), where sequence_length is the number of token embeddings and hidden_size is the hidden size (embedding size) of the model.

### Hugging Face Transformers Inputs

You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples:

```python
from vllm import LLM
import transformers

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
embedding_layer = transformers_model.get_input_embeddings()

llm = LLM(model=model_name, enable_prompt_embeds=True)

# Refer to the HuggingFace repo for the correct format to use
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')

prompt_embeds = embedding_layer(token_ids).squeeze(0)

# Single prompt inference
outputs = llm.generate({
    "prompt_embeds": prompt_embeds,
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

# Batch inference

chats = [
    [{"role": "user", "content": "Please tell me about the capital of France."}],
    [{"role": "user", "content": "When is the day longest during the year?"}],
    [{"role": "user", "content": "Which is bigger, the moon or the sun?"}]
]

token_ids_list = [
    tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') for chat in chats
]
prompt_embeds_list = [embedding_layer(token_ids).squeeze(0) for token_ids in token_ids_list]

outputs = llm.generate(
    [
        {
            "prompt_embeds": prompt_embeds,
        } for prompt_embeds in prompt_embeds_list
    ]
)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
```

## Online Serving

Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package.

When a mixture of `'prompt_embeds'` and `'prompt'` inputs is provided in a single request, the prompt embeds are always returned first.

Prompt embeddings are passed in as base64-encoded torch tensors.
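To make the wire format concrete, the sketch below (an illustration, not part of the original example) serializes a dummy tensor with `torch.save`, base64-encodes it, and posts it straight to `/v1/completions` with `requests`; it assumes a server is already running locally as launched in the next section:

```python
import base64
import io

import requests
import torch
from transformers import AutoConfig

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Random embeddings of the right shape, purely to exercise the API;
# real inputs would come from an embedding layer as in the offline example.
hidden_size = AutoConfig.from_pretrained(model_name).hidden_size
dummy_embeds = torch.randn(8, hidden_size)

buffer = io.BytesIO()
torch.save(dummy_embeds, buffer)
encoded_embeds = base64.b64encode(buffer.getvalue()).decode("utf-8")

payload = {
    "model": model_name,
    "prompt_embeds": encoded_embeds,  # the new key described above
    "max_tokens": 5,
    "temperature": 0.0,
}
response = requests.post("http://localhost:8000/v1/completions", json=payload)
print(response.json()["choices"][0]["text"])
```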

### Transformers Inputs via OpenAI Client

First, launch the OpenAI-compatible server:

```bash
vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \
  --max-model-len 4096 --enable-prompt-embeds
```

Then, you can use the OpenAI client as follows:

```python
import base64
import io

import torch
import transformers
from openai import OpenAI

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Transformers
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
embedding_layer = transformers_model.get_input_embeddings()

# Refer to the HuggingFace repo for the correct format to use
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt')

prompt_embeds = embedding_layer(token_ids).squeeze(0)

# Base64-encode the prompt embeddings
buffer = io.BytesIO()
torch.save(prompt_embeds, buffer)
buffer.seek(0)
binary_data = buffer.read()
encoded_embeds = base64.b64encode(binary_data).decode('utf-8')

completion = client.completions.create(
    model=model_name,
    # NOTE: The OpenAI client does not allow `None` as an input to
    # `prompt`. Use an empty string if you have no text prompts.
    prompt="",
    max_tokens=5,
    temperature=0.0,
    # NOTE: The OpenAI client allows passing in extra JSON body via the
    # `extra_body` argument.
    extra_body={"prompt_embeds": encoded_embeds}
)

print(completion.choices[0].text)
```
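If a request mixes a text `prompt` with `prompt_embeds`, the choices for the embeddings come back before the choices for the text, as noted above. A short usage sketch (reusing `client`, `model_name`, and `encoded_embeds` from the example above):

```python
mixed = client.completions.create(
    model=model_name,
    prompt="This is a text prompt",
    max_tokens=5,
    temperature=0.0,
    extra_body={"prompt_embeds": encoded_embeds},
)

# choices[0] corresponds to the prompt embeddings,
# choices[1] corresponds to the text prompt.
print(mixed.choices[0].text)
print(mixed.choices[1].text)
```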
tests/entrypoints/openai/test_completion_with_prompt_embeds.py (new file, 257 lines)
```python
# SPDX-License-Identifier: Apache-2.0

import base64
import io
import shutil
from tempfile import TemporaryDirectory

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from transformers import AutoConfig, AutoTokenizer

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
LORA_NAME = "typeof/zephyr-7b-beta-lora"

CONFIG = AutoConfig.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
    tmp_dir = TemporaryDirectory()
    tmp_model_dir = f"{tmp_dir.name}/zephyr"
    shutil.copytree(zephyr_lora_files, tmp_model_dir)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Copy tokenizer to adapter and add some unique tokens
    # 32000, 32001, 32002
    added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
                                 special_tokens=True)
    assert added == 3
    tokenizer.save_pretrained(tmp_model_dir)
    yield tmp_model_dir
    tmp_dir.cleanup()


@pytest.fixture(scope="module")
def default_server_args(
    zephyr_lora_files,
    zephyr_lora_added_tokens_files,
) -> list[str]:
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
        # Prompt Embeds server args
        "--enable-prompt-embeds",
        "--no-enable-chunked-prefill",
    ]


@pytest.fixture(scope="module",
                params=["", "--disable-frontend-multiprocessing"])
def server_with_prompt_embeds(default_server_args, request):
    if request.param:
        default_server_args.append(request.param)

    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client_with_prompt_embeds(server_with_prompt_embeds):
    async with server_with_prompt_embeds.get_async_client() as async_client:
        yield async_client


def create_dummy_embeds(num_tokens: int = 5) -> str:
    """Create dummy embeddings and return them as base64 encoded string."""
    dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size)
    buffer = io.BytesIO()
    torch.save(dummy_embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    # Test case: Single prompt embeds input
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion.choices[0].text) >= 1
    assert completion.choices[0].prompt_logprobs is None

    # Test case: batch completion with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    assert len(completion.choices) == 2
    assert len(completion.choices[0].text) >= 1
    assert len(completion.choices[1].text) >= 1

    # Test case: streaming with prompt_embeds
    encoded_embeds = create_dummy_embeds()
    single_completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    single_output = single_completion.choices[0].text

    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": encoded_embeds})
    chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert "".join(chunks) == single_output

    # Test case: batch streaming with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    chunks_stream_embeds: list[list[str]] = [[], []]
    finish_reason_count = 0
    async for chunk in stream:
        chunks_stream_embeds[chunk.choices[0].index].append(
            chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 2
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert len(chunks_stream_embeds[0]) > 0
    assert len(chunks_stream_embeds[1]) > 0

    # Test case: mixed text and prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion_mixed = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion_mixed.choices) == 2
    completion_text_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
    )
    completion_embeds_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # Embeddings responses should be handled first
    assert completion_mixed.choices[0].text == completion_embeds_only.choices[
        0].text
    assert completion_mixed.choices[1].text == completion_text_only.choices[
        0].text


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_errors_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    # Test error case: invalid prompt_embeds
    with pytest.raises(BadRequestError):
        await client_with_prompt_embeds.completions.create(
            prompt="",
            model=model_name,
            max_tokens=5,
            temperature=0.0,
            extra_body={"prompt_embeds": "invalid_base64"})


@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completions_with_logprobs_and_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
        model_name: str):
    # Test case: Logprobs using prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": encoded_embeds})

    logprobs = completion.choices[0].logprobs
    assert logprobs is not None
    assert len(logprobs.text_offset) == 5
    assert len(logprobs.token_logprobs) == 5
    assert len(logprobs.top_logprobs) == 5
    for top_logprobs in logprobs.top_logprobs[1:]:
        assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
    assert len(logprobs.tokens) == 5

    # Test case: Log probs with batch completion and prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})

    assert len(completion.choices) == 2
    for choice in completion.choices:
        logprobs = choice.logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) == 5
        assert len(logprobs.token_logprobs) == 5
        assert len(logprobs.top_logprobs) == 5
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) == 5
```
```diff
@@ -2,6 +2,8 @@
 from typing import Optional, Union
 
+import torch
+
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -23,6 +25,7 @@ class RequestLogger:
         request_id: str,
         prompt: Optional[str],
         prompt_token_ids: Optional[list[int]],
+        prompt_embeds: Optional[torch.Tensor],
         params: Optional[Union[SamplingParams, PoolingParams,
                                BeamSearchParams]],
         lora_request: Optional[LoRARequest],
@@ -39,6 +42,8 @@ class RequestLogger:
         logger.info(
             "Received request %s: prompt: %r, "
             "params: %s, prompt_token_ids: %s, "
+            "prompt_embeds shape: %s, "
             "lora_request: %s, prompt_adapter_request: %s.", request_id,
-            prompt, params, prompt_token_ids, lora_request,
-            prompt_adapter_request)
+            prompt, params, prompt_token_ids,
+            prompt_embeds.shape if prompt_embeds is not None else None,
+            lora_request, prompt_adapter_request)
```
```diff
@@ -286,6 +286,9 @@ def validate_parsed_serve_args(args: argparse.Namespace):
     if args.enable_auto_tool_choice and not args.tool_call_parser:
         raise TypeError("Error: --enable-auto-tool-choice requires "
                         "--tool-call-parser")
+    if args.enable_prompt_embeds and args.enable_prompt_adapter:
+        raise ValueError(
+            "Cannot use prompt embeds and prompt adapter at the same time.")
 
 
 def log_non_default_args(args: argparse.Namespace):
```
```diff
@@ -745,7 +745,8 @@ class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
     model: Optional[str] = None
-    prompt: Union[list[int], list[list[int]], str, list[str]]
+    prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
+    prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
     best_of: Optional[int] = None
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = 0.0
@@ -1025,6 +1026,14 @@ class CompletionRequest(OpenAIBaseModel):
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def validate_prompt_and_prompt_embeds(cls, data):
+        if data.get("prompt") is None and data.get("prompt_embeds") is None:
+            raise ValueError(
+                "At least one of `prompt` or `prompt_embeds` must be set.")
+        return data
+
 
 class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
```
```diff
@@ -8,6 +8,7 @@ from typing import Optional, Union, cast
 
 import jinja2
 from fastapi import Request
+from typing_extensions import assert_never
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
@@ -25,8 +26,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
-                                                    clamp_prompt_logprobs)
+                                                    clamp_prompt_logprobs,
+                                                    is_text_tokens_prompt)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
+                              is_tokens_prompt)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -90,6 +94,10 @@ class OpenAIServingCompletion(OpenAIServing):
             return self.create_error_response(
                 "suffix is not currently supported")
 
+        if request.echo and request.prompt_embeds is not None:
+            return self.create_error_response(
+                "Echo is unsupported with prompt embeds.")
+
         request_id = f"cmpl-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
@@ -130,8 +138,24 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
-                default_max_tokens = self.max_model_len - len(
-                    engine_prompt["prompt_token_ids"])
+                # Mypy does not infer that engine_prompt will have only one of
+                # "prompt_token_ids" or "prompt_embeds" defined, and both of
+                # these as Union[object, the expected type], where it infers
+                # object if engine_prompt is a subclass of one of the
+                # typeddicts that defines both keys. Worse, because of
+                # https://github.com/python/mypy/issues/8586, mypy does not
+                # infer the type of engine_prompt correctly because of the
+                # enumerate. So we need an unnecessary cast here.
+                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
+                                     engine_prompt)
+                if is_embeds_prompt(engine_prompt):
+                    input_length = len(engine_prompt["prompt_embeds"])
+                elif is_tokens_prompt(engine_prompt):
+                    input_length = len(engine_prompt["prompt_token_ids"])
+                else:
+                    assert_never(engine_prompt)
+                default_max_tokens = self.max_model_len - input_length
 
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         default_max_tokens, self.default_sampling_params)
@@ -152,6 +176,11 @@ class OpenAIServingCompletion(OpenAIServing):
                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
 
+                # Mypy inconsistently requires this second cast in different
+                # environments. It shouldn't be necessary (redundant from above)
+                # but pre-commit in CI fails without it.
+                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
+                                     engine_prompt)
                 if isinstance(sampling_params, BeamSearchParams):
                     generator = self.engine_client.beam_search(
                         prompt=engine_prompt,
@@ -211,7 +240,11 @@ class OpenAIServingCompletion(OpenAIServing):
             # We did not pass it into vLLM engine to avoid being redundant
             # with the inputs token IDs
             if final_res.prompt is None:
-                final_res.prompt = request_prompts[i]["prompt"]
+                request_prompt = request_prompts[i]
+                if is_text_tokens_prompt(request_prompt):
+                    final_res.prompt = request_prompt["prompt"]
+                else:
+                    final_res.prompt = None
 
         final_res_batch_checked = cast(list[RequestOutput],
                                        final_res_batch)
@@ -276,8 +309,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 prompt_text = res.prompt
 
                 # Prompt details are excluded from later streamed outputs
-                if res.prompt_token_ids is not None:
-                    num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
+                if prompt_token_ids is not None:
+                    num_prompt_tokens[prompt_idx] = len(prompt_token_ids)
 
                 delta_token_ids: GenericSequence[int]
                 out_logprobs: Optional[GenericSequence[Optional[dict[
```
```diff
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+import base64
+import io
 import json
 import sys
 import time
@@ -8,11 +9,18 @@ from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
 from concurrent.futures.thread import ThreadPoolExecutor
 from http import HTTPStatus
 from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
-                    TypeVar, Union)
+                    TypeVar, Union, cast, overload)
 
+import torch
 from fastapi import Request
 from pydantic import BaseModel, ConfigDict, Field
 from starlette.datastructures import Headers
+from typing_extensions import TypeIs
+
+if sys.version_info >= (3, 12):
+    from typing import TypedDict
+else:
+    from typing_extensions import TypedDict
 
 if sys.version_info >= (3, 12):
     from typing import TypedDict
@@ -53,7 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
 # yapf: enable
-from vllm.inputs import TokensPrompt
+from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
+from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -100,7 +109,22 @@ class TextTokensPrompt(TypedDict):
     prompt_token_ids: list[int]
 
 
-RequestPrompt = Union[list[int], str, TextTokensPrompt]
+class EmbedsPrompt(TypedDict):
+    prompt_embeds: torch.Tensor
+
+
+RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]
+
+
+def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
+    return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
+            and "prompt_embeds" not in prompt)
+
+
+def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
+    return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
+            and "prompt_embeds" in prompt)
+
+
 RequestT = TypeVar("RequestT", bound=AnyRequest)
 
@@ -112,8 +136,9 @@ class RequestProcessingMixin(BaseModel):
     """
     request_prompts: Optional[Sequence[RequestPrompt]] = \
         Field(default_factory=list)
-    engine_prompts: Optional[list[TokensPrompt]] = \
-        Field(default_factory=list)
+    engine_prompts: Optional[Union[list[EngineTokensPrompt],
+                                   list[EngineEmbedsPrompt]]] = Field(
+                                       default_factory=list)
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
@@ -311,6 +336,12 @@ class OpenAIServing:
                 lora_request=ctx.lora_request,
                 prompt_adapter_request=ctx.prompt_adapter_request)
 
+            # Mypy has an existing bug related to inferring the variance of
+            # TypedDicts with `builtins.enumerate`:
+            # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
+            engine_prompt = cast(
+                Union[EngineTokensPrompt, EngineEmbedsPrompt],
+                engine_prompt)
             generator = self.engine_client.encode(
                 engine_prompt,
                 pooling_params,
@@ -596,10 +627,11 @@ class OpenAIServing:
         self,
         request: AnyRequest,
         tokenizer: AnyTokenizer,
-        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
+        input_or_inputs: Optional[Union[str, list[str], list[int],
+                                        list[list[int]]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
-    ) -> list[TextTokensPrompt]:
+    ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
         """
         Tokenize/detokenize depending on the input format.
 
@@ -607,11 +639,25 @@ class OpenAIServing:
         , each input can be a string or array of tokens. Note that each request
         can pass one or more inputs.
         """
+        inputs_embeds = list[EmbedsPrompt]()
+        inputs_text = list[TextTokensPrompt]()
+
+        if (isinstance(request, CompletionRequest)
+                and request.prompt_embeds is not None):
+            inputs_embeds.extend(
+                self._load_prompt_embeds(request.prompt_embeds,
+                                         truncate_prompt_tokens))
+
+        # Empty prompts are okay as long as there are prompt embeddings
+        if input_or_inputs is None or (inputs_embeds
+                                       and input_or_inputs == ""):
+            return [], inputs_embeds
+
         # Although our type checking is based on mypy,
         # VSCode Pyright extension should still work properly
-        # "is True" is required for Pyright to perform type narrowing
+        # "is False" is required for Pyright to perform type narrowing
         # See: https://github.com/microsoft/pyright/issues/7672
-        return [
+        inputs_text.extend([
             self._normalize_prompt_text_to_input(
                 request,
                 tokenizer,
@@ -625,17 +671,56 @@ class OpenAIServing:
                 prompt_ids=prompt_input["content"],
                 truncate_prompt_tokens=truncate_prompt_tokens)
             for prompt_input in parse_and_batch_prompt(input_or_inputs)
-        ]
+        ])
+
+        return inputs_text, inputs_embeds
+
+    @overload
+    async def _preprocess_completion(
+        self,
+        request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
+                       RerankRequest, ClassificationRequest, ScoreRequest,
+                       TokenizeCompletionRequest],
+        tokenizer: AnyTokenizer,
+        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
+        add_special_tokens: bool = ...,
+    ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
+        ...
+
+    @overload
+    async def _preprocess_completion(
+        self,
+        request: CompletionRequest,
+        tokenizer: AnyTokenizer,
+        input_or_inputs: Optional[Union[str, list[str], list[int],
+                                        list[list[int]]]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
+        add_special_tokens: bool = ...,
+    ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
+            EngineTokensPrompt, EngineEmbedsPrompt]]]:
+        ...
+
     async def _preprocess_completion(
         self,
         request: CompletionLikeRequest,
         tokenizer: AnyTokenizer,
-        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
+        input_or_inputs: Optional[Union[str, list[str], list[int],
+                                        list[list[int]]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
-    ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
-        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
+    ) -> tuple[Union[list[TextTokensPrompt], list[Union[
+            TextTokensPrompt, EmbedsPrompt]]], Union[
+                list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
+                                                     EngineEmbedsPrompt]]]]:
+        if not isinstance(request,
+                          CompletionRequest) and input_or_inputs is None:
+            raise ValueError(
+                "Prompt embeds with non-completion requests is not"
+                " currently supported.")
+
+        (request_prompts_text, request_prompts_embeds
+         ) = await self._tokenize_prompt_input_or_inputs_async(
             request,
             tokenizer,
             input_or_inputs,
@@ -643,11 +728,31 @@ class OpenAIServing:
            add_special_tokens=add_special_tokens,
         )
 
-        engine_prompts = [
-            TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
-            for request_prompt in request_prompts
+        engine_prompts_text = [
+            EngineTokensPrompt(
+                prompt_token_ids=request_prompt_text["prompt_token_ids"])
+            for request_prompt_text in request_prompts_text
         ]
+
+        # This check is equivalent to simply checking if
+        # `request_prompts_embeds` is empty, but it's difficult to propagate
+        # overloads to the private helper functions to enable this check.
+        # This overload is needed because only TextPrompts are allowed for
+        # non-completion requests and if we don't add the overload here,
+        # everywhere this function is used outside of serving_completion will
+        # need logic asserting that only text prompts are in the request.
+        if not isinstance(request,
+                          CompletionRequest) and input_or_inputs is not None:
+            return request_prompts_text, engine_prompts_text
+
+        engine_prompts_embeds = [
+            EngineEmbedsPrompt(
+                prompt_embeds=request_prompt_embeds["prompt_embeds"])
+            for request_prompt_embeds in request_prompts_embeds
+        ]
+
+        request_prompts = request_prompts_embeds + request_prompts_text
+        engine_prompts = engine_prompts_embeds + engine_prompts_text
         return request_prompts, engine_prompts
 
     async def _preprocess_chat(
@@ -666,7 +771,7 @@ class OpenAIServing:
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = False,
     ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
-               list[TokensPrompt]]:
+               list[EngineTokensPrompt]]:
         model_config = self.model_config
 
         resolved_content_format = resolve_chat_template_content_format(
@@ -739,7 +844,7 @@ class OpenAIServing:
                 prompt=tokenizer.decode(request_prompt),
                 prompt_token_ids=request_prompt)
 
-        engine_prompt = TokensPrompt(
+        engine_prompt = EngineTokensPrompt(
             prompt_token_ids=prompt_inputs["prompt_token_ids"])
         if mm_data is not None:
             engine_prompt["multi_modal_data"] = mm_data
@@ -751,6 +856,35 @@ class OpenAIServing:
 
         return conversation, [request_prompt], [engine_prompt]
 
+    def _load_prompt_embeds(
+        self,
+        prompt_embeds: Optional[Union[bytes, list[bytes]]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    ) -> list[EmbedsPrompt]:
+
+        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
+            tensor = torch.load(io.BytesIO(base64.b64decode(embed)),
+                                weights_only=True)
+            assert isinstance(
+                tensor,
+                (torch.FloatTensor, torch.BFloat16Tensor, torch.HalfTensor))
+            if tensor.dim() > 2:
+                tensor = tensor.squeeze(0)
+            assert tensor.dim() == 2
+            if truncate_prompt_tokens is not None:
+                tensor = tensor[-truncate_prompt_tokens:]
+            return {"prompt_embeds": tensor}
+
+        if prompt_embeds:
+            if isinstance(prompt_embeds, list):
+                return [
+                    _load_and_validate_embed(embed) for embed in prompt_embeds
+                ]
+            else:
+                return [_load_and_validate_embed(prompt_embeds)]
+        else:
+            return []
+
     def _log_inputs(
         self,
         request_id: str,
@@ -762,13 +896,13 @@ class OpenAIServing:
     ) -> None:
         if self.request_logger is None:
             return
+        prompt, prompt_token_ids, prompt_embeds = None, None, None
         if isinstance(inputs, str):
             prompt = inputs
-            prompt_token_ids = None
         elif isinstance(inputs, list):
-            prompt = None
             prompt_token_ids = inputs
+        elif 'prompt_embeds' in inputs:
+            prompt_embeds = inputs.get("prompt_embeds")
         else:
             prompt = inputs["prompt"]
             prompt_token_ids = inputs["prompt_token_ids"]
@@ -777,6 +911,7 @@ class OpenAIServing:
             request_id,
             prompt,
             prompt_token_ids,
+            prompt_embeds,
             params=params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
```
```diff
@@ -106,7 +106,8 @@ class OpenAIServingTokenization(OpenAIServing):
 
             # Silently ignore prompt adapter since it does not affect
             # tokenization (Unlike in Embeddings API where an error is raised)
-            input_ids.extend(engine_prompt["prompt_token_ids"])
+            if isinstance(engine_prompt,
+                          dict) and "prompt_token_ids" in engine_prompt:
+                input_ids.extend(engine_prompt["prompt_token_ids"])
 
         return TokenizeResponse(tokens=input_ids,
```
```diff
@@ -3,7 +3,7 @@ from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
 
 import torch
-from typing_extensions import NotRequired, TypedDict, TypeVar
+from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
 
 if TYPE_CHECKING:
     from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
@@ -98,6 +98,17 @@ where the decoder-prompt is not specified explicitly, or
 more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt`
 """
 
+
+def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
+    return (isinstance(prompt, dict) and "prompt_token_ids" in prompt
+            and "prompt_embeds" not in prompt)
+
+
+def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
+    return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt
+            and "prompt_embeds" in prompt)
+
+
 _T1_co = TypeVar("_T1_co",
                  bound=SingletonPrompt,
                  default=SingletonPrompt,
```