Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 07:15:01 +08:00)

[Frontend] Move async logic outside of constructor (#4674)

parent 16bc0a098f
commit f12b20decc
@@ -60,13 +60,12 @@ class MockServingChat:
     tokenizer: MockTokenizer


-@pytest.mark.asyncio
-async def test_load_chat_template():
+def test_load_chat_template():
     # Testing chatml template
     tokenizer = MockTokenizer()
     mock_serving_chat = MockServingChat(tokenizer)
-    await OpenAIServingChat._load_chat_template(
-        mock_serving_chat, chat_template=chatml_jinja_path)
+    OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                          chat_template=chatml_jinja_path)

     template_content = tokenizer.chat_template

@@ -77,8 +76,7 @@ async def test_load_chat_template():
 {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501


-@pytest.mark.asyncio
-async def test_no_load_chat_template_filelike():
+def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
     tokenizer = MockTokenizer()
@@ -86,35 +84,33 @@ async def test_no_load_chat_template_filelike():
     mock_serving_chat = MockServingChat(tokenizer)

     with pytest.raises(ValueError, match="looks like a file path"):
-        await OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                                    chat_template=template)
+        OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                              chat_template=template)


-@pytest.mark.asyncio
-async def test_no_load_chat_template_literallike():
+def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
     tokenizer = MockTokenizer()

     mock_serving_chat = MockServingChat(tokenizer)
-    await OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                                chat_template=template)
+    OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                          chat_template=template)
     template_content = tokenizer.chat_template

     assert template_content == template


-@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model,template,add_generation_prompt,expected_output",
     MODEL_TEMPLATE_GENERATON_OUTPUT)
-async def test_get_gen_prompt(model, template, add_generation_prompt,
-                              expected_output):
+def test_get_gen_prompt(model, template, add_generation_prompt,
+                        expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
     mock_serving_chat = MockServingChat(tokenizer)
-    await OpenAIServingChat._load_chat_template(mock_serving_chat,
-                                                chat_template=template)
+    OpenAIServingChat._load_chat_template(mock_serving_chat,
+                                          chat_template=template)

     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
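With `_load_chat_template` now synchronous, the tests above drop `@pytest.mark.asyncio` and simply call the helper through the class, passing a mock object as `self`. A minimal sketch of that pattern, assuming hypothetical `MockTok`, `MockChat`, and `Serving` stand-ins rather than the real test fixtures:

from dataclasses import dataclass


class MockTok:
    chat_template = None


@dataclass
class MockChat:
    # Only the attribute the helper touches needs to exist.
    tokenizer: MockTok


class Serving:

    def _load_template(self, chat_template):
        # Plain method: callable via the class with any "self"-like object.
        self.tokenizer.chat_template = chat_template


mock = MockChat(tokenizer=MockTok())
Serving._load_template(mock, "{{ messages }}")  # no event loop required
assert mock.tokenizer.chat_template == "{{ messages }}"

Any object carrying the attributes the method reads works as the first argument, which is what lets these tests avoid constructing a full OpenAIServingChat.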
@@ -20,11 +20,15 @@ class MockModelConfig:
 class MockEngine:

     async def get_model_config(self):
-        return MockModelConfig
+        return MockModelConfig()


 async def _async_serving_chat_init():
-    serving_completion = OpenAIServingChat(MockEngine(),
+    engine = MockEngine()
+    model_config = await engine.get_model_config()
+
+    serving_completion = OpenAIServingChat(engine,
+                                           model_config,
                                            served_model_names=[MODEL_NAME],
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE)
@@ -516,7 +516,7 @@ class EngineArgs:
         return parser

     @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
+    def from_cli_args(cls, args: argparse.Namespace):
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
@@ -4,7 +4,7 @@ import inspect
 import re
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from typing import Set
+from typing import Optional, Set

 import fastapi
 import uvicorn
@@ -164,15 +164,32 @@ if __name__ == "__main__":
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]

     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(
         engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
+
+    event_loop: Optional[asyncio.AbstractEventLoop]
+    try:
+        event_loop = asyncio.get_running_loop()
+    except RuntimeError:
+        event_loop = None
+
+    if event_loop is not None and event_loop.is_running():
+        # If the current is instanced by Ray Serve,
+        # there is already a running event loop
+        model_config = event_loop.run_until_complete(engine.get_model_config())
+    else:
+        # When using single vLLM without engine_use_ray
+        model_config = asyncio.run(engine.get_model_config())
+
+    openai_serving_chat = OpenAIServingChat(engine, model_config,
+                                            served_model_names,
                                             args.response_role,
                                             args.lora_modules,
                                             args.chat_template)
     openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model_names, args.lora_modules)
+        engine, model_config, served_model_names, args.lora_modules)

     app.root_path = args.root_path
     uvicorn.run(app,
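The block above fetches the model config before any serving object is constructed: if an event loop is already running (the Ray Serve case noted in the comments) it is used, otherwise `asyncio.run` drives the coroutine. A minimal sketch of the fallback branch only, with a hypothetical `FakeEngine` standing in for the real async engine:

import asyncio


class FakeEngine:
    """Stand-in for the async engine; only the coroutine used above is faked."""

    async def get_model_config(self):
        return {"max_model_len": 4096}


def current_loop_or_none():
    # Outside of any coroutine, get_running_loop() raises RuntimeError,
    # which is why the start-up code falls back to asyncio.run().
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        return None


if __name__ == "__main__":
    engine = FakeEngine()
    assert current_loop_or_none() is None  # plain script: no loop running yet
    model_config = asyncio.run(engine.get_model_config())
    print(model_config["max_model_len"])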
@@ -1,4 +1,3 @@
-import asyncio
 import codecs
 import time
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
@@ -8,6 +7,7 @@ from fastapi import Request
 from openai.types.chat import (ChatCompletionContentPartParam,
                                ChatCompletionRole)

+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse,
@@ -35,17 +35,47 @@ class OpenAIServingChat(OpenAIServing):

     def __init__(self,
                  engine: AsyncLLMEngine,
+                 model_config: ModelConfig,
                  served_model_names: List[str],
                  response_role: str,
                  lora_modules: Optional[List[LoRAModulePath]] = None,
                  chat_template: Optional[str] = None):
         super().__init__(engine=engine,
+                         model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules,
-                         await_post_init=self._load_chat_template(
-                             chat_template=chat_template))
+                         lora_modules=lora_modules)

         self.response_role = response_role
+        self._load_chat_template(chat_template)
+
+    def _load_chat_template(self, chat_template: Optional[str]):
+        tokenizer = self.tokenizer
+
+        if chat_template is not None:
+            try:
+                with open(chat_template, "r") as f:
+                    tokenizer.chat_template = f.read()
+            except OSError as e:
+                JINJA_CHARS = "{}\n"
+                if not any(c in chat_template for c in JINJA_CHARS):
+                    msg = (f"The supplied chat template ({chat_template}) "
+                           f"looks like a file path, but it failed to be "
+                           f"opened. Reason: {e}")
+                    raise ValueError(msg) from e
+
+                # If opening a file fails, set chat template to be args to
+                # ensure we decode so our escape are interpreted correctly
+                tokenizer.chat_template = codecs.decode(
+                    chat_template, "unicode_escape")
+
+            logger.info("Using supplied chat template:\n%s",
+                        tokenizer.chat_template)
+        elif tokenizer.chat_template is not None:
+            logger.info("Using default chat template:\n%s",
+                        tokenizer.chat_template)
+        else:
+            logger.warning(
+                "No chat template provided. Chat API will not work.")

     def _parse_chat_message_content(
         self,
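The relocated `_load_chat_template` needs no awaiting: it first tries to open its argument as a file, and on failure treats the argument as a literal Jinja string (after escape-decoding) when it contains template-like characters, otherwise it raises `ValueError`. A rough, self-contained sketch of those branches, using a hypothetical `MiniTokenizer` and a free function instead of the real method:

import codecs


class MiniTokenizer:
    chat_template = None


def load_chat_template(tokenizer, chat_template):
    # Same branching as the method above, trimmed to the essentials.
    if chat_template is None:
        return
    try:
        with open(chat_template, "r") as f:
            tokenizer.chat_template = f.read()
    except OSError as e:
        if not any(c in chat_template for c in "{}\n"):
            raise ValueError(
                f"The supplied chat template ({chat_template}) looks like a "
                f"file path, but it failed to be opened. Reason: {e}") from e
        # Decode escape sequences so a literal template passed on the command
        # line (e.g. containing "\n") is stored in its intended form.
        tokenizer.chat_template = codecs.decode(chat_template,
                                                "unicode_escape")


tok = MiniTokenizer()
load_chat_template(tok, "{{ messages }}")  # literal Jinja: kept as the template
assert tok.chat_template == "{{ messages }}"

Because the literal string `{{ messages }}` is not an existing file, `open` raises `FileNotFoundError` (an `OSError`), and the presence of braces routes it into the escape-decoding branch rather than the `ValueError` path.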
@@ -357,36 +387,4 @@ class OpenAIServingChat(OpenAIServing):
             usage=usage,
         )

         return response
-
-    async def _load_chat_template(self, chat_template: Optional[str]):
-        while self.tokenizer is None:
-            # Give the parent class time to load the tokenizer
-            await asyncio.sleep(0.1)
-        tokenizer = self.tokenizer
-
-        if chat_template is not None:
-            try:
-                with open(chat_template, "r") as f:
-                    tokenizer.chat_template = f.read()
-            except OSError as e:
-                JINJA_CHARS = "{}\n"
-                if not any(c in chat_template for c in JINJA_CHARS):
-                    msg = (f"The supplied chat template ({chat_template}) "
-                           f"looks like a file path, but it failed to be "
-                           f"opened. Reason: {e}")
-                    raise ValueError(msg) from e
-
-                # If opening a file fails, set chat template to be args to
-                # ensure we decode so our escape are interpreted correctly
-                tokenizer.chat_template = codecs.decode(
-                    chat_template, "unicode_escape")
-
-            logger.info("Using supplied chat template:\n%s",
-                        tokenizer.chat_template)
-        elif tokenizer.chat_template is not None:
-            logger.info("Using default chat template:\n%s",
-                        tokenizer.chat_template)
-        else:
-            logger.warning(
-                "No chat template provided. Chat API will not work.")
@@ -4,6 +4,7 @@ from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,

 from fastapi import Request

+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               CompletionResponse,
@@ -52,11 +53,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:

 class OpenAIServingCompletion(OpenAIServing):

-    def __init__(self,
-                 engine: AsyncLLMEngine,
+    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]]):
         super().__init__(engine=engine,
+                         model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)

@@ -1,13 +1,12 @@
-import asyncio
 import json
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 from pydantic import Field
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Annotated

+from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest, ErrorResponse,
@@ -29,13 +28,24 @@ class LoRAModulePath:

 class OpenAIServing:

-    def __init__(self,
-                 engine: AsyncLLMEngine,
+    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]],
-                 await_post_init: Optional[Awaitable[Any]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]]):
+        super().__init__()
+
         self.engine = engine
+        self.max_model_len = model_config.max_model_len
+
+        # A separate tokenizer to map token IDs to strings.
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer,
+            tokenizer_mode=model_config.tokenizer_mode,
+            tokenizer_revision=model_config.tokenizer_revision,
+            trust_remote_code=model_config.trust_remote_code,
+            truncation_side="left")
+
         self.served_model_names = served_model_names

         if lora_modules is None:
             self.lora_requests = []
         else:
@@ -47,38 +57,6 @@ class OpenAIServing:
                 ) for i, lora in enumerate(lora_modules, start=1)
             ]

-        self.max_model_len = 0
-        # Lazy initialized
-        self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-
-        try:
-            event_loop = asyncio.get_running_loop()
-        except RuntimeError:
-            event_loop = None
-
-        if event_loop is not None and event_loop.is_running():
-            # If the current is instanced by Ray Serve,
-            # there is already a running event loop
-            event_loop.create_task(self._post_init(await_post_init))
-        else:
-            # When using single vLLM without engine_use_ray
-            asyncio.run(self._post_init(await_post_init))
-
-    async def _post_init(self, await_post_init):
-        engine_model_config = await self.engine.get_model_config()
-        self.max_model_len = engine_model_config.max_model_len
-
-        # A separate tokenizer to map token IDs to strings.
-        self.tokenizer = get_tokenizer(
-            engine_model_config.tokenizer,
-            tokenizer_mode=engine_model_config.tokenizer_mode,
-            tokenizer_revision=engine_model_config.tokenizer_revision,
-            trust_remote_code=engine_model_config.trust_remote_code,
-            truncation_side="left")
-
-        if await_post_init is not None:
-            await await_post_init
-
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
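With the hunks above, the old `_post_init` task and the tokenizer-polling loop disappear: the caller awaits `engine.get_model_config()` first, and the constructor then runs entirely synchronously. A small sketch of that ordering, where `StubEngine`, `StubModelConfig`, and `StubServing` are illustrative stand-ins, not vLLM types:

import asyncio
from dataclasses import dataclass


@dataclass
class StubModelConfig:
    max_model_len: int = 2048


class StubEngine:

    async def get_model_config(self):
        return StubModelConfig()


class StubServing:
    """Constructor stays synchronous; all async work happened before it."""

    def __init__(self, engine, model_config):
        self.engine = engine
        self.max_model_len = model_config.max_model_len


async def main():
    engine = StubEngine()
    model_config = await engine.get_model_config()  # resolved up front
    serving = StubServing(engine, model_config)     # no _post_init task needed
    print(serving.max_model_len)


asyncio.run(main())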