Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 13:25:30 +08:00)

Commit 15cb047e25 (parent 9ad0688e43)

Extend renderer with embedding support and integrate completion endpoint (#24405)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
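
For orientation, the round trip this change enables looks roughly like the client-side sketch below, assembled from the tests in this diff. The model name, server URL, and the choice to send a single base64 string are placeholders and assumptions, not anything prescribed by the commit: a [num_tokens, hidden_size] tensor is serialized with torch.save, base64-encoded, and passed to /v1/completions through extra_body.

# Client-side sketch (assumptions: local vLLM server, placeholder model name).
import asyncio
import io

import pybase64
import torch
from openai import AsyncOpenAI


async def complete_from_embeds() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # Serialize a [num_tokens, hidden_size] tensor the same way the tests do.
    embeds = torch.randn(10, 768, dtype=torch.float32)
    buffer = io.BytesIO()
    torch.save(embeds, buffer)
    encoded = pybase64.b64encode(buffer.getvalue()).decode("utf-8")

    # An empty prompt is accepted as long as prompt_embeds is non-empty;
    # sending both empty now fails validation with
    # "Either prompt or prompt_embeds must be provided and non-empty."
    await client.completions.create(
        model="placeholder-model",
        prompt="",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded},
    )


asyncio.run(complete_from_embeds())
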
@@ -10,7 +10,7 @@ import pytest
 import regex as re
 import torch
 
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.renderer import BaseRenderer
 
 from ...utils import RemoteOpenAIServer
 
@@ -27,12 +27,16 @@ async def test_empty_prompt():
     with RemoteOpenAIServer(model_name, server_args) as remote_server:
         client = remote_server.get_async_client()
 
-        with pytest.raises(openai.BadRequestError,
-                           match="decoder prompt cannot be empty"):
+        with pytest.raises(
+                openai.BadRequestError,
+                match=
+                "Either prompt or prompt_embeds must be provided and non-empty."
+        ):
             await client.completions.create(model=model_name,
                                             prompt="",
                                             max_tokens=5,
-                                            temperature=0.0)
+                                            temperature=0.0,
+                                            extra_body={"prompt_embeds": []})
 
 
 @pytest.mark.asyncio
@@ -83,7 +87,7 @@ def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
     buffer.seek(0)
     encoded_tensor = pybase64.b64encode(buffer.getvalue())
 
-    loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
+    loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor)
    assert len(loaded_prompt_embeds) == 1
    loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
    assert loaded_tensor.device.type == "cpu"
@@ -1,13 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import io
 from dataclasses import dataclass
 from typing import Optional
 from unittest.mock import AsyncMock, MagicMock
 
+import pybase64
 import pytest
+import torch
 
 from vllm.entrypoints.renderer import CompletionRenderer
+from vllm.inputs.data import is_embeds_prompt
 
 
 @dataclass
@@ -178,3 +182,132 @@ class TestRenderPrompt:
         with pytest.raises(ValueError, match="No tokenizer available"):
             await renderer_no_tokenizer.render_prompt(
                 prompt_or_prompts="Hello world", max_length=100)
+
+    @pytest.mark.asyncio
+    async def test_token_input_with_needs_detokenization(
+            self, renderer, mock_async_tokenizer):
+        # When needs_detokenization=True for token inputs, renderer should
+        # use the async tokenizer to decode and include the original text
+        # in the returned prompt object.
+        mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        tokens = [1, 2, 3, 4]
+        results = await renderer.render_prompt(
+            prompt_or_prompts=tokens,
+            needs_detokenization=True,
+        )
+
+        assert len(results) == 1
+        assert results[0]["prompt_token_ids"] == tokens
+        assert results[0]["prompt"] == "decoded text"
+        mock_async_tokenizer.decode.assert_awaited_once()
+
+
+class TestRenderEmbedPrompt:
+
+    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
+        """Helper to create base64-encoded tensor bytes"""
+        buffer = io.BytesIO()
+        torch.save(tensor, buffer)
+        buffer.seek(0)
+        return pybase64.b64encode(buffer.read())
+
+    @pytest.mark.asyncio
+    async def test_single_prompt_embed(self, renderer):
+        # Create a test tensor
+        test_tensor = torch.randn(10, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes, cache_salt="test_salt")
+
+        assert len(results) == 1
+        assert is_embeds_prompt(results[0])
+        assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
+        assert results[0]["cache_salt"] == "test_salt"
+
+    @pytest.mark.asyncio
+    async def test_multiple_prompt_embeds(self, renderer):
+        # Create multiple test tensors
+        test_tensors = [
+            torch.randn(8, 512, dtype=torch.float32),
+            torch.randn(12, 512, dtype=torch.float32),
+        ]
+        embed_bytes_list = [
+            self._create_test_embed_bytes(t) for t in test_tensors
+        ]
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes_list)
+
+        assert len(results) == 2
+        for i, result in enumerate(results):
+            assert is_embeds_prompt(result)
+            assert torch.allclose(result["prompt_embeds"], test_tensors[i])
+
+    @pytest.mark.asyncio
+    async def test_prompt_embed_truncation(self, renderer):
+        # Create tensor with more tokens than truncation limit
+        test_tensor = torch.randn(20, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
+
+        assert len(results) == 1
+        # Should keep last 10 tokens
+        expected = test_tensor[-10:]
+        assert torch.allclose(results[0]["prompt_embeds"], expected)
+
+    @pytest.mark.asyncio
+    async def test_prompt_embed_different_dtypes(self, renderer):
+        # Test different supported dtypes
+        dtypes = [torch.float32, torch.float16, torch.bfloat16]
+
+        for dtype in dtypes:
+            test_tensor = torch.randn(5, 256, dtype=dtype)
+            embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+            results = await renderer.render_prompt_and_embeds(
+                prompt_embeds=embed_bytes)
+
+            assert len(results) == 1
+            assert results[0]["prompt_embeds"].dtype == dtype
+
+    @pytest.mark.asyncio
+    async def test_prompt_embed_squeeze_batch_dim(self, renderer):
+        # Test tensor with batch dimension gets squeezed
+        test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes)
+
+        assert len(results) == 1
+        # Should be squeezed to 2D
+        assert results[0]["prompt_embeds"].shape == (10, 768)
+
+    @pytest.mark.asyncio
+    async def test_both_prompts_and_embeds(self, renderer,
+                                           mock_async_tokenizer):
+        # Set up text tokenization
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 102, 103])
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        # Create embed
+        test_tensor = torch.randn(5, 256, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
+
+        assert len(results) == 2
+        # First should be embed prompt
+        assert is_embeds_prompt(results[0])
+        # Second should be tokens prompt
+        assert "prompt_token_ids" in results[1]
+        assert results[1]["prompt_token_ids"] == [101, 102, 103]
@@ -686,7 +686,7 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
 async def test_completion_with_empty_prompt_embeds(
         client: openai.AsyncOpenAI) -> None:
     """Test completion with empty prompt embeds."""
-    payload: dict[str, list] = {"prompt_embeds": []}
+    payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []}
     headers: dict[str, str] = {"Content-Type": "application/json"}
     # base_url = http://localhost:8000/v1/completions
     response = requests.post(f"{client.base_url}completions",
@@ -1270,9 +1270,20 @@ class CompletionRequest(OpenAIBaseModel):
     @model_validator(mode="before")
     @classmethod
     def validate_prompt_and_prompt_embeds(cls, data):
-        if data.get("prompt") is None and data.get("prompt_embeds") is None:
+        prompt = data.get("prompt")
+        prompt_embeds = data.get("prompt_embeds")
+
+        prompt_is_empty = (prompt is None
+                           or (isinstance(prompt, str) and prompt == ""))
+        embeds_is_empty = (prompt_embeds is None
+                           or (isinstance(prompt_embeds, list)
+                               and len(prompt_embeds) == 0))
+
+        if prompt_is_empty and embeds_is_empty:
             raise ValueError(
-                "At least one of `prompt` or `prompt_embeds` must be set.")
+                "Either prompt or prompt_embeds must be provided and non-empty."
+            )
+
         return data
 
     @model_validator(mode="before")
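
My reading of the tightened validator above, shown as example payloads (these dicts are illustrative, not taken from the diff): a completion request is now rejected only when both fields are missing or empty.

# Illustrative payloads (assumption: "<base64 tensor>" stands for real encoded bytes).
accepted = [
    {"prompt": "Hello"},                                   # non-empty prompt
    {"prompt_embeds": ["<base64 tensor>"]},                # embeds only
    {"prompt": "", "prompt_embeds": ["<base64 tensor>"]},  # empty prompt, non-empty embeds
]
rejected = [
    {},                                                    # both missing
    {"prompt": "", "prompt_embeds": []},                   # both empty
]
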
@@ -26,12 +26,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               PromptTokenUsageInfo,
                                               RequestResponseMetadata,
                                               UsageInfo)
-from vllm.entrypoints.openai.serving_engine import (
-    EmbedsPrompt as ServingEngineEmbedsPrompt)
 from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
-                                                    TextTokensPrompt,
-                                                    clamp_prompt_logprobs,
-                                                    is_text_tokens_prompt)
+                                                    clamp_prompt_logprobs)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.utils import get_max_tokens
@@ -132,12 +128,19 @@ class OpenAIServingCompletion(OpenAIServing):
             else:
                 tokenizer = await self.engine_client.get_tokenizer(lora_request
                                                                    )
+            renderer = self._get_renderer(tokenizer)
+            max_input_tokens_len = self.max_model_len - (request.max_tokens
+                                                         or 0)
 
-            request_prompts, engine_prompts = await self._preprocess_completion(
-                request,
-                tokenizer,
-                request.prompt,
+            engine_prompts = await renderer.render_prompt_and_embeds(
+                prompt_or_prompts=request.prompt,
+                prompt_embeds=request.prompt_embeds,
+                max_length=max_input_tokens_len,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
+                cache_salt=request.cache_salt,
+                needs_detokenization=bool(request.echo
+                                          and not request.return_token_ids),
             )
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
@@ -198,7 +201,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
                 self._log_inputs(
                     request_id_item,
-                    request_prompts[i],
+                    engine_prompt,
                     params=sampling_params,
                     lora_request=lora_request,
                 )
@@ -249,7 +252,7 @@ class OpenAIServingCompletion(OpenAIServing):
         if stream:
             return self.completion_stream_generator(
                 request,
-                request_prompts,
+                engine_prompts,
                 result_generator,
                 request_id,
                 created_time,
@@ -273,11 +276,9 @@ class OpenAIServingCompletion(OpenAIServing):
                 # We did not pass it into vLLM engine to avoid being redundant
                 # with the inputs token IDs
                 if final_res.prompt is None:
-                    request_prompt = request_prompts[i]
-                    if is_text_tokens_prompt(request_prompt):
-                        final_res.prompt = request_prompt["prompt"]
-                    else:
-                        final_res.prompt = None
+                    engine_prompt = engine_prompts[i]
+                    final_res.prompt = None if is_embeds_prompt(
+                        engine_prompt) else engine_prompt.get("prompt")
 
             final_res_batch_checked = cast(list[RequestOutput],
                                            final_res_batch)
@@ -313,8 +314,7 @@ class OpenAIServingCompletion(OpenAIServing):
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
-        request_prompts: list[Union[TextTokensPrompt,
-                                    ServingEngineEmbedsPrompt]],
+        engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]],
         result_generator: AsyncIterator[tuple[int, RequestOutput]],
         request_id: str,
         created_time: int,
@@ -350,14 +350,11 @@ class OpenAIServingCompletion(OpenAIServing):
                         num_cached_tokens = res.num_cached_tokens
                         first_iteration = False
 
-                    if res.prompt is not None:
-                        prompt_text = res.prompt
-                    else:
-                        request_prompt = request_prompts[prompt_idx]
-                        if is_text_tokens_prompt(request_prompt):
-                            prompt_text = request_prompt["prompt"]
-                        else:
-                            prompt_text = None
+                    prompt_text = res.prompt
+                    if prompt_text is None:
+                        engine_prompt = engine_prompts[prompt_idx]
+                        prompt_text = None if is_embeds_prompt(
+                            engine_prompt) else engine_prompt.get("prompt")
 
                     # Prompt details are excluded from later streamed outputs
                     if prompt_token_ids is not None:
@@ -378,6 +375,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     assert request.max_tokens is not None
                     if request.echo and not has_echoed[i]:
                         assert prompt_token_ids is not None
+                        if request.return_token_ids:
+                            prompt_text = ""
                         assert prompt_text is not None
                         if request.max_tokens == 0:
                             # only return the prompt
@@ -525,6 +524,8 @@ class OpenAIServingCompletion(OpenAIServing):
         for output in final_res.outputs:
             assert request.max_tokens is not None
             if request.echo:
+                if request.return_token_ids:
+                    prompt_text = ""
                 assert prompt_text is not None
                 if request.max_tokens == 0:
                     token_ids = prompt_token_ids
@@ -28,7 +28,6 @@ from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                      TextTokensPrompt)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
-from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
@@ -290,7 +289,7 @@ class EmbeddingMixin(OpenAIServing):
     async def _create_single_prompt_generator(
         self,
         ctx: EmbeddingServeContext,
-        engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
+        engine_prompt: EngineTokensPrompt,
         pooling_params: PoolingParams,
         trace_headers: Optional[Mapping[str, str]],
         prompt_index: int,
@@ -303,12 +302,6 @@ class EmbeddingMixin(OpenAIServing):
                          params=pooling_params,
                          lora_request=ctx.lora_request)
 
-        # Mypy has an existing bug related to inferring the variance
-        # of TypedDicts with `builtins.enumerate`:
-        # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
-        engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt],
-                             engine_prompt)
-
         # Return the original generator without wrapping
         return self.engine_client.encode(
             engine_prompt,
@@ -375,12 +368,8 @@ class EmbeddingMixin(OpenAIServing):
                     continue
 
                 # Normal processing for short prompts or non-token prompts
-                # Cast engine_prompt to the expected type for mypy
-                engine_prompt_typed = cast(
-                    Union[EngineTokensPrompt, EngineEmbedsPrompt],
-                    engine_prompt)
                 generator = await self._create_single_prompt_generator(
-                    ctx, engine_prompt_typed, pooling_params, trace_headers, i)
+                    ctx, engine_prompt, pooling_params, trace_headers, i)
                 generators.append(generator)
 
         from vllm.utils import merge_async_iterators
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import asyncio
-import io
 import json
 import sys
 import time
@@ -9,10 +7,8 @@ import traceback
 from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from http import HTTPStatus
-from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
-                    TypeVar, Union, cast, overload)
+from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
 
-import pybase64
 import torch
 from fastapi import Request
 from pydantic import BaseModel, ConfigDict, Field
@@ -64,10 +60,8 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
 from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
 # yapf: enable
-from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
 from vllm.inputs.data import PromptType
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
-from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob, PromptLogprobs
 from vllm.lora.request import LoRARequest
@@ -149,8 +143,7 @@ class RequestProcessingMixin(BaseModel):
     """
 
     request_prompts: Optional[Sequence[RequestPrompt]] = []
-    engine_prompts: Optional[Union[list[EngineTokensPrompt],
-                                   list[EngineEmbedsPrompt]]] = []
+    engine_prompts: Optional[list[EngineTokensPrompt]] = []
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
@@ -368,13 +361,6 @@ class OpenAIServing:
         for i, engine_prompt in enumerate(ctx.engine_prompts):
             request_id_item = f"{ctx.request_id}-{i}"
 
-            # Mypy has an existing bug related to inferring the variance of
-            # TypedDicts with `builtins.enumerate`:
-            # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
-            engine_prompt = cast(
-                Union[EngineTokensPrompt, EngineEmbedsPrompt],
-                engine_prompt)
-
             self._log_inputs(
                 request_id_item,
                 engine_prompt,
@@ -737,170 +723,6 @@ class OpenAIServing:
             tokenizer=tokenizer,
         )
 
-    async def _tokenize_prompt_input_or_inputs_async(
-        self,
-        request: AnyRequest,
-        tokenizer: Optional[AnyTokenizer],
-        input_or_inputs: Optional[Union[str, list[str], list[int],
-                                        list[list[int]]]],
-        add_special_tokens: bool = True,
-    ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
-        """
-        Tokenize/detokenize depending on the input format.
-
-        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
-        , each input can be a string or array of tokens. Note that each request
-        can pass one or more inputs.
-        """
-        inputs_embeds = list[EmbedsPrompt]()
-        inputs_text = list[TextTokensPrompt]()
-
-        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
-                                         None)
-
-        if (truncate_prompt_tokens or 0) < 0:
-            truncate_prompt_tokens = self.max_model_len
-
-        if (isinstance(request, CompletionRequest)
-                and request.prompt_embeds is not None):
-            inputs_embeds.extend(
-                self._load_prompt_embeds(request.prompt_embeds,
-                                         truncate_prompt_tokens))
-
-        # Empty prompts are okay as long as there are prompt embeddings
-        if input_or_inputs is None or (inputs_embeds
-                                       and input_or_inputs == ""):
-            return [], inputs_embeds
-
-        # Although our type checking is based on mypy,
-        # VSCode Pyright extension should still work properly
-        # "is False" is required for Pyright to perform type narrowing
-        # See: https://github.com/microsoft/pyright/issues/7672
-
-        # Parse and batch the input prompts
-        batch_inputs = parse_and_batch_prompt(input_or_inputs)
-
-        # Process each input in the batch concurrently
-        tasks = []
-        for prompt_input in batch_inputs:
-            if prompt_input["is_tokens"] is False:
-                assert tokenizer is not None, (
-                    "Tokenizer is required for text prompts")
-                task = self._normalize_prompt_text_to_input(
-                    request,
-                    prompt_input["content"],
-                    tokenizer=tokenizer,
-                    add_special_tokens=add_special_tokens,
-                )
-            else:
-                task = self._normalize_prompt_tokens_to_input(
-                    request, prompt_input["content"], tokenizer=tokenizer)
-            tasks.append(task)
-
-        # Wait for all tokenization tasks to complete
-        results = await asyncio.gather(*tasks)
-        inputs_text.extend(results)
-
-        return inputs_text, inputs_embeds
-
-    @overload
-    async def _preprocess_completion(
-        self,
-        request: Union[
-            DetokenizeRequest,
-            EmbeddingCompletionRequest,
-            RerankRequest,
-            ClassificationRequest,
-            ScoreRequest,
-            TokenizeCompletionRequest,
-        ],
-        tokenizer: Optional[AnyTokenizer],
-        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
-        add_special_tokens: bool = ...,
-    ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
-        ...
-
-    @overload
-    async def _preprocess_completion(
-        self,
-        request: CompletionRequest,
-        tokenizer: Optional[AnyTokenizer],
-        input_or_inputs: Optional[Union[str, list[str], list[int],
-                                        list[list[int]]]],
-        add_special_tokens: bool = ...,
-    ) -> tuple[
-            list[Union[TextTokensPrompt, EmbedsPrompt]],
-            list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
-    ]:
-        ...
-
-    async def _preprocess_completion(
-        self,
-        request: CompletionLikeRequest,
-        tokenizer: Optional[AnyTokenizer],
-        input_or_inputs: Optional[Union[str, list[str], list[int],
-                                        list[list[int]]]],
-        add_special_tokens: bool = True,
-    ) -> tuple[
-            Union[list[TextTokensPrompt], list[Union[TextTokensPrompt,
-                                                     EmbedsPrompt]]],
-            Union[
-                list[EngineTokensPrompt],
-                list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
-            ],
-    ]:
-        if (not isinstance(request, CompletionRequest)
-                and input_or_inputs is None):
-            raise ValueError(
-                "Prompt embeds with non-completion requests is not"
-                " currently supported.")
-
-        (
-            request_prompts_text,
-            request_prompts_embeds,
-        ) = await self._tokenize_prompt_input_or_inputs_async(
-            request,
-            tokenizer,
-            input_or_inputs,
-            add_special_tokens=add_special_tokens,
-        )
-
-        engine_prompts_text = [
-            EngineTokensPrompt(
-                prompt_token_ids=request_prompt_text["prompt_token_ids"])
-            for request_prompt_text in request_prompts_text
-        ]
-        cache_salt = (request.cache_salt if
-                      (hasattr(request, "cache_salt")
-                       and request.cache_salt is not None) else None)
-        if cache_salt:
-            for prompt_text in engine_prompts_text:
-                prompt_text["cache_salt"] = cache_salt
-
-        # This check is equivalent to simply checking if
-        # `request_prompts_embeds` is empty, but it's difficult to propagate
-        # overloads to the private helper functions to enable this check.
-        # This overload is needed because only TextPrompts are allowed for
-        # non-completion requests and if we don't add the overload here,
-        # everywhere this function is used outside of serving_completion will
-        # need logic asserting that only text prompts are in the request.
-        if (not isinstance(request, CompletionRequest)
-                and input_or_inputs is not None):
-            return request_prompts_text, engine_prompts_text
-
-        engine_prompts_embeds = [
-            EngineEmbedsPrompt(
-                prompt_embeds=request_prompt_embeds["prompt_embeds"])
-            for request_prompt_embeds in request_prompts_embeds
-        ]
-        if cache_salt:
-            for prompt_embed in engine_prompts_embeds:
-                prompt_embed["cache_salt"] = cache_salt
-
-        request_prompts = request_prompts_embeds + request_prompts_text
-        engine_prompts = engine_prompts_embeds + engine_prompts_text
-        return request_prompts, engine_prompts
-
     async def _preprocess_chat(
         self,
         request: Union[ChatLikeRequest, ResponsesRequest],
@@ -1073,41 +895,6 @@ class OpenAIServing:
             # OPTIMIZATION
             priority = orig_priority - 1
 
-    @staticmethod
-    def _load_prompt_embeds(
-        prompt_embeds: Optional[Union[bytes, list[bytes]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
-    ) -> list[EmbedsPrompt]:
-
-        def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
-            tensor = torch.load(
-                io.BytesIO(pybase64.b64decode(embed, validate=True)),
-                weights_only=True,
-                map_location=torch.device("cpu"),
-            )
-            assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
-                torch.float32,
-                torch.bfloat16,
-                torch.float16,
-            )
-            tensor = tensor.to_dense()
-            if tensor.dim() > 2:
-                tensor = tensor.squeeze(0)
-            assert tensor.dim() == 2
-            if truncate_prompt_tokens is not None:
-                tensor = tensor[-truncate_prompt_tokens:]
-            return {"prompt_embeds": tensor}
-
-        if prompt_embeds:
-            if isinstance(prompt_embeds, list):
-                return [
-                    _load_and_validate_embed(embed) for embed in prompt_embeds
-                ]
-            else:
-                return [_load_and_validate_embed(prompt_embeds)]
-        else:
-            return []
-
     def _log_inputs(
         self,
         request_id: str,
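
The net effect on callers, paraphrased from the test update earlier in this diff: the static helper removed above now lives on the renderer base class.

loaded = OpenAIServing._load_prompt_embeds(encoded_tensor)  # before (removed above)
loaded = BaseRenderer.load_prompt_embeds(encoded_tensor)    # after (see the test change)
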
@@ -2,12 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import asyncio
+import io
 from abc import ABC, abstractmethod
 from typing import Annotated, Optional, Union
 
+import pybase64
+import torch
 from pydantic import Field
 
 from vllm.config import ModelConfig
+from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -49,37 +53,121 @@ class BaseRenderer(ABC):
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: Optional[bool] = True,
         cache_salt: Optional[str] = None,
+        needs_detokenization: Optional[bool] = False,
     ) -> list[EngineTokensPrompt]:
         """
-        Convert input prompts into tokenized format for engine processing.
+        Convert text or token inputs into engine-ready TokensPrompt objects.
 
-        This is the core method that transforms various input formats into
-        standardized TokensPrompt objects. Implementations should handle
-        tokenization, special token insertion, truncation, and validation
-        according to model requirements.
+        This method accepts text or token inputs and produces a
+        list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
+        for the engine.
 
         Args:
-            prompt_or_prompts: Input data in various formats:
-                - str: Single text prompt
-                - list[str]: Batch of text prompts
-                - list[int]: Pre-tokenized sequence
-                - list[list[int]]: Batch of pre-tokenized sequences
-            max_length: Maximum sequence length (endpoint-specific behavior)
-            truncate_prompt_tokens: Truncate to last N tokens
-                (None=no truncation, 0=empty)
-            add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP])
-                to text inputs
-            cache_salt: Optional string to disambiguate cached prompts
+            prompt_or_prompts: One of:
+                - ``str``: Single text prompt.
+                - ``list[str]``: Batch of text prompts.
+                - ``list[int]``: Single pre-tokenized sequence.
+                - ``list[list[int]]``: Batch of pre-tokenized sequences.
+            max_length: Maximum allowable total input token length. If provided,
+                token inputs longer than this raise ``ValueError``.
+            truncate_prompt_tokens: Number of tokens to keep. ``None`` means no
+                truncation. ``0`` yields an empty list (and skips embeds).
+                ``-1`` maps to ``model_config.max_model_len``.
+            add_special_tokens: Whether to add model-specific special tokens
+                during text tokenization.
+            cache_salt: Optional string to disambiguate prefix cache entries.
+            needs_detokenization: If True and ``prompt_or_prompts`` is token
+                input, detokenize IDs back to text for inclusion in outputs.
 
         Returns:
-            list[EngineTokensPrompt]: Tokenized prompts ready for engine
-            consumption
+            list[EngineTokensPrompt]: Engine-ready token prompts.
 
         Raises:
-            ValueError: If input format is invalid or length limits exceeded
+            ValueError: If input formats are invalid or length limits exceeded.
         """
         raise NotImplementedError
 
+    @abstractmethod
+    async def render_prompt_and_embeds(
+        self,
+        prompt_or_prompts: Optional[Union[str, list[str], list[int],
+                                          list[list[int]]]] = None,
+        prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
+        max_length: Optional[int] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
+        add_special_tokens: Optional[bool] = True,
+        cache_salt: Optional[str] = None,
+        needs_detokenization: Optional[bool] = False,
+    ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
+        """
+        Convert text/token and/or base64-encoded embeddings inputs into
+        engine-ready prompt objects.
+
+        At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
+        provided and non-empty. If both are omitted or empty (e.g., empty
+        string and empty list), a ``ValueError`` is raised.
+
+        Args:
+            prompt_or_prompts: Text or token inputs to include.
+            prompt_embeds: Base64-encoded bytes (or list thereof) containing a
+                torch-saved tensor to be used as prompt embeddings.
+            max_length: Maximum allowable total input token length. If provided,
+                inputs longer than this raise ``ValueError``.
+            truncate_prompt_tokens: Number of tokens/rows to keep from the end
+                of the sequence. ``-1`` maps to ``model_config.max_model_len``.
+            add_special_tokens: Whether to add model-specific special tokens
+                during text tokenization.
+            cache_salt: Optional string to disambiguate prefix cache entries.
+            needs_detokenization: If True and ``prompt_or_prompts`` is token
+                input, detokenize IDs back to text for inclusion in outputs.
+
+        Returns:
+            list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
+                Engine-ready prompt objects.
+
+        Raises:
+            ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds``
+                are omitted or empty (decoder prompt cannot be empty), or if
+                length limits are exceeded.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def load_prompt_embeds(
+        cls,
+        prompt_embeds: Union[bytes, list[bytes]],
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
+        cache_salt: Optional[str] = None,
+    ) -> list[EngineEmbedsPrompt]:
+        """Load and validate base64-encoded embeddings into prompt objects."""
+
+        def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
+            tensor = torch.load(
+                io.BytesIO(pybase64.b64decode(embed, validate=True)),
+                weights_only=True,
+                map_location=torch.device("cpu"),
+            )
+            assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
+                torch.float32,
+                torch.bfloat16,
+                torch.float16,
+            )
+            tensor = tensor.to_dense()
+            if tensor.dim() > 2:
+                tensor = tensor.squeeze(0)
+            assert tensor.dim() == 2
+            if truncate_prompt_tokens is not None:
+                tensor = tensor[-truncate_prompt_tokens:]
+            embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
+            if cache_salt is not None:
+                embeds_prompt["cache_salt"] = cache_salt
+            return embeds_prompt
+
+        if isinstance(prompt_embeds, list):
+            return [_load_and_validate_embed(embed) for embed in prompt_embeds]
+        else:
+            return [_load_and_validate_embed(prompt_embeds)]
+
 
 class CompletionRenderer(BaseRenderer):
 
@@ -101,50 +189,110 @@ class CompletionRenderer(BaseRenderer):
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: Optional[bool] = True,
         cache_salt: Optional[str] = None,
+        needs_detokenization: Optional[bool] = False,
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.
 
         Uses async tokenizer pooling for improved performance. See base class
         for detailed parameter documentation.
         """
-        if truncate_prompt_tokens is not None:
-            if truncate_prompt_tokens == 0:
-                return []
-            if truncate_prompt_tokens < 0:
-                truncate_prompt_tokens = self.model_config.max_model_len
-            if max_length is not None and truncate_prompt_tokens > max_length:
-                raise ValueError(
-                    f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
-                    f"cannot be greater than max_length ({max_length}). "
-                    f"Please select a smaller truncation size.")
+        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
+            truncate_prompt_tokens, max_length)
+        if truncate_prompt_tokens == 0:
+            return []
 
         # Parse and batch the input prompts
         batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
 
-        rendered_prompts: list[EngineTokensPrompt] = []
-        tokenize_tasks = []
+        tasks = []
         for prompt_input in batch_inputs:
             if prompt_input["is_tokens"] is True:
                 # Token input
-                token_ids = self._maybe_apply_truncation(
-                    prompt_input["content"], truncate_prompt_tokens)
-                rendered_prompts.append(
-                    self._create_tokens_prompt(token_ids, max_length,
-                                               cache_salt))
+                detokenize_task = asyncio.create_task(
+                    # Note: detokenization is needed when echo is enabled,
+                    # where the input token IDs are decoded back to text.
+                    self._maybe_detokenize(prompt_input["content"], max_length,
+                                           truncate_prompt_tokens, cache_salt,
+                                           needs_detokenization))
+                tasks.append(detokenize_task)
             else:
                 # Text input
                 tokenize_task = asyncio.create_task(
                     self._tokenize(prompt_input["content"], max_length,
                                    truncate_prompt_tokens, add_special_tokens,
                                    cache_salt))
-                tokenize_tasks.append(tokenize_task)
+                tasks.append(tokenize_task)
 
         # Wait for all text tokenization to finish
-        if tokenize_tasks:
-            tokenized_text_prompts = await asyncio.gather(*tokenize_tasks)
-            rendered_prompts.extend(tokenized_text_prompts)
+        if tasks:
+            tokenized_text_prompts = await asyncio.gather(*tasks)
+            return tokenized_text_prompts
 
-        return rendered_prompts
+        return []
+
+    async def render_prompt_and_embeds(
+        self,
+        prompt_or_prompts: Optional[Union[str, list[str], list[int],
+                                          list[list[int]]]] = None,
+        prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
+        max_length: Optional[int] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
+        add_special_tokens: Optional[bool] = True,
+        cache_salt: Optional[str] = None,
+        needs_detokenization: Optional[bool] = False,
+    ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
+        """
+        Render text/token prompts and/or precomputed embedding prompts. At
+        least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
+        """
+        truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
+            truncate_prompt_tokens, max_length)
+        if truncate_prompt_tokens == 0:
+            return []
+
+        rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []
+
+        if prompt_embeds is not None:
+            rendered.extend(
+                self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
+                                        cache_salt))
+        if prompt_or_prompts is None or prompt_or_prompts == "":
+            return rendered
+
+        token_prompts = await self.render_prompt(
+            prompt_or_prompts=prompt_or_prompts,
+            max_length=max_length,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            add_special_tokens=add_special_tokens,
+            cache_salt=cache_salt,
+            needs_detokenization=needs_detokenization,
+        )
+        rendered.extend(token_prompts)
+
+        return rendered
+
+    def _validate_and_normalize_truncate_tokens(
+        self,
+        truncate_prompt_tokens: Optional[int],
+        max_length: Optional[int],
+    ) -> Optional[int]:
+        """Validate and normalize truncate_prompt_tokens parameter."""
+        if truncate_prompt_tokens is None:
+            return None
+
+        if truncate_prompt_tokens == 0:
+            return 0
+
+        if truncate_prompt_tokens < 0:
+            truncate_prompt_tokens = self.model_config.max_model_len
+
+        if max_length is not None and truncate_prompt_tokens > max_length:
+            raise ValueError(
+                f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
+                f"cannot be greater than max_length ({max_length}). "
+                f"Please select a smaller truncation size.")
+
+        return truncate_prompt_tokens
 
     def _maybe_apply_truncation(
         self, token_ids: list[int],
@@ -186,7 +334,29 @@ class CompletionRenderer(BaseRenderer):
                                         max_length=truncate_prompt_tokens)
 
         return self._create_tokens_prompt(encoded.input_ids, max_length,
-                                          cache_salt)
+                                          cache_salt, text)
+
+    async def _maybe_detokenize(
+        self,
+        token_ids: list[int],
+        max_length: Optional[int],
+        truncate_prompt_tokens: Optional[int],
+        cache_salt: Optional[str],
+        needs_detokenization: Optional[bool] = False,
+    ) -> EngineTokensPrompt:
+        """Optionally detokenize token IDs and build a tokens prompt."""
+        token_ids = self._maybe_apply_truncation(token_ids,
+                                                 truncate_prompt_tokens)
+
+        prompt = None
+        if needs_detokenization is True:
+            async_tokenizer = self._get_async_tokenizer()
+            prompt = await async_tokenizer.decode(token_ids)
+
+        return self._create_tokens_prompt(token_ids=token_ids,
+                                          max_length=max_length,
+                                          cache_salt=cache_salt,
+                                          prompt=prompt)
 
     def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
         """Get or create async tokenizer using shared pool."""
@@ -210,6 +380,7 @@ class CompletionRenderer(BaseRenderer):
         token_ids: list[int],
         max_length: Optional[int] = None,
         cache_salt: Optional[str] = None,
+        prompt: Optional[str] = None,
     ) -> EngineTokensPrompt:
         """Create validated EngineTokensPrompt."""
         if max_length is not None and len(token_ids) > max_length:
@@ -221,4 +392,6 @@ class CompletionRenderer(BaseRenderer):
         tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
         if cache_salt is not None:
             tokens_prompt["cache_salt"] = cache_salt
-        return tokens_prompt
+        if prompt is not None:
+            tokens_prompt["prompt"] = prompt
+        return tokens_prompt
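
A short sketch of what the new renderer API returns, drawn from the renderer tests earlier in this diff; the renderer construction and the encoded bytes are assumed to exist and are not shown. Embedding prompts come first, followed by tokenized text prompts, and cache_salt is threaded into both.

# Sketch (assumptions: `renderer` is a CompletionRenderer with a tokenizer,
# `encoded_tensor_bytes` is a base64-encoded, torch-saved 2D tensor).
prompts = await renderer.render_prompt_and_embeds(
    prompt_or_prompts="Hello world",
    prompt_embeds=encoded_tensor_bytes,
    cache_salt="salt",
)
# prompts[0] -> embeds prompt: {"prompt_embeds": <2D tensor>, "cache_salt": "salt"}
# prompts[1] -> tokens prompt: {"prompt_token_ids": [...], "cache_salt": "salt"}
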
@@ -52,6 +52,9 @@ class TokensPrompt(TypedDict):
     prompt_token_ids: list[int]
     """A list of token IDs to pass to the model."""
 
+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token IDs, if available."""
+
     token_type_ids: NotRequired[list[int]]
     """A list of token type IDs to pass to the cross encoder model."""
 