[mypy] Enable following imports for entrypoints (#7248)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Fei <dfdfcai4@gmail.com>
Cyrus Leung 2024-08-21 14:28:21 +08:00 committed by GitHub
parent 4506641212
commit baaedfdb2d
26 changed files with 480 additions and 320 deletions

View File

@@ -38,7 +38,6 @@ jobs:
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
-mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip

View File

@@ -6,7 +6,7 @@ sphinx-argparse==0.4.0
msgspec
# packages to install to build the documentation
-pydantic
+pydantic >= 2.8
-f https://download.pytorch.org/whl/cpu
torch
py-cpuinfo

View File

@@ -102,7 +102,6 @@ mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
-mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip

View File

@@ -56,6 +56,7 @@ files = [
    "vllm/*.py",
    "vllm/adapter_commons",
    "vllm/assets",
+    "vllm/entrypoints",
    "vllm/inputs",
    "vllm/logging",
    "vllm/multimodal",

View File

@@ -11,7 +11,7 @@ fastapi
aiohttp
openai >= 1.0  # Ensure modern openai package (ensure types module present)
uvicorn[standard]
-pydantic >= 2.0  # Required for OpenAI server.
+pydantic >= 2.8  # Required for OpenAI server.
pillow  # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0

View File

@@ -1,7 +1,7 @@
# imports for guided decoding tests
import json
import re
-from typing import List
+from typing import Dict, List, Optional

import jsonschema
import openai  # use the official client for correctness check
@@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
    assert message.content is not None and len(message.content) >= 0


+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name, prompt_logprobs",
+    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
+)
+async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                    model_name: str,
+                                    prompt_logprobs: Optional[int]):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name
+    }
+
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs is not None and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError):
+            await client.chat.completions.create(**params)
+    else:
+        completion = await client.chat.completions.create(**params)
+        if prompt_logprobs is not None:
+            assert completion.prompt_logprobs is not None
+            assert len(completion.prompt_logprobs) > 0
+        else:
+            assert completion.prompt_logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                                  model_name: str):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name,
+        "extra_body": {
+            "prompt_logprobs": 1
+        }
+    }
+
+    completion_1 = await client.chat.completions.create(**params)
+    params["extra_body"] = {"prompt_logprobs": 2}
+    completion_2 = await client.chat.completions.create(**params)
+
+    assert len(completion_1.prompt_logprobs[3]) == 1
+    assert len(completion_2.prompt_logprobs[3]) == 2
+
+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",

View File

@@ -3,7 +3,7 @@ import json
import re
import shutil
from tempfile import TemporaryDirectory
-from typing import Dict, List
+from typing import Dict, List, Optional

import jsonschema
import openai  # use the official client for correctness check
@@ -268,92 +268,6 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
    assert len(completion.choices[0].text) >= 0


-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name, prompt_logprobs",
-    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
-)
-async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
-                                    model_name: str, prompt_logprobs: int):
-    params: Dict = {
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "Who won the world series in 2020?"
-        }, {
-            "role":
-            "assistant",
-            "content":
-            "The Los Angeles Dodgers won the World Series in 2020."
-        }, {
-            "role": "user",
-            "content": "Where was it played?"
-        }],
-        "model":
-        model_name
-    }
-
-    if prompt_logprobs is not None:
-        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
-
-    if prompt_logprobs and prompt_logprobs < 0:
-        with pytest.raises(BadRequestError) as err_info:
-            await client.chat.completions.create(**params)
-        expected_err_string = (
-            "Error code: 400 - {'object': 'error', 'message': "
-            "'Prompt_logprobs set to invalid negative value: -1',"
-            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
-        assert str(err_info.value) == expected_err_string
-    else:
-        completion = await client.chat.completions.create(**params)
-        if prompt_logprobs and prompt_logprobs > 0:
-            assert completion.prompt_logprobs is not None
-            assert len(completion.prompt_logprobs) > 0
-        else:
-            assert completion.prompt_logprobs is None
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
-                                                  model_name: str):
-    params: Dict = {
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "Who won the world series in 2020?"
-        }, {
-            "role":
-            "assistant",
-            "content":
-            "The Los Angeles Dodgers won the World Series in 2020."
-        }, {
-            "role": "user",
-            "content": "Where was it played?"
-        }],
-        "model":
-        model_name,
-        "extra_body": {
-            "prompt_logprobs": 1
-        }
-    }
-
-    completion_1 = await client.chat.completions.create(**params)
-    params["extra_body"] = {"prompt_logprobs": 2}
-    completion_2 = await client.chat.completions.create(**params)
-
-    assert len(completion_1.prompt_logprobs[3]) == 1
-    assert len(completion_2.prompt_logprobs[3]) == 2
-
-
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
                                                         (MODEL_NAME, 0),
@@ -361,7 +275,7 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
                                                         (MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
                                          model_name: str,
-                                          prompt_logprobs: int):
+                                          prompt_logprobs: Optional[int]):
    params: Dict = {
        "prompt": ["A robot may not injure another robot", "My name is"],
        "model": model_name,
@@ -369,17 +283,12 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
    if prompt_logprobs is not None:
        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

-    if prompt_logprobs and prompt_logprobs < 0:
-        with pytest.raises(BadRequestError) as err_info:
+    if prompt_logprobs is not None and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError):
            await client.completions.create(**params)
-        expected_err_string = (
-            "Error code: 400 - {'object': 'error', 'message': "
-            "'Prompt_logprobs set to invalid negative value: -1',"
-            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
-        assert str(err_info.value) == expected_err_string
    else:
        completion = await client.completions.create(**params)
-        if prompt_logprobs and prompt_logprobs > 0:
+        if prompt_logprobs is not None:
            assert completion.choices[0].prompt_logprobs is not None
            assert len(completion.choices[0].prompt_logprobs) > 0

View File

@@ -6,7 +6,6 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
                    Optional, Set, Tuple, Type, Union)

import torch
-from transformers import PreTrainedTokenizer
from typing_extensions import assert_never

import vllm.envs as envs
@@ -31,6 +30,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
@@ -427,8 +427,8 @@ class _AsyncLLMEngine(LLMEngine):
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`."""
-        tokenizer = self.get_tokenizer_group("prompts must be None if "
-                                             "skip_tokenizer_init is True")
+        tokenizer = self.get_tokenizer_group(
+            missing_msg="prompts must be None if skip_tokenizer_init is True")

        return await tokenizer.encode_async(request_id=request_id,
                                            prompt=prompt,
@@ -771,7 +771,7 @@ class AsyncLLMEngine:
    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
-    ) -> "PreTrainedTokenizer":
+    ) -> AnyTokenizer:
        if self.engine_use_ray:
            return await self.engine.get_tokenizer.remote(  # type: ignore
                lora_request)
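Aside: this diff only imports AnyTokenizer; its definition is not shown on this page. As an assumption about its shape at this point in the codebase (not confirmed by the diff), it is roughly a union over the Hugging Face tokenizer classes that PreTrainedTokenizer annotations previously approximated:

# Hedged sketch, not taken from this commit: a plausible definition of the
# AnyTokenizer alias that replaces the bare PreTrainedTokenizer annotations.
from typing import Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]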

View File

@@ -3,9 +3,9 @@ from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
                    Mapping, Optional)
from typing import Sequence as GenericSequence
-from typing import Set, Tuple, Type, TypeVar, Union
+from typing import Set, Tuple, Type, Union

-from typing_extensions import assert_never
+from typing_extensions import TypeVar, assert_never

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@@ -43,8 +43,9 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
-    AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
+    BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
from vllm.utils import Counter, Device
@@ -67,6 +68,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    return config.to_diff_dict()

+_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

PromptComponents = Tuple[Optional[str], List[int],
@@ -494,11 +496,20 @@ class LLMEngine:
    def get_tokenizer_group(
        self,
-        fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
-        if self.tokenizer is None:
-            raise ValueError(fail_msg)
-
-        return self.tokenizer
+        group_type: Type[_G] = BaseTokenizerGroup,
+        *,
+        missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
+    ) -> _G:
+        tokenizer_group = self.tokenizer
+
+        if tokenizer_group is None:
+            raise ValueError(missing_msg)
+        if not isinstance(tokenizer_group, group_type):
+            raise TypeError("Invalid type of tokenizer group. "
+                            f"Expected type: {group_type}, but "
+                            f"found type: {type(tokenizer_group)}")
+
+        return tokenizer_group

    def get_tokenizer(
        self,
@@ -693,8 +704,8 @@ class LLMEngine:
        * prompt token ids
        '''
-        tokenizer = self.get_tokenizer_group("prompts must be None if "
-                                             "skip_tokenizer_init is True")
+        tokenizer = self.get_tokenizer_group(
+            missing_msg="prompts must be None if skip_tokenizer_init is True")

        return tokenizer.encode(request_id=request_id,
                                prompt=prompt,
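Aside: the new get_tokenizer_group signature relies on typing_extensions.TypeVar with a default (PEP 696) so callers can request a concrete tokenizer-group subclass and get a correspondingly narrowed return type. A self-contained sketch of the same pattern, using hypothetical BaseGroup/ConcreteGroup/Engine names rather than the vLLM classes:

from typing import Optional, Type

from typing_extensions import TypeVar


class BaseGroup:
    pass


class ConcreteGroup(BaseGroup):

    def extra(self) -> str:
        return "only available on the subclass"


# The default keeps plain `get_group()` calls typed as BaseGroup.
_G = TypeVar("_G", bound=BaseGroup, default=BaseGroup)


class Engine:

    def __init__(self, group: Optional[BaseGroup]) -> None:
        self._group = group

    def get_group(self, group_type: Type[_G] = BaseGroup) -> _G:
        group = self._group
        if group is None:
            raise ValueError("no group configured")
        if not isinstance(group, group_type):
            raise TypeError(f"expected {group_type}, found {type(group)}")
        return group


engine = Engine(ConcreteGroup())
# The call site names the concrete type, so the result is narrowed to it and
# subclass-only attributes type-check without a cast.
print(engine.get_group(ConcreteGroup).extra())

This is the same reason LLM.get_tokenizer() in a later file of this commit can ask for TokenizerGroup specifically and then reach .tokenizer directly.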

View File

@@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
from typing import Callable, List

-from transformers import PreTrainedTokenizer
-
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter
@@ -29,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
-        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
+        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: "StopChecker",
    ):
        """Create an output processor.

View File

@@ -1,8 +1,6 @@
import functools
from typing import Callable, List

-from transformers import PreTrainedTokenizer
-
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
@@ -12,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
                           SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter

logger = init_logger(__name__)
@@ -36,7 +35,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
-        get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
+        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: StopChecker,
    ):
        self.detokenizer = detokenizer

View File

@@ -1,10 +1,9 @@
from typing import Callable, Optional

-from transformers import PreTrainedTokenizer
-
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
+from vllm.transformers_utils.tokenizer import AnyTokenizer


class StopChecker:
@@ -15,8 +14,7 @@ class StopChecker:
    """

    def __init__(self, max_model_len: int,
-                 get_tokenizer_for_seq: Callable[[Sequence],
-                                                 PreTrainedTokenizer]):
+                 get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
        # Do not use it directly, but use `self._get_max_model_len`.
        self._max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq

View File

@@ -1,8 +1,6 @@
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
                    runtime_checkable)

-from transformers import PreTrainedTokenizer
-
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
@@ -12,6 +10,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.tokenizer import AnyTokenizer


@runtime_checkable
@@ -40,6 +39,7 @@ class AsyncEngineClient(Protocol):
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncGenerator[RequestOutput, None]:
        """Generates outputs for a request"""
+        ...

    def encode(
        self,
@@ -50,6 +50,7 @@ class AsyncEngineClient(Protocol):
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
        """Generate outputs for a request from an embedding model."""
+        ...

    async def abort(self, request_id: str) -> None:
        """Abort a request.
@@ -60,25 +61,29 @@ class AsyncEngineClient(Protocol):
    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
+        ...

    async def get_decoding_config(self) -> DecodingConfig:
+        ...
        """Get the decoding configuration of the vLLM engine."""

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
-    ) -> PreTrainedTokenizer:
-        """Get the appropriate Tokenizer for the request"""
+    ) -> AnyTokenizer:
+        """Get the appropriate tokenizer for the request"""
+        ...

    async def is_tracing_enabled(self) -> bool:
-        pass
+        ...

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
-        pass
+        ...

    async def check_health(self) -> None:
        """Raise if unhealthy"""
+        ...
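Aside: the switch from pass to ... follows the usual convention for typing.Protocol stubs, and @runtime_checkable only enables structural isinstance() checks on method names. A small standalone sketch with hypothetical classes (not the vLLM ones):

from typing import Protocol, runtime_checkable


@runtime_checkable
class HealthCheckable(Protocol):

    async def check_health(self) -> None:
        """Raise if unhealthy"""
        ...


class DummyClient:

    async def check_health(self) -> None:
        return None


# isinstance() against a runtime_checkable Protocol only verifies that the
# method exists; signatures and return types are still checked statically.
print(isinstance(DummyClient(), HealthCheckable))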

View File

@@ -61,6 +61,7 @@ async def generate(request: Request) -> Response:
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
+            assert prompt is not None
            text_outputs = [
                prompt + output.text for output in request_output.outputs
            ]
@@ -80,6 +81,7 @@ async def generate(request: Request) -> Response:
    assert final_output is not None
    prompt = final_output.prompt
+    assert prompt is not None
    text_outputs = [prompt + output.text for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)
@@ -115,6 +117,7 @@ async def run_server(args: Namespace,
    logger.info("args: %s", args)

    app = await init_app(args, llm_engine)
+    assert engine is not None

    shutdown_task = await serve_http(
        app,

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
-                    Union, cast)
+                    Union)

# yapf conflicts with isort for this block
# yapf: disable
@@ -15,9 +15,8 @@ from openai.types.chat import (
    ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
-from pydantic import ConfigDict
-from transformers import PreTrainedTokenizer
-from typing_extensions import Required, TypedDict
+from pydantic import ConfigDict, TypeAdapter
+from typing_extensions import Required, TypeAlias, TypedDict

from vllm.config import ModelConfig
from vllm.logger import init_logger
@@ -50,9 +49,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
    """The type of the content part."""


-ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
-                                       ChatCompletionContentPartAudioParam,
-                                       CustomChatCompletionContentPartParam]
+ChatCompletionContentPartParam: TypeAlias = Union[
+    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
+    CustomChatCompletionContentPartParam, ]


class CustomChatCompletionMessageParam(TypedDict, total=False):
@@ -114,7 +113,7 @@ def load_chat_template(


@lru_cache(maxsize=None)
-def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
+def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
                  modality: Literal["image", "audio"]) -> Optional[str]:
    # TODO: Let user specify how to insert image tokens into prompt
    # (similar to chat template)
@@ -151,11 +150,16 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str,
    return f"{placeholder_token_str}\n{text_prompt}"


+_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
+_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
+_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
+
+
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    model_config: ModelConfig,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
    texts: List[str] = []
    mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -164,7 +168,7 @@ def _parse_chat_message_content_parts(
    for part in parts:
        part_type = part["type"]
        if part_type == "text":
-            text = cast(ChatCompletionContentPartTextParam, part)["text"]
+            text = _TextParser.validate_python(part)["text"]
            texts.append(text)
        elif part_type == "image_url":
            modality = "image"
@@ -172,8 +176,7 @@ def _parse_chat_message_content_parts(
                raise NotImplementedError(
                    "Multiple multimodal inputs is currently not supported.")

-            image_url = cast(ChatCompletionContentPartImageParam,
-                             part)["image_url"]
+            image_url = _ImageParser.validate_python(part)["image_url"]

            if image_url.get("detail", "auto") != "auto":
                logger.warning(
@@ -188,8 +191,7 @@ def _parse_chat_message_content_parts(
                raise NotImplementedError(
                    "Multiple multimodal inputs is currently not supported.")

-            audio_url = cast(ChatCompletionContentPartAudioParam,
-                             part)["audio_url"]
+            audio_url = _AudioParser.validate_python(part)["audio_url"]

            audio_future = async_get_and_parse_audio(audio_url["url"])
            mm_futures.append(audio_future)
        else:
@@ -219,7 +221,7 @@ def _parse_chat_message_content_parts(
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    model_config: ModelConfig,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
    role = message["role"]
    content = message.get("content")
@@ -230,14 +232,18 @@ def _parse_chat_message_content(
        messages = [ConversationMessage(role=role, content=content)]
        return ChatMessageParseResult(messages=messages, mm_futures=[])

-    return _parse_chat_message_content_parts(role, content, model_config,
-                                             tokenizer)
+    return _parse_chat_message_content_parts(
+        role,
+        content,  # type: ignore
+        model_config,
+        tokenizer,
+    )


def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
-    tokenizer: PreTrainedTokenizer,
+    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
    conversation: List[ConversationMessage] = []
    mm_futures: List[Awaitable[MultiModalDataDict]] = []
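Aside: the _TextParser/_ImageParser/_AudioParser helpers above replace typing.cast with real runtime validation. A minimal standalone sketch of the same pydantic TypeAdapter pattern, using a simplified stand-in TypedDict rather than the openai package's types:

from pydantic import TypeAdapter, ValidationError
from typing_extensions import Literal, Required, TypedDict


class TextPart(TypedDict, total=False):
    type: Required[Literal["text"]]
    text: Required[str]


_TextParser = TypeAdapter(TextPart)

# validate_python checks the payload and raises on mismatch, whereas
# typing.cast would silently accept anything.
part = {"type": "text", "text": "hello"}
print(_TextParser.validate_python(part)["text"])

try:
    _TextParser.validate_python({"type": "text"})
except ValidationError as exc:
    print(f"rejected invalid part: {exc.error_count()} error(s)")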

View File

@@ -1,8 +1,7 @@
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload

-from tqdm.auto import tqdm
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from tqdm import tqdm

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
@@ -20,7 +19,9 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer import get_cached_tokenizer
+from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+                                               get_cached_tokenizer)
+from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs
@@ -122,7 +123,7 @@ class LLM:
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
-        swap_space: int = 4,
+        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
@@ -175,22 +176,19 @@ class LLM:
            engine_args, usage_context=UsageContext.LLM_CLASS)
        self.request_counter = Counter()

-    def get_tokenizer(
-            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_engine.tokenizer.tokenizer
-
-    def set_tokenizer(
-        self,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    ) -> None:
+    def get_tokenizer(self) -> AnyTokenizer:
+        return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
+
+    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
+        tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
+
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
-            self.llm_engine.tokenizer.tokenizer = tokenizer
+            tokenizer_group.tokenizer = tokenizer
        else:
-            self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
-                tokenizer)
+            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
@@ -578,6 +576,8 @@ class LLM:
        inputs: List[PromptInputs] = []
        for i in range(num_requests):
+            item: PromptInputs
+
            if prompts is not None:
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
@@ -635,7 +635,7 @@ class LLM:
        self,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
-        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        request_id = str(next(self.request_counter))

View File

@@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount
+from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig
@@ -29,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              CompletionRequest,
+                                              CompletionResponse,
                                              DetokenizeRequest,
                                              DetokenizeResponse,
-                                              EmbeddingRequest, ErrorResponse,
+                                              EmbeddingRequest,
+                                              EmbeddingResponse, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
+# yapf: enable
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
-# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -90,7 +93,8 @@ async def lifespan(app: FastAPI):

@asynccontextmanager
-async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
+async def build_async_engine_client(
+        args: Namespace) -> AsyncIterator[AsyncEngineClient]:
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
@@ -142,12 +146,15 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
            logger.info("Started engine process with PID %d",
                        rpc_server_process.pid)

            # Build RPCClient, which conforms to AsyncEngineClient Protocol.
-            async_engine_client = AsyncEngineRPCClient(rpc_path)
+            # NOTE: Actually, this is not true yet. We still need to support
+            # embedding models via RPC (see TODO above)
+            rpc_client = AsyncEngineRPCClient(rpc_path)
+            async_engine_client = rpc_client  # type: ignore

            try:
                while True:
                    try:
-                        await async_engine_client.setup()
+                        await rpc_client.setup()
                        break
                    except TimeoutError as e:
                        if not rpc_server_process.is_alive():
@@ -161,7 +168,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
                rpc_server_process.terminate()

                # Close all open connections to the backend
-                async_engine_client.close()
+                rpc_client.close()

                # Wait for server process to join
                rpc_server_process.join()
@@ -216,10 +223,11 @@ async def tokenize(request: TokenizeRequest):
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
-    else:
-        assert isinstance(generator, TokenizeResponse)
+    elif isinstance(generator, TokenizeResponse):
        return JSONResponse(content=generator.model_dump())

+    assert_never(generator)

@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
@@ -227,10 +235,11 @@ async def detokenize(request: DetokenizeRequest):
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
-    else:
-        assert isinstance(generator, DetokenizeResponse)
+    elif isinstance(generator, DetokenizeResponse):
        return JSONResponse(content=generator.model_dump())

+    assert_never(generator)

@router.get("/v1/models")
async def show_available_models():
@@ -252,13 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
-    if request.stream:
-        return StreamingResponse(content=generator,
-                                 media_type="text/event-stream")
-    else:
-        assert isinstance(generator, ChatCompletionResponse)
+    elif isinstance(generator, ChatCompletionResponse):
        return JSONResponse(content=generator.model_dump())

+    return StreamingResponse(content=generator, media_type="text/event-stream")

@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
@@ -267,12 +274,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
-    if request.stream:
-        return StreamingResponse(content=generator,
-                                 media_type="text/event-stream")
-    else:
+    elif isinstance(generator, CompletionResponse):
        return JSONResponse(content=generator.model_dump())

+    return StreamingResponse(content=generator, media_type="text/event-stream")

@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
@@ -281,9 +287,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
-    else:
+    elif isinstance(generator, EmbeddingResponse):
        return JSONResponse(content=generator.model_dump())

+    assert_never(generator)

def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
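Aside: the route handlers above now rely on typing_extensions.assert_never for exhaustiveness instead of a bare else/assert. A standalone sketch of the pattern with hypothetical result classes:

from typing import Union

from typing_extensions import assert_never


class ErrorResult:
    pass


class OkResult:
    pass


def render(result: Union[ErrorResult, OkResult]) -> str:
    if isinstance(result, ErrorResult):
        return "error"
    elif isinstance(result, OkResult):
        return "ok"

    # mypy reports this call if a new member is added to the Union without a
    # branch above; at runtime it raises if it is ever reached.
    assert_never(result)


print(render(OkResult()))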

View File

@@ -7,6 +7,7 @@ purposes.
import argparse
import json
import ssl
+from typing import List, Optional, Sequence, Union

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
@@ -16,8 +17,19 @@ from vllm.utils import FlexibleArgumentParser

class LoRAParserAction(argparse.Action):

-    def __call__(self, parser, namespace, values, option_string=None):
-        lora_list = []
+    def __call__(
+        self,
+        parser: argparse.ArgumentParser,
+        namespace: argparse.Namespace,
+        values: Optional[Union[str, Sequence[str]]],
+        option_string: Optional[str] = None,
+    ):
+        if values is None:
+            values = []
+        if isinstance(values, str):
+            raise TypeError("Expected values to be a list")
+
+        lora_list: List[LoRAModulePath] = []
        for item in values:
            name, path = item.split('=')
            lora_list.append(LoRAModulePath(name, path))
@@ -26,8 +38,19 @@ class LoRAParserAction(argparse.Action):

class PromptAdapterParserAction(argparse.Action):

-    def __call__(self, parser, namespace, values, option_string=None):
-        adapter_list = []
+    def __call__(
+        self,
+        parser: argparse.ArgumentParser,
+        namespace: argparse.Namespace,
+        values: Optional[Union[str, Sequence[str]]],
+        option_string: Optional[str] = None,
+    ):
+        if values is None:
+            values = []
+        if isinstance(values, str):
+            raise TypeError("Expected values to be a list")
+
+        adapter_list: List[PromptAdapterPath] = []
        for item in values:
            name, path = item.split('=')
            adapter_list.append(PromptAdapterPath(name, path))
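Aside: a runnable standalone version of the typed argparse.Action pattern introduced above, with a hypothetical --modules option and a ModulePath dataclass standing in for the LoRA/prompt-adapter types:

import argparse
from dataclasses import dataclass
from typing import List, Optional, Sequence, Union


@dataclass
class ModulePath:
    name: str
    path: str


class ModulePathParserAction(argparse.Action):

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        # argparse may hand us None or a single string depending on nargs;
        # normalize and reject the shapes this action does not support.
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        module_list: List[ModulePath] = []
        for item in values:
            name, path = item.split("=")
            module_list.append(ModulePath(name, path))
        setattr(namespace, self.dest, module_list)


parser = argparse.ArgumentParser()
parser.add_argument("--modules", nargs="+", action=ModulePathParserAction)
print(parser.parse_args(["--modules", "alpha=/tmp/a", "beta=/tmp/b"]).modules)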

View File

@@ -2,9 +2,9 @@ from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union

import torch
-from transformers import PreTrainedTokenizer

from vllm.sampling_params import LogitsProcessor
+from vllm.transformers_utils.tokenizer import AnyTokenizer


class AllowedTokenIdsLogitsProcessor:
@@ -53,8 +53,9 @@ def logit_bias_logits_processor(
def get_logits_processors(
    logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
    allowed_token_ids: Optional[List[int]],
-    tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
-    logits_processors = []
+    tokenizer: AnyTokenizer,
+) -> List[LogitsProcessor]:
+    logits_processors: List[LogitsProcessor] = []
    if logit_bias:
        try:
            # Convert token_id to integer

View File

@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
-from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -14,11 +13,13 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
+from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)

+_LONG_INFO: Union["torch.iinfo", Namespace]
+
try:
    from sphinx.ext.autodoc.mock import _MockModule
@@ -235,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
-            self, tokenizer: PreTrainedTokenizer,
+            self, tokenizer: AnyTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessor],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

+        prompt_logprobs = self.prompt_logprobs
+        if prompt_logprobs is None and self.echo:
+            prompt_logprobs = self.top_logprobs
+
        # We now allow logprobs being true without top_logrobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
@@ -251,7 +256,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

-        return SamplingParams(
+        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
@@ -265,8 +270,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
-            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
-            (self.top_logprobs if self.echo else None),
+            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
@@ -280,14 +284,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

-    @model_validator(mode='before')
+    @model_validator(mode="before")
    @classmethod
-    def validate_stream_options(cls, values):
-        if (values.get('stream_options') is not None
-                and not values.get('stream')):
+    def validate_stream_options(cls, data):
+        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
-                "stream_options can only be set if stream is true")
-        return values
+                "Stream options can only be defined when `stream=True`.")
+
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
+            if data.get("stream") and prompt_logprobs > 0:
+                raise ValueError(
+                    "`prompt_logprobs` are not available when `stream=True`.")
+
+            if prompt_logprobs < 0:
+                raise ValueError("`prompt_logprobs` must be a positive value.")
+
+        if (top_logprobs := data.get("top_logprobs")) is not None:
+            if top_logprobs < 0:
+                raise ValueError("`top_logprobs` must be a positive value.")
+
+            if not data.get("logprobs"):
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+
+        return data

    @model_validator(mode="before")
    @classmethod
@@ -320,19 +346,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
                    "When using `tool_choice`, `tools` must be set.")
        return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_logprobs(cls, data):
-        if "top_logprobs" in data and data["top_logprobs"] is not None:
-            if "logprobs" not in data or data["logprobs"] is False:
-                raise ValueError(
-                    "when using `top_logprobs`, `logprobs` must be set to true."
-                )
-            elif data["top_logprobs"] < 0:
-                raise ValueError(
-                    "`top_logprobs` must be a value a positive value.")
-        return data


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
@@ -422,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel):
    # doc: end-completion-extra-params

    def to_sampling_params(
-            self, tokenizer: PreTrainedTokenizer,
+            self, tokenizer: AnyTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessor],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

+        prompt_logprobs = self.prompt_logprobs
+        if prompt_logprobs is None and self.echo:
+            prompt_logprobs = self.logprobs
+
        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
@@ -439,7 +456,7 @@ class CompletionRequest(OpenAIBaseModel):
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

-        return SamplingParams(
+        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
@@ -458,8 +475,7 @@ class CompletionRequest(OpenAIBaseModel):
            min_tokens=self.min_tokens,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
-            prompt_logprobs=self.prompt_logprobs
-            if self.prompt_logprobs else self.logprobs if self.echo else None,
+            prompt_logprobs=prompt_logprobs,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
@@ -485,9 +501,17 @@ class CompletionRequest(OpenAIBaseModel):
    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
-        if "logprobs" in data and data[
-                "logprobs"] is not None and not data["logprobs"] >= 0:
-            raise ValueError("if passed, `logprobs` must be a positive value.")
+        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
+            if data.get("stream") and prompt_logprobs > 0:
+                raise ValueError(
+                    "`prompt_logprobs` are not available when `stream=True`.")
+
+            if prompt_logprobs < 0:
+                raise ValueError("`prompt_logprobs` must be a positive value.")
+
+        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
+            raise ValueError("`logprobs` must be a positive value.")
+
        return data

    @model_validator(mode="before")
@@ -495,7 +519,8 @@ class CompletionRequest(OpenAIBaseModel):
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
-                "Stream options can only be defined when stream is true.")
+                "Stream options can only be defined when `stream=True`.")
+
        return data


@@ -504,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel):
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
-    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
+    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
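Aside: a self-contained sketch of the mode="before" validator pattern used by check_logprobs above, on a trimmed-down request model (a hypothetical MiniCompletionRequest, not the vLLM class):

from typing import Optional

from pydantic import BaseModel, model_validator


class MiniCompletionRequest(BaseModel):
    stream: bool = False
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        # mode="before" sees the raw input mapping, prior to field coercion.
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")
            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")
        if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
            raise ValueError("`logprobs` must be a positive value.")
        return data


print(MiniCompletionRequest(prompt_logprobs=1).prompt_logprobs)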

View File

@@ -23,8 +23,8 @@ class AsyncEngineRPCServer:
    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, rpc_path: str):
        # Initialize engine first.
-        self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
-                                                      usage_context)
+        self.engine = AsyncLLMEngine.from_engine_args(
+            async_engine_args, usage_context=usage_context)

        # Initialize context.
        self.context = zmq.asyncio.Context()
@@ -39,7 +39,7 @@ class AsyncEngineRPCServer:
        self.context.destroy()
        self.engine.shutdown_background_loop()
        # Clear the engine reference so that it can be GC'ed.
-        self.engine = None
+        del self.engine

    async def get_model_config(self, identity):
        """Send the ModelConfig"""

View File

@ -1,11 +1,10 @@
import asyncio import asyncio
import time import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Union from typing import Union
from fastapi import Request from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
@ -24,13 +23,14 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing, OpenAIServing,
PromptAdapterPath) PromptAdapterPath)
from vllm.inputs import PromptInputs from vllm.inputs import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@ -67,9 +67,9 @@ class OpenAIServingChat(OpenAIServing):
async def create_chat_completion( async def create_chat_completion(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
raw_request: Optional[Request] = None raw_request: Optional[Request] = None,
) -> Union[ErrorResponse, AsyncGenerator[str, None], ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ChatCompletionResponse]: ErrorResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create See https://platform.openai.com/docs/api-reference/chat/create
@ -83,16 +83,6 @@ class OpenAIServingChat(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
if request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid "
f"negative value: {request.prompt_logprobs}")
try: try:
( (
lora_request, lora_request,
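The streaming and negative-value checks deleted from the handler here can equally be expressed as request-model validation; the pydantic sketch below is illustrative only and is not the actual vLLM protocol code (field names mirror the diff).

from typing import Optional

from pydantic import BaseModel, model_validator


class ChatRequestSketch(BaseModel):
    stream: bool = False
    prompt_logprobs: Optional[int] = None

    @model_validator(mode="after")
    def validate_prompt_logprobs(self) -> "ChatRequestSketch":
        # Cross-field check: prompt logprobs cannot be combined with streaming,
        # and negative counts are rejected outright.
        if self.prompt_logprobs is not None:
            if self.stream and self.prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")
            if self.prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must not be negative, got "
                                 f"{self.prompt_logprobs}.")
        return self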
@ -160,9 +150,8 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request) prompt_adapter_request=prompt_adapter_request)
engine_inputs: PromptInputs = { engine_inputs = TokensPrompt(
"prompt_token_ids": prompt_inputs["prompt_token_ids"], prompt_token_ids=prompt_inputs["prompt_token_ids"])
}
if mm_data is not None: if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data engine_inputs["multi_modal_data"] = mm_data
@ -214,11 +203,11 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput],
request_id: str, request_id: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0] model_name = self.served_model_names[0]
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
# Send response for each token for each request.n (index) # Send response for each token for each request.n (index)
@ -438,7 +427,7 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput], result_generator: AsyncIterator[RequestOutput],
request_id: str, request_id: str,
conversation: List[ConversationMessage], conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]: ) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0] model_name = self.served_model_names[0]
@ -523,7 +512,7 @@ class OpenAIServingChat(OpenAIServing):
def _get_top_logprobs( def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]: tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
return [ return [
ChatCompletionLogProb(token=(token := self._get_decoded_token( ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1], p[1],
@ -541,12 +530,11 @@ class OpenAIServingChat(OpenAIServing):
self, self,
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
num_output_top_logprobs: Optional[int] = None, num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs: ) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs.""" """Create OpenAI-style logprobs."""
logprobs_content: List[ChatCompletionLogProbsContent] = []
logprobs_content = []
for i, token_id in enumerate(token_ids): for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
@ -554,23 +542,32 @@ class OpenAIServingChat(OpenAIServing):
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids: if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
logprobs_content.append( logprobs_content.append(
ChatCompletionLogProbsContent( ChatCompletionLogProbsContent(
token=token, token=token,
bytes=list(token.encode("utf-8", errors="replace")))) bytes=list(token.encode("utf-8", errors="replace")),
))
else: else:
step_token = step_top_logprobs[token_id]
step_decoded = step_token.decoded_token
logprobs_content.append( logprobs_content.append(
ChatCompletionLogProbsContent( ChatCompletionLogProbsContent(
token=self._get_decoded_token( token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer, step_token,
self.return_tokens_as_token_ids), token_id,
logprob=max(step_top_logprobs[token_id].logprob, tokenizer,
-9999.0), self.return_tokens_as_token_ids,
bytes=list( ),
step_top_logprobs[token_id].decoded_token.encode( logprob=max(step_token.logprob, -9999.0),
"utf-8", errors="replace")), bytes=None if step_decoded is None else list(
step_decoded.encode("utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs( top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs, step_top_logprobs,
tokenizer))) num_output_top_logprobs,
tokenizer,
),
))
return ChatCompletionLogProbs(content=logprobs_content) return ChatCompletionLogProbs(content=logprobs_content)
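A condensed, self-contained sketch of the entry construction above, using a stand-in dataclass instead of `vllm.sequence.Logprob` and a plain dict instead of `ChatCompletionLogProbsContent`; the -9999.0 clamp and the None-safe byte handling mirror the diff.

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class LogprobSketch:
    logprob: float
    decoded_token: Optional[str] = None


def to_entry(token_id: int,
             step_top_logprobs: Dict[int, LogprobSketch]) -> dict:
    step_token = step_top_logprobs[token_id]
    step_decoded = step_token.decoded_token
    return {
        # Fall back to a token-id tag when no decoded text is available.
        "token": step_decoded if step_decoded is not None else f"token_id:{token_id}",
        # Clamp -inf logprobs so the JSON response stays finite.
        "logprob": max(step_token.logprob, -9999.0),
        # The decoded token can be None, so guard before encoding to bytes.
        "bytes": None if step_decoded is None else list(
            step_decoded.encode("utf-8", errors="replace")),
    }


print(to_entry(5, {5: LogprobSketch(logprob=float("-inf"), decoded_token="hello")}))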

View File

@ -3,10 +3,9 @@ import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional) Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Tuple, cast from typing import Tuple, Union, cast
from fastapi import Request from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
@ -19,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, CompletionStreamResponse,
UsageInfo) ErrorResponse, UsageInfo)
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing, OpenAIServing,
@ -29,6 +28,7 @@ from vllm.outputs import RequestOutput
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers, from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning) log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@ -60,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids) return_tokens_as_token_ids=return_tokens_as_token_ids)
async def create_completion(self, request: CompletionRequest, async def create_completion(
raw_request: Request): self,
request: CompletionRequest,
raw_request: Request,
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create See https://platform.openai.com/docs/api-reference/completions/create
@ -84,15 +87,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time()) created_time = int(time.time())
if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
elif request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid negative "
f"value: {request.prompt_logprobs}")
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = [] generators: List[AsyncGenerator[RequestOutput, None]] = []
try: try:
@ -153,8 +147,7 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[ result_generator = merge_async_iterators(
int, RequestOutput]] = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected) *generators, is_cancelled=raw_request.is_disconnected)
# Similar to the OpenAI API, when n != best_of, we do not stream the # Similar to the OpenAI API, when n != best_of, we do not stream the
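The real `vllm.utils.merge_async_iterators` also accepts an `is_cancelled` callback and handles cancellation and error propagation; the simplified sketch below only shows the core fan-in behavior and the `(index, item)` shape that the code downstream consumes.

import asyncio
from typing import AsyncIterator, Tuple, TypeVar

T = TypeVar("T")

_DONE = object()


async def merge_tagged(*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Yield (producer_index, item) pairs as soon as any producer has output."""
    queue: asyncio.Queue = asyncio.Queue()

    async def drain(idx: int, it: AsyncIterator[T]) -> None:
        async for item in it:
            await queue.put((idx, item))
        await queue.put((idx, _DONE))  # sentinel: this producer is exhausted

    tasks = [asyncio.create_task(drain(i, it)) for i, it in enumerate(iterators)]
    remaining = len(tasks)
    while remaining:
        idx, item = await queue.get()
        if item is _DONE:
            remaining -= 1
        else:
            yield idx, item
    for task in tasks:
        await task


async def _demo() -> None:
    async def count(start: int) -> AsyncIterator[int]:
        for i in range(start, start + 2):
            yield i
            await asyncio.sleep(0)

    async for prompt_idx, value in merge_tagged(count(0), count(10)):
        print(prompt_idx, value)


asyncio.run(_demo())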
@ -227,7 +220,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int, created_time: int,
model_name: str, model_name: str,
num_prompts: int, num_prompts: int,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts previous_texts = [""] * num_choices * num_prompts
@ -236,6 +229,13 @@ class OpenAIServingCompletion(OpenAIServing):
try: try:
async for prompt_idx, res in result_generator: async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt
delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]]
for output in res.outputs: for output in res.outputs:
i = output.index + prompt_idx * num_choices i = output.index + prompt_idx * num_choices
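Declaring `delta_token_ids` and `out_logprobs` before the branches gives mypy a single type to check every assignment against, instead of inferring incompatible per-branch types; a minimal illustration of the pattern (the helper below is hypothetical):

from typing import List, Sequence, Tuple


def select_ids(echo: bool, prompt_ids: Tuple[int, ...],
               output_ids: List[int]) -> Sequence[int]:
    delta_token_ids: Sequence[int]  # declared once, so mypy unifies both branches
    if echo:
        delta_token_ids = prompt_ids   # a tuple in this branch
    else:
        delta_token_ids = output_ids   # a list in this branch
    return delta_token_ids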
@ -244,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
assert prompt_text is not None
# only return the prompt # only return the prompt
delta_text = res.prompt delta_text = prompt_text
delta_token_ids = res.prompt_token_ids delta_token_ids = prompt_token_ids
out_logprobs = res.prompt_logprobs out_logprobs = prompt_logprobs
has_echoed[i] = True has_echoed[i] = True
elif (request.echo and request.max_tokens > 0 elif (request.echo and request.max_tokens > 0
and not has_echoed[i]): and not has_echoed[i]):
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token # echo the prompt and first token
delta_text = res.prompt + output.text delta_text = prompt_text + output.text
delta_token_ids = (res.prompt_token_ids + delta_token_ids = [
output.token_ids) *prompt_token_ids, *output.token_ids
out_logprobs = res.prompt_logprobs + (output.logprobs ]
or []) out_logprobs = [
*prompt_logprobs,
*(output.logprobs or []),
]
has_echoed[i] = True has_echoed[i] = True
else: else:
# return just the delta # return just the delta
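Star-unpacking replaces `+` here because the prompt and output token sequences may be different concrete types (tuple vs. list), and `Sequence` has no typed `+` operator; a tiny illustration:

from typing import List, Tuple

prompt_token_ids: Tuple[int, ...] = (1, 2, 3)   # engine output is often a tuple
output_token_ids: List[int] = [4, 5]

# `prompt_token_ids + output_token_ids` raises TypeError for tuple + list, and
# mypy cannot type `+` on Sequence operands; unpacking always builds a list.
delta_token_ids: List[int] = [*prompt_token_ids, *output_token_ids]
assert delta_token_ids == [1, 2, 3, 4, 5]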
@ -301,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None): or output.finish_reason is not None):
prompt_tokens = len(res.prompt_token_ids) prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids) completion_tokens = len(output.token_ids)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
@ -342,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
) -> CompletionResponse: ) -> CompletionResponse:
choices: List[CompletionResponseChoice] = [] choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
@ -353,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = final_res.prompt_logprobs prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
Logprob]]]]
for output in final_res.outputs: for output in final_res.outputs:
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
assert prompt_text is not None
token_ids = prompt_token_ids token_ids = prompt_token_ids
out_logprobs = prompt_logprobs out_logprobs = prompt_logprobs
output_text = prompt_text output_text = prompt_text
elif request.echo and request.max_tokens > 0: elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + list(output.token_ids) assert prompt_text is not None
out_logprobs = (prompt_logprobs + output.logprobs token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is not None else None)
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text output_text = prompt_text + output.text
else: else:
token_ids = output.token_ids token_ids = output.token_ids
@ -413,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int, num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer, tokenizer: AnyTokenizer,
initial_text_offset: int = 0, initial_text_offset: int = 0,
) -> CompletionLogProbs: ) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API.""" """Create logprobs for OpenAI Completion API."""
@ -430,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing):
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids: if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
out_tokens.append(token) out_tokens.append(token)
out_token_logprobs.append(None) out_token_logprobs.append(None)
out_top_logprobs.append(None) out_top_logprobs.append(None)
else: else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token( token = self._get_decoded_token(
step_top_logprobs[token_id], step_token,
token_id, token_id,
tokenizer, tokenizer,
return_as_token_id=self.return_tokens_as_token_ids) return_as_token_id=self.return_tokens_as_token_ids,
token_logprob = max(step_top_logprobs[token_id].logprob, )
-9999.0) token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token) out_tokens.append(token)
out_token_logprobs.append(token_logprob) out_token_logprobs.append(token_logprob)

View File

@ -1,11 +1,11 @@
import asyncio import asyncio
import base64 import base64
import time import time
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple, from typing import AsyncGenerator, List, Literal, Optional, Union, cast
Union, cast)
import numpy as np import numpy as np
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient from vllm.engine.protocol import AsyncEngineClient
@ -16,7 +16,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
ErrorResponse, UsageInfo) ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@ -24,18 +24,28 @@ logger = init_logger(__name__)
TypeTokenIDs = List[int] TypeTokenIDs = List[int]
def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
return output.embedding
elif encoding_format == "base64":
embedding_bytes = np.array(output.embedding).tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8")
assert_never(encoding_format)
def request_output_to_embedding_response( def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str, final_res_batch: List[EmbeddingRequestOutput], request_id: str,
created_time: int, model_name: str, created_time: int, model_name: str,
encoding_format: str) -> EmbeddingResponse: encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = [] data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch): for idx, final_res in enumerate(final_res_batch):
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding embedding = _get_embedding(final_res.outputs, encoding_format)
if encoding_format == "base64":
embedding_bytes = np.array(embedding).tobytes()
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data) data.append(embedding_data)
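A standalone sketch of the exhaustiveness pattern used by `_get_embedding`: with a `Literal` parameter, `assert_never` makes mypy flag any branch the code forgets to handle, and the base64 branch serializes the raw float bytes for JSON transport. The function below is a stand-in, not the vLLM helper itself.

import base64
from typing import List, Literal, Union

import numpy as np
from typing_extensions import assert_never


def encode_embedding(values: List[float],
                     encoding_format: Literal["float", "base64"],
                     ) -> Union[List[float], str]:
    if encoding_format == "float":
        return values
    elif encoding_format == "base64":
        # Pack the floats into raw bytes, then base64-encode for JSON transport.
        return base64.b64encode(np.array(values).tobytes()).decode("utf-8")
    # Unreachable if every Literal value is handled; mypy errors here otherwise.
    assert_never(encoding_format)


print(encode_embedding([0.1, 0.2], "base64"))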
@ -76,8 +86,8 @@ class OpenAIServingEmbedding(OpenAIServing):
async def create_embedding( async def create_embedding(
self, self,
request: EmbeddingRequest, request: EmbeddingRequest,
raw_request: Optional[Request] = None raw_request: Optional[Request] = None,
) -> Union[ErrorResponse, EmbeddingResponse]: ) -> Union[EmbeddingResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API. """Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create See https://platform.openai.com/docs/api-reference/embeddings/create
@ -89,8 +99,7 @@ class OpenAIServingEmbedding(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
encoding_format = (request.encoding_format encoding_format = request.encoding_format
if request.encoding_format else "float")
if request.dimensions is not None: if request.dimensions is not None:
return self.create_error_response( return self.create_error_response(
"dimensions is currently not supported") "dimensions is currently not supported")
@ -145,11 +154,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[ result_generator = merge_async_iterators(
int, EmbeddingRequestOutput]] = merge_async_iterators(
*generators, *generators,
is_cancelled=raw_request.is_disconnected is_cancelled=raw_request.is_disconnected if raw_request else None,
if raw_request else None) )
# Non-streaming response # Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]] final_res_batch: List[Optional[EmbeddingRequestOutput]]
@ -175,7 +183,7 @@ class OpenAIServingEmbedding(OpenAIServing):
return response return response
def _check_embedding_mode(self, embedding_mode: bool): def _check_embedding_mode(self, embedding_mode: bool) -> bool:
if not embedding_mode: if not embedding_mode:
logger.warning( logger.warning(
"embedding_mode is False. Embedding API will not work.") "embedding_mode is False. Embedding API will not work.")

View File

@ -31,7 +31,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
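The import fix above points at `vllm.transformers_utils.tokenizer`, where `AnyTokenizer` is defined as a union of the tokenizer variants the server can hand out; the exact members of that union are an assumption in this sketch, which only illustrates the alias pattern that replaces annotating with the single `PreTrainedTokenizer` class.

from typing import List, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# A stand-in alias; the members of vLLM's actual AnyTokenizer union may differ.
AnyTokenizerSketch = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def decode_ids(tokenizer: AnyTokenizerSketch, token_ids: List[int]) -> str:
    # Both variants expose .decode(), so callers stay agnostic to the concrete class.
    return tokenizer.decode(token_ids)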

View File

@ -153,6 +153,68 @@ class SamplingParams(
output_text_buffer_length: int = 0 output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
@staticmethod
def from_optional(
n: Optional[int] = 1,
best_of: Optional[int] = None,
presence_penalty: Optional[float] = 0.0,
frequency_penalty: Optional[float] = 0.0,
repetition_penalty: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 1.0,
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
min_tokens: int = 0,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
detokenize: bool = True,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
) -> "SamplingParams":
return SamplingParams(
n=1 if n is None else n,
best_of=best_of,
presence_penalty=0.0
if presence_penalty is None else presence_penalty,
frequency_penalty=0.0
if frequency_penalty is None else frequency_penalty,
repetition_penalty=1.0
if repetition_penalty is None else repetition_penalty,
temperature=1.0 if temperature is None else temperature,
top_p=1.0 if top_p is None else top_p,
top_k=top_k,
min_p=min_p,
seed=seed,
use_beam_search=use_beam_search,
length_penalty=length_penalty,
early_stopping=early_stopping,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
min_tokens=min_tokens,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs,
detokenize=detokenize,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.best_of = self.best_of or self.n self.best_of = self.best_of or self.n
if 0 < self.temperature < _MAX_TEMP: if 0 < self.temperature < _MAX_TEMP:
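A minimal sketch of the `from_optional` pattern added above, using a dataclass stand-in rather than the msgspec-based `SamplingParams`; the point is that OpenAI-style JSON may carry explicit nulls, which this constructor maps back to the library defaults instead of storing `None` on the core params object.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ParamsSketch:
    n: int = 1
    temperature: float = 1.0
    top_p: float = 1.0

    @staticmethod
    def from_optional(n: Optional[int] = None,
                      temperature: Optional[float] = None,
                      top_p: Optional[float] = None) -> "ParamsSketch":
        # JSON payloads can contain explicit nulls; fall back to defaults here
        # so the params object never stores None for these fields.
        return ParamsSketch(
            n=1 if n is None else n,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
        )


print(ParamsSketch.from_optional(n=None, temperature=0.7))  # n=1, temperature=0.7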