[Frontend] Refactor prompt processing (#4028)

Co-authored-by: Roger Wang <ywang@roblox.com>
Cyrus Leung 2024-07-23 01:13:53 +08:00 committed by GitHub
parent 89c1c6a196
commit 739b61a348
24 changed files with 699 additions and 391 deletions

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptStrictInputs from vllm.inputs import PromptInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size, size=(args.batch_size,
args.input_len)) args.input_len))
dummy_inputs: List[PromptStrictInputs] = [{ dummy_inputs: List[PromptInputs] = [{
"prompt_token_ids": batch "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()] } for batch in dummy_prompt_token_ids.tolist()]

View File

@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>` Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`. via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`. by following :ref:`this guide <adding_multimodal_plugin>`.

View File

@ -1,7 +1,7 @@
LLM Inputs LLM Inputs
========== ==========
.. autodata:: vllm.inputs.PromptStrictInputs .. autodata:: vllm.inputs.PromptInputs
.. autoclass:: vllm.inputs.TextPrompt .. autoclass:: vllm.inputs.TextPrompt
:show-inheritance: :show-inheritance:
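For context on the renamed type, here is a minimal sketch of the three prompt forms that ``PromptInputs`` accepts (a plain string, a ``TextPrompt`` dict, and a ``TokensPrompt`` dict). The model name and token IDs are illustrative assumptions, not part of this commit:

from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt, TokensPrompt

# Illustrative model; any generative model supported by vLLM works here.
llm = LLM(model="facebook/opt-125m")

# All three items below are valid instances of vllm.inputs.PromptInputs.
outputs = llm.generate(
    [
        "Hello, my name is",                             # plain string
        TextPrompt(prompt="The capital of France is"),   # TypedDict with text
        TokensPrompt(prompt_token_ids=[1, 2, 3]),        # TypedDict with token IDs (made up)
    ],
    sampling_params=SamplingParams(max_tokens=16),
)
for out in outputs:
    print(out.outputs[0].text)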

View File

@ -30,7 +30,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
internally for each model. internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`: To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
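A minimal sketch of how the two fields above fit together. The model and its ``USER:/ASSISTANT:`` prompt format are assumptions chosen for illustration (a LLaVA-1.5 checkpoint), not something this commit prescribes:

from PIL import Image

from vllm import LLM, SamplingParams

# Assumed example model; the prompt format below is the one documented for it.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")

image = Image.open("example.jpg")
prompt = "USER: <image>\nWhat is shown in this image?\nASSISTANT:"

outputs = llm.generate(
    {
        "prompt": prompt,                      # text part of PromptInputs
        "multi_modal_data": {"image": image},  # schema from MultiModalDataDict
    },
    sampling_params=SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)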

View File

@ -35,8 +35,8 @@ def sequence_with_eos(text: str, eos_token: str,
@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [ @pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
("This text ends with EOS token", "</s>", 2), ("This text ends with EOS token", "</s>", 2),
]) ])
@pytest.mark.parametrize("ignore_eos", [True, False, None]) @pytest.mark.parametrize("ignore_eos", [True, False])
@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None]) @pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
ignore_eos: bool, include_stop_str_in_output: bool): ignore_eos: bool, include_stop_str_in_output: bool):

View File

@ -32,7 +32,10 @@ async def _async_serving_chat_init():
model_config, model_config,
served_model_names=[MODEL_NAME], served_model_names=[MODEL_NAME],
response_role="assistant", response_role="assistant",
chat_template=CHAT_TEMPLATE) chat_template=CHAT_TEMPLATE,
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_completion return serving_completion

View File

@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput, from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput) EmbeddingRequestOutput, RequestOutput)
@ -19,7 +19,7 @@ __all__ = [
"__version__", "__version__",
"LLM", "LLM",
"ModelRegistry", "ModelRegistry",
"PromptStrictInputs", "PromptInputs",
"TextPrompt", "TextPrompt",
"TokensPrompt", "TokensPrompt",
"SamplingParams", "SamplingParams",

View File

@ -827,7 +827,6 @@ class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine.""" """Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False engine_use_ray: bool = False
disable_log_requests: bool = False disable_log_requests: bool = False
max_log_len: Optional[int] = None
@staticmethod @staticmethod
def add_cli_args(parser: FlexibleArgumentParser, def add_cli_args(parser: FlexibleArgumentParser,
@ -841,12 +840,6 @@ class AsyncEngineArgs(EngineArgs):
parser.add_argument('--disable-log-requests', parser.add_argument('--disable-log-requests',
action='store_true', action='store_true',
help='Disable logging requests.') help='Disable logging requests.')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser return parser

View File

@ -1,8 +1,8 @@
import asyncio import asyncio
import time import time
from functools import partial from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
Set, Tuple, Type, Union) Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@ -151,7 +151,10 @@ class RequestTracker:
logger.info("Finished request %s.", request_id) logger.info("Finished request %s.", request_id)
self.abort_request(request_id) self.abort_request(request_id)
def add_request(self, request_id: str, def add_request(self,
request_id: str,
*,
verbose: bool = False,
**engine_add_request_kwargs) -> AsyncStream: **engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background """Add a request to be sent to the engine on the next background
loop iteration.""" loop iteration."""
@ -166,6 +169,9 @@ class RequestTracker:
self.new_requests_event.set() self.new_requests_event.set()
if verbose:
logger.info("Added request %s.", request_id)
return stream return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None: def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@ -305,8 +311,8 @@ class _AsyncLLMEngine(LLMEngine):
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@ -353,8 +359,6 @@ class AsyncLLMEngine:
async frontend will be executed in a separate process as the async frontend will be executed in a separate process as the
model workers. model workers.
log_requests: Whether to log the requests. log_requests: Whether to log the requests.
max_log_len: Maximum number of prompt characters or prompt ID numbers
being printed in log.
start_engine_loop: If True, the background task to run the engine start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call. will be automatically started in the generate call.
*args: Arguments for :class:`LLMEngine`. *args: Arguments for :class:`LLMEngine`.
@ -368,13 +372,11 @@ class AsyncLLMEngine:
engine_use_ray: bool, engine_use_ray: bool,
*args, *args,
log_requests: bool = True, log_requests: bool = True,
max_log_len: Optional[int] = None,
start_engine_loop: bool = True, start_engine_loop: bool = True,
**kwargs) -> None: **kwargs) -> None:
self.worker_use_ray = worker_use_ray self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray self.engine_use_ray = engine_use_ray
self.log_requests = log_requests self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs) self.engine = self._init_engine(*args, **kwargs)
self.background_loop: Optional[asyncio.Future] = None self.background_loop: Optional[asyncio.Future] = None
@ -468,7 +470,6 @@ class AsyncLLMEngine:
executor_class=executor_class, executor_class=executor_class,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop, start_engine_loop=start_engine_loop,
usage_context=usage_context, usage_context=usage_context,
stat_loggers=stat_loggers, stat_loggers=stat_loggers,
@ -667,30 +668,9 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, params,
shortened_token_ids, lora_request)
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
self.start_background_loop() self.start_background_loop()
@ -706,6 +686,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
verbose=self.log_requests,
inputs=inputs, inputs=inputs,
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
@ -721,7 +702,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]: ) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request. """Generate outputs for a request.
@ -804,7 +785,7 @@ class AsyncLLMEngine:
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]: ) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
@ -882,7 +863,7 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
*, *,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or """Common logic to process requests with SamplingParams or

View File

@ -1,6 +1,7 @@
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union from typing import Set, Type, TypeVar, Union
@ -522,7 +523,7 @@ class LLMEngine:
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> None: ) -> None:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
@ -603,7 +604,7 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
@ -677,7 +678,7 @@ class LLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams.""" """Creates a SequenceGroup with SamplingParams."""

View File

@ -6,8 +6,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
TextTokensPrompt, TokensPrompt,
parse_and_batch_prompt) parse_and_batch_prompt)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -238,7 +237,7 @@ class LLM:
@overload @overload
def generate( def generate(
self, self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API /, # We may enable `inputs` keyword after removing the old API
*, *,
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
@ -255,7 +254,7 @@ class LLM:
"instead.") "instead.")
def generate( def generate(
self, self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None, Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None, Sequence[SamplingParams]]] = None,
@ -302,9 +301,7 @@ class LLM:
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
) )
else: else:
inputs = cast( inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
@ -383,7 +380,7 @@ class LLM:
@overload @overload
def encode( def encode(
self, self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API /, # We may enable `inputs` keyword after removing the old API
*, *,
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
@ -400,7 +397,7 @@ class LLM:
"instead.") "instead.")
def encode( def encode(
self, self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None, Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
@ -417,7 +414,7 @@ class LLM:
Args: Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs` batch inference. See :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters. use the default pooling parameters.
@ -446,9 +443,7 @@ class LLM:
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
) )
else: else:
inputs = cast( inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
@ -496,14 +491,8 @@ class LLM:
inputs: List[PromptInputs] = [] inputs: List[PromptInputs] = []
for i in range(num_requests): for i in range(num_requests):
if prompts is not None: if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i]) item = TextPrompt(prompt=prompts[i])
else: elif prompt_token_ids is not None:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else: else:
raise AssertionError raise AssertionError
@ -514,7 +503,7 @@ class LLM:
def _validate_and_add_requests( def _validate_and_add_requests(
self, self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], inputs: Union[PromptInputs, Sequence[PromptInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]], Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
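For the ``encode`` path, a similarly minimal sketch using the renamed ``PromptInputs`` union; the embedding model name and token IDs are illustrative assumptions:

from vllm import LLM

# Illustrative embedding model; any model vLLM can serve in embedding mode works.
llm = LLM(model="intfloat/e5-mistral-7b-instruct")

# A plain string and a TokensPrompt-style dict are both valid PromptInputs.
outputs = llm.encode([
    "vLLM is a fast inference engine.",
    {"prompt_token_ids": [1, 2, 3]},  # made-up token IDs
])
for out in outputs:
    print(len(out.outputs.embedding))  # length of each embedding vector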

View File

@ -0,0 +1,41 @@
from typing import List, Optional, Union
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
logger = init_logger(__name__)
class RequestLogger:
def __init__(self, *, max_log_len: Optional[int]) -> None:
super().__init__()
self.max_log_len = max_log_len
def log_inputs(
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
params: Optional[Union[SamplingParams, PoolingParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
if prompt is not None:
prompt = prompt[:max_log_len]
if prompt_token_ids is not None:
prompt_token_ids = prompt_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s, prompt_adapter_request: %s.", request_id,
prompt, params, prompt_token_ids, lora_request,
prompt_adapter_request)
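A short sketch of how a frontend is meant to drive the new helper; the request ID and prompt values are illustrative:

from vllm.entrypoints.logger import RequestLogger
from vllm.sampling_params import SamplingParams

# Truncate logged prompts and token-ID lists to 32 items; None means unlimited.
request_logger = RequestLogger(max_log_len=32)

request_logger.log_inputs(
    request_id="cmpl-example-0",      # illustrative ID
    prompt="Hello, my name is",
    prompt_token_ids=None,            # a frontend may log either or both
    params=SamplingParams(max_tokens=16),
    lora_request=None,
    prompt_adapter_request=None,
)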

View File

@ -18,6 +18,7 @@ from starlette.routing import Mount
import vllm.envs as envs import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
@ -244,24 +245,48 @@ def run_server(args, llm_engine=None):
# When using single vLLM without engine_use_ray # When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config()) model_config = asyncio.run(engine.get_model_config())
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
global openai_serving_chat global openai_serving_chat
global openai_serving_completion global openai_serving_completion
global openai_serving_embedding global openai_serving_embedding
global openai_serving_tokenization global openai_serving_tokenization
openai_serving_chat = OpenAIServingChat(engine, model_config, openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names, served_model_names,
args.response_role, args.response_role,
args.lora_modules, lora_modules=args.lora_modules,
args.chat_template) prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
)
openai_serving_completion = OpenAIServingCompletion( openai_serving_completion = OpenAIServingCompletion(
engine, model_config, served_model_names, args.lora_modules, engine,
args.prompt_adapters) model_config,
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names,
served_model_names) lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
served_model_names,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization( openai_serving_tokenization = OpenAIServingTokenization(
engine, model_config, served_model_names, args.lora_modules, engine,
args.chat_template) model_config,
served_model_names,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
)
app.root_path = args.root_path app.root_path = args.root_path
logger.info("Available routes are:") logger.info("Available routes are:")

View File

@ -130,6 +130,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using app.add_middleware(). ") "using app.add_middleware(). ")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser return parser
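With the flag now owned by the OpenAI frontend, truncated request logging is configured at server launch. A small sketch of the relocated flag being parsed (equivalent to launching the API server with ``--max-log-len 128``); the model name is illustrative, and the one-parser-argument signature of make_arg_parser is taken from the hunk header above:

from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

# Equivalent to: python -m vllm.entrypoints.openai.api_server \
#     --model facebook/opt-125m --max-log-len 128
parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m", "--max-log-len", "128"])

# Prompts and token-ID lists longer than this are truncated in the request log.
assert args.max_log_len == 128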

View File

@ -121,40 +121,42 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: begin-chat-completion-sampling-params # doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None best_of: Optional[int] = None
use_beam_search: Optional[bool] = False use_beam_search: bool = False
top_k: Optional[int] = -1 top_k: int = -1
min_p: Optional[float] = 0.0 min_p: float = 0.0
repetition_penalty: Optional[float] = 1.0 repetition_penalty: float = 1.0
length_penalty: Optional[float] = 1.0 length_penalty: float = 1.0
early_stopping: Optional[bool] = False early_stopping: bool = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True include_stop_str_in_output: bool = False
spaces_between_special_tokens: Optional[bool] = True ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-chat-completion-sampling-params # doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params # doc: begin-chat-completion-extra-params
echo: Optional[bool] = Field( echo: bool = Field(
default=False, default=False,
description=( description=(
"If true, the new message will be prepended with the last message " "If true, the new message will be prepended with the last message "
"if they belong to the same role."), "if they belong to the same role."),
) )
add_generation_prompt: Optional[bool] = Field( add_generation_prompt: bool = Field(
default=True, default=True,
description= description=
("If true, the generation prompt will be added to the chat template. " ("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the " "This is a parameter used by chat template in tokenizer config of the "
"model."), "model."),
) )
add_special_tokens: Optional[bool] = Field( add_special_tokens: bool = Field(
default=False, default=False,
description=( description=(
"If true, special tokens (e.g. BOS) will be added to the prompt " "If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. " "on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the " "For most models, the chat template takes care of adding the "
"special tokens so this should be set to False (as is the " "special tokens so this should be set to false (as is the "
"default)."), "default)."),
) )
documents: Optional[List[Dict[str, str]]] = Field( documents: Optional[List[Dict[str, str]]] = Field(
@ -178,12 +180,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=("Additional kwargs to pass to the template renderer. " description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."), "Will be accessible by the chat template."),
) )
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field( guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None, default=None,
description=("If specified, the output will follow the JSON schema."), description=("If specified, the output will follow the JSON schema."),
@ -244,22 +240,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
return SamplingParams( return SamplingParams(
n=self.n, n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty, presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty, frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty, repetition_penalty=self.repetition_penalty,
temperature=self.temperature, temperature=self.temperature,
top_p=self.top_p, top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p, min_p=self.min_p,
seed=self.seed, seed=self.seed,
stop=self.stop, stop=self.stop,
stop_token_ids=self.stop_token_ids, stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
logprobs=self.top_logprobs if self.logprobs else None, logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None, prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search, use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping, early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
@ -267,6 +263,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
) )
@model_validator(mode='before') @model_validator(mode='before')
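The parameters above that fall outside the official OpenAI schema (for example ``top_k``, ``min_tokens``, ``include_stop_str_in_output``, and the newly plumbed ``truncate_prompt_tokens``) are typically supplied through the client's ``extra_body``. A hedged sketch against a locally running vLLM server; the base URL and served model name are assumptions:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # assumed local server

completion = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative served model
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
    max_tokens=64,
    extra_body={
        # vLLM-specific sampling params from the schema above
        "top_k": 40,
        "min_tokens": 8,
        "include_stop_str_in_output": True,
        "truncate_prompt_tokens": 512,
    },
)
print(completion.choices[0].message.content)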
@ -348,26 +345,27 @@ class CompletionRequest(OpenAIBaseModel):
user: Optional[str] = None user: Optional[str] = None
# doc: begin-completion-sampling-params # doc: begin-completion-sampling-params
use_beam_search: Optional[bool] = False use_beam_search: bool = False
top_k: Optional[int] = -1 top_k: int = -1
min_p: Optional[float] = 0.0 min_p: float = 0.0
repetition_penalty: Optional[float] = 1.0 repetition_penalty: float = 1.0
length_penalty: Optional[float] = 1.0 length_penalty: float = 1.0
early_stopping: Optional[bool] = False early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False include_stop_str_in_output: bool = False
min_tokens: Optional[int] = 0 ignore_eos: bool = False
skip_special_tokens: Optional[bool] = True min_tokens: int = 0
spaces_between_special_tokens: Optional[bool] = True skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params # doc: end-completion-sampling-params
# doc: begin-completion-extra-params # doc: begin-completion-extra-params
include_stop_str_in_output: Optional[bool] = Field( add_special_tokens: bool = Field(
default=False, default=True,
description=( description=(
"Whether to include the stop string in the output. " "If true (the default), special tokens (e.g. BOS) will be added to "
"This is only applied when the stop or stop_token_ids is set."), "the prompt."),
) )
response_format: Optional[ResponseFormat] = Field( response_format: Optional[ResponseFormat] = Field(
default=None, default=None,
@ -447,15 +445,15 @@ class CompletionRequest(OpenAIBaseModel):
seed=self.seed, seed=self.seed,
stop=self.stop, stop=self.stop,
stop_token_ids=self.stop_token_ids, stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs,
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1, max_tokens=self.max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens, min_tokens=self.min_tokens,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search, use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping, early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None, prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens), spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
@ -489,11 +487,11 @@ class CompletionRequest(OpenAIBaseModel):
def validate_stream_options(cls, data): def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"): if data.get("stream_options") and not data.get("stream"):
raise ValueError( raise ValueError(
"Stream options can only be defined when stream is True.") "Stream options can only be defined when stream is true.")
return data return data
class EmbeddingRequest(BaseModel): class EmbeddingRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings # https://platform.openai.com/docs/api-reference/embeddings
model: str model: str
@ -565,13 +563,13 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None) usage: Optional[UsageInfo] = Field(default=None)
class EmbeddingResponseData(BaseModel): class EmbeddingResponseData(OpenAIBaseModel):
index: int index: int
object: str = "embedding" object: str = "embedding"
embedding: Union[List[float], str] embedding: Union[List[float], str]
class EmbeddingResponse(BaseModel): class EmbeddingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "list" object: str = "list"
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
@ -670,8 +668,8 @@ class BatchRequestInput(OpenAIBaseModel):
# /v1/chat/completions is supported. # /v1/chat/completions is supported.
url: str url: str
# The parameteters of the request. # The parameters of the request.
body: Union[ChatCompletionRequest, ] body: ChatCompletionRequest
class BatchResponseData(OpenAIBaseModel): class BatchResponseData(OpenAIBaseModel):
@ -703,12 +701,22 @@ class BatchRequestOutput(OpenAIBaseModel):
error: Optional[Any] error: Optional[Any]
class TokenizeRequest(OpenAIBaseModel): class TokenizeCompletionRequest(OpenAIBaseModel):
model: str
prompt: str
add_special_tokens: bool = Field(default=True)
class TokenizeChatRequest(OpenAIBaseModel):
model: str
messages: List[ChatCompletionMessageParam]
add_generation_prompt: bool = Field(default=True) add_generation_prompt: bool = Field(default=True)
add_special_tokens: bool = Field(default=False) add_special_tokens: bool = Field(default=False)
prompt: Optional[str] = Field(default=None)
messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None)
model: str TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
class TokenizeResponse(OpenAIBaseModel): class TokenizeResponse(OpenAIBaseModel):
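Splitting the model in two means the tokenization endpoint now validates either a plain-text body or a chat-style body, never a mix of both. A sketch of the two shapes, with field values illustrative:

from vllm.entrypoints.openai.protocol import (TokenizeChatRequest,
                                               TokenizeCompletionRequest)

# Plain-text tokenization request
completion_req = TokenizeCompletionRequest(
    model="my-model",                 # illustrative served model name
    prompt="Hello, world",
    add_special_tokens=True,          # default shown in the diff
)

# Chat-style tokenization request: the chat template is applied first
chat_req = TokenizeChatRequest(
    model="my-model",
    messages=[{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
    add_special_tokens=False,         # default shown in the diff
)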

View File

@ -6,6 +6,7 @@ import aiohttp
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (BatchRequestInput, from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput, BatchRequestOutput,
BatchResponseData, BatchResponseData,
@ -44,9 +45,17 @@ def parse_args():
type=nullable_str, type=nullable_str,
default="assistant", default="assistant",
help="The role name to return if " help="The role name to return if "
"`request.add_generation_prompt=true`.") "`request.add_generation_prompt=True`.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser.parse_args() return parser.parse_args()
@ -114,11 +123,20 @@ async def main(args):
# When using single vLLM without engine_use_ray # When using single vLLM without engine_use_ray
model_config = await engine.get_model_config() model_config = await engine.get_model_config()
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
openai_serving_chat = OpenAIServingChat( openai_serving_chat = OpenAIServingChat(
engine, engine,
model_config, model_config,
served_model_names, served_model_names,
args.response_role, args.response_role,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
) )
# Submit all requests in the file to the engine "concurrently". # Submit all requests in the file to the engine "concurrently".

View File

@ -12,6 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.chat_utils import (ConversationMessage, from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template, load_chat_template,
parse_chat_message_content) parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProb, ChatCompletionLogProbs,
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
@ -20,7 +21,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
FunctionCall, ToolCall, UsageInfo) FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing) OpenAIServing,
PromptAdapterPath)
from vllm.inputs import PromptInputs from vllm.inputs import PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
@ -37,17 +39,24 @@ logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing): class OpenAIServingChat(OpenAIServing):
def __init__(self, def __init__(
self,
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
response_role: str, response_role: str,
lora_modules: Optional[List[LoRAModulePath]] = None, *,
chat_template: Optional[str] = None): lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine, super().__init__(engine=engine,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules) lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger)
self.response_role = response_role self.response_role = response_role
@ -74,7 +83,12 @@ class OpenAIServingChat(OpenAIServing):
return error_check_ret return error_check_ret
try: try:
_, lora_request = self._maybe_get_adapter(request) (
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.engine.get_tokenizer(lora_request)
conversation: List[ConversationMessage] = [] conversation: List[ConversationMessage] = []
@ -82,7 +96,7 @@ class OpenAIServingChat(OpenAIServing):
for msg in request.messages: for msg in request.messages:
chat_parsed_result = parse_chat_message_content( chat_parsed_result = parse_chat_message_content(
msg, self.model_config, tokenizer) msg, model_config, tokenizer)
conversation.extend(chat_parsed_result.messages) conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures) mm_futures.extend(chat_parsed_result.mm_futures)
@ -116,14 +130,8 @@ class OpenAIServingChat(OpenAIServing):
logger.error("Error in loading multi-modal data: %s", e) logger.error("Error in loading multi-modal data: %s", e)
return self.create_error_response(str(e)) return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}" request_id = f"chat-{random_uuid()}"
try: try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
request,
tokenizer,
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config() decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \ guided_decoding_backend = request.guided_decoding_backend \
@ -137,31 +145,47 @@ class OpenAIServingChat(OpenAIServing):
sampling_params.logits_processors = [] sampling_params.logits_processors = []
sampling_params.logits_processors.append( sampling_params.logits_processors.append(
guided_decode_logits_processor) guided_decode_logits_processor)
except ValueError as e:
return self.create_error_response(str(e))
inputs: PromptInputs = { prompt_inputs = self._tokenize_prompt_input(
"prompt": prompt_text, request,
"prompt_token_ids": prompt_ids, tokenizer,
prompt,
truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
self._log_inputs(request_id,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
engine_inputs: PromptInputs = {
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
} }
if mm_data: if mm_data is not None:
inputs["multi_modal_data"] = mm_data engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled() is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None trace_headers = None
if is_tracing_enabled and raw_request: if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers) trace_headers = extract_trace_headers(raw_request.headers)
if not is_tracing_enabled and raw_request and contains_trace_headers( if (not is_tracing_enabled and raw_request
raw_request.headers): and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning() log_tracing_disabled_warning()
result_generator = self.engine.generate( result_generator = self.engine.generate(
inputs, engine_inputs,
sampling_params, sampling_params,
request_id, request_id,
lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
) )
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# Streaming response # Streaming response
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
@ -195,10 +219,11 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = True first_iteration = True
# Send response for each token for each request.n (index) # Send response for each token for each request.n (index)
assert request.n is not None num_choices = 1 if request.n is None else request.n
previous_texts = [""] * request.n previous_texts = [""] * num_choices
previous_num_tokens = [0] * request.n previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * request.n finish_reason_sent = [False] * num_choices
try: try:
async for res in result_generator: async for res in result_generator:
# We need to do it here, because if there are exceptions in # We need to do it here, because if there are exceptions in
@ -208,7 +233,7 @@ class OpenAIServingChat(OpenAIServing):
# Send first response for each request.n (index) with # Send first response for each request.n (index) with
# the role # the role
role = self.get_chat_request_role(request) role = self.get_chat_request_role(request)
for i in range(request.n): for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=DeltaMessage(role=role), delta=DeltaMessage(role=role),
@ -236,19 +261,19 @@ class OpenAIServingChat(OpenAIServing):
last_msg_content = conversation[-1]["content"] last_msg_content = conversation[-1]["content"]
if last_msg_content: if last_msg_content:
for i in range(request.n): for i in range(num_choices):
choice_data = ( choice_data = (
ChatCompletionResponseStreamChoice( ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=DeltaMessage( delta=DeltaMessage(
content=last_msg_content), content=last_msg_content),
logprobs=None,
finish_reason=None)) finish_reason=None))
chunk = ChatCompletionStreamResponse( chunk = ChatCompletionStreamResponse(
id=request_id, id=request_id,
object=chunk_object_type, object=chunk_object_type,
created=created_time, created=created_time,
choices=[choice_data], choices=[choice_data],
logprobs=None,
model=model_name) model=model_name)
if (request.stream_options and if (request.stream_options and
request.stream_options.include_usage): request.stream_options.include_usage):

View File

@ -2,13 +2,14 @@ import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional) Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Tuple from typing import Tuple, cast
from fastapi import Request from fastapi import Request
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs, from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
@ -39,40 +40,24 @@ TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs] [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
def parse_prompt_format(prompt) -> Tuple[bool, list]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens = False
prompts = [prompt] # case 1: a string
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
elif isinstance(prompt[0], str):
prompt_is_tokens = False
prompts = prompt # case 2: array of strings
elif isinstance(prompt[0], int):
prompt_is_tokens = True
prompts = [prompt] # case 3: array of tokens
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
return prompt_is_tokens, prompts
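For reference, the four prompt shapes the removed helper handled (and which its replacement, _tokenize_prompt_input_or_inputs, must still accept) written out as illustrative request bodies; the token IDs are made up:

# The "prompt" field of /v1/completions accepts four shapes:
single_string  = {"model": "my-model", "prompt": "Hello, world"}
string_array   = {"model": "my-model", "prompt": ["Hello", "Bonjour"]}
token_array    = {"model": "my-model", "prompt": [15496, 11, 995]}
token_array_2d = {"model": "my-model", "prompt": [[15496, 11], [31373, 995]]}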
class OpenAIServingCompletion(OpenAIServing): class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]], lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]]): prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
):
super().__init__(engine=engine, super().__init__(engine=engine,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules, lora_modules=lora_modules,
prompt_adapters=prompt_adapters) prompt_adapters=prompt_adapters,
request_logger=request_logger)
async def create_completion(self, request: CompletionRequest, async def create_completion(self, request: CompletionRequest,
raw_request: Request): raw_request: Request):
@ -101,12 +86,11 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = [] generators: List[AsyncIterator[RequestOutput]] = []
try: try:
adapter_type, adapter_request = self._maybe_get_adapter(request) (
lora_request, prompt_adapter_request = None, None lora_request,
if adapter_type == 'LoRA': prompt_adapter_request,
lora_request, prompt_adapter_request = adapter_request, None ) = self._maybe_get_adapters(request)
elif adapter_type == 'PromptAdapter':
lora_request, prompt_adapter_request = None, adapter_request
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.engine.get_tokenizer(lora_request)
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
@ -122,17 +106,25 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params.logits_processors = [] sampling_params.logits_processors = []
sampling_params.logits_processors.append( sampling_params.logits_processors.append(
guided_decode_logit_processor) guided_decode_logit_processor)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts): prompts = list(
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt" self._tokenize_prompt_input_or_inputs(
prompt_formats = await self._validate_prompt_and_tokenize(
request, request,
tokenizer, tokenizer,
request.prompt,
truncate_prompt_tokens=sampling_params. truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens, truncate_prompt_tokens,
**{prompt_arg: prompt}) add_special_tokens=request.add_special_tokens,
prompt_ids, prompt_text = prompt_formats ))
for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = await self.engine.is_tracing_enabled() is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None trace_headers = None
@ -143,12 +135,9 @@ class OpenAIServingCompletion(OpenAIServing):
log_tracing_disabled_warning() log_tracing_disabled_warning()
generator = self.engine.generate( generator = self.engine.generate(
{ {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params, sampling_params,
f"{request_id}-{i}", request_id_item,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers, trace_headers=trace_headers,
@ -189,9 +178,27 @@ class OpenAIServingCompletion(OpenAIServing):
await self.engine.abort(f"{request_id}-{i}") await self.engine.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
final_res_batch[i] = res final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
final_res.prompt = prompts[i]["prompt"]
final_res_batch_checked = cast(List[RequestOutput],
final_res_batch)
response = self.request_output_to_completion_response( response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name, final_res_batch_checked,
tokenizer) request,
request_id,
created_time,
model_name,
tokenizer,
)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
@ -220,10 +227,10 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts: int, num_prompts: int,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
assert request.n is not None num_choices = 1 if request.n is None else request.n
previous_texts = [""] * request.n * num_prompts previous_texts = [""] * num_choices * num_prompts
previous_num_tokens = [0] * request.n * num_prompts previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * request.n * num_prompts has_echoed = [False] * num_choices * num_prompts
try: try:
async for prompt_idx, res in result_generator: async for prompt_idx, res in result_generator:
@ -234,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
raise StopAsyncIteration() raise StopAsyncIteration()
for output in res.outputs: for output in res.outputs:
i = output.index + prompt_idx * request.n i = output.index + prompt_idx * num_choices
# TODO(simon): optimize the performance by avoiding full # TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending. # text O(n^2) sending.
@ -343,8 +350,8 @@ class OpenAIServingCompletion(OpenAIServing):
choices: List[CompletionResponseChoice] = [] choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
num_generated_tokens = 0 num_generated_tokens = 0
for final_res in final_res_batch: for final_res in final_res_batch:
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt prompt_text = final_res.prompt

View File

@ -1,16 +1,16 @@
import base64 import base64
import time import time
from typing import AsyncIterator, List, Optional, Tuple from typing import AsyncIterator, List, Optional, Tuple, cast
import numpy as np import numpy as np
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest, from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse, EmbeddingResponse,
EmbeddingResponseData, UsageInfo) EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput from vllm.outputs import EmbeddingRequestOutput
@@ -28,11 +28,11 @@ def request_output_to_embedding_response(
data: List[EmbeddingResponseData] = [] data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch): for idx, final_res in enumerate(final_res_batch):
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding embedding = final_res.outputs.embedding
if encoding_format == "base64": if encoding_format == "base64":
embedding = base64.b64encode(np.array(embedding)) embedding_bytes = np.array(embedding).tobytes()
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data) data.append(embedding_data)
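The base64 branch above now serializes the raw float bytes and returns a plain string so the value is JSON-serializable. A standalone sketch of the same idea (the dtype is pinned here only to make the round trip deterministic; the diff relies on NumPy's default):

import base64

import numpy as np

embedding = [0.1, -0.2, 0.3]
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
embedding_b64 = base64.b64encode(embedding_bytes).decode("utf-8")

# Round trip, for illustration only.
decoded = np.frombuffer(base64.b64decode(embedding_b64), dtype=np.float32)
print(embedding_b64, decoded.tolist())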
@@ -54,12 +54,20 @@ def request_output_to_embedding_response(
class OpenAIServingEmbedding(OpenAIServing): class OpenAIServingEmbedding(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, def __init__(
served_model_names: List[str]): self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
request_logger: Optional[RequestLogger],
):
super().__init__(engine=engine, super().__init__(engine=engine,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=None) lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self._check_embedding_mode(model_config.embedding_mode) self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(self, request: EmbeddingRequest, async def create_embedding(self, request: EmbeddingRequest,
@@ -80,29 +88,47 @@ class OpenAIServingEmbedding(OpenAIServing):
"dimensions is currently not supported") "dimensions is currently not supported")
model_name = request.model model_name = request.model
request_id = f"cmpl-{random_uuid()}" request_id = f"embd-{random_uuid()}"
created_time = int(time.monotonic()) created_time = int(time.monotonic())
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators = [] generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
try: try:
prompt_is_tokens, prompts = parse_prompt_format(request.input) (
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
pooling_params = request.to_pooling_params() pooling_params = request.to_pooling_params()
tokenizer = await self.engine.get_tokenizer() prompts = list(
for i, prompt in enumerate(prompts): self._tokenize_prompt_input_or_inputs(
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt" request,
prompt_formats = await self._validate_prompt_and_tokenize( tokenizer,
request, tokenizer, **{prompt_arg: prompt}) request.input,
prompt_ids, prompt_text = prompt_formats ))
for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError(
"Prompt adapter is not supported "
"for embedding models")
generator = self.engine.encode( generator = self.engine.encode(
{ {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
pooling_params, pooling_params,
f"{request_id}-{i}", request_id_item,
lora_request=lora_request,
) )
generators.append(generator) generators.append(generator)
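The loop above gives every input in a batched embedding request its own sub-request id before handing it to the engine. A minimal sketch of that bookkeeping, with a hypothetical pre-tokenized batch and the engine call left as a comment (uuid stands in for vLLM's random_uuid helper):

import uuid

request_id = f"embd-{uuid.uuid4().hex}"
prompts = [
    {"prompt": "hello", "prompt_token_ids": [1, 2, 3]},   # made-up token IDs
    {"prompt": "world", "prompt_token_ids": [4, 5]},
]
for i, prompt_inputs in enumerate(prompts):
    request_id_item = f"{request_id}-{i}"  # one id per item in the batch
    # engine.encode({"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
    #               pooling_params, request_id_item, lora_request=lora_request)
    print(request_id_item, prompt_inputs["prompt_token_ids"])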
@@ -121,11 +147,17 @@ class OpenAIServingEmbedding(OpenAIServing):
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}") await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
final_res_batch[i] = res final_res_batch[i] = res
for final_res in final_res_batch:
assert final_res is not None
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
final_res_batch)
response = request_output_to_embedding_response( response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name, final_res_batch_checked, request_id, created_time, model_name,
encoding_format) encoding_format)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error

vllm/entrypoints/openai/serving_engine.py View File

@@ -2,23 +2,33 @@ import json
import pathlib import pathlib
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field from pydantic import Field
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest, CompletionRequest,
DetokenizeRequest, DetokenizeRequest,
EmbeddingRequest, ErrorResponse, EmbeddingRequest, ErrorResponse,
ModelCard, ModelList, ModelCard, ModelList,
ModelPermission, TokenizeRequest) ModelPermission,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeRequest)
# yapf: enable
from vllm.inputs import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -36,6 +46,17 @@ class LoRAModulePath:
local_path: str local_path: str
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class TextTokensPrompt(TypedDict):
prompt: str
prompt_token_ids: List[int]
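TextTokensPrompt is a plain TypedDict, so constructing one is just building a dict that carries the text and its token IDs together; a self-contained sketch (the token IDs are made up):

from typing import List, TypedDict

class TextTokensPrompt(TypedDict):
    prompt: str
    prompt_token_ids: List[int]

# Keeping both representations side by side means downstream logging and
# length validation never need to re-tokenize or re-detokenize.
inputs = TextTokensPrompt(prompt="Hello!", prompt_token_ids=[9906, 0])
print(inputs["prompt"], inputs["prompt_token_ids"])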
class OpenAIServing: class OpenAIServing:
def __init__( def __init__(
@@ -43,8 +64,10 @@ class OpenAIServing:
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]], lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]] = None, prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
): ):
super().__init__() super().__init__()
@@ -78,6 +101,8 @@ class OpenAIServing:
prompt_adapter_local_path=prompt_adapter.local_path, prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens)) prompt_adapter_num_virtual_tokens=num_virtual_tokens))
self.request_logger = request_logger
async def show_available_models(self) -> ModelList: async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model.""" """Show available models. Right now we only have one model."""
model_cards = [ model_cards = [
@@ -126,9 +151,8 @@ class OpenAIServing:
return json_str return json_str
async def _check_model( async def _check_model(
self, request: Union[ChatCompletionRequest, CompletionRequest, self,
DetokenizeRequest, EmbeddingRequest, request: AnyRequest,
TokenizeRequest]
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None
@@ -144,64 +168,65 @@ class OpenAIServing:
err_type="NotFoundError", err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_adapter( def _maybe_get_adapters(
self, request: Union[CompletionRequest, ChatCompletionRequest, self, request: AnyRequest
EmbeddingRequest, TokenizeRequest, ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
DetokenizeRequest] None, PromptAdapterRequest]]:
) -> Tuple[Optional[str], Optional[Union[LoRARequest,
PromptAdapterRequest]]]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None, None return None, None
for lora in self.lora_requests: for lora in self.lora_requests:
if request.model == lora.lora_name: if request.model == lora.lora_name:
return 'LoRA', lora return lora, None
for prompt_adapter in self.prompt_adapter_requests: for prompt_adapter in self.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name: if request.model == prompt_adapter.prompt_adapter_name:
return 'PromptAdapter', prompt_adapter return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable # if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.") raise ValueError(f"The model `{request.model}` does not exist.")
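Callers of _maybe_get_adapters now unpack a positional 2-tuple instead of matching on a string tag. A hypothetical, self-contained analogue of that contract (the string return values stand in for LoRARequest / PromptAdapterRequest objects):

from typing import Optional, Tuple

def maybe_get_adapters(model: str) -> Tuple[Optional[str], Optional[str]]:
    # At most one of the two slots is populated for a given model name.
    if model == "my-lora":
        return "lora-request", None
    if model == "my-prompt-adapter":
        return None, "prompt-adapter-request"
    return None, None

lora_request, prompt_adapter_request = maybe_get_adapters("my-lora")
print(lora_request, prompt_adapter_request)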
async def _validate_prompt_and_tokenize( def _normalize_prompt_text_to_input(
self, self,
request: Union[ChatCompletionRequest, CompletionRequest, request: AnyRequest,
DetokenizeRequest, EmbeddingRequest, tokenizer: AnyTokenizer,
TokenizeRequest], prompt: str,
tokenizer: "PreTrainedTokenizer", truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
prompt: Optional[str] = None, add_special_tokens: bool,
prompt_ids: Optional[List[int]] = None, ) -> TextTokensPrompt:
truncate_prompt_tokens: Optional[Annotated[int, if truncate_prompt_tokens is None:
Field(ge=1)]] = None, encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
add_special_tokens: Optional[bool] = True
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if prompt and prompt_ids:
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
if prompt_ids is None:
# When using OpenAIServingChat for chat completions, for
# most models the special tokens (e.g., BOS) have already
# been added by the chat template. Therefore, we do not
# need to add them again.
# Set add_special_tokens to False (by default) to avoid
# adding the BOS tokens again.
tokenizer_kwargs: Dict[str, Any] = {
"add_special_tokens": add_special_tokens
}
if truncate_prompt_tokens is not None:
tokenizer_kwargs.update({
"truncation": True,
"max_length": truncate_prompt_tokens,
})
input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else: else:
input_ids = prompt_ids encoded = tokenizer(prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
input_text = prompt if prompt is not None else tokenizer.decode( input_ids = encoded.input_ids
input_ids)
input_text = prompt
return self._validate_input(request, input_ids, input_text)
def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_ids: List[int],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
input_ids = prompt_ids
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
input_text = tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
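The two normalization helpers truncate differently: text prompts delegate to the tokenizer's truncation/max_length arguments, while pre-tokenized prompts simply keep the last truncate_prompt_tokens IDs. A short sketch (the token IDs are invented; the tokenizer call is only indicated in a comment):

# Pre-tokenized prompt: keep only the last N token IDs.
prompt_ids = [101, 7592, 2088, 102, 2023, 2003, 1037, 3231]
truncate_prompt_tokens = 4
input_ids = prompt_ids[-truncate_prompt_tokens:]
print(input_ids)  # only the last four IDs survive

# Text prompt: the Hugging Face tokenizer does the truncating, roughly:
#   encoded = tokenizer(prompt, add_special_tokens=add_special_tokens,
#                       truncation=True, max_length=truncate_prompt_tokens)
#   input_ids = encoded.input_ids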
def _validate_input(
self,
request: AnyRequest,
input_ids: List[int],
input_text: str,
) -> TextTokensPrompt:
token_num = len(input_ids) token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens # Note: EmbeddingRequest doesn't have max_tokens
@@ -211,13 +236,16 @@ class OpenAIServing:
f"This model's maximum context length is " f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested " f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for embedding " f"{token_num} tokens in the input for embedding "
f"generation. Please reduce the length of the input.", ) f"generation. Please reduce the length of the input.")
return input_ids, input_text return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation # and does not require model context length validation
if isinstance(request, (TokenizeRequest, DetokenizeRequest)): if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
return input_ids, input_text DetokenizeRequest)):
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
if request.max_tokens is None: if request.max_tokens is None:
if token_num >= self.max_model_len: if token_num >= self.max_model_len:
@@ -225,7 +253,7 @@ class OpenAIServing:
f"This model's maximum context length is " f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested " f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, " f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.", ) f"Please reduce the length of the messages.")
request.max_tokens = self.max_model_len - token_num request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len: if token_num + request.max_tokens > self.max_model_len:
@@ -235,13 +263,132 @@ class OpenAIServing:
f"{request.max_tokens + token_num} tokens " f"{request.max_tokens + token_num} tokens "
f"({token_num} in the messages, " f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). " f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", ) f"Please reduce the length of the messages or completion.")
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
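The context-length bookkeeping in _validate_input reduces to simple arithmetic; a worked example with made-up numbers:

max_model_len = 4096
token_num = 1000                      # tokens already in the prompt/messages

requested_max_tokens = None           # request.max_tokens not set
if requested_max_tokens is None:
    # Default to whatever room is left in the context window.
    requested_max_tokens = max_model_len - token_num   # 3096

# Anything larger would overflow the window and is rejected above.
assert token_num + requested_max_tokens <= max_model_len
print(requested_max_tokens)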
def _tokenize_prompt_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_input: Union[str, List[int]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes single input.
"""
return next(
self._tokenize_prompt_inputs(
request,
tokenizer,
[prompt_input],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
))
def _tokenize_prompt_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes multiple inputs.
"""
for text in prompt_inputs:
if isinstance(text, str):
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=text,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else: else:
return input_ids, input_text yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=text,
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _tokenize_prompt_input_or_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
for prompt_input in parse_and_batch_prompt(input_or_inputs):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if prompt_input["is_tokens"] is False:
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
)
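Per the docstring above, the OpenAI-style input can be a single string, a list of strings, a single token array, or a list of token arrays, and parse_and_batch_prompt normalizes all four shapes into items carrying is_tokens and content. A small usage sketch (requires vLLM installed; the token IDs are made up):

from vllm.inputs import parse_and_batch_prompt

for payload in ("Hello, world",
                ["Hello, world", "Goodbye"],
                [9906, 11, 1917],
                [[9906, 11, 1917], [1175]]):
    parsed = parse_and_batch_prompt(payload)
    # Each parsed item exposes {"is_tokens": bool, "content": str | List[int]}.
    print([(item["is_tokens"], item["content"]) for item in parsed])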
def _log_inputs(
self,
request_id: str,
inputs: Union[str, List[int], TextTokensPrompt],
params: Optional[Union[SamplingParams, PoolingParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
if self.request_logger is None:
return
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = None
elif isinstance(inputs, list):
prompt = None
prompt_token_ids = inputs
else:
prompt = inputs["prompt"]
prompt_token_ids = inputs["prompt_token_ids"]
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
params=params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
@staticmethod @staticmethod
def _get_decoded_token(logprob: Logprob, token_id: int, def _get_decoded_token(
tokenizer: PreTrainedTokenizer) -> str: logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
) -> str:
if logprob.decoded_token is not None: if logprob.decoded_token is not None:
return logprob.decoded_token return logprob.decoded_token
return tokenizer.decode(token_id) return tokenizer.decode(token_id)

vllm/entrypoints/openai/serving_tokenization.py View File

@@ -1,83 +1,135 @@
from typing import List, Optional from typing import List, Optional, Union
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ConversationMessage, from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template, load_chat_template,
parse_chat_message_content) parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (DetokenizeRequest, from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
ErrorResponse,
TokenizeChatRequest,
TokenizeRequest, TokenizeRequest,
TokenizeResponse) TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing) OpenAIServing)
from vllm.utils import random_uuid
class OpenAIServingTokenization(OpenAIServing): class OpenAIServingTokenization(OpenAIServing):
def __init__(self, def __init__(
self,
engine: AsyncLLMEngine, engine: AsyncLLMEngine,
model_config: ModelConfig, model_config: ModelConfig,
served_model_names: List[str], served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None, *,
chat_template: Optional[str] = None): lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine, super().__init__(engine=engine,
model_config=model_config, model_config=model_config,
served_model_names=served_model_names, served_model_names=served_model_names,
lora_modules=lora_modules) lora_modules=lora_modules,
prompt_adapters=None,
request_logger=request_logger)
# If this is None we use the tokenizer's default chat template # If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template) self.chat_template = load_chat_template(chat_template)
async def create_tokenize(self, async def create_tokenize(
request: TokenizeRequest) -> TokenizeResponse: self,
request: TokenizeRequest,
) -> Union[TokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
if not (request.prompt or request.messages): request_id = f"tokn-{random_uuid()}"
return self.create_error_response(
"Either `prompt` or `messages` should be provided.")
if (request.prompt and request.messages): (
return self.create_error_response( lora_request,
"Only one of `prompt` or `messages` should be provided.") prompt_adapter_request,
) = self._maybe_get_adapters(request)
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.engine.get_tokenizer(lora_request)
if request.messages:
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
conversation: List[ConversationMessage] = [] conversation: List[ConversationMessage] = []
for message in request.messages: for message in request.messages:
result = parse_chat_message_content(message, self.model_config, result = parse_chat_message_content(message, model_config,
tokenizer) tokenizer)
conversation.extend(result.messages) conversation.extend(result.messages)
request.prompt = tokenizer.apply_chat_template( prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
conversation=conversation, conversation=conversation,
tokenize=False, tokenize=False,
chat_template=self.chat_template) chat_template=self.chat_template)
assert isinstance(prompt, str)
else:
prompt = request.prompt
(input_ids, input_text) = await self._validate_prompt_and_tokenize( self._log_inputs(request_id,
prompt,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect tokenization
prompt_input = self._tokenize_prompt_input(
request, request,
tokenizer, tokenizer,
prompt=request.prompt, prompt,
add_special_tokens=request.add_special_tokens) add_special_tokens=request.add_special_tokens,
)
input_ids = prompt_input["prompt_token_ids"]
return TokenizeResponse(tokens=input_ids, return TokenizeResponse(tokens=input_ids,
count=len(input_ids), count=len(input_ids),
max_model_len=self.max_model_len) max_model_len=self.max_model_len)
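For the chat variant, the prompt string is produced by the tokenizer's chat template before tokenization. A standalone sketch using the Hugging Face API that the code above calls (the model name is only an example and must ship a chat template):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
conversation = [{"role": "user", "content": "Tokenize me, please."}]

prompt = tokenizer.apply_chat_template(
    conversation=conversation,
    add_generation_prompt=True,
    tokenize=False,          # return the rendered string, not token IDs
)
token_ids = tokenizer(prompt, add_special_tokens=False).input_ids
print(prompt)
print(len(token_ids), "tokens")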
async def create_detokenize( async def create_detokenize(
self, request: DetokenizeRequest) -> DetokenizeResponse: self,
request: DetokenizeRequest,
) -> Union[DetokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
_, lora_request = self._maybe_get_adapter(request) request_id = f"tokn-{random_uuid()}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request) tokenizer = await self.engine.get_tokenizer(lora_request)
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
request, tokenizer, prompt_ids=request.tokens) self._log_inputs(request_id,
request.tokens,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for tokenization")
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
request.tokens,
)
input_text = prompt_input["prompt"]
return DetokenizeResponse(prompt=input_text) return DetokenizeResponse(prompt=input_text)

vllm/inputs/__init__.py View File

@@ -1,6 +1,5 @@
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
PromptStrictInputs, TextPrompt, TextTokensPrompt, TextPrompt, TokensPrompt, parse_and_batch_prompt)
TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
@@ -14,6 +13,6 @@ See also:
__all__ = [ __all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
"TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs", "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
"LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry" "InputContext", "InputRegistry"
] ]

vllm/inputs/data.py View File

@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
""" """
class TextTokensPrompt(TypedDict): PromptInputs = Union[str, TextPrompt, TokensPrompt]
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
""" """
The inputs to the LLM, which can take one of the following forms: The inputs to the LLM, which can take one of the following forms:
@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
""" """
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict): class LLMInputs(TypedDict):
""" """

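With PromptStrictInputs folded into PromptInputs, the LLM entry points accept exactly these three forms. A usage sketch under that assumption (the model name and token IDs are placeholders):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(max_tokens=16)

outputs = llm.generate(
    [
        "Hello, my name is",                      # plain string
        {"prompt": "The capital of France is"},   # TextPrompt
        {"prompt_token_ids": [2, 31414, 232]},    # TokensPrompt (made-up IDs)
    ],
    sampling_params,
)
for output in outputs:
    print(output.outputs[0].text)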
vllm/sequence.py View File

@@ -5,7 +5,8 @@ import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union)
import torch import torch
@@ -438,7 +439,7 @@ class SequenceGroup:
embeddings: Optional[List[float]] = None, embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None, pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None, encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
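Typing trace_headers as Mapping rather than Dict lets callers hand over any read-only mapping (for example, framework header objects) without first copying into a dict; a minimal illustration with a hypothetical consumer:

from types import MappingProxyType
from typing import Mapping, Optional

def make_group(trace_headers: Optional[Mapping[str, str]] = None) -> None:
    # Only read access is needed, so any Mapping-like object is acceptable.
    if trace_headers:
        print(dict(trace_headers))

make_group(MappingProxyType({"traceparent": "00-abc-def-01"}))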