[mypy] Enable following imports for entrypoints (#7248)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Fei <dfdfcai4@gmail.com>
Cyrus Leung 2024-08-21 14:28:21 +08:00 committed by GitHub
parent 4506641212
commit baaedfdb2d
26 changed files with 480 additions and 320 deletions

View File

@ -38,7 +38,6 @@ jobs:
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip

View File

@ -6,7 +6,7 @@ sphinx-argparse==0.4.0
msgspec
# packages to install to build the documentation
pydantic
pydantic >= 2.8
-f https://download.pytorch.org/whl/cpu
torch
py-cpuinfo

View File

@ -102,7 +102,6 @@ mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip

View File

@ -56,6 +56,7 @@ files = [
"vllm/*.py",
"vllm/adapter_commons",
"vllm/assets",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",

View File

@ -11,7 +11,7 @@ fastapi
aiohttp
openai >= 1.0 # Ensure modern openai package (ensure types module present)
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pydantic >= 2.8 # Required for OpenAI server.
pillow # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0

View File

@ -1,7 +1,7 @@
# imports for guided decoding tests
import json
import re
from typing import List
from typing import Dict, List, Optional
import jsonschema
import openai # use the official client for correctness check
@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: Optional[int]):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError):
await client.chat.completions.create(**params)
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs is not None:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}
completion_1 = await client.chat.completions.create(**params)
params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)
assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",

View File

@ -3,7 +3,7 @@ import json
import re
import shutil
from tempfile import TemporaryDirectory
from typing import Dict, List
from typing import Dict, List, Optional
import jsonschema
import openai # use the official client for correctness check
@ -268,92 +268,6 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
assert len(completion.choices[0].text) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str, prompt_logprobs: int):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.chat.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}
completion_1 = await client.chat.completions.create(**params)
params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)
assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
(MODEL_NAME, 0),
@ -361,7 +275,7 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
(MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: int):
prompt_logprobs: Optional[int]):
params: Dict = {
"prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name,
@ -369,17 +283,12 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError):
await client.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
if prompt_logprobs is not None:
assert completion.choices[0].prompt_logprobs is not None
assert len(completion.choices[0].prompt_logprobs) > 0

View File

@ -6,7 +6,6 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
import torch
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
import vllm.envs as envs
@ -31,6 +30,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
@ -427,8 +427,8 @@ class _AsyncLLMEngine(LLMEngine):
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")
return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
@ -771,7 +771,7 @@ class AsyncLLMEngine:
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
) -> AnyTokenizer:
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote( # type: ignore
lora_request)

View File

@ -3,9 +3,9 @@ from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, TypeVar, Union
from typing import Set, Tuple, Type, Union
from typing_extensions import assert_never
from typing_extensions import TypeVar, assert_never
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@ -43,8 +43,9 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device
@ -67,6 +68,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
return config.to_diff_dict()
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
PromptComponents = Tuple[Optional[str], List[int],
@ -493,12 +495,21 @@ class LLMEngine:
"skip_tokenizer_init is True")
def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
self,
group_type: Type[_G] = BaseTokenizerGroup,
*,
missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
) -> _G:
tokenizer_group = self.tokenizer
return self.tokenizer
if tokenizer_group is None:
raise ValueError(missing_msg)
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return tokenizer_group
def get_tokenizer(
self,
@ -693,8 +704,8 @@ class LLMEngine:
* prompt token ids
'''
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")
return tokenizer.encode(request_id=request_id,
prompt=prompt,
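
A minimal, self-contained sketch (not part of this diff) of the typed accessor pattern introduced in get_tokenizer_group above, assuming typing_extensions >= 4.4 for TypeVar defaults (PEP 696); the class bodies are placeholders, not the real vLLM ones.

from typing import Optional, Type

from typing_extensions import TypeVar  # supports default= (PEP 696)


class BaseTokenizerGroup:
    """Placeholder for the vLLM base tokenizer group."""


class TokenizerGroup(BaseTokenizerGroup):
    """Placeholder concrete subclass."""


_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)


def get_tokenizer_group(tokenizer: Optional[BaseTokenizerGroup],
                        group_type: Type[_G] = BaseTokenizerGroup,
                        *,
                        missing_msg: str = "tokenizer is not initialized") -> _G:
    # Raise if the group is missing or not of the requested subclass,
    # so callers get a return value narrowed to `group_type`.
    if tokenizer is None:
        raise ValueError(missing_msg)
    if not isinstance(tokenizer, group_type):
        raise TypeError(f"Expected {group_type}, found {type(tokenizer)}")
    return tokenizer

Calling get_tokenizer_group(tg, TokenizerGroup) is then typed as TokenizerGroup, which is the shape the LLM class relies on later in this commit.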

View File

@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
@ -29,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.

View File

@ -1,8 +1,6 @@
import functools
from typing import Callable, List
from transformers import PreTrainedTokenizer
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
@ -12,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
logger = init_logger(__name__)
@ -36,7 +35,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: StopChecker,
):
self.detokenizer = detokenizer

View File

@ -1,10 +1,9 @@
from typing import Callable, Optional
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
class StopChecker:
@ -15,8 +14,7 @@ class StopChecker:
"""
def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence],
PreTrainedTokenizer]):
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq

View File

@ -1,8 +1,6 @@
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
runtime_checkable)
from transformers import PreTrainedTokenizer
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
@ -12,6 +10,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
@ -40,6 +39,7 @@ class AsyncEngineClient(Protocol):
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request"""
...
def encode(
self,
@ -50,6 +50,7 @@ class AsyncEngineClient(Protocol):
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
...
async def abort(self, request_id: str) -> None:
"""Abort a request.
@ -60,25 +61,29 @@ class AsyncEngineClient(Protocol):
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
...
async def get_decoding_config(self) -> DecodingConfig:
...
"""Get the decoding configuration of the vLLM engine."""
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> PreTrainedTokenizer:
"""Get the appropriate Tokenizer for the request"""
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
...
async def is_tracing_enabled(self) -> bool:
pass
...
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass
...
async def check_health(self) -> None:
"""Raise if unhealthy"""
...
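
A tiny sketch of the Protocol idiom behind the pass -> ... edits above: ellipsis bodies mark unimplemented protocol methods, and @runtime_checkable allows structural isinstance checks. The names here are illustrative, not vLLM's.

from typing import Protocol, runtime_checkable


@runtime_checkable
class HealthCheckable(Protocol):

    async def check_health(self) -> None:
        """Raise if unhealthy."""
        ...


class DummyClient:

    async def check_health(self) -> None:
        return None


# Structural check: DummyClient conforms without inheriting from the Protocol.
assert isinstance(DummyClient(), HealthCheckable)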

View File

@ -61,6 +61,7 @@ async def generate(request: Request) -> Response:
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
assert prompt is not None
text_outputs = [
prompt + output.text for output in request_output.outputs
]
@ -80,6 +81,7 @@ async def generate(request: Request) -> Response:
assert final_output is not None
prompt = final_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)
@ -115,6 +117,7 @@ async def run_server(args: Namespace,
logger.info("args: %s", args)
app = await init_app(args, llm_engine)
assert engine is not None
shutdown_task = await serve_http(
app,

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union, cast)
Union)
# yapf conflicts with isort for this block
# yapf: disable
@ -15,9 +15,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer
from typing_extensions import Required, TypedDict
from pydantic import ConfigDict, TypeAdapter
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
@ -50,9 +49,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam]
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam, ]
class CustomChatCompletionMessageParam(TypedDict, total=False):
@ -114,7 +113,7 @@ def load_chat_template(
@lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
@ -151,11 +150,16 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str,
return f"{placeholder_token_str}\n{text_prompt}"
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
@ -164,7 +168,7 @@ def _parse_chat_message_content_parts(
for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
text = _TextParser.validate_python(part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"
@ -172,8 +176,7 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
image_url = _ImageParser.validate_python(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
@ -188,8 +191,7 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
audio_url = cast(ChatCompletionContentPartAudioParam,
part)["audio_url"]
audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
else:
@ -219,7 +221,7 @@ def _parse_chat_message_content_parts(
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")
@ -230,14 +232,18 @@ def _parse_chat_message_content(
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
return _parse_chat_message_content_parts(
role,
content, # type: ignore
model_config,
tokenizer,
)
def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
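
A hedged, standalone sketch of the TypeAdapter pattern this hunk switches to: instead of typing.cast, pydantic validates the raw dict against the TypedDict at runtime while keeping the static type. The TypedDict below is a stand-in for the real OpenAI content-part type.

from pydantic import TypeAdapter
from typing_extensions import Required, TypedDict


class TextPart(TypedDict, total=False):
    # Stand-in for ChatCompletionContentPartTextParam.
    type: Required[str]
    text: Required[str]


_TextParser = TypeAdapter(TextPart)

# validate_python checks the shape at runtime and returns a typed dict,
# unlike cast(), which only changes the static type.
part = _TextParser.validate_python({"type": "text", "text": "hello"})
assert part["text"] == "hello"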

View File

@ -1,8 +1,7 @@
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
@ -20,7 +19,9 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs
@ -122,7 +123,7 @@ class LLM:
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
@ -175,22 +176,19 @@ class LLM:
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
def get_tokenizer(
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer.tokenizer
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
self.llm_engine.tokenizer.tokenizer = tokenizer
tokenizer_group.tokenizer = tokenizer
else:
self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
tokenizer)
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
@overload # LEGACY: single (prompt + optional token ids)
def generate(
@ -578,6 +576,8 @@ class LLM:
inputs: List[PromptInputs] = []
for i in range(num_requests):
item: PromptInputs
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:
@ -635,7 +635,7 @@ class LLM:
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
request_id = str(next(self.request_counter))

View File

@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
@ -29,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@ -90,7 +93,8 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
@ -142,12 +146,15 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
logger.info("Started engine process with PID %d",
rpc_server_process.pid)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(rpc_path)
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore
try:
while True:
try:
await async_engine_client.setup()
await rpc_client.setup()
break
except TimeoutError as e:
if not rpc_server_process.is_alive():
@ -161,7 +168,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
rpc_server_process.terminate()
# Close all open connections to the backend
async_engine_client.close()
rpc_client.close()
# Wait for server process to join
rpc_server_process.join()
@ -216,10 +223,11 @@ async def tokenize(request: TokenizeRequest):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
@ -227,10 +235,11 @@ async def detokenize(request: DetokenizeRequest):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.get("/v1/models")
async def show_available_models():
@ -252,13 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
@ -267,12 +274,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@ -281,9 +287,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)

View File

@ -7,6 +7,7 @@ purposes.
import argparse
import json
import ssl
from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
@ -16,8 +17,19 @@ from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
lora_list: List[LoRAModulePath] = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
@ -26,8 +38,19 @@ class LoRAParserAction(argparse.Action):
class PromptAdapterParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
adapter_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
adapter_list: List[PromptAdapterPath] = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
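
A hedged sketch of how a custom argparse.Action with the signature annotated above is typically registered and invoked; the flag name and dataclass are illustrative, not taken from the vLLM CLI.

import argparse
from dataclasses import dataclass
from typing import List, Optional, Sequence, Union


@dataclass
class ModulePath:  # illustrative stand-in for LoRAModulePath / PromptAdapterPath
    name: str
    path: str


class ModulePathAction(argparse.Action):

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        # With nargs="+" argparse passes a list; the other annotated
        # possibilities (None or a bare str) are guarded explicitly.
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")
        modules: List[ModulePath] = [
            ModulePath(*item.split("=", 1)) for item in values
        ]
        setattr(namespace, self.dest, modules)


parser = argparse.ArgumentParser()
parser.add_argument("--modules", nargs="+", action=ModulePathAction, default=[])
args = parser.parse_args(["--modules", "alpha=/tmp/a", "beta=/tmp/b"])
assert args.modules[0].name == "alpha"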

View File

@ -2,9 +2,9 @@ from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
import torch
from transformers import PreTrainedTokenizer
from vllm.sampling_params import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
class AllowedTokenIdsLogitsProcessor:
@ -51,10 +51,11 @@ def logit_bias_logits_processor(
def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
logits_processors = []
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: AnyTokenizer,
) -> List[LogitsProcessor]:
logits_processors: List[LogitsProcessor] = []
if logit_bias:
try:
# Convert token_id to integer

View File

@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@ -14,11 +13,13 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]
try:
from sphinx.ext.autodoc.mock import _MockModule
@ -235,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
# We now allow logprobs being true without top_logprobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,
@ -251,7 +256,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
return SamplingParams(
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
@ -265,8 +270,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
(self.top_logprobs if self.echo else None),
prompt_logprobs=prompt_logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
@ -280,14 +284,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, values):
if (values.get('stream_options') is not None
and not values.get('stream')):
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"stream_options can only be set if stream is true")
return values
"Stream options can only be defined when `stream=True`.")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and prompt_logprobs > 0:
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")
if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")
if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0:
raise ValueError("`top_logprobs` must be a positive value.")
if not data.get("logprobs"):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
return data
@model_validator(mode="before")
@classmethod
@ -320,19 +346,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif data["top_logprobs"] < 0:
raise ValueError(
"`top_logprobs` must be a value a positive value.")
return data
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
@ -422,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = get_logits_processors(
@ -439,7 +456,7 @@ class CompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)
return SamplingParams(
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
@ -458,8 +475,7 @@ class CompletionRequest(OpenAIBaseModel):
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.prompt_logprobs
if self.prompt_logprobs else self.logprobs if self.echo else None,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
@ -485,9 +501,17 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not data["logprobs"] >= 0:
raise ValueError("if passed, `logprobs` must be a positive value.")
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and prompt_logprobs > 0:
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")
if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise ValueError("`logprobs` must be a positive value.")
return data
@model_validator(mode="before")
@ -495,7 +519,8 @@ class CompletionRequest(OpenAIBaseModel):
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when stream is true.")
"Stream options can only be defined when `stream=True`.")
return data
@ -504,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/embeddings
model: str
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
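
A self-contained sketch of the pydantic-v2 model_validator(mode="before") pattern that the rewritten check_logprobs / validate_stream_options checks rely on; the model below is a toy request, not the real one.

from typing import Optional

from pydantic import BaseModel, model_validator


class ToyRequest(BaseModel):
    stream: bool = False
    prompt_logprobs: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        # `data` is the raw input (a dict here), seen before field validation.
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")
            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be non-negative.")
        return data


ToyRequest(prompt_logprobs=1)  # ok
# ToyRequest(stream=True, prompt_logprobs=1)  # raises pydantic.ValidationError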

View File

@ -23,8 +23,8 @@ class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, rpc_path: str):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context)
self.engine = AsyncLLMEngine.from_engine_args(
async_engine_args, usage_context=usage_context)
# Initialize context.
self.context = zmq.asyncio.Context()
@ -39,7 +39,7 @@ class AsyncEngineRPCServer:
self.context.destroy()
self.engine.shutdown_background_loop()
# Clear the engine reference so that it can be GC'ed.
self.engine = None
del self.engine
async def get_model_config(self, identity):
"""Send the ModelConfig"""

View File

@ -1,11 +1,10 @@
import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
from typing import Sequence as GenericSequence
from typing import Union
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
@ -24,13 +23,14 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.inputs import PromptInputs
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__)
@ -67,9 +67,9 @@ class OpenAIServingChat(OpenAIServing):
async def create_chat_completion(
self,
request: ChatCompletionRequest,
raw_request: Optional[Request] = None
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
@ -83,16 +83,6 @@ class OpenAIServingChat(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
if request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid "
f"negative value: {request.prompt_logprobs}")
try:
(
lora_request,
@ -160,9 +150,8 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
engine_inputs: PromptInputs = {
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
}
engine_inputs = TokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
@ -214,11 +203,11 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
# Send response for each token for each request.n (index)
@ -438,7 +427,7 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0]
@ -523,7 +512,7 @@ class OpenAIServingChat(OpenAIServing):
def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1],
@ -541,12 +530,11 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
logprobs_content = []
logprobs_content: List[ChatCompletionLogProbsContent] = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
@ -554,23 +542,32 @@ class OpenAIServingChat(OpenAIServing):
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
bytes=list(token.encode("utf-8", errors="replace")),
))
else:
step_token = step_top_logprobs[token_id]
step_decoded = step_token.decoded_token
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer,
self.return_tokens_as_token_ids),
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
bytes=list(
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
step_token,
token_id,
tokenizer,
self.return_tokens_as_token_ids,
),
logprob=max(step_token.logprob, -9999.0),
bytes=None if step_decoded is None else list(
step_decoded.encode("utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs,
tokenizer)))
step_top_logprobs,
num_output_top_logprobs,
tokenizer,
),
))
return ChatCompletionLogProbs(content=logprobs_content)

View File

@ -3,10 +3,9 @@ import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
from typing import Sequence as GenericSequence
from typing import Tuple, cast
from typing import Tuple, Union, cast
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
@ -19,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
UsageInfo)
ErrorResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
@ -29,6 +28,7 @@ from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
@ -60,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
async def create_completion(
self,
request: CompletionRequest,
raw_request: Request,
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
@ -84,15 +87,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
elif request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid negative "
f"value: {request.prompt_logprobs}")
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
@ -153,9 +147,8 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, RequestOutput]] = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
@ -227,7 +220,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
@ -236,6 +229,13 @@ class OpenAIServingCompletion(OpenAIServing):
try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt
delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]]
for output in res.outputs:
i = output.index + prompt_idx * num_choices
@ -244,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_text is not None
# only return the prompt
delta_text = res.prompt
delta_token_ids = res.prompt_token_ids
out_logprobs = res.prompt_logprobs
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = (res.prompt_token_ids +
output.token_ids)
out_logprobs = res.prompt_logprobs + (output.logprobs
or [])
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids, *output.token_ids
]
out_logprobs = [
*prompt_logprobs,
*(output.logprobs or []),
]
has_echoed[i] = True
else:
# return just the delta
@ -301,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = len(res.prompt_token_ids)
prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
@ -342,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@ -353,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
Logprob]]]]
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_text is not None
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + list(output.token_ids)
out_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs is not None else None)
assert prompt_text is not None
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
@ -413,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
@ -430,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing):
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token(
step_top_logprobs[token_id],
step_token,
token_id,
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
return_as_token_id=self.return_tokens_as_token_ids,
)
token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)

View File

@ -1,11 +1,11 @@
import asyncio
import base64
import time
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
Union, cast)
from typing import AsyncGenerator, List, Literal, Optional, Union, cast
import numpy as np
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
@ -16,7 +16,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
@ -24,18 +24,28 @@ logger = init_logger(__name__)
TypeTokenIDs = List[int]
def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
return output.embedding
elif encoding_format == "base64":
embedding_bytes = np.array(output.embedding).tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8")
assert_never(encoding_format)
def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
created_time: int, model_name: str,
encoding_format: str) -> EmbeddingResponse:
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding
if encoding_format == "base64":
embedding_bytes = np.array(embedding).tobytes()
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
embedding = _get_embedding(final_res.outputs, encoding_format)
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)
@ -76,8 +86,8 @@ class OpenAIServingEmbedding(OpenAIServing):
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None
) -> Union[ErrorResponse, EmbeddingResponse]:
raw_request: Optional[Request] = None,
) -> Union[EmbeddingResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
@ -89,8 +99,7 @@ class OpenAIServingEmbedding(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
encoding_format = (request.encoding_format
if request.encoding_format else "float")
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
@ -145,11 +154,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, EmbeddingRequestOutput]] = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected
if raw_request else None)
result_generator = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
@ -175,7 +183,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return response
def _check_embedding_mode(self, embedding_mode: bool):
def _check_embedding_mode(self, embedding_mode: bool) -> bool:
if not embedding_mode:
logger.warning(
"embedding_mode is False. Embedding API will not work.")

View File

@ -31,7 +31,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)

View File

@ -153,6 +153,68 @@ class SamplingParams(
output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
@staticmethod
def from_optional(
n: Optional[int] = 1,
best_of: Optional[int] = None,
presence_penalty: Optional[float] = 0.0,
frequency_penalty: Optional[float] = 0.0,
repetition_penalty: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 1.0,
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
min_tokens: int = 0,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
detokenize: bool = True,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
) -> "SamplingParams":
return SamplingParams(
n=1 if n is None else n,
best_of=best_of,
presence_penalty=0.0
if presence_penalty is None else presence_penalty,
frequency_penalty=0.0
if frequency_penalty is None else frequency_penalty,
repetition_penalty=1.0
if repetition_penalty is None else repetition_penalty,
temperature=1.0 if temperature is None else temperature,
top_p=1.0 if top_p is None else top_p,
top_k=top_k,
min_p=min_p,
seed=seed,
use_beam_search=use_beam_search,
length_penalty=length_penalty,
early_stopping=early_stopping,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
min_tokens=min_tokens,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs,
detokenize=detokenize,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
)
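
A short usage sketch (assuming an installed vLLM that includes this commit) of the new from_optional constructor above: fields the OpenAI layer may pass as None are coalesced to the usual defaults, so the request classes can hand possibly-None values straight through.

from vllm import SamplingParams

params = SamplingParams.from_optional(
    n=None,                 # -> 1
    temperature=None,       # -> 1.0
    presence_penalty=None,  # -> 0.0
    max_tokens=128,
    logprobs=5,
    prompt_logprobs=1,
)
assert params.n == 1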
def __post_init__(self) -> None:
self.best_of = self.best_of or self.n
if 0 < self.temperature < _MAX_TEMP: