mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 18:44:31 +08:00
[Renderer] Move Processor out of AsyncLLM (#24138)
Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
f376868620
commit
ff1daf6c8a
@ -122,6 +122,9 @@ def mock_serving_setup():
|
|||||||
models,
|
models,
|
||||||
request_logger=None)
|
request_logger=None)
|
||||||
|
|
||||||
|
serving_completion._process_inputs = AsyncMock(return_value=(MagicMock(
|
||||||
|
name="engine_request"), {}))
|
||||||
|
|
||||||
return mock_engine, serving_completion
|
return mock_engine, serving_completion
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ import asyncio
|
|||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
@ -230,6 +230,7 @@ class MockHFConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MockModelConfig:
|
class MockModelConfig:
|
||||||
task = "generate"
|
task = "generate"
|
||||||
|
runner_type = "generate"
|
||||||
tokenizer = MODEL_NAME
|
tokenizer = MODEL_NAME
|
||||||
trust_remote_code = False
|
trust_remote_code = False
|
||||||
tokenizer_mode = "auto"
|
tokenizer_mode = "auto"
|
||||||
@ -244,11 +245,33 @@ class MockModelConfig:
|
|||||||
encoder_config = None
|
encoder_config = None
|
||||||
generation_config: str = "auto"
|
generation_config: str = "auto"
|
||||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||||
|
skip_tokenizer_init = False
|
||||||
|
|
||||||
def get_diff_sampling_param(self):
|
def get_diff_sampling_param(self):
|
||||||
return self.diff_sampling_param or {}
|
return self.diff_sampling_param or {}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_serving_chat(engine: AsyncLLM,
                        model_config: MockModelConfig) -> OpenAIServingChat:
    """Build an ``OpenAIServingChat`` for tests with input processing stubbed.

    The serving object is constructed against ``engine`` and ``model_config``
    using the module-level ``BASE_MODEL_PATHS`` and ``CHAT_TEMPLATE``. Its
    ``_process_inputs`` coroutine is then replaced by an ``AsyncMock`` so that
    tests can inspect the engine prompt each request was processed with.

    Args:
        engine: Engine client handed to both the models registry and the
            chat serving object.
        model_config: Model configuration shared by both objects.

    Returns:
        The serving object, with ``_process_inputs`` mocked.
    """

    async def _echo_engine_prompt(request_id, engine_prompt, sampling_params,
                                  *, lora_request, trace_headers, priority):
        # Stand-in for the real implementation: return a plain-dict copy of
        # the engine prompt together with empty tokenization kwargs.
        return dict(engine_prompt), {}

    serving_models = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
        model_config=model_config,
    )
    chat = OpenAIServingChat(
        engine,
        model_config,
        serving_models,
        response_role="assistant",
        chat_template=CHAT_TEMPLATE,
        chat_template_content_format="auto",
        request_logger=None,
    )

    # Recording mock: every await is captured in ``await_args_list``.
    chat._process_inputs = AsyncMock(side_effect=_echo_engine_prompt)
    return chat
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MockEngine:
|
class MockEngine:
|
||||||
|
|
||||||
@ -282,16 +305,7 @@ async def test_serving_chat_returns_correct_model_name():
|
|||||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||||
mock_engine.errored = False
|
mock_engine.errored = False
|
||||||
|
|
||||||
models = OpenAIServingModels(engine_client=mock_engine,
|
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
|
||||||
base_model_paths=BASE_MODEL_PATHS,
|
|
||||||
model_config=MockModelConfig())
|
|
||||||
serving_chat = OpenAIServingChat(mock_engine,
|
|
||||||
MockModelConfig(),
|
|
||||||
models,
|
|
||||||
response_role="assistant",
|
|
||||||
chat_template=CHAT_TEMPLATE,
|
|
||||||
chat_template_content_format="auto",
|
|
||||||
request_logger=None)
|
|
||||||
messages = [{"role": "user", "content": "what is 1+1?"}]
|
messages = [{"role": "user", "content": "what is 1+1?"}]
|
||||||
|
|
||||||
async def return_model_name(*args):
|
async def return_model_name(*args):
|
||||||
@ -318,16 +332,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
|||||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||||
mock_engine.errored = False
|
mock_engine.errored = False
|
||||||
|
|
||||||
models = OpenAIServingModels(engine_client=mock_engine,
|
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
|
||||||
base_model_paths=BASE_MODEL_PATHS,
|
|
||||||
model_config=MockModelConfig())
|
|
||||||
serving_chat = OpenAIServingChat(mock_engine,
|
|
||||||
MockModelConfig(),
|
|
||||||
models,
|
|
||||||
response_role="assistant",
|
|
||||||
chat_template=CHAT_TEMPLATE,
|
|
||||||
chat_template_content_format="auto",
|
|
||||||
request_logger=None)
|
|
||||||
|
|
||||||
req = ChatCompletionRequest(
|
req = ChatCompletionRequest(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
@ -361,16 +366,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
|||||||
mock_engine.errored = False
|
mock_engine.errored = False
|
||||||
|
|
||||||
# Initialize the serving chat
|
# Initialize the serving chat
|
||||||
models = OpenAIServingModels(engine_client=mock_engine,
|
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||||
base_model_paths=BASE_MODEL_PATHS,
|
|
||||||
model_config=mock_model_config)
|
|
||||||
serving_chat = OpenAIServingChat(mock_engine,
|
|
||||||
mock_model_config,
|
|
||||||
models,
|
|
||||||
response_role="assistant",
|
|
||||||
chat_template=CHAT_TEMPLATE,
|
|
||||||
chat_template_content_format="auto",
|
|
||||||
request_logger=None)
|
|
||||||
|
|
||||||
# Test Case 1: No max_tokens specified in request
|
# Test Case 1: No max_tokens specified in request
|
||||||
req = ChatCompletionRequest(
|
req = ChatCompletionRequest(
|
||||||
@ -415,16 +411,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
|||||||
mock_engine.errored = False
|
mock_engine.errored = False
|
||||||
|
|
||||||
# Initialize the serving chat
|
# Initialize the serving chat
|
||||||
models = OpenAIServingModels(engine_client=mock_engine,
|
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||||
base_model_paths=BASE_MODEL_PATHS,
|
|
||||||
model_config=mock_model_config)
|
|
||||||
serving_chat = OpenAIServingChat(mock_engine,
|
|
||||||
mock_model_config,
|
|
||||||
models,
|
|
||||||
response_role="assistant",
|
|
||||||
chat_template=CHAT_TEMPLATE,
|
|
||||||
chat_template_content_format="auto",
|
|
||||||
request_logger=None)
|
|
||||||
|
|
||||||
# Test case 1: No max_tokens specified, defaults to context_window
|
# Test case 1: No max_tokens specified, defaults to context_window
|
||||||
req = ChatCompletionRequest(
|
req = ChatCompletionRequest(
|
||||||
@ -471,16 +458,7 @@ async def test_serving_chat_could_load_correct_generation_config():
|
|||||||
mock_engine.errored = False
|
mock_engine.errored = False
|
||||||
|
|
||||||
# Initialize the serving chat
|
# Initialize the serving chat
|
||||||
models = OpenAIServingModels(engine_client=mock_engine,
|
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||||
base_model_paths=BASE_MODEL_PATHS,
|
|
||||||
model_config=mock_model_config)
|
|
||||||
serving_chat = OpenAIServingChat(mock_engine,
|
|
||||||
mock_model_config,
|
|
||||||
models,
|
|
||||||
response_role="assistant",
|
|
||||||
chat_template=CHAT_TEMPLATE,
|
|
||||||
chat_template_content_format="auto",
|
|
||||||
request_logger=None)
|
|
||||||
|
|
||||||
req = ChatCompletionRequest(
|
req = ChatCompletionRequest(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
@ -525,17 +503,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
|||||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||||
mock_engine.errored = False
|
mock_engine.errored = False
|
||||||
|
|
||||||
# Initialize the serving chat
|
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||||
models = OpenAIServingModels(engine_client=mock_engine,
|
|
||||||
base_model_paths=BASE_MODEL_PATHS,
|
|
||||||
model_config=mock_model_config)
|
|
||||||
serving_chat = OpenAIServingChat(mock_engine,
|
|
||||||
mock_model_config,
|
|
||||||
models,
|
|
||||||
response_role="assistant",
|
|
||||||
chat_template=CHAT_TEMPLATE,
|
|
||||||
chat_template_content_format="auto",
|
|
||||||
request_logger=None)
|
|
||||||
|
|
||||||
# Test cache_salt
|
# Test cache_salt
|
||||||
req = ChatCompletionRequest(
|
req = ChatCompletionRequest(
|
||||||
@ -549,10 +517,12 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
|||||||
# By default, cache_salt in the engine prompt is not set
|
# By default, cache_salt in the engine prompt is not set
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
await serving_chat.create_chat_completion(req)
|
await serving_chat.create_chat_completion(req)
|
||||||
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
|
engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
|
||||||
|
assert "cache_salt" not in engine_prompt
|
||||||
|
|
||||||
# Test with certain cache_salt
|
# Test with certain cache_salt
|
||||||
req.cache_salt = "test_salt"
|
req.cache_salt = "test_salt"
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
await serving_chat.create_chat_completion(req)
|
await serving_chat.create_chat_completion(req)
|
||||||
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
|
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
|
||||||
|
assert engine_prompt.get("cache_salt") == "test_salt"
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
|
|||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils import Device, collect_from_async_generator, random_uuid
|
from vllm.utils import Device, collect_from_async_generator, random_uuid
|
||||||
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -49,12 +50,16 @@ class EngineClient(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
prompt: Union[EngineCoreRequest, PromptType],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
|
*,
|
||||||
|
prompt_text: Optional[str] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> AsyncGenerator[RequestOutput, None]:
|
) -> AsyncGenerator[RequestOutput, None]:
|
||||||
"""Generate outputs for a request."""
|
"""Generate outputs for a request."""
|
||||||
...
|
...
|
||||||
|
|||||||
@ -274,7 +274,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||||
try:
|
try:
|
||||||
for i, engine_prompt in enumerate(engine_prompts):
|
for i, engine_prompt in enumerate(engine_prompts):
|
||||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
prompt_text, _, _ = (self._get_prompt_components(
|
||||||
|
request_prompts[i]))
|
||||||
|
|
||||||
if self.default_sampling_params is None:
|
if self.default_sampling_params is None:
|
||||||
self.default_sampling_params = {}
|
self.default_sampling_params = {}
|
||||||
@ -285,6 +286,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
input_length=len(engine_prompt["prompt_token_ids"]),
|
input_length=len(engine_prompt["prompt_token_ids"]),
|
||||||
default_sampling_params=self.default_sampling_params)
|
default_sampling_params=self.default_sampling_params)
|
||||||
|
|
||||||
|
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||||
if request.use_beam_search:
|
if request.use_beam_search:
|
||||||
sampling_params = request.to_beam_search_params(
|
sampling_params = request.to_beam_search_params(
|
||||||
max_tokens, self.default_sampling_params)
|
max_tokens, self.default_sampling_params)
|
||||||
@ -309,13 +311,25 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
engine_request, tokenization_kwargs = (
|
||||||
|
await self._process_inputs(
|
||||||
|
request_id,
|
||||||
|
engine_prompt,
|
||||||
|
sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
trace_headers=trace_headers,
|
||||||
|
priority=request.priority,
|
||||||
|
))
|
||||||
|
|
||||||
generator = self.engine_client.generate(
|
generator = self.engine_client.generate(
|
||||||
engine_prompt,
|
engine_request,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
request_id,
|
request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
generators.append(generator)
|
generators.append(generator)
|
||||||
|
|||||||
@ -9,7 +9,6 @@ from typing import Optional, Union, cast
|
|||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from typing_extensions import assert_never
|
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
@ -32,8 +31,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
|||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.entrypoints.renderer import RenderConfig
|
from vllm.entrypoints.renderer import RenderConfig
|
||||||
from vllm.entrypoints.utils import get_max_tokens
|
from vllm.entrypoints.utils import get_max_tokens
|
||||||
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
|
||||||
is_tokens_prompt)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logprobs import Logprob
|
from vllm.logprobs import Logprob
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
@ -157,23 +155,16 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||||
try:
|
try:
|
||||||
for i, engine_prompt in enumerate(engine_prompts):
|
for i, engine_prompt in enumerate(engine_prompts):
|
||||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
prompt_text, prompt_token_ids, prompt_embeds = (
|
||||||
# Mypy does not infer that engine_prompt will have only one of
|
self._get_prompt_components(engine_prompt))
|
||||||
# "prompt_token_ids" or "prompt_embeds" defined, and both of
|
|
||||||
# these as Union[object, the expected type], where it infers
|
input_length = None
|
||||||
# object if engine_prompt is a subclass of one of the
|
if prompt_token_ids is not None:
|
||||||
# typeddicts that defines both keys. Worse, because of
|
input_length = len(prompt_token_ids)
|
||||||
# https://github.com/python/mypy/issues/8586, mypy does not
|
elif prompt_embeds is not None:
|
||||||
# infer the type of engine_prompt correctly because of the
|
input_length = len(prompt_embeds)
|
||||||
# enumerate. So we need an unnecessary cast here.
|
|
||||||
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
|
|
||||||
engine_prompt)
|
|
||||||
if is_embeds_prompt(engine_prompt):
|
|
||||||
input_length = len(engine_prompt["prompt_embeds"])
|
|
||||||
elif is_tokens_prompt(engine_prompt):
|
|
||||||
input_length = len(engine_prompt["prompt_token_ids"])
|
|
||||||
else:
|
else:
|
||||||
assert_never(engine_prompt)
|
raise NotImplementedError
|
||||||
|
|
||||||
if self.default_sampling_params is None:
|
if self.default_sampling_params is None:
|
||||||
self.default_sampling_params = {}
|
self.default_sampling_params = {}
|
||||||
@ -185,6 +176,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
default_sampling_params=self.default_sampling_params,
|
default_sampling_params=self.default_sampling_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||||
if request.use_beam_search:
|
if request.use_beam_search:
|
||||||
sampling_params = request.to_beam_search_params(
|
sampling_params = request.to_beam_search_params(
|
||||||
max_tokens, self.default_sampling_params)
|
max_tokens, self.default_sampling_params)
|
||||||
@ -220,13 +212,25 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
engine_request, tokenization_kwargs = (
|
||||||
|
await self._process_inputs(
|
||||||
|
request_id_item,
|
||||||
|
engine_prompt,
|
||||||
|
sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
trace_headers=trace_headers,
|
||||||
|
priority=request.priority,
|
||||||
|
))
|
||||||
|
|
||||||
generator = self.engine_client.generate(
|
generator = self.engine_client.generate(
|
||||||
engine_prompt,
|
engine_request,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
request_id_item,
|
request_id_item,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
generators.append(generator)
|
generators.append(generator)
|
||||||
|
|||||||
@ -7,7 +7,8 @@ import traceback
|
|||||||
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
|
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional,
|
||||||
|
TypeVar, Union)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@ -15,6 +16,11 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
from starlette.datastructures import Headers
|
from starlette.datastructures import Headers
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
|
from vllm.entrypoints.utils import _validate_truncation_size
|
||||||
|
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
|
||||||
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
|
from vllm.v1.engine.processor import Processor
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
else:
|
else:
|
||||||
@ -134,6 +140,12 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
|
|||||||
and "prompt_embeds" in prompt)
|
and "prompt_embeds" in prompt)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptComponents(NamedTuple):
    """The parts a prompt may carry; any part that is absent is ``None``.

    Unpacks positionally as ``(text, token_ids, embeds)``.
    """

    # Prompt as a string, when one is available.
    text: Optional[str] = None

    # Prompt as a list of token ids, when one is available.
    token_ids: Optional[list[int]] = None

    # Prompt as an embeddings tensor, when one is available.
    embeds: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
||||||
|
|
||||||
|
|
||||||
@ -239,6 +251,16 @@ class OpenAIServing:
|
|||||||
AsyncMicrobatchTokenizer] = {}
|
AsyncMicrobatchTokenizer] = {}
|
||||||
self.log_error_stack = log_error_stack
|
self.log_error_stack = log_error_stack
|
||||||
|
|
||||||
|
async def _get_processor(self) -> Processor:
    """Return this instance's ``Processor``, building it on first use.

    The processor is stored on ``self._processor`` once created and the
    stored object is returned on later calls.

    Returns:
        The ``Processor`` built from the engine client's vLLM config and,
        unless ``model_config.skip_tokenizer_init`` is set, a tokenizer
        initialised from ``model_config``.
    """
    # Already built: reuse it.
    if hasattr(self, "_processor"):
        return self._processor

    vllm_config = await self.engine_client.get_vllm_config()

    # No tokenizer is created when tokenizer initialisation is skipped.
    tokenizer = (None if self.model_config.skip_tokenizer_init else
                 init_tokenizer_from_configs(self.model_config))

    self._processor = Processor(vllm_config, tokenizer)
    return self._processor
|
||||||
|
|
||||||
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
|
def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
|
||||||
"""
|
"""
|
||||||
Get a Renderer instance with the provided tokenizer.
|
Get a Renderer instance with the provided tokenizer.
|
||||||
@ -850,6 +872,36 @@ class OpenAIServing:
|
|||||||
|
|
||||||
return conversation, [request_prompt], [engine_prompt]
|
return conversation, [request_prompt], [engine_prompt]
|
||||||
|
|
||||||
|
async def _process_inputs(
    self,
    request_id: str,
    engine_prompt: PromptType,
    sampling_params: SamplingParams,
    *,
    lora_request: Optional[LoRARequest],
    trace_headers: Optional[Mapping[str, str]],
    priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
    """Turn an engine prompt into an engine request using the ``Processor``.

    Args:
        request_id: Identifier forwarded to the processor.
        engine_prompt: Prompt to process.
        sampling_params: Sampling parameters; ``truncate_prompt_tokens`` is
            read from here for truncation validation.
        lora_request: Forwarded to the processor unchanged.
        trace_headers: Forwarded to the processor unchanged.
        priority: Forwarded to the processor unchanged.

    Returns:
        A ``(engine_request, tokenization_kwargs)`` pair: the processor's
        result and the kwargs dict that was passed to it.
    """
    tok_kwargs: dict[str, Any] = {}

    # The dict is handed to the validator before processing; presumably the
    # validator fills in truncation settings in place -- confirm against
    # ``_validate_truncation_size``.
    truncate_to = sampling_params.truncate_prompt_tokens
    _validate_truncation_size(self.max_model_len, truncate_to, tok_kwargs)

    input_processor = await self._get_processor()
    core_request = input_processor.process_inputs(
        request_id,
        engine_prompt,
        sampling_params,
        lora_request=lora_request,
        tokenization_kwargs=tok_kwargs,
        trace_headers=trace_headers,
        priority=priority,
    )
    return core_request, tok_kwargs
|
||||||
|
|
||||||
async def _generate_with_builtin_tools(
|
async def _generate_with_builtin_tools(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
@ -861,6 +913,7 @@ class OpenAIServing:
|
|||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
prompt_text, _, _ = self._get_prompt_components(request_prompt)
|
||||||
orig_priority = priority
|
orig_priority = priority
|
||||||
while True:
|
while True:
|
||||||
self._log_inputs(
|
self._log_inputs(
|
||||||
@ -869,14 +922,27 @@ class OpenAIServing:
|
|||||||
params=sampling_params,
|
params=sampling_params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
generator = self.engine_client.generate(
|
trace_headers = kwargs.get("trace_headers")
|
||||||
|
engine_request, tokenization_kwargs = (await self._process_inputs(
|
||||||
|
request_id,
|
||||||
engine_prompt,
|
engine_prompt,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
trace_headers=trace_headers,
|
||||||
|
priority=priority,
|
||||||
|
))
|
||||||
|
|
||||||
|
generator = self.engine_client.generate(
|
||||||
|
engine_request,
|
||||||
|
sampling_params,
|
||||||
request_id,
|
request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for res in generator:
|
async for res in generator:
|
||||||
context.append_output(res)
|
context.append_output(res)
|
||||||
# NOTE(woosuk): The stop condition is handled by the engine.
|
# NOTE(woosuk): The stop condition is handled by the engine.
|
||||||
@ -905,6 +971,28 @@ class OpenAIServing:
|
|||||||
# OPTIMIZATION
|
# OPTIMIZATION
|
||||||
priority = orig_priority - 1
|
priority = orig_priority - 1
|
||||||
|
|
||||||
|
def _get_prompt_components(
    self,
    inputs: Union[RequestPrompt, PromptType],
) -> PromptComponents:
    """Split a prompt into its text, token-id and embedding parts.

    Args:
        inputs: A string, a list, a dict, or any other object.

    Returns:
        ``PromptComponents`` where:
        - a ``str`` becomes ``text``;
        - a ``list`` becomes ``token_ids``;
        - a ``dict`` is read through its ``"prompt"``,
          ``"prompt_token_ids"`` and ``"prompt_embeds"`` keys;
        - anything else is read through attributes of the same names.
        Parts that are not present are ``None``.
    """
    if isinstance(inputs, str):
        return PromptComponents(text=inputs)
    if isinstance(inputs, list):
        return PromptComponents(token_ids=inputs)

    # Dicts and other objects expose the same three names, as keys or as
    # attributes respectively; a missing name yields None either way.
    names = ("prompt", "prompt_token_ids", "prompt_embeds")
    if isinstance(inputs, dict):
        found = [inputs.get(name) for name in names]
    else:
        found = [getattr(inputs, name, None) for name in names]

    text, token_ids, embeds = found
    return PromptComponents(text=text, token_ids=token_ids, embeds=embeds)
|
||||||
|
|
||||||
def _log_inputs(
|
def _log_inputs(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
@ -915,14 +1003,9 @@ class OpenAIServing:
|
|||||||
) -> None:
|
) -> None:
|
||||||
if self.request_logger is None:
|
if self.request_logger is None:
|
||||||
return
|
return
|
||||||
prompt, prompt_token_ids, prompt_embeds = None, None, None
|
|
||||||
if isinstance(inputs, str):
|
prompt, prompt_token_ids, prompt_embeds = (
|
||||||
prompt = inputs
|
self._get_prompt_components(inputs))
|
||||||
elif isinstance(inputs, list):
|
|
||||||
prompt_token_ids = inputs
|
|
||||||
else:
|
|
||||||
prompt = getattr(inputs, 'prompt', None)
|
|
||||||
prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)
|
|
||||||
|
|
||||||
self.request_logger.log_inputs(
|
self.request_logger.log_inputs(
|
||||||
request_id,
|
request_id,
|
||||||
|
|||||||
@ -261,7 +261,7 @@ class AsyncLLM(EngineClient):
|
|||||||
async def add_request(
|
async def add_request(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
prompt: PromptType,
|
prompt: Union[EngineCoreRequest, PromptType],
|
||||||
params: Union[SamplingParams, PoolingParams],
|
params: Union[SamplingParams, PoolingParams],
|
||||||
arrival_time: Optional[float] = None,
|
arrival_time: Optional[float] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -269,6 +269,7 @@ class AsyncLLM(EngineClient):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
data_parallel_rank: Optional[int] = None,
|
data_parallel_rank: Optional[int] = None,
|
||||||
|
prompt_text: Optional[str] = None,
|
||||||
) -> RequestOutputCollector:
|
) -> RequestOutputCollector:
|
||||||
"""Add new request to the AsyncLLM."""
|
"""Add new request to the AsyncLLM."""
|
||||||
|
|
||||||
@ -281,13 +282,20 @@ class AsyncLLM(EngineClient):
|
|||||||
queue = RequestOutputCollector(output_kind=params.output_kind)
|
queue = RequestOutputCollector(output_kind=params.output_kind)
|
||||||
|
|
||||||
# Convert Input --> Request.
|
# Convert Input --> Request.
|
||||||
request = self.processor.process_inputs(request_id, prompt, params,
|
if isinstance(prompt, EngineCoreRequest):
|
||||||
arrival_time, lora_request,
|
request = prompt
|
||||||
tokenization_kwargs,
|
else:
|
||||||
trace_headers, priority,
|
assert prompt_text is None
|
||||||
data_parallel_rank)
|
logger.warning_once(
|
||||||
prompt_text = prompt if isinstance(prompt,
|
"Processor has been moved under OpenAIServing and will "
|
||||||
str) else prompt.get("prompt")
|
"be removed from AsyncLLM in v0.13.")
|
||||||
|
request = self.processor.process_inputs(request_id, prompt, params,
|
||||||
|
arrival_time, lora_request,
|
||||||
|
tokenization_kwargs,
|
||||||
|
trace_headers, priority,
|
||||||
|
data_parallel_rank)
|
||||||
|
prompt_text = (prompt if isinstance(prompt, str) else
|
||||||
|
prompt.get("prompt"))
|
||||||
|
|
||||||
if is_pooling or params.n == 1:
|
if is_pooling or params.n == 1:
|
||||||
await self._add_request(request, prompt_text, None, 0, queue)
|
await self._add_request(request, prompt_text, None, 0, queue)
|
||||||
@ -332,10 +340,13 @@ class AsyncLLM(EngineClient):
|
|||||||
# re-multiplexed in the API server anyhow.
|
# re-multiplexed in the API server anyhow.
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
prompt: Union[EngineCoreRequest, PromptType],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
|
*,
|
||||||
|
prompt_text: Optional[str] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
data_parallel_rank: Optional[int] = None,
|
data_parallel_rank: Optional[int] = None,
|
||||||
@ -368,25 +379,25 @@ class AsyncLLM(EngineClient):
|
|||||||
# to handle startup failure gracefully in the OpenAI server.
|
# to handle startup failure gracefully in the OpenAI server.
|
||||||
self._run_output_handler()
|
self._run_output_handler()
|
||||||
|
|
||||||
tokenization_kwargs: dict[str, Any] = {}
|
if tokenization_kwargs is None:
|
||||||
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
|
tokenization_kwargs = {}
|
||||||
|
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
|
||||||
|
|
||||||
_validate_truncation_size(
|
_validate_truncation_size(
|
||||||
self.model_config.max_model_len,
|
self.model_config.max_model_len,
|
||||||
truncate_prompt_tokens,
|
truncate_prompt_tokens,
|
||||||
tokenization_kwargs,
|
tokenization_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
q = await self.add_request(
|
q = await self.add_request(request_id,
|
||||||
request_id,
|
prompt,
|
||||||
prompt,
|
sampling_params,
|
||||||
sampling_params,
|
lora_request=lora_request,
|
||||||
lora_request=lora_request,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
data_parallel_rank=data_parallel_rank,
|
||||||
data_parallel_rank=data_parallel_rank,
|
prompt_text=prompt_text)
|
||||||
)
|
|
||||||
|
|
||||||
# The output_handler task pushes items into the queue.
|
# The output_handler task pushes items into the queue.
|
||||||
# This task pulls from the queue and yields to caller.
|
# This task pulls from the queue and yields to caller.
|
||||||
@ -535,7 +546,7 @@ class AsyncLLM(EngineClient):
|
|||||||
self._run_output_handler()
|
self._run_output_handler()
|
||||||
|
|
||||||
if tokenization_kwargs is None:
|
if tokenization_kwargs is None:
|
||||||
tokenization_kwargs = dict[str, Any]()
|
tokenization_kwargs = {}
|
||||||
_validate_truncation_size(
|
_validate_truncation_size(
|
||||||
self.model_config.max_model_len,
|
self.model_config.max_model_len,
|
||||||
truncate_prompt_tokens,
|
truncate_prompt_tokens,
|
||||||
@ -547,9 +558,9 @@ class AsyncLLM(EngineClient):
|
|||||||
prompt,
|
prompt,
|
||||||
pooling_params,
|
pooling_params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# The output_handler task pushes items into the queue.
|
# The output_handler task pushes items into the queue.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user