[Renderer] Move Processor out of AsyncLLM (#24138)

Signed-off-by: Yang <lymailforjob@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by Yang Liu on 2025-10-03 04:29:45 -07:00; committed by yewentao256
parent f376868620
commit ff1daf6c8a
7 changed files with 215 additions and 125 deletions

View File

@ -122,6 +122,9 @@ def mock_serving_setup():
models,
request_logger=None)
serving_completion._process_inputs = AsyncMock(return_value=(MagicMock(
name="engine_request"), {}))
return mock_engine, serving_completion
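
The stub above replaces the serving layer's new pre-processing step so the test never runs the real Processor: _process_inputs simply hands back a fake engine request and empty tokenization kwargs. Below is a minimal, self-contained sketch of how that AsyncMock stubbing pattern behaves; FakeServing is a stand-in class for illustration, not part of this diff.

import asyncio
from unittest.mock import AsyncMock, MagicMock

class FakeServing:
    pass

async def main():
    serving = FakeServing()
    # _process_inputs is expected to return (engine_request, tokenization_kwargs).
    serving._process_inputs = AsyncMock(
        return_value=(MagicMock(name="engine_request"), {}))
    engine_request, tok_kwargs = await serving._process_inputs(
        "req-0", {"prompt_token_ids": [1, 2, 3]}, None)
    assert tok_kwargs == {}
    # Every await is recorded, so tests can later inspect what was passed in.
    assert serving._process_inputs.await_count == 1

asyncio.run(main())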

View File

@ -7,7 +7,7 @@ import asyncio
from contextlib import suppress
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
@ -230,6 +230,7 @@ class MockHFConfig:
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
@ -244,11 +245,33 @@ class MockModelConfig:
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def _build_serving_chat(engine: AsyncLLM,
model_config: MockModelConfig) -> OpenAIServingChat:
models = OpenAIServingModels(engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=model_config)
serving_chat = OpenAIServingChat(engine,
model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
async def _fake_process_inputs(request_id, engine_prompt, sampling_params,
*, lora_request, trace_headers, priority):
return dict(engine_prompt), {}
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_chat
@dataclass
class MockEngine:
@ -282,16 +305,7 @@ async def test_serving_chat_returns_correct_model_name():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
messages = [{"role": "user", "content": "what is 1+1?"}]
async def return_model_name(*args):
@ -318,16 +332,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
req = ChatCompletionRequest(
model=MODEL_NAME,
@ -361,16 +366,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
# Test Case 1: No max_tokens specified in request
req = ChatCompletionRequest(
@ -415,16 +411,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False
# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
# Test case 1: No max_tokens specified, defaults to context_window
req = ChatCompletionRequest(
@ -471,16 +458,7 @@ async def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False
# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
req = ChatCompletionRequest(
model=MODEL_NAME,
@ -525,17 +503,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
# Test cache_salt
req = ChatCompletionRequest(
@ -549,10 +517,12 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
# By default, cache_salt in the engine prompt is not set
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
assert "cache_salt" not in engine_prompt
# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
assert engine_prompt.get("cache_salt") == "test_salt"
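
The cache_salt assertions above no longer look at mock_engine.generate.call_args; they read the engine prompt back out of the recorded awaits on the stubbed _process_inputs. A self-contained illustration of the AsyncMock.await_args_list mechanism those assertions rely on (the request ids and prompts here are invented for the demo):

import asyncio
from unittest.mock import AsyncMock

async def main():
    process_inputs = AsyncMock(return_value=("engine_request", {}))
    await process_inputs("req-0", {"prompt_token_ids": [1, 2]}, "params")
    await process_inputs("req-1",
                         {"prompt_token_ids": [3], "cache_salt": "test_salt"},
                         "params")
    # args[1] is the engine prompt positional argument of each recorded call.
    assert "cache_salt" not in process_inputs.await_args_list[0].args[1]
    assert process_inputs.await_args_list[1].args[1]["cache_salt"] == "test_salt"

asyncio.run(main())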

View File

@ -19,6 +19,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
@ -49,12 +50,16 @@ class EngineClient(ABC):
@abstractmethod
def generate(
self,
prompt: PromptType,
prompt: Union[EngineCoreRequest, PromptType],
sampling_params: SamplingParams,
request_id: str,
*,
prompt_text: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...
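
With this change the abstract generate() contract accepts either a raw PromptType or an already-processed EngineCoreRequest, and gains keyword-only prompt_text and tokenization_kwargs parameters. A hedged caller-side sketch of the widened signature; the authoritative call sites are in serving_chat.py and serving_completion.py later in this diff, and the helper name stream_request is invented here:

async def stream_request(engine_client, engine_request, sampling_params,
                         request_id, prompt_text, tokenization_kwargs):
    # engine_request may be a pre-processed EngineCoreRequest (new path) or a
    # raw PromptType (still accepted); prompt_text and tokenization_kwargs are
    # the keyword-only extras introduced by this change.
    generator = engine_client.generate(
        engine_request,
        sampling_params,
        request_id,
        prompt_text=prompt_text,
        tokenization_kwargs=tokenization_kwargs,
        priority=0,
    )
    async for output in generator:
        yield output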

View File

@ -274,7 +274,8 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
prompt_text, _, _ = (self._get_prompt_components(
request_prompts[i]))
if self.default_sampling_params is None:
self.default_sampling_params = {}
@ -285,6 +286,7 @@ class OpenAIServingChat(OpenAIServing):
input_length=len(engine_prompt["prompt_token_ids"]),
default_sampling_params=self.default_sampling_params)
sampling_params: Union[SamplingParams, BeamSearchParams]
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
@ -309,13 +311,25 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request,
)
else:
engine_request, tokenization_kwargs = (
await self._process_inputs(
request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
))
generator = self.engine_client.generate(
engine_prompt,
engine_request,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
)
generators.append(generator)

View File

@ -9,7 +9,6 @@ from typing import Optional, Union, cast
import jinja2
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@ -32,8 +31,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt)
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
@ -157,23 +155,16 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
# Mypy does not infer that engine_prompt will have only one of
# "prompt_token_ids" or "prompt_embeds" defined, and both of
# these as Union[object, the expected type], where it infers
# object if engine_prompt is a subclass of one of the
# typeddicts that defines both keys. Worse, because of
# https://github.com/python/mypy/issues/8586, mypy does not
# infer the type of engine_prompt correctly because of the
# enumerate. So we need an unnecessary cast here.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if is_embeds_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_embeds"])
elif is_tokens_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_token_ids"])
prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt))
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
assert_never(engine_prompt)
raise NotImplementedError
if self.default_sampling_params is None:
self.default_sampling_params = {}
@ -185,6 +176,7 @@ class OpenAIServingCompletion(OpenAIServing):
default_sampling_params=self.default_sampling_params,
)
sampling_params: Union[SamplingParams, BeamSearchParams]
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
@ -220,13 +212,25 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
)
else:
engine_request, tokenization_kwargs = (
await self._process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
))
generator = self.engine_client.generate(
engine_prompt,
engine_request,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
)
generators.append(generator)

View File

@ -7,7 +7,8 @@ import traceback
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional,
TypeVar, Union)
import torch
from fastapi import Request
@ -15,6 +16,11 @@ from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers
from typing_extensions import TypeIs
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.processor import Processor
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
@ -134,6 +140,12 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
and "prompt_embeds" in prompt)
class PromptComponents(NamedTuple):
text: Optional[str] = None
token_ids: Optional[list[int]] = None
embeds: Optional[torch.Tensor] = None
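
PromptComponents is a NamedTuple whose fields are all optional, so call sites can unpack it positionally (as serving_chat.py does with "prompt_text, _, _ = ...") or read fields by name. A tiny standalone illustration using a renamed copy of the class:

from typing import NamedTuple, Optional

class PromptComponentsDemo(NamedTuple):
    text: Optional[str] = None
    token_ids: Optional[list] = None
    embeds: Optional[object] = None  # torch.Tensor in the real class

comps = PromptComponentsDemo(text="what is 1+1?")
prompt_text, _, _ = comps  # positional unpacking, as used in the serving code
assert prompt_text == "what is 1+1?" and comps.token_ids is None
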
RequestT = TypeVar("RequestT", bound=AnyRequest)
@ -239,6 +251,16 @@ class OpenAIServing:
AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
async def _get_processor(self) -> Processor:
if not hasattr(self, "_processor"):
vllm_config = await self.engine_client.get_vllm_config()
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_configs(self.model_config)
self._processor = Processor(vllm_config, tokenizer)
return self._processor
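
_get_processor() builds the Processor lazily on first use and caches it on the serving instance, so tokenization and input processing now happen in the API-server process before a request ever reaches AsyncLLM. A minimal sketch of that per-instance memoization pattern, assuming stand-in ProcessorStub/ServingStub classes rather than the real vLLM types:

import asyncio

class ProcessorStub:
    def __init__(self, vllm_config, tokenizer):
        self.vllm_config, self.tokenizer = vllm_config, tokenizer

class ServingStub:
    skip_tokenizer_init = False

    async def _get_processor(self):
        if not hasattr(self, "_processor"):
            # Stands in for "await self.engine_client.get_vllm_config()".
            vllm_config = await asyncio.sleep(0, result="vllm_config")
            tokenizer = None if self.skip_tokenizer_init else "tokenizer"
            self._processor = ProcessorStub(vllm_config, tokenizer)  # built once, then reused
        return self._processor

async def main():
    serving = ServingStub()
    assert (await serving._get_processor()) is (await serving._get_processor())

asyncio.run(main())
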
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
@ -850,6 +872,36 @@ class OpenAIServing:
return conversation, [request_prompt], [engine_prompt]
async def _process_inputs(
self,
request_id: str,
engine_prompt: PromptType,
sampling_params: SamplingParams,
*,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]],
priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""
Use the Processor to convert the inputs into an EngineCoreRequest for AsyncLLM.
"""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.max_model_len,
sampling_params.truncate_prompt_tokens,
tokenization_kwargs)
processor = await self._get_processor()
engine_request = processor.process_inputs(
request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
)
return engine_request, tokenization_kwargs
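
A reduced stand-in showing the shape of _process_inputs(): it builds tokenization kwargs from the sampling params, delegates to the Processor, and returns both so the caller can forward them to generate(). FakeProcessor and the simplified body are assumptions for illustration, not vLLM internals:

from typing import Any

class FakeProcessor:
    """Placeholder for the real Processor class (illustrative only)."""

    def process_inputs(self, request_id, engine_prompt, sampling_params, **kwargs):
        # Stand-in for building an EngineCoreRequest.
        return {"request_id": request_id, "prompt": engine_prompt}

async def process_inputs_sketch(processor, request_id, engine_prompt,
                                sampling_params):
    tokenization_kwargs: dict[str, Any] = {}
    # The real helper, _validate_truncation_size(), checks
    # sampling_params.truncate_prompt_tokens against max_model_len and fills
    # tokenization_kwargs accordingly; the details are omitted here.
    engine_request = processor.process_inputs(
        request_id, engine_prompt, sampling_params,
        tokenization_kwargs=tokenization_kwargs)
    return engine_request, tokenization_kwargs
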
async def _generate_with_builtin_tools(
self,
request_id: str,
@ -861,6 +913,7 @@ class OpenAIServing:
priority: int = 0,
**kwargs,
):
prompt_text, _, _ = self._get_prompt_components(request_prompt)
orig_priority = priority
while True:
self._log_inputs(
@ -869,14 +922,27 @@ class OpenAIServing:
params=sampling_params,
lora_request=lora_request,
)
generator = self.engine_client.generate(
trace_headers = kwargs.get("trace_headers")
engine_request, tokenization_kwargs = (await self._process_inputs(
request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
))
generator = self.engine_client.generate(
engine_request,
sampling_params,
request_id,
lora_request=lora_request,
priority=priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
**kwargs,
)
async for res in generator:
context.append_output(res)
# NOTE(woosuk): The stop condition is handled by the engine.
@ -905,6 +971,28 @@ class OpenAIServing:
# OPTIMIZATION
priority = orig_priority - 1
def _get_prompt_components(
self,
inputs: Union[RequestPrompt, PromptType],
) -> PromptComponents:
if isinstance(inputs, str):
return PromptComponents(text=inputs)
if isinstance(inputs, list):
return PromptComponents(token_ids=inputs)
if isinstance(inputs, dict):
return PromptComponents(
text=inputs.get("prompt"), # type: ignore[arg-type]
token_ids=inputs.get(
"prompt_token_ids"), # type: ignore[arg-type]
embeds=inputs.get("prompt_embeds"),
)
return PromptComponents(
text=getattr(inputs, "prompt", None),
token_ids=getattr(inputs, "prompt_token_ids", None),
embeds=getattr(inputs, "prompt_embeds", None),
)
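
_get_prompt_components() normalizes the accepted prompt shapes (plain text, a token-id list, or a TokensPrompt/EmbedsPrompt-style dict, with an attribute-based fallback for other objects). A standalone sketch of the same dispatch that returns plain (text, token_ids, embeds) tuples instead of the PromptComponents NamedTuple:

def get_prompt_components_demo(inputs):
    """Return (text, token_ids, embeds) for the accepted prompt shapes."""
    if isinstance(inputs, str):        # raw text prompt
        return inputs, None, None
    if isinstance(inputs, list):       # pre-tokenized prompt
        return None, inputs, None
    if isinstance(inputs, dict):       # TokensPrompt / EmbedsPrompt dicts
        return (inputs.get("prompt"), inputs.get("prompt_token_ids"),
                inputs.get("prompt_embeds"))
    return (getattr(inputs, "prompt", None),
            getattr(inputs, "prompt_token_ids", None),
            getattr(inputs, "prompt_embeds", None))

assert get_prompt_components_demo("hi") == ("hi", None, None)
assert get_prompt_components_demo([1, 2, 3]) == (None, [1, 2, 3], None)
assert get_prompt_components_demo({"prompt_token_ids": [4]}) == (None, [4], None)
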
def _log_inputs(
self,
request_id: str,
@ -915,14 +1003,9 @@ class OpenAIServing:
) -> None:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = None, None, None
if isinstance(inputs, str):
prompt = inputs
elif isinstance(inputs, list):
prompt_token_ids = inputs
else:
prompt = getattr(inputs, 'prompt', None)
prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
prompt, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(inputs))
self.request_logger.log_inputs(
request_id,

View File

@ -261,7 +261,7 @@ class AsyncLLM(EngineClient):
async def add_request(
self,
request_id: str,
prompt: PromptType,
prompt: Union[EngineCoreRequest, PromptType],
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@ -269,6 +269,7 @@ class AsyncLLM(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
prompt_text: Optional[str] = None,
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""
@ -281,13 +282,20 @@ class AsyncLLM(EngineClient):
queue = RequestOutputCollector(output_kind=params.output_kind)
# Convert Input --> Request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority,
data_parallel_rank)
prompt_text = prompt if isinstance(prompt,
str) else prompt.get("prompt")
if isinstance(prompt, EngineCoreRequest):
request = prompt
else:
assert prompt_text is None
logger.warning_once(
"Processor has been moved under OpenAIServing and will "
"be removed from AsyncLLM in v0.13.")
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority,
data_parallel_rank)
prompt_text = (prompt if isinstance(prompt, str) else
prompt.get("prompt"))
if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue)
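
add_request() now short-circuits when it is handed an EngineCoreRequest and only falls back to the in-engine Processor (with a deprecation warning) for raw prompts. A hedged sketch of that branching with stand-in types; EngineCoreRequestStub and ProcessorStub are placeholders, not vLLM classes:

class EngineCoreRequestStub:
    pass

class ProcessorStub:
    def process_inputs(self, request_id, prompt, params):
        return EngineCoreRequestStub()

def resolve_request(prompt, processor, prompt_text=None):
    if isinstance(prompt, EngineCoreRequestStub):
        # New path: the OpenAI serving layer already ran the Processor and
        # supplies prompt_text explicitly.
        return prompt, prompt_text
    # Legacy path: process inside the engine and derive prompt_text locally.
    assert prompt_text is None
    request = processor.process_inputs("req-0", prompt, None)
    prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
    return request, prompt_text

req, text = resolve_request("hello", ProcessorStub())
assert isinstance(req, EngineCoreRequestStub) and text == "hello"
req2, text2 = resolve_request(EngineCoreRequestStub(), ProcessorStub(),
                              prompt_text="hello")
assert text2 == "hello"
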
@ -332,10 +340,13 @@ class AsyncLLM(EngineClient):
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: PromptType,
prompt: Union[EngineCoreRequest, PromptType],
sampling_params: SamplingParams,
request_id: str,
*,
prompt_text: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
@ -368,25 +379,25 @@ class AsyncLLM(EngineClient):
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()
tokenization_kwargs: dict[str, Any] = {}
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
if tokenization_kwargs is None:
tokenization_kwargs = {}
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
_validate_truncation_size(
self.model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs,
)
_validate_truncation_size(
self.model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs,
)
q = await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
q = await self.add_request(request_id,
prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
prompt_text=prompt_text)
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
@ -535,7 +546,7 @@ class AsyncLLM(EngineClient):
self._run_output_handler()
if tokenization_kwargs is None:
tokenization_kwargs = dict[str, Any]()
tokenization_kwargs = {}
_validate_truncation_size(
self.model_config.max_model_len,
truncate_prompt_tokens,
@ -547,9 +558,9 @@ class AsyncLLM(EngineClient):
prompt,
pooling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
)
# The output_handler task pushes items into the queue.