[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

This commit is contained in:
parent 8a74c68bd1
commit e2fbaee725
@@ -1,6 +1,5 @@
import os
import pathlib
from dataclasses import dataclass

import pytest

@@ -50,23 +49,9 @@ TEST_MESSAGES = [
]


@dataclass
class MockTokenizer:
chat_template = None


@dataclass
class MockServingChat:
tokenizer: MockTokenizer


def test_load_chat_template():
# Testing chatml template
tokenizer = MockTokenizer()
mock_serving_chat = MockServingChat(tokenizer)
load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path)

template_content = tokenizer.chat_template
template_content = load_chat_template(chat_template=chatml_jinja_path)

# Test assertions
assert template_content is not None
@@ -78,22 +63,16 @@ def test_load_chat_template():
def test_no_load_chat_template_filelike():
# Testing chatml template
template = "../../examples/does_not_exist"
tokenizer = MockTokenizer()

mock_serving_chat = MockServingChat(tokenizer)

with pytest.raises(ValueError, match="looks like a file path"):
load_chat_template(mock_serving_chat, chat_template=template)
load_chat_template(chat_template=template)


def test_no_load_chat_template_literallike():
# Testing chatml template
template = "{{ messages }}"
tokenizer = MockTokenizer()

mock_serving_chat = MockServingChat(tokenizer)
load_chat_template(mock_serving_chat, chat_template=template)
template_content = tokenizer.chat_template
template_content = load_chat_template(chat_template=template)

assert template_content == template

@@ -105,8 +84,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
expected_output):
# Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model)
mock_serving_chat = MockServingChat(tokenizer)
load_chat_template(mock_serving_chat, chat_template=template)
template_content = load_chat_template(chat_template=template)

# Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest(
@@ -118,7 +96,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
result = tokenizer.apply_chat_template(
conversation=mock_request.messages,
tokenize=False,
add_generation_prompt=mock_request.add_generation_prompt)
add_generation_prompt=mock_request.add_generation_prompt,
chat_template=mock_request.chat_template or template_content)

# Test assertion
assert result == expected_output, (
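The hunks above drop the MockServingChat/MockTokenizer indirection: load_chat_template no longer mutates a tokenizer, it simply returns the resolved template text, which the test then feeds to apply_chat_template. A minimal sketch of that new call pattern (illustrative only; the import path and the chatml template path are assumptions, not taken from this diff):

from vllm.entrypoints.openai.chat_utils import load_chat_template  # assumed module path
from vllm.transformers_utils.tokenizer import get_tokenizer

chatml_jinja_path = "examples/template_chatml.jinja"  # assumed location of a ChatML template
template_content = load_chat_template(chat_template=chatml_jinja_path)  # returns the file's text

tokenizer = get_tokenizer(tokenizer_name="HuggingFaceH4/zephyr-7b-beta")
prompt = tokenizer.apply_chat_template(
    conversation=[{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
    chat_template=template_content,  # template passed explicitly; no tokenizer mutation
)
print(prompt)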
@@ -7,11 +7,11 @@ import jsonschema
import openai # use the official client for correctness check
import pytest
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError

from ...utils import RemoteOpenAIServer
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -21,12 +21,7 @@ LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def server(zephyr_lora_files):
def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -38,7 +33,7 @@ def server(zephyr_lora_files):
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
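With the fixture change above, the server registers zephyr-lora2 against an adapter directory that ships its own tokenizer files, so requests routed to that model name are templated with the adapter's tokenizer. A hedged client-side sketch (the server address is an assumption; model names follow the fixtures above):

import asyncio

import openai


async def main() -> None:
    # Assumes a vLLM OpenAI-compatible server launched with the --lora-modules args above.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    chat = await client.chat.completions.create(
        model="zephyr-lora2",  # served LoRA name; its own tokenizer drives the chat template
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=16,
        temperature=0.0,
    )
    print(chat.choices[0].message.content)


asyncio.run(main())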
@@ -1,6 +1,8 @@
# imports for guided decoding tests
import json
import re
import shutil
from tempfile import TemporaryDirectory
from typing import List

import jsonschema
@@ -9,6 +11,7 @@ import pytest
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -30,13 +33,29 @@ def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
tmp_dir = TemporaryDirectory()
tmp_model_dir = f"{tmp_dir.name}/zephyr"
shutil.copytree(zephyr_lora_files, tmp_model_dir)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Copy tokenizer to adapter and add some unique tokens
# 32000, 32001, 32002
added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
special_tokens=True)
assert added == 3
tokenizer.save_pretrained(tmp_model_dir)
yield tmp_model_dir
tmp_dir.cleanup()


@pytest.fixture(scope="module")
def zephyr_pa_files():
return snapshot_download(repo_id=PA_NAME)


@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_pa_files):
def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -50,7 +69,7 @@ def server(zephyr_lora_files, zephyr_pa_files):
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
@@ -111,6 +130,34 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
assert len(completion.choices[0].text) >= 1


@pytest.mark.asyncio
async def test_added_lora_tokens(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model="zephyr-lora2",
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should appear in tokenized prompt
assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")


@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should not appear in tokenized prompt
assert "vllm" not in completion.choices[0].text


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
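The zephyr_lora_added_tokens_files fixture copies the adapter and saves a tokenizer with three extra special tokens, so IDs 32000-32002 decode to vllm1/vllm2/vllm3 only when the adapter's tokenizer is selected — exactly what test_added_lora_tokens and its base-model counterpart assert. A standalone sketch of that ID mapping (assumes the Hub checkpoint is reachable):

from tempfile import TemporaryDirectory

from transformers import AutoTokenizer

with TemporaryDirectory() as tmp_model_dir:
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], special_tokens=True)
    tokenizer.save_pretrained(tmp_model_dir)

    lora_tokenizer = AutoTokenizer.from_pretrained(tmp_model_dir)
    # The base vocabulary ends at 31999, so the added tokens land at 32000-32002.
    print(lora_tokenizer.convert_ids_to_tokens([32000, 32001, 32002]))
    # -> ['vllm1', 'vllm2', 'vllm3']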
@@ -38,5 +38,4 @@ async def _async_serving_chat_init():

def test_async_serving_chat_init():
serving_completion = asyncio.run(_async_serving_chat_init())
assert serving_completion.tokenizer is not None
assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
assert serving_completion.chat_template == CHAT_TEMPLATE
@@ -5,13 +5,15 @@ import requests
from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import RemoteOpenAIServer
from .test_completion import zephyr_lora_added_tokens_files # noqa: F401
from .test_completion import zephyr_lora_files # noqa: F401

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def server():
def server(zephyr_lora_added_tokens_files: str): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -21,12 +23,25 @@ def server():
"--enforce-eager",
"--max-num-seqs",
"128",
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest.fixture(scope="module")
def tokenizer_name(model_name: str,
zephyr_lora_added_tokens_files: str): # noqa: F811
return zephyr_lora_added_tokens_files if (
model_name == "zephyr-lora2") else model_name


@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
@@ -34,16 +49,18 @@ def client(server):

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_completions(client: openai.AsyncOpenAI,
model_name: str):
model_name: str, tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")

for add_special in [False, True]:
prompt = "This is a test prompt."
prompt = "vllm1 This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

response = requests.post(base_url + "/tokenize",
@@ -63,12 +80,15 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")

for add_generation in [False, True]:
for add_special in [False, True]:
@@ -80,7 +100,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
"content": "Nice to meet you!"
}, {
"role": "user",
"content": "Can I ask a question?"
"content": "Can I ask a question? vllm1"
}]

prompt = tokenizer.apply_chat_template(
@@ -108,16 +128,20 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
indirect=["tokenizer_name"],
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
tokenizer_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode="fast")

prompt = "This is a test prompt."
prompt = "This is a test prompt. vllm1"
tokens = tokenizer.encode(prompt, add_special_tokens=False)

print(f"CALLING {base_url} FOR {model_name}")
response = requests.post(base_url + "/detokenize",
json={
"model": model_name,
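Because /tokenize and /detokenize now resolve the tokenizer from the requested model, sending the LoRA name selects the adapter's tokenizer and the added token vllm1 encodes as a single ID. A hedged sketch of the raw call the tests above make (server address and response field name are assumptions):

import requests

base_url = "http://localhost:8000"  # assumed; the tests derive this from the OpenAI client
response = requests.post(base_url + "/tokenize",
                         json={
                             "model": "zephyr-lora2",
                             "prompt": "vllm1 This is a test prompt.",
                             "add_special_tokens": True,
                         })
response.raise_for_status()
print(response.json()["tokens"])  # token IDs from the adapter's tokenizer ("tokens" field assumed)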
@@ -480,11 +480,16 @@ class AsyncLLMEngine:
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)

async def get_tokenizer(self) -> "PreTrainedTokenizer":
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote() # type: ignore
else:
return self.engine.get_tokenizer()
return await self.engine.get_tokenizer.remote( # type: ignore
lora_request)

return await (self.engine.get_tokenizer_group().
get_lora_tokenizer_async(lora_request))

def start_background_loop(self) -> None:
"""Start the background loop."""
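AsyncLLMEngine.get_tokenizer now accepts an optional LoRARequest and resolves the tokenizer through the tokenizer group, so an adapter that ships extra tokens is honored per request. A hedged usage sketch (the engine argument is assumed to be an already-constructed AsyncLLMEngine):

from typing import List, Optional

from vllm.lora.request import LoRARequest


async def encode_for_adapter(engine, prompt: str,
                             lora_request: Optional[LoRARequest] = None) -> List[int]:
    # Falls back to the base-model tokenizer when lora_request is None.
    tokenizer = await engine.get_tokenizer(lora_request)
    return tokenizer.encode(prompt)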
@@ -455,8 +455,11 @@ class LLMEngine:

return self.tokenizer

def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
@@ -257,7 +257,8 @@ def run_server(args, llm_engine=None):
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
openai_serving_tokenization = OpenAIServingTokenization(
engine, model_config, served_model_names, args.chat_template)
engine, model_config, served_model_names, args.lora_modules,
args.chat_template)
app.root_path = args.root_path

logger.info("Available routes are:")
@@ -5,10 +5,11 @@ from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final

from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam)
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam,
ChatCompletionMessageParam)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
@@ -29,40 +30,34 @@ class ChatMessageParseResult:
default_factory=list)


def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]):
tokenizer = engine.tokenizer
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
if chat_template is None:
return None
try:
with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e

if chat_template is not None:
try:
with open(chat_template, "r") as f:
tokenizer.chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
tokenizer.chat_template = codecs.decode(chat_template,
"unicode_escape")

logger.info("Using supplied chat template:\n%s",
tokenizer.chat_template)
elif tokenizer.chat_template is not None:
logger.info("Using default chat template:\n%s",
tokenizer.chat_template)
else:
logger.warning("No chat template provided. Chat API will not work.")
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
return resolved_chat_template


@lru_cache(maxsize=None)
def _image_token_str(engine: OpenAIServing) -> Optional[str]:
def _image_token_str(model_config: ModelConfig,
tokenizer: PreTrainedTokenizer) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = engine.model_config.hf_config.model_type
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
@@ -70,17 +65,14 @@ def _image_token_str(engine: OpenAIServing) -> Optional[str]:
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return engine.tokenizer.decode(
engine.model_config.hf_config.image_token_index)
return tokenizer.decode(model_config.hf_config.image_token_index)

else:
raise TypeError("Unknown model type: {model_type}")
raise TypeError("Unknown model type: {model_type}")


# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
text_prompt: str) -> str:
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
"""Combine image and text prompts for vision language model"""

# NOTE: For now we assume all model architectures use the same
@@ -89,9 +81,10 @@ def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,


def _parse_chat_message_content_parts(
engine: OpenAIServing,
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -122,7 +115,7 @@ def _parse_chat_message_content_parts(
text_prompt = "\n".join(texts)

if mm_futures:
image_token_str = _image_token_str(engine)
image_token_str = _image_token_str(model_config, tokenizer)
if image_token_str is not None:
if image_token_str in text_prompt:
logger.warning(
@@ -130,7 +123,6 @@ def _parse_chat_message_content_parts(
"Skipping prompt formatting.")
else:
text_prompt = _get_full_image_text_prompt(
engine,
image_token_str=image_token_str,
text_prompt=text_prompt,
)
@@ -141,8 +133,9 @@ def _parse_chat_message_content_parts(


def parse_chat_message_content(
engine: OpenAIServing,
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
@@ -153,4 +146,5 @@ def parse_chat_message_content(
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])

return _parse_chat_message_content_parts(engine, role, content)
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
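After this refactor load_chat_template is a pure helper: None passes through, a readable path is read from disk, and a string that already contains Jinja characters is unicode-unescaped and returned instead of raising. A small illustrative sketch (import path and example template path are assumptions):

from vllm.entrypoints.openai.chat_utils import load_chat_template  # assumed module path

assert load_chat_template(chat_template=None) is None

literal = load_chat_template(chat_template="{{ messages }}")  # Jinja-looking literal passes through
assert literal == "{{ messages }}"

resolved = load_chat_template(chat_template="examples/template_chatml.jinja")  # assumed file path
print(resolved[:80] if resolved else resolved)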
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Union

from fastapi import Request
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -49,7 +50,9 @@ class OpenAIServingChat(OpenAIServing):
lora_modules=lora_modules)

self.response_role = response_role
load_chat_template(self, chat_template)

# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)

async def create_chat_completion(
self,
@@ -71,11 +74,15 @@ class OpenAIServingChat(OpenAIServing):
return error_check_ret

try:
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)

conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []

for msg in request.messages:
chat_parsed_result = parse_chat_message_content(self, msg)
chat_parsed_result = parse_chat_message_content(
msg, self.model_config, tokenizer)

conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
@@ -84,13 +91,13 @@ class OpenAIServingChat(OpenAIServing):
tool.model_dump() for tool in request.tools
]

prompt = self.tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
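The chat path now resolves the adapter before any templating, so message parsing and apply_chat_template both run on the tokenizer that matches the request, and the template priority is: the request's chat_template, then the server-level --chat-template, then the tokenizer's own default. A condensed, hedged sketch of that flow (the arguments stand in for the real objects used above):

from transformers import PreTrainedTokenizer


async def resolve_chat_prompt(serving_chat, request, conversation) -> str:
    """Condensed sketch of the per-request tokenizer/template resolution above."""
    _, lora_request = serving_chat._maybe_get_adapter(request)
    tokenizer: PreTrainedTokenizer = await serving_chat.engine.get_tokenizer(lora_request)
    return tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=False,
        add_generation_prompt=request.add_generation_prompt,
        # Per-request template first, then the server template, then the tokenizer default.
        chat_template=request.chat_template or serving_chat.chat_template,
    )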
@@ -112,19 +119,19 @@ class OpenAIServingChat(OpenAIServing):
request_id = f"cmpl-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
request,
tokenizer,
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params()
_, lora_request = self._maybe_get_adapter(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
@@ -158,12 +165,12 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation)
request, result_generator, request_id, conversation, tokenizer)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
conversation)
conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -175,9 +182,12 @@ class OpenAIServingChat(OpenAIServing):
return request.messages[-1]["role"]

async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
@@ -264,6 +274,7 @@ class OpenAIServingChat(OpenAIServing):
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
@@ -352,9 +363,13 @@ class OpenAIServingChat(OpenAIServing):
yield "data: [DONE]\n\n"

async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
self,
request: ChatCompletionRequest,
raw_request: Optional[Request],
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.served_model_names[0]
@@ -382,6 +397,7 @@ class OpenAIServingChat(OpenAIServing):
token_ids=token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs,
tokenizer=tokenizer,
)
else:
logprobs = None
@@ -436,16 +452,14 @@ class OpenAIServingChat(OpenAIServing):
return response

def _get_top_logprobs(
self, logprobs: Dict[int, Logprob],
top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=self._get_decoded_token(p[1], p[0]),
token=(token := self._get_decoded_token(p[1], p[0],
tokenizer)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
self._get_decoded_token(p[1],
p[0]).encode("utf-8",
errors="replace")))
bytes=list(token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
@@ -454,6 +468,7 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
@@ -463,12 +478,11 @@ class OpenAIServingChat(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self.tokenizer.decode(token_id),
bytes=list(
self.tokenizer.decode(token_id).encode(
"utf-8", errors="replace"))))
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
@@ -479,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs)))
step_top_logprobs, num_output_top_logprobs,
tokenizer)))

return ChatCompletionLogProbs(content=logprobs_content)
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Tuple

from fastapi import Request
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -100,20 +101,22 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
try:
sampling_params = request.to_sampling_params()
adapter_type, adapter_request = self._maybe_get_adapter(request)
lora_request, prompt_adapter_request = None, None
if adapter_type == 'LoRA':
lora_request, prompt_adapter_request = adapter_request, None
elif adapter_type == 'PromptAdapter':
lora_request, prompt_adapter_request = None, adapter_request
tokenizer = await self.engine.get_tokenizer(lora_request)

sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
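_maybe_get_adapter returns an adapter-type tag plus the matching request object, and the completion path above splits that into a LoRA request or a prompt-adapter request before asking the engine for the right tokenizer. A small sketch of that dispatch (the helper name is mine, not vLLM's):

from typing import Optional, Tuple, Union

from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest

AdapterRequest = Union[LoRARequest, PromptAdapterRequest]


def split_adapter(
    adapter_type: Optional[str],
    adapter_request: Optional[AdapterRequest],
) -> Tuple[Optional[LoRARequest], Optional[PromptAdapterRequest]]:
    """Route the adapter into the (lora_request, prompt_adapter_request) pair."""
    if adapter_type == 'LoRA':
        return adapter_request, None
    if adapter_type == 'PromptAdapter':
        return None, adapter_request
    return None, None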
@@ -122,18 +125,13 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
prompt_formats = await self._validate_prompt_and_tokenize(
request,
tokenizer,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens,
**{prompt_arg: prompt})
prompt_ids, prompt_text = prompt_formats

is_tracing_enabled = await self.engine.is_tracing_enabled()
@@ -179,7 +177,8 @@ class OpenAIServingCompletion(OpenAIServing):
request_id,
created_time,
model_name,
num_prompts=len(prompts))
num_prompts=len(prompts),
tokenizer=tokenizer)

# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -191,7 +190,8 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name)
final_res_batch, request, request_id, created_time, model_name,
tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -218,6 +218,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts
@@ -268,6 +269,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]),
)
else:
@@ -336,6 +338,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: PreTrainedTokenizer,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@@ -367,6 +370,7 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
)
else:
@@ -404,6 +408,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
@@ -417,13 +422,13 @@ class OpenAIServingCompletion(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
token = tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id)
token_id, tokenizer)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
@@ -436,7 +441,7 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0]):
self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
@@ -89,14 +89,11 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_is_tokens, prompts = parse_prompt_format(request.input)
pooling_params = request.to_pooling_params()

tokenizer = await self.engine.get_tokenizer()
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request, prompt=prompt)

prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
prompt_formats = await self._validate_prompt_and_tokenize(
request, tokenizer, **{prompt_arg: prompt})
prompt_ids, prompt_text = prompt_formats

generator = self.engine.encode(
@@ -5,6 +5,7 @@ from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import Field
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from vllm.config import ModelConfig
@@ -19,7 +20,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import get_tokenizer

logger = init_logger(__name__)

@@ -52,14 +52,6 @@ class OpenAIServing:
self.model_config = model_config
self.max_model_len = model_config.max_model_len

# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
tokenizer_revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
truncation_side="left")

self.served_model_names = served_model_names

self.lora_requests = []
@@ -154,7 +146,8 @@ class OpenAIServing:

def _maybe_get_adapter(
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
EmbeddingRequest, TokenizeRequest,
DetokenizeRequest]
) -> Tuple[Optional[str], Optional[Union[LoRARequest,
PromptAdapterRequest]]]:
if request.model in self.served_model_names:
@@ -168,11 +161,12 @@ class OpenAIServing:
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")

def _validate_prompt_and_tokenize(
async def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest,
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest],
tokenizer: "PreTrainedTokenizer",
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
@@ -181,7 +175,7 @@ class OpenAIServing:
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
if prompt and prompt_ids:
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")

@@ -200,14 +194,14 @@ class OpenAIServing:
"truncation": True,
"max_length": truncate_prompt_tokens,
})
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else:
input_ids = prompt_ids

input_text = prompt if prompt is not None else self.tokenizer.decode(
prompt_ids)
input_text = prompt if prompt is not None else tokenizer.decode(
input_ids)
token_num = len(input_ids)

# Note: EmbeddingRequest doesn't have max_tokens
@@ -245,7 +239,9 @@ class OpenAIServing:
else:
return input_ids, input_text

def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
@staticmethod
def _get_decoded_token(logprob: Logprob, token_id: int,
tokenizer: PreTrainedTokenizer) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
return self.tokenizer.decode(token_id)
return tokenizer.decode(token_id)
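The shared self.tokenizer is gone: _validate_prompt_and_tokenize is now a coroutine that takes the per-request tokenizer explicitly, and _get_decoded_token is a static helper with the same extra argument. A hedged sketch of the new signatures in use (these are internal helpers, shown only for illustration):

from transformers import AutoTokenizer

from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.sequence import Logprob

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Static now, so no serving instance (and no shared base-model tokenizer) is needed.
token_text = OpenAIServing._get_decoded_token(Logprob(logprob=-0.25),
                                              token_id=22172,
                                              tokenizer=tokenizer)


async def tokenize_prompt(serving: OpenAIServing, request, prompt: str):
    # The validation helper is awaited and receives the resolved tokenizer.
    return await serving._validate_prompt_and_tokenize(request, tokenizer, prompt=prompt)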
@@ -9,7 +9,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse,
TokenizeRequest,
TokenizeResponse)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)


class OpenAIServingTokenization(OpenAIServing):
@@ -18,13 +19,15 @@ class OpenAIServingTokenization(OpenAIServing):
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None)
lora_modules=lora_modules)

load_chat_template(self, chat_template)
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)

async def create_tokenize(self,
request: TokenizeRequest) -> TokenizeResponse:
@@ -40,20 +43,25 @@ class OpenAIServingTokenization(OpenAIServing):
return self.create_error_response(
"Only one of `prompt` or `messages` should be provided.")

_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
if request.messages:
conversation: List[ConversationMessage] = []

for message in request.messages:
conversation.extend(
parse_chat_message_content(self, message).messages)
result = parse_chat_message_content(message, self.model_config,
tokenizer)
conversation.extend(result.messages)

request.prompt = self.tokenizer.apply_chat_template(
request.prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
conversation=conversation,
tokenize=False)
tokenize=False,
chat_template=self.chat_template)

(input_ids, input_text) = self._validate_prompt_and_tokenize(
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
request,
tokenizer,
prompt=request.prompt,
add_special_tokens=request.add_special_tokens)

@@ -67,7 +75,9 @@ class OpenAIServingTokenization(OpenAIServing):
if error_check_ret is not None:
return error_check_ret

(input_ids, input_text) = self._validate_prompt_and_tokenize(
request, prompt_ids=request.tokens)
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
request, tokenizer, prompt_ids=request.tokens)

return DetokenizeResponse(prompt=input_text)
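With the tokenization endpoints now adapter-aware, detokenizing the added-token IDs through the LoRA model name round-trips to the added strings, which the base model cannot produce. A hedged HTTP sketch (server address assumed; the response field follows DetokenizeResponse above):

import requests

base_url = "http://localhost:8000"  # assumed server address
response = requests.post(base_url + "/detokenize",
                         json={
                             "model": "zephyr-lora2",
                             "tokens": [32000, 32001, 32002],
                         })
response.raise_for_status()
print(response.json()["prompt"])  # expected to contain "vllm1vllm2vllm3" via the adapter tokenizer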
@@ -165,6 +165,12 @@ class Detokenizer:
return len(new_decoded_token_text)


def _replace_none_with_empty(tokens: List[Optional[str]]):
for i, token in enumerate(tokens):
if token is None:
tokens[i] = ""


def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str],
@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens(
read_offset = len(new_tokens)
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens)
return new_tokens, prefix_offset, read_offset
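_replace_none_with_empty guards the incremental detokenizer against prompt token IDs that fall outside the active tokenizer's vocabulary (now possible when IDs produced with an adapter tokenizer meet a different one): fast tokenizers return None for such IDs, which would break the later string joins. A small sketch of the failure mode it covers (assumes the Hub checkpoint is reachable):

from typing import List, Optional

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# 99999 is far outside the 32000-token base vocabulary.
tokens: List[Optional[str]] = tokenizer.convert_ids_to_tokens([1, 99999])
print(tokens)  # typically ['<s>', None] for a fast tokenizer

# The new guard rewrites the None entries in place so downstream joins don't crash.
tokens = [token if token is not None else "" for token in tokens]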
@@ -88,6 +88,9 @@ def get_tokenizer(
"Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False

if "truncation_side" not in kwargs:
kwargs["truncation_side"] = "left"

try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
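get_tokenizer now defaults truncation_side to "left" (previously the OpenAI serving layer set this only on its private tokenizer copy), so prompt truncation keeps the end of the prompt for every tokenizer the frontend obtains, including LoRA ones. A small sketch (model name reused from the tests above):

from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("HuggingFaceH4/zephyr-7b-beta")
print(tokenizer.truncation_side)  # "left" unless the caller overrides it

# With left-side truncation, the trailing tokens of an over-long prompt survive.
ids = tokenizer("one two three four five six", truncation=True, max_length=4).input_ids
print(tokenizer.decode(ids))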