[mypy] Enable following imports for entrypoints (#7248)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Fei <dfdfcai4@gmail.com>

parent 4506641212
commit baaedfdb2d

.github/workflows/mypy.yaml (vendored)
@@ -38,7 +38,6 @@ jobs:
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip

@@ -6,7 +6,7 @@ sphinx-argparse==0.4.0
msgspec

# packages to install to build the documentation
pydantic
pydantic >= 2.8
-f https://download.pytorch.org/whl/cpu
torch
py-cpuinfo

@@ -102,7 +102,6 @@ mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip

@@ -56,6 +56,7 @@ files = [
"vllm/*.py",
"vllm/adapter_commons",
"vllm/assets",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",

@@ -11,7 +11,7 @@ fastapi
aiohttp
openai >= 1.0 # Ensure modern openai package (ensure types module present)
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pydantic >= 2.8 # Required for OpenAI server.
pillow # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0

@@ -1,7 +1,7 @@
# imports for guided decoding tests
import json
import re
from typing import List
from typing import Dict, List, Optional

import jsonschema
import openai # use the official client for correctness check

@@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: Optional[int]):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}

if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError):
await client.chat.completions.create(**params)
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs is not None:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}

completion_1 = await client.chat.completions.create(**params)

params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)

assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
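For reference, a minimal client-side sketch of what these tests exercise, assuming a vLLM OpenAI-compatible server is already running; the URL and model name below are placeholders, and the vLLM-specific prompt_logprobs option is passed through the OpenAI client's extra_body hook, as in the tests above.

# Illustrative only: assumes a running vLLM server at a hypothetical URL.
import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def demo() -> None:
    completion = await client.chat.completions.create(
        model="my-model",  # placeholder model name
        messages=[{"role": "user", "content": "Where was it played?"}],
        # vLLM-specific field, forwarded via the client's extra_body hook.
        extra_body={"prompt_logprobs": 1},
    )
    # The tests above treat this as a per-prompt-token list of logprob dicts.
    print(completion.prompt_logprobs)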
@@ -3,7 +3,7 @@ import json
import re
import shutil
from tempfile import TemporaryDirectory
from typing import Dict, List
from typing import Dict, List, Optional

import jsonschema
import openai # use the official client for correctness check

@@ -268,92 +268,6 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
assert len(completion.choices[0].text) >= 0

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str, prompt_logprobs: int):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}

if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.chat.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}

completion_1 = await client.chat.completions.create(**params)

params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)

assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2

@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
(MODEL_NAME, 0),

@@ -361,7 +275,7 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
(MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: int):
prompt_logprobs: Optional[int]):
params: Dict = {
"prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name,

@@ -369,17 +283,12 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError):
await client.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
if prompt_logprobs is not None:
assert completion.choices[0].prompt_logprobs is not None
assert len(completion.choices[0].prompt_logprobs) > 0
@@ -6,7 +6,6 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)

import torch
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never

import vllm.envs as envs

@@ -31,6 +30,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once

@@ -427,8 +427,8 @@ class _AsyncLLMEngine(LLMEngine):
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")

return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,

@@ -771,7 +771,7 @@ class AsyncLLMEngine:
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
) -> AnyTokenizer:
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote( # type: ignore
lora_request)

@@ -3,9 +3,9 @@ from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, TypeVar, Union
from typing import Set, Tuple, Type, Union

from typing_extensions import assert_never
from typing_extensions import TypeVar, assert_never

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,

@@ -43,8 +43,9 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device

@@ -67,6 +68,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
return config.to_diff_dict()

_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

PromptComponents = Tuple[Optional[str], List[int],

@@ -493,12 +495,21 @@ class LLMEngine:
"skip_tokenizer_init is True")

def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
self,
group_type: Type[_G] = BaseTokenizerGroup,
*,
missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
) -> _G:
tokenizer_group = self.tokenizer

return self.tokenizer
if tokenizer_group is None:
raise ValueError(missing_msg)
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")

return tokenizer_group

def get_tokenizer(
self,

@@ -693,8 +704,8 @@ class LLMEngine:
* prompt token ids
'''

tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")

return tokenizer.encode(request_id=request_id,
prompt=prompt,
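The reworked get_tokenizer_group relies on typing_extensions.TypeVar with a default= value (PEP 696), so callers that pass no group_type get BaseTokenizerGroup while callers that pass a concrete class get that class back for the type checker. A minimal standalone sketch of the pattern follows; the class names here are stand-ins, not vLLM's real tokenizer groups.

# Sketch of the typed-accessor pattern; BaseGroup/ConcreteGroup are stand-ins.
from typing import Optional, Type

from typing_extensions import TypeVar  # supports default= (PEP 696)


class BaseGroup:
    pass


class ConcreteGroup(BaseGroup):
    pass


_G = TypeVar("_G", bound=BaseGroup, default=BaseGroup)


class Engine:

    def __init__(self, group: Optional[BaseGroup]) -> None:
        self.group = group

    def get_group(self, group_type: Type[_G] = BaseGroup) -> _G:
        # Raise instead of returning None, and verify the runtime type
        # matches what the caller asked for.
        if self.group is None:
            raise ValueError("no tokenizer group initialized")
        if not isinstance(self.group, group_type):
            raise TypeError(f"expected {group_type}, found {type(self.group)}")
        return self.group


engine = Engine(ConcreteGroup())
group = engine.get_group(ConcreteGroup)  # intended to type-check as ConcreteGroup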
@@ -1,13 +1,12 @@
from abc import ABC, abstractmethod
from typing import Callable, List

from transformers import PreTrainedTokenizer

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter

@@ -29,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.

@@ -1,8 +1,6 @@
import functools
from typing import Callable, List

from transformers import PreTrainedTokenizer

from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)

@@ -12,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter

logger = init_logger(__name__)

@@ -36,7 +35,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: StopChecker,
):
self.detokenizer = detokenizer

@@ -1,10 +1,9 @@
from typing import Callable, Optional

from transformers import PreTrainedTokenizer

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer

class StopChecker:

@@ -15,8 +14,7 @@ class StopChecker:
"""

def __init__(self, max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence],
PreTrainedTokenizer]):
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq

@@ -1,8 +1,6 @@
from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
runtime_checkable)

from transformers import PreTrainedTokenizer

from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs

@@ -12,6 +10,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer

@runtime_checkable

@@ -40,6 +39,7 @@ class AsyncEngineClient(Protocol):
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request"""
...

def encode(
self,

@@ -50,6 +50,7 @@ class AsyncEngineClient(Protocol):
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model."""
...

async def abort(self, request_id: str) -> None:
"""Abort a request.

@@ -60,25 +61,29 @@ class AsyncEngineClient(Protocol):

async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
...

async def get_decoding_config(self) -> DecodingConfig:
...
"""Get the decoding configuration of the vLLM engine."""

async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> PreTrainedTokenizer:
"""Get the appropriate Tokenizer for the request"""
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
...

async def is_tracing_enabled(self) -> bool:
pass
...

async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass
...

async def check_health(self) -> None:
"""Raise if unhealthy"""
...
@@ -61,6 +61,7 @@ async def generate(request: Request) -> Response:
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
assert prompt is not None
text_outputs = [
prompt + output.text for output in request_output.outputs
]

@@ -80,6 +81,7 @@ async def generate(request: Request) -> Response:

assert final_output is not None
prompt = final_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)

@@ -115,6 +117,7 @@ async def run_server(args: Namespace,
logger.info("args: %s", args)

app = await init_app(args, llm_engine)
assert engine is not None

shutdown_task = await serve_http(
app,

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union, cast)
Union)

# yapf conflicts with isort for this block
# yapf: disable

@@ -15,9 +15,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer
from typing_extensions import Required, TypedDict
from pydantic import ConfigDict, TypeAdapter
from typing_extensions import Required, TypeAlias, TypedDict

from vllm.config import ModelConfig
from vllm.logger import init_logger

@@ -50,9 +49,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""

ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam]
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam, ]

class CustomChatCompletionMessageParam(TypedDict, total=False):

@@ -114,7 +113,7 @@ def load_chat_template(

@lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)

@@ -151,11 +150,16 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str,
return f"{placeholder_token_str}\n{text_prompt}"

_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)

def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []

@@ -164,7 +168,7 @@ def _parse_chat_message_content_parts(
for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
text = _TextParser.validate_python(part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"

@@ -172,8 +176,7 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")

image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
image_url = _ImageParser.validate_python(part)["image_url"]

if image_url.get("detail", "auto") != "auto":
logger.warning(

@@ -188,8 +191,7 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")

audio_url = cast(ChatCompletionContentPartAudioParam,
part)["audio_url"]
audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
else:

@@ -219,7 +221,7 @@ def _parse_chat_message_content_parts(
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
role = message["role"]
content = message.get("content")

@@ -230,14 +232,18 @@ def _parse_chat_message_content(
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])

return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
return _parse_chat_message_content_parts(
role,
content, # type: ignore
model_config,
tokenizer,
)

def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
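The cast(...) calls above are replaced by module-level pydantic TypeAdapter instances, which actually validate the incoming content-part dict at runtime instead of merely asserting its type to the checker. A small self-contained sketch of that pattern; the TypedDict below is illustrative, not the OpenAI content-part type used in chat_utils.py.

# Standalone sketch of the TypeAdapter pattern; TextPart is a stand-in.
from pydantic import TypeAdapter
from typing_extensions import Literal, TypedDict


class TextPart(TypedDict):
    type: Literal["text"]
    text: str


_TextParser = TypeAdapter(TextPart)

part = {"type": "text", "text": "hello"}
# validate_python checks the shape at runtime and returns the typed dict.
text = _TextParser.validate_python(part)["text"]
print(text)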
@@ -1,8 +1,7 @@
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload

from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

@@ -20,7 +19,9 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs

@@ -122,7 +123,7 @@ class LLM:
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,

@@ -175,22 +176,19 @@ class LLM:
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()

def get_tokenizer(
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer.tokenizer
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer

def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
self.llm_engine.tokenizer.tokenizer = tokenizer
tokenizer_group.tokenizer = tokenizer
else:
self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
tokenizer)
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

@overload # LEGACY: single (prompt + optional token ids)
def generate(

@@ -578,6 +576,8 @@ class LLM:

inputs: List[PromptInputs] = []
for i in range(num_requests):
item: PromptInputs

if prompts is not None:
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:

@@ -635,7 +635,7 @@ class LLM:
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
request_id = str(next(self.request_counter))
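With this change, LLM.get_tokenizer() is typed as returning AnyTokenizer and both accessors go through the typed get_tokenizer_group(TokenizerGroup) call. A hedged usage sketch; the model name is only an example.

# Usage sketch; any small HF model works, "facebook/opt-125m" is just an example.
from vllm import LLM

llm = LLM(model="facebook/opt-125m")

tokenizer = llm.get_tokenizer()  # AnyTokenizer (HF slow or fast tokenizer)
print(type(tokenizer).__name__)

# Swapping a tokenizer back in; per the diff above, vLLM wraps it in its
# cached variant internally unless it is already a cached tokenizer.
llm.set_tokenizer(tokenizer)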
@@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig

@@ -29,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding

@@ -90,7 +93,8 @@ async def lifespan(app: FastAPI):

@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args

@@ -142,12 +146,15 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
logger.info("Started engine process with PID %d",
rpc_server_process.pid)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(rpc_path)
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore

try:
while True:
try:
await async_engine_client.setup()
await rpc_client.setup()
break
except TimeoutError as e:
if not rpc_server_process.is_alive():

@@ -161,7 +168,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
rpc_server_process.terminate()

# Close all open connections to the backend
async_engine_client.close()
rpc_client.close()

# Wait for server process to join
rpc_server_process.join()

@@ -216,10 +223,11 @@ async def tokenize(request: TokenizeRequest):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())

assert_never(generator)

@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):

@@ -227,10 +235,11 @@ async def detokenize(request: DetokenizeRequest):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())

assert_never(generator)

@router.get("/v1/models")
async def show_available_models():

@@ -252,13 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())

return StreamingResponse(content=generator, media_type="text/event-stream")

@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):

@@ -267,12 +274,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())

return StreamingResponse(content=generator, media_type="text/event-stream")

@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):

@@ -281,9 +287,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())

assert_never(generator)

def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
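The route handlers above switch from a bare assert isinstance(...) to an elif isinstance(...) chain closed by typing_extensions.assert_never, so mypy can verify that every possible response type is handled. A minimal self-contained sketch of the pattern; the union members here are illustrative, not vLLM's response classes.

# Exhaustiveness-checking sketch; Error/Result stand in for the response types.
from dataclasses import dataclass
from typing import Union

from typing_extensions import assert_never


@dataclass
class Error:
    message: str


@dataclass
class Result:
    value: int


def render(obj: Union[Error, Result]) -> str:
    if isinstance(obj, Error):
        return f"error: {obj.message}"
    elif isinstance(obj, Result):
        return f"ok: {obj.value}"

    # mypy flags this call if a new union member is added but not handled above.
    assert_never(obj)


print(render(Result(value=3)))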
@@ -7,6 +7,7 @@ purposes.
import argparse
import json
import ssl
from typing import List, Optional, Sequence, Union

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,

@@ -16,8 +17,19 @@ from vllm.utils import FlexibleArgumentParser

class LoRAParserAction(argparse.Action):

def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")

lora_list: List[LoRAModulePath] = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))

@@ -26,8 +38,19 @@ class LoRAParserAction(argparse.Action):

class PromptAdapterParserAction(argparse.Action):

def __call__(self, parser, namespace, values, option_string=None):
adapter_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")

adapter_list: List[PromptAdapterPath] = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
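The parser actions now carry full type annotations and reject a bare string for values. A self-contained sketch of the same name=path parsing pattern; the option name and dataclass below are illustrative, not vLLM's CLI.

# Standalone sketch of a typed argparse.Action parsing name=path pairs,
# mirroring LoRAParserAction above; ModulePath is a stand-in dataclass.
import argparse
from dataclasses import dataclass
from typing import List, Optional, Sequence, Union


@dataclass
class ModulePath:
    name: str
    path: str


class ModulePathAction(argparse.Action):

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        modules: List[ModulePath] = []
        for item in values:
            name, path = item.split("=")
            modules.append(ModulePath(name, path))
        setattr(namespace, self.dest, modules)


parser = argparse.ArgumentParser()
parser.add_argument("--modules", nargs="+", action=ModulePathAction)
args = parser.parse_args(["--modules", "alpha=/tmp/a", "beta=/tmp/b"])
print(args.modules)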
@@ -2,9 +2,9 @@ from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer

from vllm.sampling_params import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer

class AllowedTokenIdsLogitsProcessor:

@@ -51,10 +51,11 @@ def logit_bias_logits_processor(

def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
logits_processors = []
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: AnyTokenizer,
) -> List[LogitsProcessor]:
logits_processors: List[LogitsProcessor] = []
if logit_bias:
try:
# Convert token_id to integer
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam

@@ -14,11 +13,13 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]

try:
from sphinx.ext.autodoc.mock import _MockModule

@@ -235,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params

def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs

# We now allow logprobs being true without top_logrobs.
logits_processors = get_logits_processors(
logit_bias=self.logit_bias,

@@ -251,7 +256,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)

return SamplingParams(
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,

@@ -265,8 +270,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
(self.top_logprobs if self.echo else None),
prompt_logprobs=prompt_logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,

@@ -280,14 +284,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens=self.truncate_prompt_tokens,
)

@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, values):
if (values.get('stream_options') is not None
and not values.get('stream')):
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"stream_options can only be set if stream is true")
return values
"Stream options can only be defined when `stream=True`.")

return data

@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and prompt_logprobs > 0:
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")

if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")

if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0:
raise ValueError("`top_logprobs` must be a positive value.")

if not data.get("logprobs"):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)

return data

@model_validator(mode="before")
@classmethod

@@ -320,19 +346,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.")
return data

@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif data["top_logprobs"] < 0:
raise ValueError(
"`top_logprobs` must be a value a positive value.")
return data

class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation

@@ -422,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params

def to_sampling_params(
self, tokenizer: PreTrainedTokenizer,
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens

prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs

echo_without_generation = self.echo and self.max_tokens == 0

logits_processors = get_logits_processors(

@@ -439,7 +456,7 @@ class CompletionRequest(OpenAIBaseModel):
if guided_decode_logits_processor:
logits_processors.append(guided_decode_logits_processor)

return SamplingParams(
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,

@@ -458,8 +475,7 @@ class CompletionRequest(OpenAIBaseModel):
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.prompt_logprobs
if self.prompt_logprobs else self.logprobs if self.echo else None,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,

@@ -485,9 +501,17 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not data["logprobs"] >= 0:
raise ValueError("if passed, `logprobs` must be a positive value.")
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and prompt_logprobs > 0:
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")

if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")

if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise ValueError("`logprobs` must be a positive value.")

return data

@model_validator(mode="before")

@@ -495,7 +519,8 @@ class CompletionRequest(OpenAIBaseModel):
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when stream is true.")
"Stream options can only be defined when `stream=True`.")

return data

@@ -504,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/embeddings
model: str
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
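The request models consolidate their prompt_logprobs and logprobs checks into @model_validator(mode="before") hooks that inspect the raw payload dict before field parsing, which is what lets the server reject bad values with a 400 instead of handling them per-endpoint. A minimal standalone sketch of that validation pattern; the model below is illustrative and much smaller than vLLM's request classes.

# Standalone sketch of the "before" model_validator pattern used above.
from typing import Optional

from pydantic import BaseModel, model_validator


class DemoRequest(BaseModel):
    stream: bool = False
    prompt_logprobs: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")
            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")
        return data


DemoRequest(prompt_logprobs=1)     # accepted
# DemoRequest(prompt_logprobs=-1)  # would raise a ValidationError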
@@ -23,8 +23,8 @@ class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, rpc_path: str):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context)
self.engine = AsyncLLMEngine.from_engine_args(
async_engine_args, usage_context=usage_context)

# Initialize context.
self.context = zmq.asyncio.Context()

@@ -39,7 +39,7 @@ class AsyncEngineRPCServer:
self.context.destroy()
self.engine.shutdown_background_loop()
# Clear the engine reference so that it can be GC'ed.
self.engine = None
del self.engine

async def get_model_config(self, identity):
"""Send the ModelConfig"""

@@ -1,11 +1,10 @@
import asyncio
import time
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
from typing import Sequence as GenericSequence
from typing import Union

from fastapi import Request
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient

@@ -24,13 +23,14 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.inputs import PromptInputs
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid

logger = init_logger(__name__)

@@ -67,9 +67,9 @@ class OpenAIServingChat(OpenAIServing):
async def create_chat_completion(
self,
request: ChatCompletionRequest,
raw_request: Optional[Request] = None
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ErrorResponse]:
"""Completion API similar to OpenAI's API.

See https://platform.openai.com/docs/api-reference/chat/create

@@ -83,16 +83,6 @@ class OpenAIServingChat(OpenAIServing):
if error_check_ret is not None:
return error_check_ret

if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")

if request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid "
f"negative value: {request.prompt_logprobs}")

try:
(
lora_request,

@@ -160,9 +150,8 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

engine_inputs: PromptInputs = {
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
}
engine_inputs = TokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data

@@ -214,11 +203,11 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True

# Send response for each token for each request.n (index)

@@ -438,7 +427,7 @@ class OpenAIServingChat(OpenAIServing):
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:

model_name = self.served_model_names[0]

@@ -523,7 +512,7 @@ class OpenAIServingChat(OpenAIServing):

def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1],

@@ -541,12 +530,11 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""

logprobs_content = []
logprobs_content: List[ChatCompletionLogProbsContent] = []

for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]

@@ -554,23 +542,32 @@ class OpenAIServingChat(OpenAIServing):
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"

logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
bytes=list(token.encode("utf-8", errors="replace"))))
bytes=list(token.encode("utf-8", errors="replace")),
))
else:
step_token = step_top_logprobs[token_id]
step_decoded = step_token.decoded_token

logprobs_content.append(
ChatCompletionLogProbsContent(
token=self._get_decoded_token(
step_top_logprobs[token_id], token_id, tokenizer,
self.return_tokens_as_token_ids),
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
bytes=list(
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
step_token,
token_id,
tokenizer,
self.return_tokens_as_token_ids,
),
logprob=max(step_token.logprob, -9999.0),
bytes=None if step_decoded is None else list(
step_decoded.encode("utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs,
tokenizer)))
step_top_logprobs,
num_output_top_logprobs,
tokenizer,
),
))

return ChatCompletionLogProbs(content=logprobs_content)
@ -3,10 +3,9 @@ import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
|
||||
Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple, cast
|
||||
from typing import Tuple, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
@ -19,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
UsageInfo)
|
||||
ErrorResponse, UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
OpenAIServing,
|
||||
@ -29,6 +28,7 @@ from vllm.outputs import RequestOutput
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -60,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
|
||||
async def create_completion(self, request: CompletionRequest,
|
||||
raw_request: Request):
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
@ -84,15 +87,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())

if request.prompt_logprobs is not None:
if request.stream and request.prompt_logprobs > 0:
return self.create_error_response(
"Prompt_logprobs are not available when stream is enabled")
elif request.prompt_logprobs < 0:
return self.create_error_response(
f"Prompt_logprobs set to invalid negative "
f"value: {request.prompt_logprobs}")

# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
@ -153,9 +147,8 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

result_generator: AsyncIterator[Tuple[
int, RequestOutput]] = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)

# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
@ -227,7 +220,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
@ -236,6 +229,13 @@ class OpenAIServingCompletion(OpenAIServing):

try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt

delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]]

for output in res.outputs:
i = output.index + prompt_idx * num_choices
@ -244,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing):

assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_text is not None
# only return the prompt
delta_text = res.prompt
delta_token_ids = res.prompt_token_ids
out_logprobs = res.prompt_logprobs
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = (res.prompt_token_ids +
output.token_ids)
out_logprobs = res.prompt_logprobs + (output.logprobs
or [])
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids, *output.token_ids
]
out_logprobs = [
*prompt_logprobs,
*(output.logprobs or []),
]
has_echoed[i] = True
else:
# return just the delta
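A side note on the list-unpacking form introduced above (our reading of the change, not stated in the diff): prompt token IDs and output token IDs are not necessarily the same concrete sequence type, so "+" concatenation can fail both at type-check time and at runtime, whereas unpacking accepts any iterables and always produces a plain list. Illustrative only:

from typing import List, Sequence


def concat_ids(prompt_ids: List[int], output_ids: Sequence[int]) -> List[int]:
    # prompt_ids + output_ids fails mypy when output_ids is a generic
    # Sequence (and fails at runtime if it is, e.g., a tuple);
    # unpacking does not care about the concrete types.
    return [*prompt_ids, *output_ids]


assert concat_ids([1, 2], (3, 4)) == [1, 2, 3, 4]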
@ -301,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = len(res.prompt_token_ids)
prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
@ -342,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@ -353,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt

token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
Logprob]]]]

for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_text is not None
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + list(output.token_ids)
out_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs is not None else None)
assert prompt_text is not None
token_ids = [*prompt_token_ids, *output.token_ids]

if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]

output_text = prompt_text + output.text
else:
token_ids = output.token_ids
@ -413,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer,
tokenizer: AnyTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
@ -430,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing):
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"

out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
step_token = step_top_logprobs[token_id]

token = self._get_decoded_token(
step_top_logprobs[token_id],
step_token,
token_id,
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
return_as_token_id=self.return_tokens_as_token_ids,
)
token_logprob = max(step_token.logprob, -9999.0)

out_tokens.append(token)
out_token_logprobs.append(token_logprob)


@ -1,11 +1,11 @@
import asyncio
import base64
import time
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
Union, cast)
from typing import AsyncGenerator, List, Literal, Optional, Union, cast

import numpy as np
from fastapi import Request
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
@ -16,7 +16,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)
@ -24,18 +24,28 @@ logger = init_logger(__name__)
TypeTokenIDs = List[int]


def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
return output.embedding
elif encoding_format == "base64":
embedding_bytes = np.array(output.embedding).tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8")

assert_never(encoding_format)


def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
created_time: int, model_name: str,
encoding_format: str) -> EmbeddingResponse:
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding
if encoding_format == "base64":
embedding_bytes = np.array(embedding).tobytes()
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
embedding = _get_embedding(final_res.outputs, encoding_format)
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)
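For context on the base64 branch factored into _get_embedding above: the encoded payload is the raw buffer of a NumPy array built from a list of Python floats, which defaults to float64. A hypothetical client-side helper that reverses the encoding (not part of this diff; the dtype assumption is ours):

import base64
from typing import List

import numpy as np


def decode_base64_embedding(data: str) -> List[float]:
    # Mirror of the base64 branch: b64decode the string, then reinterpret the
    # buffer as float64 (np.array on a list of Python floats uses float64).
    return np.frombuffer(base64.b64decode(data), dtype=np.float64).tolist()


encoded = base64.b64encode(np.array([0.1, 0.2, 0.3]).tobytes()).decode("utf-8")
assert decode_base64_embedding(encoded) == [0.1, 0.2, 0.3]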
@ -76,8 +86,8 @@ class OpenAIServingEmbedding(OpenAIServing):
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None
) -> Union[ErrorResponse, EmbeddingResponse]:
raw_request: Optional[Request] = None,
) -> Union[EmbeddingResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.

See https://platform.openai.com/docs/api-reference/embeddings/create
@ -89,8 +99,7 @@ class OpenAIServingEmbedding(OpenAIServing):
if error_check_ret is not None:
return error_check_ret

encoding_format = (request.encoding_format
if request.encoding_format else "float")
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
@ -145,11 +154,10 @@ class OpenAIServingEmbedding(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

result_generator: AsyncIterator[Tuple[
int, EmbeddingRequestOutput]] = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected
if raw_request else None)
result_generator = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)

# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
@ -175,7 +183,7 @@ class OpenAIServingEmbedding(OpenAIServing):

return response

def _check_embedding_mode(self, embedding_mode: bool):
def _check_embedding_mode(self, embedding_mode: bool) -> bool:
if not embedding_mode:
logger.warning(
"embedding_mode is False. Embedding API will not work.")

@ -31,7 +31,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)


@ -153,6 +153,68 @@ class SamplingParams(
output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)

@staticmethod
def from_optional(
n: Optional[int] = 1,
best_of: Optional[int] = None,
presence_penalty: Optional[float] = 0.0,
frequency_penalty: Optional[float] = 0.0,
repetition_penalty: Optional[float] = 1.0,
temperature: Optional[float] = 1.0,
top_p: Optional[float] = 1.0,
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
min_tokens: int = 0,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
detokenize: bool = True,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
) -> "SamplingParams":
return SamplingParams(
n=1 if n is None else n,
best_of=best_of,
presence_penalty=0.0
if presence_penalty is None else presence_penalty,
frequency_penalty=0.0
if frequency_penalty is None else frequency_penalty,
repetition_penalty=1.0
if repetition_penalty is None else repetition_penalty,
temperature=1.0 if temperature is None else temperature,
top_p=1.0 if top_p is None else top_p,
top_k=top_k,
min_p=min_p,
seed=seed,
use_beam_search=use_beam_search,
length_penalty=length_penalty,
early_stopping=early_stopping,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
min_tokens=min_tokens,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs,
detokenize=detokenize,
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
)

def __post_init__(self) -> None:
self.best_of = self.best_of or self.n
if 0 < self.temperature < _MAX_TEMP:
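A quick usage sketch of the new constructor (ours, not from the diff): from_optional maps explicit None values coming from an OpenAI-style request body back to the regular defaults, instead of forwarding None into the strict SamplingParams fields; parameters without a None-mapping (for example max_tokens) are passed through unchanged.

from vllm import SamplingParams

# None for temperature/top_p falls back to 1.0, per the mapping shown above.
params = SamplingParams.from_optional(
    temperature=None,
    top_p=None,
    prompt_logprobs=1,
)
assert params.temperature == 1.0
assert params.top_p == 1.0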