diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py
index 0561158dcf65a..45aa2070d0a28 100644
--- a/tests/entrypoints/openai/test_lora_resolvers.py
+++ b/tests/entrypoints/openai/test_lora_resolvers.py
@@ -122,6 +122,9 @@ def mock_serving_setup():
                                                  models,
                                                  request_logger=None)
 
+    serving_completion._process_inputs = AsyncMock(return_value=(MagicMock(
+        name="engine_request"), {}))
+
     return mock_engine, serving_completion
 
 
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 07f39fe2b9bd0..81683854e177e 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -7,7 +7,7 @@ import asyncio
 from contextlib import suppress
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Optional
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 import pytest_asyncio
@@ -230,6 +230,7 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -244,11 +245,33 @@ class MockModelConfig:
     encoder_config = None
     generation_config: str = "auto"
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False
 
     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}
 
 
+def _build_serving_chat(engine: AsyncLLM,
+                        model_config: MockModelConfig) -> OpenAIServingChat:
+    models = OpenAIServingModels(engine_client=engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=model_config)
+    serving_chat = OpenAIServingChat(engine,
+                                     model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    async def _fake_process_inputs(request_id, engine_prompt, sampling_params,
+                                   *, lora_request, trace_headers, priority):
+        return dict(engine_prompt), {}
+
+    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
+    return serving_chat
+
+
 @dataclass
 class MockEngine:
@@ -282,16 +305,7 @@ async def test_serving_chat_returns_correct_model_name():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
 
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=MockModelConfig())
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     MockModelConfig(),
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
     messages = [{"role": "user", "content": "what is 1+1?"}]
 
     async def return_model_name(*args):
@@ -318,16 +332,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
 
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=MockModelConfig())
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     MockModelConfig(),
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
 
     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -361,16 +366,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
 
     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)
 
     # Test Case 1: No max_tokens specified in request
     req = ChatCompletionRequest(
@@ -415,16 +411,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False
 
     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)
 
     # Test case 1: No max_tokens specified, defaults to context_window
     req = ChatCompletionRequest(
@@ -471,16 +458,7 @@ async def test_serving_chat_could_load_correct_generation_config():
     mock_engine.errored = False
 
     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)
 
     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -525,17 +503,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
 
-    # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)
 
     # Test cache_salt
     req = ChatCompletionRequest(
@@ -549,10 +517,12 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     # By default, cache_salt in the engine prompt is not set
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
-    assert "cache_salt" not in mock_engine.generate.call_args.args[0]
+    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
+    assert "cache_salt" not in engine_prompt
 
     # Test with certain cache_salt
     req.cache_salt = "test_salt"
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
-    assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
+    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
+    assert engine_prompt.get("cache_salt") == "test_salt"
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 997c99af24089..bc917f2f57f0c 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -19,6 +19,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.tasks import SupportedTask
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Device, collect_from_async_generator, random_uuid
+from vllm.v1.engine import EngineCoreRequest
 
 logger = init_logger(__name__)
 
@@ -49,12 +50,16 @@ class EngineClient(ABC):
     @abstractmethod
     def generate(
         self,
-        prompt: PromptType,
+        prompt: Union[EngineCoreRequest, PromptType],
         sampling_params: SamplingParams,
         request_id: str,
+        *,
+        prompt_text: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request."""
         ...
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 2336158ac51ba..54eb60a8589de 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -274,7 +274,8 @@ class OpenAIServingChat(OpenAIServing):
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
-                sampling_params: Union[SamplingParams, BeamSearchParams]
+                prompt_text, _, _ = (self._get_prompt_components(
+                    request_prompts[i]))
 
                 if self.default_sampling_params is None:
                     self.default_sampling_params = {}
@@ -285,6 +286,7 @@ class OpenAIServingChat(OpenAIServing):
                     input_length=len(engine_prompt["prompt_token_ids"]),
                     default_sampling_params=self.default_sampling_params)
 
+                sampling_params: Union[SamplingParams, BeamSearchParams]
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         max_tokens, self.default_sampling_params)
@@ -309,13 +311,25 @@ class OpenAIServingChat(OpenAIServing):
                         lora_request=lora_request,
                     )
                 else:
+                    engine_request, tokenization_kwargs = (
+                        await self._process_inputs(
+                            request_id,
+                            engine_prompt,
+                            sampling_params,
+                            lora_request=lora_request,
+                            trace_headers=trace_headers,
+                            priority=request.priority,
+                        ))
+
                     generator = self.engine_client.generate(
-                        engine_prompt,
+                        engine_request,
                         sampling_params,
                         request_id,
                         lora_request=lora_request,
                         trace_headers=trace_headers,
                         priority=request.priority,
+                        prompt_text=prompt_text,
+                        tokenization_kwargs=tokenization_kwargs,
                     )
 
                 generators.append(generator)
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 0c61c48da0bc8..d0756e42b7963 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -9,7 +9,6 @@ from typing import Optional, Union, cast
 
 import jinja2
 from fastapi import Request
-from typing_extensions import assert_never
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
@@ -32,8 +31,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import get_max_tokens
-from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
-                              is_tokens_prompt)
+from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import RequestOutput
@@ -157,23 +155,16 @@ class OpenAIServingCompletion(OpenAIServing):
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
-                sampling_params: Union[SamplingParams, BeamSearchParams]
-                # Mypy does not infer that engine_prompt will have only one of
-                # "prompt_token_ids" or "prompt_embeds" defined, and both of
-                # these as Union[object, the expected type], where it infers
-                # object if engine_prompt is a subclass of one of the
-                # typeddicts that defines both keys. Worse, because of
-                # https://github.com/python/mypy/issues/8586, mypy does not
-                # infer the type of engine_prompt correctly because of the
-                # enumerate. So we need an unnecessary cast here.
-                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
-                                     engine_prompt)
-                if is_embeds_prompt(engine_prompt):
-                    input_length = len(engine_prompt["prompt_embeds"])
-                elif is_tokens_prompt(engine_prompt):
-                    input_length = len(engine_prompt["prompt_token_ids"])
+                prompt_text, prompt_token_ids, prompt_embeds = (
+                    self._get_prompt_components(engine_prompt))
+
+                input_length = None
+                if prompt_token_ids is not None:
+                    input_length = len(prompt_token_ids)
+                elif prompt_embeds is not None:
+                    input_length = len(prompt_embeds)
                 else:
-                    assert_never(engine_prompt)
+                    raise NotImplementedError
 
                 if self.default_sampling_params is None:
                     self.default_sampling_params = {}
@@ -185,6 +176,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     default_sampling_params=self.default_sampling_params,
                 )
 
+                sampling_params: Union[SamplingParams, BeamSearchParams]
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         max_tokens, self.default_sampling_params)
@@ -220,13 +212,25 @@ class OpenAIServingCompletion(OpenAIServing):
                         lora_request=lora_request,
                     )
                 else:
+                    engine_request, tokenization_kwargs = (
+                        await self._process_inputs(
+                            request_id_item,
+                            engine_prompt,
+                            sampling_params,
+                            lora_request=lora_request,
+                            trace_headers=trace_headers,
+                            priority=request.priority,
+                        ))
+
                     generator = self.engine_client.generate(
-                        engine_prompt,
+                        engine_request,
                         sampling_params,
                         request_id_item,
                         lora_request=lora_request,
                         trace_headers=trace_headers,
                         priority=request.priority,
+                        prompt_text=prompt_text,
+                        tokenization_kwargs=tokenization_kwargs,
                     )
 
                 generators.append(generator)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 4eb1f8b89d64f..dc41723800d0d 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -7,7 +7,8 @@ import traceback
 from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from http import HTTPStatus
-from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
+from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional,
+                    TypeVar, Union)
 
 import torch
 from fastapi import Request
@@ -15,6 +16,11 @@ from pydantic import BaseModel, ConfigDict, Field
 from starlette.datastructures import Headers
 from typing_extensions import TypeIs
 
+from vllm.entrypoints.utils import _validate_truncation_size
+from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.processor import Processor
+
 if sys.version_info >= (3, 12):
     from typing import TypedDict
 else:
@@ -134,6 +140,12 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
             and "prompt_embeds" in prompt)
 
 
+class PromptComponents(NamedTuple):
+    text: Optional[str] = None
+    token_ids: Optional[list[int]] = None
+    embeds: Optional[torch.Tensor] = None
+
+
 RequestT = TypeVar("RequestT", bound=AnyRequest)
 
 
@@ -239,6 +251,16 @@ class OpenAIServing:
                                          AsyncMicrobatchTokenizer] = {}
         self.log_error_stack = log_error_stack
 
+    async def _get_processor(self) -> Processor:
+        if not hasattr(self, "_processor"):
+            vllm_config = await self.engine_client.get_vllm_config()
+            if self.model_config.skip_tokenizer_init:
+                tokenizer = None
+            else:
+                tokenizer = init_tokenizer_from_configs(self.model_config)
+            self._processor = Processor(vllm_config, tokenizer)
+        return self._processor
+
     def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
         """
         Get a Renderer instance with the provided tokenizer.
@@ -850,6 +872,36 @@ class OpenAIServing:
 
         return conversation, [request_prompt], [engine_prompt]
 
+    async def _process_inputs(
+        self,
+        request_id: str,
+        engine_prompt: PromptType,
+        sampling_params: SamplingParams,
+        *,
+        lora_request: Optional[LoRARequest],
+        trace_headers: Optional[Mapping[str, str]],
+        priority: int,
+    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
+        """
+        using the Processor to process inputs for AsyncLLM
+        """
+        tokenization_kwargs: dict[str, Any] = {}
+        _validate_truncation_size(self.max_model_len,
+                                  sampling_params.truncate_prompt_tokens,
+                                  tokenization_kwargs)
+
+        processor = await self._get_processor()
+        engine_request = processor.process_inputs(
+            request_id,
+            engine_prompt,
+            sampling_params,
+            lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
+            trace_headers=trace_headers,
+            priority=priority,
+        )
+        return engine_request, tokenization_kwargs
+
     async def _generate_with_builtin_tools(
         self,
         request_id: str,
@@ -861,6 +913,7 @@ class OpenAIServing:
         priority: int = 0,
         **kwargs,
     ):
+        prompt_text, _, _ = self._get_prompt_components(request_prompt)
         orig_priority = priority
         while True:
             self._log_inputs(
@@ -869,14 +922,27 @@ class OpenAIServing:
                 params=sampling_params,
                 lora_request=lora_request,
             )
-            generator = self.engine_client.generate(
+            trace_headers = kwargs.get("trace_headers")
+            engine_request, tokenization_kwargs = (await self._process_inputs(
+                request_id,
                 engine_prompt,
                 sampling_params,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                priority=priority,
+            ))
+
+            generator = self.engine_client.generate(
+                engine_request,
+                sampling_params,
                 request_id,
                 lora_request=lora_request,
                 priority=priority,
+                prompt_text=prompt_text,
+                tokenization_kwargs=tokenization_kwargs,
                 **kwargs,
             )
+
             async for res in generator:
                 context.append_output(res)
                 # NOTE(woosuk): The stop condition is handled by the engine.
@@ -905,6 +971,28 @@ class OpenAIServing:
                 # OPTIMIZATION
                 priority = orig_priority - 1
 
+    def _get_prompt_components(
+        self,
+        inputs: Union[RequestPrompt, PromptType],
+    ) -> PromptComponents:
+        if isinstance(inputs, str):
+            return PromptComponents(text=inputs)
+        if isinstance(inputs, list):
+            return PromptComponents(token_ids=inputs)
+        if isinstance(inputs, dict):
+            return PromptComponents(
+                text=inputs.get("prompt"),  # type: ignore[arg-type]
+                token_ids=inputs.get(
+                    "prompt_token_ids"),  # type: ignore[arg-type]
+                embeds=inputs.get("prompt_embeds"),
+            )
+
+        return PromptComponents(
+            text=getattr(inputs, "prompt", None),
+            token_ids=getattr(inputs, "prompt_token_ids", None),
+            embeds=getattr(inputs, "prompt_embeds", None),
+        )
+
     def _log_inputs(
         self,
         request_id: str,
@@ -915,14 +1003,9 @@ class OpenAIServing:
     ) -> None:
         if self.request_logger is None:
             return
-        prompt, prompt_token_ids, prompt_embeds = None, None, None
-        if isinstance(inputs, str):
-            prompt = inputs
-        elif isinstance(inputs, list):
-            prompt_token_ids = inputs
-        else:
-            prompt = getattr(inputs, 'prompt', None)
-            prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
+
+        prompt, prompt_token_ids, prompt_embeds = (
+            self._get_prompt_components(inputs))
 
         self.request_logger.log_inputs(
             request_id,
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index e88b4c5346c30..ab3a4e5e6fe55 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -261,7 +261,7 @@ class AsyncLLM(EngineClient):
     async def add_request(
         self,
         request_id: str,
-        prompt: PromptType,
+        prompt: Union[EngineCoreRequest, PromptType],
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -269,6 +269,7 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
+        prompt_text: Optional[str] = None,
     ) -> RequestOutputCollector:
         """Add new request to the AsyncLLM."""
 
@@ -281,13 +282,20 @@ class AsyncLLM(EngineClient):
         queue = RequestOutputCollector(output_kind=params.output_kind)
 
         # Convert Input --> Request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                tokenization_kwargs,
-                                                trace_headers, priority,
-                                                data_parallel_rank)
-        prompt_text = prompt if isinstance(prompt,
-                                           str) else prompt.get("prompt")
+        if isinstance(prompt, EngineCoreRequest):
+            request = prompt
+        else:
+            assert prompt_text is None
+            logger.warning_once(
+                "Processor has been moved under OpenAIServing and will "
+                "be removed from AsyncLLM in v0.13.")
+            request = self.processor.process_inputs(request_id, prompt, params,
+                                                    arrival_time, lora_request,
+                                                    tokenization_kwargs,
+                                                    trace_headers, priority,
+                                                    data_parallel_rank)
+            prompt_text = (prompt if isinstance(prompt, str) else
+                           prompt.get("prompt"))
 
         if is_pooling or params.n == 1:
             await self._add_request(request, prompt_text, None, 0, queue)
@@ -332,10 +340,13 @@ class AsyncLLM(EngineClient):
     # re-multiplexed in the API server anyhow.
     async def generate(
         self,
-        prompt: PromptType,
+        prompt: Union[EngineCoreRequest, PromptType],
         sampling_params: SamplingParams,
         request_id: str,
+        *,
+        prompt_text: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
@@ -368,25 +379,25 @@ class AsyncLLM(EngineClient):
             # to handle startup failure gracefully in the OpenAI server.
             self._run_output_handler()
 
-            tokenization_kwargs: dict[str, Any] = {}
-            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
+            if tokenization_kwargs is None:
+                tokenization_kwargs = {}
+                truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
 
-            _validate_truncation_size(
-                self.model_config.max_model_len,
-                truncate_prompt_tokens,
-                tokenization_kwargs,
-            )
+                _validate_truncation_size(
+                    self.model_config.max_model_len,
+                    truncate_prompt_tokens,
+                    tokenization_kwargs,
+                )
 
-            q = await self.add_request(
-                request_id,
-                prompt,
-                sampling_params,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-                priority=priority,
-                tokenization_kwargs=tokenization_kwargs,
-                data_parallel_rank=data_parallel_rank,
-            )
+            q = await self.add_request(request_id,
+                                       prompt,
+                                       sampling_params,
+                                       lora_request=lora_request,
+                                       tokenization_kwargs=tokenization_kwargs,
+                                       trace_headers=trace_headers,
+                                       priority=priority,
+                                       data_parallel_rank=data_parallel_rank,
+                                       prompt_text=prompt_text)
 
             # The output_handler task pushes items into the queue.
             # This task pulls from the queue and yields to caller.
@@ -535,7 +546,7 @@ class AsyncLLM(EngineClient):
         self._run_output_handler()
 
         if tokenization_kwargs is None:
-            tokenization_kwargs = dict[str, Any]()
+            tokenization_kwargs = {}
             _validate_truncation_size(
                 self.model_config.max_model_len,
                 truncate_prompt_tokens,
@@ -547,9 +558,9 @@ class AsyncLLM(EngineClient):
             prompt,
             pooling_params,
             lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
            priority=priority,
-            tokenization_kwargs=tokenization_kwargs,
         )
 
         # The output_handler task pushes items into the queue.