[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

parent 8a74c68bd1
commit e2fbaee725
@@ -1,6 +1,5 @@
 import os
 import pathlib
-from dataclasses import dataclass
 
 import pytest
 

@@ -50,23 +49,9 @@ TEST_MESSAGES = [
 ]
 
 
-@dataclass
-class MockTokenizer:
-    chat_template = None
-
-
-@dataclass
-class MockServingChat:
-    tokenizer: MockTokenizer
-
-
 def test_load_chat_template():
     # Testing chatml template
-    tokenizer = MockTokenizer()
-    mock_serving_chat = MockServingChat(tokenizer)
-    load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path)
-
-    template_content = tokenizer.chat_template
+    template_content = load_chat_template(chat_template=chatml_jinja_path)
 
     # Test assertions
     assert template_content is not None

@@ -78,22 +63,16 @@ def test_load_chat_template():
 def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
-    tokenizer = MockTokenizer()
 
-    mock_serving_chat = MockServingChat(tokenizer)
-
     with pytest.raises(ValueError, match="looks like a file path"):
-        load_chat_template(mock_serving_chat, chat_template=template)
+        load_chat_template(chat_template=template)
 
 
 def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
-    tokenizer = MockTokenizer()
 
-    mock_serving_chat = MockServingChat(tokenizer)
-    load_chat_template(mock_serving_chat, chat_template=template)
-    template_content = tokenizer.chat_template
+    template_content = load_chat_template(chat_template=template)
 
     assert template_content == template
 

@@ -105,8 +84,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
                         expected_output):
    # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
-    mock_serving_chat = MockServingChat(tokenizer)
-    load_chat_template(mock_serving_chat, chat_template=template)
+    template_content = load_chat_template(chat_template=template)
 
     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(

@@ -118,7 +96,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     result = tokenizer.apply_chat_template(
         conversation=mock_request.messages,
         tokenize=False,
-        add_generation_prompt=mock_request.add_generation_prompt)
+        add_generation_prompt=mock_request.add_generation_prompt,
+        chat_template=mock_request.chat_template or template_content)
 
     # Test assertion
     assert result == expected_output, (

@@ -7,11 +7,11 @@ import jsonschema
 import openai  # use the official client for correctness check
 import pytest
 import torch
-# downloading lora to test lora requests
-from huggingface_hub import snapshot_download
 from openai import BadRequestError
 
 from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

@@ -21,12 +21,7 @@ LORA_NAME = "typeof/zephyr-7b-beta-lora"
 
 
 @pytest.fixture(scope="module")
-def zephyr_lora_files():
-    return snapshot_download(repo_id=LORA_NAME)
-
-
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files):
+def server(zephyr_lora_files, zephyr_lora_added_tokens_files):  # noqa: F811
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",

@@ -38,7 +33,7 @@ def server(zephyr_lora_files):
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",
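Client-side, the behavior fixed here can be checked against a running server in the same way the new completion tests do. A rough sketch using the synchronous OpenAI client (assumes a local vLLM server started with --enable-lora and a "zephyr-lora2" module backed by the added-tokens adapter built by the test_completion.py fixtures that follow):

# Sketch only: server address and adapter name are taken from the tests in
# this commit; any other setup needs its own values.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="zephyr-lora2",                # LoRA module name, not the base model
    prompt=[0, 0, 32000, 32001, 32002],  # token IDs only the adapter defines
    echo=True,
    max_tokens=5,
    temperature=0.0,
)
# With the adapter's tokenizer in use, the echoed prompt contains the added
# tokens; before this fix the base tokenizer rendered 32000-32002 as <unk>.
print(completion.choices[0].text)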
@@ -1,6 +1,8 @@
 # imports for guided decoding tests
 import json
 import re
+import shutil
+from tempfile import TemporaryDirectory
 from typing import List
 
 import jsonschema

@@ -9,6 +11,7 @@ import pytest
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
+from transformers import AutoTokenizer
 
 from vllm.transformers_utils.tokenizer import get_tokenizer
 

@@ -30,13 +33,29 @@ def zephyr_lora_files():
     return snapshot_download(repo_id=LORA_NAME)
 
 
+@pytest.fixture(scope="module")
+def zephyr_lora_added_tokens_files(zephyr_lora_files):
+    tmp_dir = TemporaryDirectory()
+    tmp_model_dir = f"{tmp_dir.name}/zephyr"
+    shutil.copytree(zephyr_lora_files, tmp_model_dir)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    # Copy tokenizer to adapter and add some unique tokens
+    # 32000, 32001, 32002
+    added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
+                                 special_tokens=True)
+    assert added == 3
+    tokenizer.save_pretrained(tmp_model_dir)
+    yield tmp_model_dir
+    tmp_dir.cleanup()
+
+
 @pytest.fixture(scope="module")
 def zephyr_pa_files():
     return snapshot_download(repo_id=PA_NAME)
 
 
 @pytest.fixture(scope="module")
-def server(zephyr_lora_files, zephyr_pa_files):
+def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",

@@ -50,7 +69,7 @@ def server(zephyr_lora_files, zephyr_pa_files):
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",

@@ -111,6 +130,34 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
     assert len(completion.choices[0].text) >= 1
 
 
+@pytest.mark.asyncio
+async def test_added_lora_tokens(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model="zephyr-lora2",
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should appear in tokenized prompt
+    assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")
+
+
+@pytest.mark.asyncio
+async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should not appear in tokenized prompt
+    assert "vllm" not in completion.choices[0].text
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras, then test prompt adapters
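The zephyr_lora_added_tokens_files fixture above can be reproduced outside pytest roughly as follows (a sketch; it needs network access to the two Hugging Face repos named in the diff):

# Standalone sketch of what the fixture builds: a copy of the LoRA adapter
# whose tokenizer defines three extra tokens, so "zephyr-lora2" tokenizes
# differently from the base model.
import shutil
from tempfile import TemporaryDirectory

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
LORA_NAME = "typeof/zephyr-7b-beta-lora"

with TemporaryDirectory() as tmp_dir:
    adapter_dir = f"{tmp_dir}/zephyr"
    shutil.copytree(snapshot_download(repo_id=LORA_NAME), adapter_dir)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
                                 special_tokens=True)
    assert added == 3
    tokenizer.save_pretrained(adapter_dir)

    # The base vocabulary ends at 31999, so the new tokens land at
    # 32000-32002, the IDs exercised by test_added_lora_tokens above.
    print(tokenizer.convert_tokens_to_ids(["vllm1", "vllm2", "vllm3"]))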
@@ -38,5 +38,4 @@ async def _async_serving_chat_init():
 
 def test_async_serving_chat_init():
     serving_completion = asyncio.run(_async_serving_chat_init())
-    assert serving_completion.tokenizer is not None
-    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
+    assert serving_completion.chat_template == CHAT_TEMPLATE

@@ -5,13 +5,15 @@ import requests
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 
 
 @pytest.fixture(scope="module")
-def server():
+def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",

@@ -21,12 +23,25 @@ def server():
         "--enforce-eager",
         "--max-num-seqs",
         "128",
+        # lora config
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
+        "--max-lora-rank",
+        "64",
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
 
 
+@pytest.fixture(scope="module")
+def tokenizer_name(model_name: str,
+                   zephyr_lora_added_tokens_files: str):  # noqa: F811
+    return zephyr_lora_added_tokens_files if (
+        model_name == "zephyr-lora2") else model_name
+
+
 @pytest.fixture(scope="module")
 def client(server):
     return server.get_async_client()

@@ -34,16 +49,18 @@ def client(server):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
 )
 async def test_tokenize_completions(client: openai.AsyncOpenAI,
-                                    model_name: str):
+                                    model_name: str, tokenizer_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
 
     for add_special in [False, True]:
-        prompt = "This is a test prompt."
+        prompt = "vllm1 This is a test prompt."
         tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
 
         response = requests.post(base_url + "/tokenize",

@@ -63,12 +80,15 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
 )
-async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
+async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
+                             tokenizer_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
 
     for add_generation in [False, True]:
         for add_special in [False, True]:

@@ -80,7 +100,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
                 "content": "Nice to meet you!"
             }, {
                 "role": "user",
-                "content": "Can I ask a question?"
+                "content": "Can I ask a question? vllm1"
             }]
 
             prompt = tokenizer.apply_chat_template(

@@ -108,16 +128,20 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
 )
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
+async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
+                          tokenizer_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")
 
-    prompt = "This is a test prompt."
+    prompt = "This is a test prompt. vllm1"
     tokens = tokenizer.encode(prompt, add_special_tokens=False)
 
+    print(f"CALLING {base_url} FOR {model_name}")
     response = requests.post(base_url + "/detokenize",
                              json={
                                  "model": model_name,

@@ -480,11 +480,16 @@ class AsyncLLMEngine:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
-    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+    async def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote()  # type: ignore
-        else:
-            return self.engine.get_tokenizer()
+            return await self.engine.get_tokenizer.remote(  # type: ignore
+                lora_request)
+
+        return await (self.engine.get_tokenizer_group().
+                      get_lora_tokenizer_async(lora_request))
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
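With the signature above, frontend code resolves the tokenizer per request instead of reusing one engine-wide instance. A minimal usage sketch (the helper name is illustrative; it assumes an already-constructed AsyncLLMEngine):

# Sketch only. None selects the base-model tokenizer; a LoRARequest selects
# that adapter's tokenizer, fetched through the engine's tokenizer group.
from typing import Optional

from vllm.lora.request import LoRARequest


async def tokenize_for(engine, prompt: str,
                       lora_request: Optional[LoRARequest] = None):
    tokenizer = await engine.get_tokenizer(lora_request)
    return tokenizer.encode(prompt)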
@@ -455,8 +455,11 @@ class LLMEngine:
 
         return self.tokenizer
 
-    def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.get_tokenizer_group().get_lora_tokenizer(None)
+    def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
+        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
 
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":

@@ -257,7 +257,8 @@ def run_server(args, llm_engine=None):
     openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                       served_model_names)
     openai_serving_tokenization = OpenAIServingTokenization(
-        engine, model_config, served_model_names, args.chat_template)
+        engine, model_config, served_model_names, args.lora_modules,
+        args.chat_template)
     app.root_path = args.root_path
 
     logger.info("Available routes are:")

@@ -5,10 +5,11 @@ from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final
 
 from openai.types.chat import (ChatCompletionContentPartImageParam,
                                ChatCompletionContentPartTextParam)
+from transformers import PreTrainedTokenizer
 
+from vllm.config import ModelConfig
 from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam,
                                               ChatCompletionMessageParam)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import async_get_and_parse_image

@@ -29,40 +30,34 @@ class ChatMessageParseResult:
         default_factory=list)
 
 
-def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]):
-    tokenizer = engine.tokenizer
-
-    if chat_template is not None:
-        try:
-            with open(chat_template, "r") as f:
-                tokenizer.chat_template = f.read()
-        except OSError as e:
-            JINJA_CHARS = "{}\n"
-            if not any(c in chat_template for c in JINJA_CHARS):
-                msg = (f"The supplied chat template ({chat_template}) "
                        f"looks like a file path, but it failed to be "
                        f"opened. Reason: {e}")
-                raise ValueError(msg) from e
-
-            # If opening a file fails, set chat template to be args to
-            # ensure we decode so our escape are interpreted correctly
-            tokenizer.chat_template = codecs.decode(chat_template,
                                                     "unicode_escape")
-
-        logger.info("Using supplied chat template:\n%s",
                     tokenizer.chat_template)
-    elif tokenizer.chat_template is not None:
-        logger.info("Using default chat template:\n%s",
                     tokenizer.chat_template)
-    else:
-        logger.warning("No chat template provided. Chat API will not work.")
+def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+    if chat_template is None:
+        return None
+    try:
+        with open(chat_template, "r") as f:
+            resolved_chat_template = f.read()
+    except OSError as e:
+        JINJA_CHARS = "{}\n"
+        if not any(c in chat_template for c in JINJA_CHARS):
+            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
+            raise ValueError(msg) from e
+
+        # If opening a file fails, set chat template to be args to
+        # ensure we decode so our escape are interpreted correctly
+        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
+
+    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
+    return resolved_chat_template
 
 
 @lru_cache(maxsize=None)
-def _image_token_str(engine: OpenAIServing) -> Optional[str]:
+def _image_token_str(model_config: ModelConfig,
+                     tokenizer: PreTrainedTokenizer) -> Optional[str]:
     # TODO: Let user specify how to insert image tokens into prompt
     # (similar to chat template)
-    model_type = engine.model_config.hf_config.model_type
+    model_type = model_config.hf_config.model_type
     if model_type == "phi3_v":
         # Workaround since this token is not defined in the tokenizer
         return "<|image_1|>"
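load_chat_template now returns the resolved template instead of writing it onto a shared tokenizer; the serving classes store the result and pass request.chat_template or self.chat_template to apply_chat_template. A small sketch of the three cases (the import path is an assumption based on the imports visible in this diff):

# Sketch only; the module path below is assumed, not confirmed by this page.
import tempfile

from vllm.entrypoints.openai.chat_utils import load_chat_template

# 1) A literal Jinja string is unicode-unescaped and returned as-is.
assert load_chat_template("{{ messages }}") == "{{ messages }}"

# 2) None means "defer to the tokenizer's own default template".
assert load_chat_template(None) is None

# 3) A readable file path is resolved to the file's contents.
with tempfile.NamedTemporaryFile("w", suffix=".jinja") as f:
    f.write("{% for m in messages %}{{ m.content }}{% endfor %}")
    f.flush()
    assert load_chat_template(f.name).startswith("{% for m in messages %}")

# A string that looks like a path but cannot be opened raises ValueError
# ("looks like a file path"), as exercised by the updated tests above.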
@@ -70,17 +65,14 @@ def _image_token_str(engine: OpenAIServing) -> Optional[str]:
         # These models do not use image tokens in the prompt
         return None
     if model_type.startswith("llava"):
-        return engine.tokenizer.decode(
-            engine.model_config.hf_config.image_token_index)
+        return tokenizer.decode(model_config.hf_config.image_token_index)
 
-    else:
-        raise TypeError("Unknown model type: {model_type}")
+    raise TypeError("Unknown model type: {model_type}")
 
 
 # TODO: Let user specify how to insert image tokens into prompt
 # (similar to chat template)
-def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
-                                text_prompt: str) -> str:
+def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
     """Combine image and text prompts for vision language model"""
 
     # NOTE: For now we assume all model architectures use the same

@@ -89,9 +81,10 @@ def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str,
 
 
 def _parse_chat_message_content_parts(
-    engine: OpenAIServing,
     role: str,
     parts: Iterable[ChatCompletionContentPartParam],
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
 ) -> ChatMessageParseResult:
     texts: List[str] = []
     mm_futures: List[Awaitable[MultiModalDataDict]] = []

@@ -122,7 +115,7 @@ def _parse_chat_message_content_parts(
     text_prompt = "\n".join(texts)
 
     if mm_futures:
-        image_token_str = _image_token_str(engine)
+        image_token_str = _image_token_str(model_config, tokenizer)
         if image_token_str is not None:
             if image_token_str in text_prompt:
                 logger.warning(

@@ -130,7 +123,6 @@ def _parse_chat_message_content_parts(
                     "Skipping prompt formatting.")
             else:
                 text_prompt = _get_full_image_text_prompt(
-                    engine,
                     image_token_str=image_token_str,
                     text_prompt=text_prompt,
                 )

@@ -141,8 +133,9 @@ def _parse_chat_message_content_parts(
 
 
 def parse_chat_message_content(
-    engine: OpenAIServing,
     message: ChatCompletionMessageParam,
+    model_config: ModelConfig,
+    tokenizer: PreTrainedTokenizer,
 ) -> ChatMessageParseResult:
     role = message["role"]
     content = message.get("content")

@@ -153,4 +146,5 @@ def parse_chat_message_content(
         messages = [ConversationMessage(role=role, content=content)]
         return ChatMessageParseResult(messages=messages, mm_futures=[])
 
-    return _parse_chat_message_content_parts(engine, role, content)
+    return _parse_chat_message_content_parts(role, content, model_config,
                                              tokenizer)

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Union
 
 from fastapi import Request
+from transformers import PreTrainedTokenizer
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine

@@ -49,7 +50,9 @@ class OpenAIServingChat(OpenAIServing):
                          lora_modules=lora_modules)
 
         self.response_role = response_role
-        load_chat_template(self, chat_template)
+
+        # If this is None we use the tokenizer's default chat template
+        self.chat_template = load_chat_template(chat_template)
 
     async def create_chat_completion(
         self,

@@ -71,11 +74,15 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret
 
         try:
+            _, lora_request = self._maybe_get_adapter(request)
+            tokenizer = await self.engine.get_tokenizer(lora_request)
+
             conversation: List[ConversationMessage] = []
             mm_futures: List[Awaitable[MultiModalDataDict]] = []
 
             for msg in request.messages:
-                chat_parsed_result = parse_chat_message_content(self, msg)
+                chat_parsed_result = parse_chat_message_content(
+                    msg, self.model_config, tokenizer)
 
                 conversation.extend(chat_parsed_result.messages)
                 mm_futures.extend(chat_parsed_result.mm_futures)

@@ -84,13 +91,13 @@ class OpenAIServingChat(OpenAIServing):
                 tool.model_dump() for tool in request.tools
             ]
 
-            prompt = self.tokenizer.apply_chat_template(
+            prompt = tokenizer.apply_chat_template(
                 conversation=conversation,
                 tokenize=False,
                 add_generation_prompt=request.add_generation_prompt,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template,
+                chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
         except Exception as e:

@@ -112,19 +119,19 @@ class OpenAIServingChat(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         try:
             # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+            prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
                 request,
+                tokenizer,
                 prompt=prompt,
                 add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
-            _, lora_request = self._maybe_get_adapter(request)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
             guided_decode_logits_processor = (
-                await get_guided_decoding_logits_processor(
-                    guided_decoding_backend, request, await
-                    self.engine.get_tokenizer()))
+                await
+                get_guided_decoding_logits_processor(guided_decoding_backend,
                                                      request, tokenizer))
             if guided_decode_logits_processor:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []

@@ -158,12 +165,12 @@ class OpenAIServingChat(OpenAIServing):
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id, conversation)
+                request, result_generator, request_id, conversation, tokenizer)
         else:
             try:
                 return await self.chat_completion_full_generator(
                     request, raw_request, result_generator, request_id,
-                    conversation)
+                    conversation, tokenizer)
             except ValueError as e:
                 # TODO: Use a vllm-specific Validation Error
                 return self.create_error_response(str(e))

@@ -175,9 +182,12 @@ class OpenAIServingChat(OpenAIServing):
         return request.messages[-1]["role"]
 
     async def chat_completion_stream_generator(
-            self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str,
-            conversation: List[ConversationMessage]
+        self,
+        request: ChatCompletionRequest,
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())

@@ -264,6 +274,7 @@ class OpenAIServingChat(OpenAIServing):
                     logprobs = self._create_chat_logprobs(
                         token_ids=delta_token_ids,
                         top_logprobs=out_logprobs,
+                        tokenizer=tokenizer,
                         num_output_top_logprobs=request.top_logprobs,
                     )
                 else:

@@ -352,9 +363,13 @@ class OpenAIServingChat(OpenAIServing):
         yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
-            self, request: ChatCompletionRequest, raw_request: Optional[Request],
-            result_generator: AsyncIterator[RequestOutput], request_id: str,
-            conversation: List[ConversationMessage]
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Optional[Request],
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
     ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = self.served_model_names[0]

@@ -382,6 +397,7 @@ class OpenAIServingChat(OpenAIServing):
                     token_ids=token_ids,
                     top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.top_logprobs,
+                    tokenizer=tokenizer,
                 )
             else:
                 logprobs = None

@@ -436,16 +452,14 @@ class OpenAIServingChat(OpenAIServing):
         return response
 
     def _get_top_logprobs(
-            self, logprobs: Dict[int, Logprob],
-            top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
+            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
+            tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
         return [
             ChatCompletionLogProb(
-                token=self._get_decoded_token(p[1], p[0]),
+                token=(token := self._get_decoded_token(p[1], p[0],
                                                         tokenizer)),
                 logprob=max(p[1].logprob, -9999.0),
-                bytes=list(
-                    self._get_decoded_token(p[1],
                                             p[0]).encode("utf-8",
                                                          errors="replace")))
+                bytes=list(token.encode("utf-8", errors="replace")))
             for i, p in enumerate(logprobs.items())
             if top_logprobs and i < top_logprobs
         ]

@@ -454,6 +468,7 @@ class OpenAIServingChat(OpenAIServing):
         self,
         token_ids: GenericSequence[int],
         top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        tokenizer: PreTrainedTokenizer,
         num_output_top_logprobs: Optional[int] = None,
     ) -> ChatCompletionLogProbs:
         """Create OpenAI-style logprobs."""

@@ -463,12 +478,11 @@ class OpenAIServingChat(OpenAIServing):
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
+                token = tokenizer.decode(token_id)
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
-                        token=self.tokenizer.decode(token_id),
-                        bytes=list(
-                            self.tokenizer.decode(token_id).encode(
-                                "utf-8", errors="replace"))))
+                        token=token,
+                        bytes=list(token.encode("utf-8", errors="replace"))))
             else:
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(

@@ -479,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
                             step_top_logprobs[token_id].decoded_token.encode(
                                 "utf-8", errors="replace")),
                         top_logprobs=self._get_top_logprobs(
-                            step_top_logprobs, num_output_top_logprobs)))
+                            step_top_logprobs, num_output_top_logprobs,
+                            tokenizer)))
 
         return ChatCompletionLogProbs(content=logprobs_content)

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Tuple
 
 from fastapi import Request
+from transformers import PreTrainedTokenizer
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine

@@ -100,20 +101,22 @@ class OpenAIServingCompletion(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: List[AsyncIterator[RequestOutput]] = []
         try:
-            sampling_params = request.to_sampling_params()
             adapter_type, adapter_request = self._maybe_get_adapter(request)
             lora_request, prompt_adapter_request = None, None
             if adapter_type == 'LoRA':
                 lora_request, prompt_adapter_request = adapter_request, None
             elif adapter_type == 'PromptAdapter':
                 lora_request, prompt_adapter_request = None, adapter_request
+            tokenizer = await self.engine.get_tokenizer(lora_request)
+
+            sampling_params = request.to_sampling_params()
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
             guided_decode_logit_processor = (
-                await get_guided_decoding_logits_processor(
-                    guided_decoding_backend, request, await
-                    self.engine.get_tokenizer()))
+                await
+                get_guided_decoding_logits_processor(guided_decoding_backend,
                                                      request, tokenizer))
             if guided_decode_logit_processor is not None:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []

@@ -122,18 +125,13 @@ class OpenAIServingCompletion(OpenAIServing):
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
-                if prompt_is_tokens:
-                    prompt_formats = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt_ids=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
-                else:
-                    prompt_formats = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
+                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
+                prompt_formats = await self._validate_prompt_and_tokenize(
+                    request,
+                    tokenizer,
+                    truncate_prompt_tokens=sampling_params.
+                    truncate_prompt_tokens,
+                    **{prompt_arg: prompt})
                 prompt_ids, prompt_text = prompt_formats
 
             is_tracing_enabled = await self.engine.is_tracing_enabled()

@@ -179,7 +177,8 @@ class OpenAIServingCompletion(OpenAIServing):
                                                     request_id,
                                                     created_time,
                                                     model_name,
-                                                    num_prompts=len(prompts))
+                                                    num_prompts=len(prompts),
+                                                    tokenizer=tokenizer)
 
         # Non-streaming response
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)

@@ -191,7 +190,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
             response = self.request_output_to_completion_response(
-                final_res_batch, request, request_id, created_time, model_name)
+                final_res_batch, request, request_id, created_time, model_name,
+                tokenizer)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

@@ -218,6 +218,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time: int,
         model_name: str,
         num_prompts: int,
+        tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
         assert request.n is not None
         previous_texts = [""] * request.n * num_prompts

@@ -268,6 +269,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         token_ids=delta_token_ids,
                         top_logprobs=out_logprobs,
                         num_output_top_logprobs=request.logprobs,
+                        tokenizer=tokenizer,
                         initial_text_offset=len(previous_texts[i]),
                     )
                 else:

@@ -336,6 +338,7 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id: str,
         created_time: int,
         model_name: str,
+        tokenizer: PreTrainedTokenizer,
     ) -> CompletionResponse:
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0

@@ -367,6 +370,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 logprobs = self._create_completion_logprobs(
                     token_ids=token_ids,
                     top_logprobs=out_logprobs,
+                    tokenizer=tokenizer,
                     num_output_top_logprobs=request.logprobs,
                 )
             else:

@@ -404,6 +408,7 @@ class OpenAIServingCompletion(OpenAIServing):
         token_ids: GenericSequence[int],
         top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
         num_output_top_logprobs: int,
+        tokenizer: PreTrainedTokenizer,
         initial_text_offset: int = 0,
     ) -> CompletionLogProbs:
         """Create logprobs for OpenAI Completion API."""

@@ -417,13 +422,13 @@ class OpenAIServingCompletion(OpenAIServing):
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
-                token = self.tokenizer.decode(token_id)
+                token = tokenizer.decode(token_id)
                 out_tokens.append(token)
                 out_token_logprobs.append(None)
                 out_top_logprobs.append(None)
             else:
                 token = self._get_decoded_token(step_top_logprobs[token_id],
-                                                token_id)
+                                                token_id, tokenizer)
                 token_logprob = max(step_top_logprobs[token_id].logprob,
                                     -9999.0)
                 out_tokens.append(token)

@@ -436,7 +441,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 out_top_logprobs.append({
                     # Convert float("-inf") to the
                     # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1], top_lp[0]):
+                    self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i

@@ -89,14 +89,11 @@ class OpenAIServingEmbedding(OpenAIServing):
         prompt_is_tokens, prompts = parse_prompt_format(request.input)
         pooling_params = request.to_pooling_params()
-
+        tokenizer = await self.engine.get_tokenizer()
         for i, prompt in enumerate(prompts):
-            if prompt_is_tokens:
-                prompt_formats = self._validate_prompt_and_tokenize(
-                    request, prompt_ids=prompt)
-            else:
-                prompt_formats = self._validate_prompt_and_tokenize(
-                    request, prompt=prompt)
+            prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
+            prompt_formats = await self._validate_prompt_and_tokenize(
+                request, tokenizer, **{prompt_arg: prompt})
 
             prompt_ids, prompt_text = prompt_formats
 
             generator = self.engine.encode(
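The completion and embedding handlers above replace their prompt/prompt_ids branches with a keyword name chosen at runtime and splatted into the call. A standalone illustration of that pattern (validate() is a stand-in, not a vLLM API):

# Runnable sketch of the **{prompt_arg: prompt} idiom used in the diff above.
from typing import List, Optional


def validate(prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None) -> str:
    # Exactly one of the two keywords must be supplied.
    if (prompt is None) == (prompt_ids is None):
        raise ValueError("Provide exactly one of prompt or prompt_ids.")
    return prompt if prompt is not None else f"<{len(prompt_ids)} token ids>"


for item in ["Hello there", [1, 2, 3]]:
    is_tokens = not isinstance(item, str)
    prompt_arg = "prompt_ids" if is_tokens else "prompt"
    # A single keyword argument is passed, selected by name at runtime.
    print(validate(**{prompt_arg: item}))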
@@ -5,6 +5,7 @@ from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from pydantic import Field
+from transformers import PreTrainedTokenizer
 from typing_extensions import Annotated
 
 from vllm.config import ModelConfig

@@ -19,7 +20,6 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import Logprob
-from vllm.transformers_utils.tokenizer import get_tokenizer
 
 logger = init_logger(__name__)
 

@@ -52,14 +52,6 @@ class OpenAIServing:
         self.model_config = model_config
         self.max_model_len = model_config.max_model_len
 
-        # A separate tokenizer to map token IDs to strings.
-        self.tokenizer = get_tokenizer(
-            model_config.tokenizer,
-            tokenizer_mode=model_config.tokenizer_mode,
-            tokenizer_revision=model_config.tokenizer_revision,
-            trust_remote_code=model_config.trust_remote_code,
-            truncation_side="left")
-
         self.served_model_names = served_model_names
 
         self.lora_requests = []

@@ -154,7 +146,8 @@ class OpenAIServing:
 
     def _maybe_get_adapter(
         self, request: Union[CompletionRequest, ChatCompletionRequest,
-                             EmbeddingRequest]
+                             EmbeddingRequest, TokenizeRequest,
+                             DetokenizeRequest]
     ) -> Tuple[Optional[str], Optional[Union[LoRARequest,
                                              PromptAdapterRequest]]]:
         if request.model in self.served_model_names:

@@ -168,11 +161,12 @@ class OpenAIServing:
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")
 
-    def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(
         self,
         request: Union[ChatCompletionRequest, CompletionRequest,
                        DetokenizeRequest, EmbeddingRequest,
                        TokenizeRequest],
+        tokenizer: "PreTrainedTokenizer",
         prompt: Optional[str] = None,
         prompt_ids: Optional[List[int]] = None,
         truncate_prompt_tokens: Optional[Annotated[int,

@@ -181,7 +175,7 @@ class OpenAIServing:
     ) -> Tuple[List[int], str]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
-        if (prompt and prompt_ids):
+        if prompt and prompt_ids:
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
 

@@ -200,14 +194,14 @@ class OpenAIServing:
                 "truncation": True,
                 "max_length": truncate_prompt_tokens,
             })
-            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
+            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
         elif truncate_prompt_tokens is not None:
             input_ids = prompt_ids[-truncate_prompt_tokens:]
         else:
             input_ids = prompt_ids
 
-        input_text = prompt if prompt is not None else self.tokenizer.decode(
-            prompt_ids)
+        input_text = prompt if prompt is not None else tokenizer.decode(
+            input_ids)
         token_num = len(input_ids)
 
         # Note: EmbeddingRequest doesn't have max_tokens

@@ -245,7 +239,9 @@ class OpenAIServing:
         else:
             return input_ids, input_text
 
-    def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
+    @staticmethod
+    def _get_decoded_token(logprob: Logprob, token_id: int,
+                           tokenizer: PreTrainedTokenizer) -> str:
         if logprob.decoded_token is not None:
             return logprob.decoded_token
-        return self.tokenizer.decode(token_id)
+        return tokenizer.decode(token_id)
@@ -9,7 +9,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                               DetokenizeResponse,
                                               TokenizeRequest,
                                               TokenizeResponse)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 
 
 class OpenAIServingTokenization(OpenAIServing):

@@ -18,13 +19,15 @@ class OpenAIServingTokenization(OpenAIServing):
                  engine: AsyncLLMEngine,
                  model_config: ModelConfig,
                  served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]] = None,
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=None)
+                         lora_modules=lora_modules)
 
-        load_chat_template(self, chat_template)
+        # If this is None we use the tokenizer's default chat template
+        self.chat_template = load_chat_template(chat_template)
 
     async def create_tokenize(self,
                               request: TokenizeRequest) -> TokenizeResponse:

@@ -40,20 +43,25 @@ class OpenAIServingTokenization(OpenAIServing):
             return self.create_error_response(
                 "Only one of `prompt` or `messages` should be provided.")
 
+        _, lora_request = self._maybe_get_adapter(request)
+        tokenizer = await self.engine.get_tokenizer(lora_request)
         if request.messages:
             conversation: List[ConversationMessage] = []
 
             for message in request.messages:
-                conversation.extend(
-                    parse_chat_message_content(self, message).messages)
+                result = parse_chat_message_content(message, self.model_config,
                                                     tokenizer)
+                conversation.extend(result.messages)
 
-            request.prompt = self.tokenizer.apply_chat_template(
+            request.prompt = tokenizer.apply_chat_template(
                 add_generation_prompt=request.add_generation_prompt,
                 conversation=conversation,
-                tokenize=False)
+                tokenize=False,
+                chat_template=self.chat_template)
 
-        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+        (input_ids, input_text) = await self._validate_prompt_and_tokenize(
             request,
+            tokenizer,
             prompt=request.prompt,
             add_special_tokens=request.add_special_tokens)
 

@@ -67,7 +75,9 @@ class OpenAIServingTokenization(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret
 
-        (input_ids, input_text) = self._validate_prompt_and_tokenize(
-            request, prompt_ids=request.tokens)
+        _, lora_request = self._maybe_get_adapter(request)
+        tokenizer = await self.engine.get_tokenizer(lora_request)
+        (input_ids, input_text) = await self._validate_prompt_and_tokenize(
+            request, tokenizer, prompt_ids=request.tokens)
 
         return DetokenizeResponse(prompt=input_text)

@@ -165,6 +165,12 @@ class Detokenizer:
         return len(new_decoded_token_text)
 
 
+def _replace_none_with_empty(tokens: List[Optional[str]]):
+    for i, token in enumerate(tokens):
+        if token is None:
+            tokens[i] = ""
+
+
 def _convert_tokens_to_string_with_added_encoders(
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         output_tokens: List[str],

@@ -223,6 +229,8 @@ def convert_prompt_ids_to_tokens(
     read_offset = len(new_tokens)
     prefix_offset = max(
         read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
+    # This is required to guard against out-of-vocab prompt token ids
+    _replace_none_with_empty(new_tokens)
     return new_tokens, prefix_offset, read_offset
 
 
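The detokenizer change guards against prompt token IDs that the active tokenizer cannot map back to strings (for example adapter-added IDs seen by the base tokenizer), which the tokenizer may return as None. Standalone illustration of the new helper:

# Runnable sketch: None entries would otherwise break string concatenation
# during incremental detokenization, so they are replaced with "".
from typing import List, Optional


def _replace_none_with_empty(tokens: List[Optional[str]]):
    for i, token in enumerate(tokens):
        if token is None:
            tokens[i] = ""


tokens = ["Hello", None, "world", None]
_replace_none_with_empty(tokens)
assert tokens == ["Hello", "", "world", ""]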
@@ -88,6 +88,9 @@ def get_tokenizer(
                 "Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False
 
+    if "truncation_side" not in kwargs:
+        kwargs["truncation_side"] = "left"
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
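Finally, get_tokenizer() now defaults truncation_side to "left" (previously this was only set on the OpenAI serving layer's private tokenizer, which is removed above). An equivalent transformers-only illustration, using the base model from the tests as an example:

# Sketch only: any Hugging Face tokenizer works; the model name requires a
# download and is taken from the tests in this commit.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta",
                                    truncation_side="left")
ids = tok("one two three four five", truncation=True, max_length=3,
          add_special_tokens=False).input_ids
# Left-side truncation keeps the end of the prompt rather than the beginning.
print(tok.decode(ids))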