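"""Shared base class for vLLM's OpenAI-compatible API request handlers."""
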
import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union

from pydantic import Field
from typing_extensions import Annotated

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              ModelCard, ModelList,
                                              ModelPermission,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest,
                                              TokenizeRequest,
                                              UnloadLoraAdapterRequest)
# yapf: enable
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter

logger = init_logger(__name__)


@dataclass
class BaseModelPath:
    name: str
    model_path: str


@dataclass
class PromptAdapterPath:
    name: str
    local_path: str


@dataclass
class LoRAModulePath:
    name: str
    path: str
    base_model_name: Optional[str] = None


AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
                   EmbeddingRequest, TokenizeRequest]


class TextTokensPrompt(TypedDict):
    prompt: str
    prompt_token_ids: List[int]


class OpenAIServing:
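    """Base class for vLLM's OpenAI-compatible API request handlers.

    Provides the shared plumbing: model and adapter resolution, prompt
    tokenization and context-length validation, request logging, error
    response construction, and runtime loading/unloading of LoRA adapters.
    """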

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]],
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
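        """Set up the shared serving state.

        LoRA modules and prompt adapters configured at startup are registered
        here; additional LoRA adapters can also be loaded and unloaded at
        runtime via ``load_lora_adapter`` and ``unload_lora_adapter``.
        """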
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.base_model_paths = base_model_paths

        self.lora_id_counter = AtomicCounter(0)
        self.lora_requests = []
        if lora_modules is not None:
            self.lora_requests = [
                LoRARequest(lora_name=lora.name,
                            lora_int_id=i,
                            lora_path=lora.path,
                            base_model_name=lora.base_model_name
                            if lora.base_model_name
                            and self._is_model_supported(lora.base_model_name)
                            else self.base_model_paths[0].name)
                for i, lora in enumerate(lora_modules, start=1)
            ]
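
        # Prompt adapters are also registered up front; the number of virtual
        # tokens is read from each adapter's adapter_config.json.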
        self.prompt_adapter_requests = []
        if prompt_adapters is not None:
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
                with pathlib.Path(prompt_adapter.local_path,
                                  "adapter_config.json").open() as f:
                    adapter_config = json.load(f)
                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
                    self.prompt_adapter_requests.append(
                        PromptAdapterRequest(
                            prompt_adapter_name=prompt_adapter.name,
                            prompt_adapter_id=i,
                            prompt_adapter_local_path=prompt_adapter.local_path,
                            prompt_adapter_num_virtual_tokens=num_virtual_tokens))

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

    async def show_available_models(self) -> ModelList:
        """Show the available models: the base model(s) plus any registered
        LoRA adapters and prompt adapters."""
        model_cards = [
            ModelCard(id=base_model.name,
                      max_model_len=self.max_model_len,
                      root=base_model.model_path,
                      permission=[ModelPermission()])
            for base_model in self.base_model_paths
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=lora.local_path,
                      parent=lora.base_model_name if lora.base_model_name else
                      self.base_model_paths[0].name,
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        prompt_adapter_cards = [
            ModelCard(id=prompt_adapter.prompt_adapter_name,
                      root=self.base_model_paths[0].name,
                      permission=[ModelPermission()])
            for prompt_adapter in self.prompt_adapter_requests
        ]
        model_cards.extend(lora_cards)
        model_cards.extend(prompt_adapter_cards)
        return ModelList(data=model_cards)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
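        """Like :meth:`create_error_response`, but return the error serialized
        as a JSON string wrapped in an ``{"error": ...}`` object, for use in
        streaming responses."""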
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
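        """Return ``None`` if ``request.model`` matches a base model or a
        registered LoRA/prompt adapter; otherwise return a ``NotFoundError``
        error response."""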
        if self._is_model_supported(request.model):
            return None
        if request.model in [lora.lora_name for lora in self.lora_requests]:
            return None
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.prompt_adapter_requests
        ]:
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
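        """Resolve ``request.model`` to ``(lora_request,
        prompt_adapter_request)``; ``(None, None)`` means the base model is
        used directly."""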
        if self._is_model_supported(request.model):
            return None, None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
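        """Tokenize a text prompt, optionally truncating it to at most
        ``truncate_prompt_tokens`` tokens, and validate the result against
        the model's context length."""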
        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
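        """Accept an already-tokenized prompt, optionally keeping only the
        last ``truncate_prompt_tokens`` token IDs, decode it back to text,
        and validate the result against the model's context length."""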
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
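        """Check the prompt length (and, for generation requests, the
        requested completion length) against the model's maximum context
        length, raising ``ValueError`` if it is exceeded."""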
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes a single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_,
        each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        for prompt_input in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly
            # An explicit "is False" check is required for Pyright to perform
            # type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if prompt_input["is_tokens"] is False:
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[str, List[int], TextTokensPrompt],
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
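        """Log the request's prompt, parameters, and adapter information via
        the configured ``RequestLogger``, if one was provided."""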
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
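        """Return the decoded text of a sampled token, or ``token_id:<id>``
        when raw token IDs were requested; fall back to the tokenizer if the
        logprob carries no decoded text."""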
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)

    async def _check_load_lora_adapter_request(
            self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
        # Check if both 'lora_name' and 'lora_path' are provided
        if not request.lora_name or not request.lora_path:
            return self.create_error_response(
                message="Both 'lora_name' and 'lora_path' must be provided.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Check if the lora adapter with the given name already exists
        if any(lora_request.lora_name == request.lora_name
               for lora_request in self.lora_requests):
            return self.create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' has already been "
                "loaded.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        return None

    async def _check_unload_lora_adapter_request(
            self,
            request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
        # Check if either 'lora_name' or 'lora_int_id' is provided
        if not request.lora_name and not request.lora_int_id:
            return self.create_error_response(
                message=
                "Either 'lora_name' or 'lora_int_id' needs to be provided.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        # Check if the lora adapter with the given name exists
        if not any(lora_request.lora_name == request.lora_name
                   for lora_request in self.lora_requests):
            return self.create_error_response(
                message=
                f"The lora adapter '{request.lora_name}' cannot be found.",
                err_type="InvalidUserInput",
                status_code=HTTPStatus.BAD_REQUEST)

        return None

    async def load_lora_adapter(
            self,
            request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
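        """Register a new LoRA adapter at runtime, assigning it the next
        integer ID from ``lora_id_counter``; returns an ``ErrorResponse`` if
        validation fails."""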
        error_check_ret = await self._check_load_lora_adapter_request(request)
        if error_check_ret is not None:
            return error_check_ret

        lora_name, lora_path = request.lora_name, request.lora_path
        unique_id = self.lora_id_counter.inc(1)
        self.lora_requests.append(
            LoRARequest(lora_name=lora_name,
                        lora_int_id=unique_id,
                        lora_path=lora_path))
        return f"Success: LoRA adapter '{lora_name}' added successfully."

    async def unload_lora_adapter(
            self,
            request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
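        """Remove a previously loaded LoRA adapter by name; returns an
        ``ErrorResponse`` if the request is invalid or the adapter cannot be
        found."""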
        error_check_ret = await self._check_unload_lora_adapter_request(
            request)
        if error_check_ret is not None:
            return error_check_ret

        lora_name = request.lora_name
        self.lora_requests = [
            lora_request for lora_request in self.lora_requests
            if lora_request.lora_name != lora_name
        ]
        return f"Success: LoRA adapter '{lora_name}' removed successfully."

    def _is_model_supported(self, model_name: str) -> bool:
        return any(model.name == model_name
                   for model in self.base_model_paths)