[Frontend][Bugfix] Disallow extra fields in OpenAI API (#4355)
commit 8947bc3c15
parent 12628d3c78
@@ -8,6 +8,7 @@ py-cpuinfo
 transformers >= 4.40.0  # Required for StarCoder2 & Llava, Llama 3.
 tokenizers >= 0.19.1  # Required for Llama 3.
 fastapi
+openai
 uvicorn[standard]
 pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0

@@ -21,7 +21,6 @@ pytest-rerunfailures
 pytest-shard
 httpx
 einops  # required for MPT
-openai
 requests
 ray
 peft

@@ -15,6 +15,7 @@ import ray
 import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
+from openai import BadRequestError

 from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -770,6 +771,21 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
     assert loaded == {"result": 2}, loaded
 
 
+async def test_extra_fields(server, client: openai.AsyncOpenAI):
+    with pytest.raises(BadRequestError) as exc_info:
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant.",
+                "extra_field": "0",
+            }],  # type: ignore
+            temperature=0,
+            seed=0)
+
+    assert "extra_forbidden" in exc_info.value.message
+
+
 async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     simple_sql_grammar = """
 start: select_statement

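Note: the new test_extra_fields test above drives the rejection through the openai client, where the
server's HTTP 400 surfaces as BadRequestError. A rough standalone sketch of the same check over plain
HTTP follows; the base URL, port, and model name are assumptions for illustration, not part of this
commit.

# Hedged sketch, not part of the commit: the extra-field rejection seen as a raw HTTP response.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed local vLLM OpenAI-compatible server
    json={
        "model": "HuggingFaceH4/zephyr-7b-beta",  # example model name
        "messages": [{
            "role": "system",
            "content": "You are a helpful assistant.",
            "extra_field": "0",  # not a valid OpenAI field, so it is rejected
        }],
    },
)
print(resp.status_code)                # expected: 400 (BadRequestError on the client side)
print("extra_forbidden" in resp.text)  # expected: True
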
@@ -9,7 +9,7 @@ import json
 import ssl
 
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.openai.serving_engine import LoRA
+from vllm.entrypoints.openai.serving_engine import LoRAModulePath
 
 
 class LoRAParserAction(argparse.Action):

@@ -18,7 +18,7 @@ class LoRAParserAction(argparse.Action):
         lora_list = []
         for item in values:
             name, path = item.split('=')
-            lora_list.append(LoRA(name, path))
+            lora_list.append(LoRAModulePath(name, path))
         setattr(namespace, self.dest, lora_list)
 
 
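For context on the hunk above: LoRAParserAction turns each name=path string given on the command line
into a LoRAModulePath. A minimal standalone sketch of that parsing, with the argparse wiring
(--lora-modules, nargs="+") assumed for illustration:

# Standalone sketch of the parsing logic shown in the diff; the CLI flag name and nargs are assumptions.
import argparse
from dataclasses import dataclass


@dataclass
class LoRAModulePath:
    name: str
    local_path: str


class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            # Each value is expected to look like "<name>=<local_path>".
            name, path = item.split('=')
            lora_list.append(LoRAModulePath(name, path))
        setattr(namespace, self.dest, lora_list)


parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
args = parser.parse_args(["--lora-modules", "sql-lora=/path/to/sql-lora"])
print(args.lora_modules)  # [LoRAModulePath(name='sql-lora', local_path='/path/to/sql-lora')]
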
@@ -4,14 +4,20 @@ import time
 from typing import Dict, List, Literal, Optional, Union
 
 import torch
-from pydantic import BaseModel, Field, model_validator
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import Annotated
 
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
 
-class ErrorResponse(BaseModel):
+class OpenAIBaseModel(BaseModel):
+    # OpenAI API does not allow extra fields
+    model_config = ConfigDict(extra="forbid")
+
+
+class ErrorResponse(OpenAIBaseModel):
     object: str = "error"
     message: str
     type: str

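The core of this change is the new OpenAIBaseModel with extra="forbid". A self-contained illustration
of how Pydantic v2 reports an unexpected field under that config; the ExampleMessage model is invented
for the demo and is not part of the commit:

# Demo of ConfigDict(extra="forbid") behavior in Pydantic v2; ExampleMessage is hypothetical.
from pydantic import BaseModel, ConfigDict, ValidationError


class OpenAIBaseModel(BaseModel):
    # OpenAI API does not allow extra fields
    model_config = ConfigDict(extra="forbid")


class ExampleMessage(OpenAIBaseModel):
    role: str
    content: str


try:
    ExampleMessage(role="system", content="hi", extra_field="0")
except ValidationError as e:
    # The error type is "extra_forbidden", which is the string the new
    # test_extra_fields test asserts on in the server's error message.
    print(e.errors()[0]["type"])  # extra_forbidden
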
@@ -19,7 +25,7 @@ class ErrorResponse(BaseModel):
     code: int
 
 
-class ModelPermission(BaseModel):
+class ModelPermission(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
     object: str = "model_permission"
     created: int = Field(default_factory=lambda: int(time.time()))

@@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
     is_blocking: bool = False
 
 
-class ModelCard(BaseModel):
+class ModelCard(OpenAIBaseModel):
     id: str
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))

@@ -44,26 +50,26 @@ class ModelCard(BaseModel):
     permission: List[ModelPermission] = Field(default_factory=list)
 
 
-class ModelList(BaseModel):
+class ModelList(OpenAIBaseModel):
     object: str = "list"
     data: List[ModelCard] = Field(default_factory=list)
 
 
-class UsageInfo(BaseModel):
+class UsageInfo(OpenAIBaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
 
 
-class ResponseFormat(BaseModel):
+class ResponseFormat(OpenAIBaseModel):
     # type must be "json_object" or "text"
     type: Literal["text", "json_object"]
 
 
-class ChatCompletionRequest(BaseModel):
+class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
-    messages: List[Dict[str, str]]
+    messages: List[ChatCompletionMessageParam]
     model: str
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None

@@ -204,7 +210,7 @@ class ChatCompletionRequest(BaseModel):
         return data
 
 
-class CompletionRequest(BaseModel):
+class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
     model: str

@@ -343,19 +349,19 @@ class CompletionRequest(BaseModel):
         return data
 
 
-class LogProbs(BaseModel):
+class LogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
     top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
 
 
-class CompletionResponseChoice(BaseModel):
+class CompletionResponseChoice(OpenAIBaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "

@@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel):
     )
 
 
-class CompletionResponse(BaseModel):
+class CompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))

@@ -373,12 +379,12 @@ class CompletionResponse(BaseModel):
     usage: UsageInfo
 
 
-class CompletionResponseStreamChoice(BaseModel):
+class CompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = Field(
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = Field(
         default=None,
         description=(
             "The stop string or token id that caused the completion "

@@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel):
     )
 
 
-class CompletionStreamResponse(BaseModel):
+class CompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))

@@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = Field(default=None)
 
 
-class ChatMessage(BaseModel):
+class ChatMessage(OpenAIBaseModel):
     role: str
     content: str
 
 
-class ChatCompletionResponseChoice(BaseModel):
+class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
 
 
-class ChatCompletionResponse(BaseModel):
+class ChatCompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
     object: str = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))

@@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel):
     usage: UsageInfo
 
 
-class DeltaMessage(BaseModel):
+class DeltaMessage(OpenAIBaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
 
 
-class ChatCompletionResponseStreamChoice(BaseModel):
+class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length"]] = None
-    stop_reason: Union[None, int, str] = None
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
 
 
-class ChatCompletionStreamResponse(BaseModel):
+class ChatCompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
     object: str = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))

@@ -1,8 +1,11 @@
 import codecs
 import time
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
+from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
+                    Optional, Tuple, TypedDict, Union, final)
 
 from fastapi import Request
+from openai.types.chat import (ChatCompletionContentPartParam,
+                               ChatCompletionRole)
 
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (

@@ -10,7 +13,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)

@@ -20,20 +24,41 @@ from vllm.utils import random_uuid
 logger = init_logger(__name__)
 
 
+@final  # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+    role: str
+    content: str
+
+
 class OpenAIServingChat(OpenAIServing):
 
     def __init__(self,
                  engine: AsyncLLMEngine,
                  served_model_names: List[str],
                  response_role: str,
-                 lora_modules: Optional[List[LoRA]] = None,
-                 chat_template=None):
+                 lora_modules: Optional[List[LoRAModulePath]] = None,
+                 chat_template: Optional[str] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)
 
+    def _parse_chat_message_content(
+        self,
+        role: ChatCompletionRole,
+        content: Optional[Union[str,
+                                Iterable[ChatCompletionContentPartParam]]],
+    ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
+        if content is None:
+            return [], []
+        if isinstance(content, str):
+            return [ConversationMessage(role=role, content=content)], []
+
+        # To be implemented: https://github.com/vllm-project/vllm/pull/3467
+        # To be implemented: https://github.com/vllm-project/vllm/pull/4200
+        raise NotImplementedError("Complex input not supported yet")
+
     async def create_chat_completion(
         self, request: ChatCompletionRequest, raw_request: Request
     ) -> Union[ErrorResponse, AsyncGenerator[str, None],

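A hedged sketch of what the new _parse_chat_message_content helper does for plain string content,
reduced to a standalone function; the real method is defined on OpenAIServingChat and keeps a second
return value (a list of awaitables) for future multimodal content parts:

# Standalone approximation of the string-content path shown in the diff.
from typing import (Awaitable, Iterable, List, Optional, Tuple, TypedDict,
                    Union, final)


@final  # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
    role: str
    content: str


def parse_chat_message_content(
    role: str,
    content: Optional[Union[str, Iterable[dict]]],
) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
    if content is None:
        return [], []
    if isinstance(content, str):
        return [ConversationMessage(role=role, content=content)], []
    # Structured "content parts" (e.g. images) are not handled yet.
    raise NotImplementedError("Complex input not supported yet")


print(parse_chat_message_content("user", "Hello!"))
# ([{'role': 'user', 'content': 'Hello!'}], [])
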
@@ -52,10 +77,19 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret
 
         try:
+            conversation: List[ConversationMessage] = []
+
+            for m in request.messages:
+                messages, _ = self._parse_chat_message_content(
+                    m["role"], m["content"])
+
+                conversation.extend(messages)
+
             prompt = self.tokenizer.apply_chat_template(
-                conversation=request.messages,
+                conversation=conversation,
                 tokenize=False,
-                add_generation_prompt=request.add_generation_prompt)
+                add_generation_prompt=request.add_generation_prompt,
+            )
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))

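For reference, the flattened conversation built above is what gets handed to the Hugging Face
tokenizer's chat template. A hedged usage sketch; the model name is only an example and its tokenizer
has to be downloaded:

# Illustration only, not part of the diff: applying a chat template to flattened role/content dicts.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
conversation = [{"role": "user", "content": "Hello!"}]
prompt = tokenizer.apply_chat_template(conversation=conversation,
                                       tokenize=False,
                                       add_generation_prompt=True)
print(prompt)  # templated prompt string, ready to be tokenized by the engine
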
@@ -105,9 +139,8 @@ class OpenAIServingChat(OpenAIServing):
 
     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str
-    ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
-
+            result_generator: AsyncIterator[RequestOutput],
+            request_id: str) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"

@@ -252,7 +285,7 @@ class OpenAIServingChat(OpenAIServing):
 
         model_name = self.served_model_names[0]
         created_time = int(time.time())
-        final_res: RequestOutput = None
+        final_res: Optional[RequestOutput] = None
 
         async for res in result_generator:
             if await raw_request.is_disconnected():

@@ -317,7 +350,7 @@ class OpenAIServingChat(OpenAIServing):
 
         return response
 
-    def _load_chat_template(self, chat_template):
+    def _load_chat_template(self, chat_template: Optional[str]):
         tokenizer = self.tokenizer
 
         if chat_template is not None:

@@ -11,7 +11,8 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
                                               LogProbs, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)

@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
     def __init__(self,
                  engine: AsyncLLMEngine,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRA]] = None):
+                 lora_modules: Optional[List[LoRAModulePath]] = None):
         super().__init__(engine=engine,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules)

@@ -84,7 +85,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.time())
 
         # Schedule the request and get the result generator.
-        generators = []
+        generators: List[AsyncIterator[RequestOutput]] = []
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)

@@ -148,7 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
                                       num_prompts=len(prompts))
 
         # Non-streaming response
-        final_res_batch: RequestOutput = [None] * len(prompts)
+        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
         try:
             async for i, res in result_generator:
                 if await raw_request.is_disconnected():

@@ -22,17 +22,15 @@ logger = init_logger(__name__)
 
 
 @dataclass
-class LoRA:
+class LoRAModulePath:
     name: str
     local_path: str
 
 
 class OpenAIServing:
 
-    def __init__(self,
-                 engine: AsyncLLMEngine,
-                 served_model_names: List[str],
-                 lora_modules=Optional[List[LoRA]]):
+    def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
+                 lora_modules: Optional[List[LoRAModulePath]]):
         self.engine = engine
         self.served_model_names = served_model_names
         if lora_modules is None:

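Aside from the rename, the old __init__ signature above carried a typing bug:
lora_modules=Optional[List[LoRA]] used "=" instead of ":", so it set the parameter's default value to
the typing object rather than annotating it. A small illustration of the difference; the function
names here are invented for the demo:

# Hypothetical demo of the "=" vs ":" mix-up fixed above; not vLLM code.
from typing import List, Optional


def old_style(lora_modules=Optional[List[int]]):
    # "=" makes the typing construct the parameter's *default value*; nothing is annotated.
    return lora_modules


def new_style(lora_modules: Optional[List[int]]):
    # ":" annotates the parameter; callers now pass an explicit value (possibly None).
    return lora_modules


print(old_style())      # typing.Optional[typing.List[int]]
print(new_style(None))  # None
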
@@ -158,7 +156,9 @@ class OpenAIServing:
         })
         return json_str
 
-    async def _check_model(self, request) -> Optional[ErrorResponse]:
+    async def _check_model(
+        self, request: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Optional[ErrorResponse]:
         if request.model in self.served_model_names:
             return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:

@@ -168,14 +168,16 @@ class OpenAIServing:
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)
 
-    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
+    def _maybe_get_lora(
+        self, request: Union[CompletionRequest, ChatCompletionRequest]
+    ) -> Optional[LoRARequest]:
         if request.model in self.served_model_names:
             return None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
                 return lora
         # if _check_model has been called earlier, this will be unreachable
-        raise ValueError("The model `{request.model}` does not exist.")
+        raise ValueError(f"The model `{request.model}` does not exist.")
 
     def _validate_prompt_and_tokenize(
         self,
