[Renderer] Move Processor out of AsyncLLM (#24138)
Signed-off-by: Yang <lymailforjob@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent f376868620 · commit ff1daf6c8a
@@ -122,6 +122,9 @@ def mock_serving_setup():
models,
request_logger=None)

serving_completion._process_inputs = AsyncMock(return_value=(MagicMock(
name="engine_request"), {}))

return mock_engine, serving_completion


@@ -7,7 +7,7 @@ import asyncio
from contextlib import suppress
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio
@@ -230,6 +230,7 @@ class MockHFConfig:
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
@@ -244,11 +245,33 @@ class MockModelConfig:
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}


def _build_serving_chat(engine: AsyncLLM,
model_config: MockModelConfig) -> OpenAIServingChat:
models = OpenAIServingModels(engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=model_config)
serving_chat = OpenAIServingChat(engine,
model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)

async def _fake_process_inputs(request_id, engine_prompt, sampling_params,
*, lora_request, trace_headers, priority):
return dict(engine_prompt), {}

serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_chat


@dataclass
class MockEngine:

@@ -282,16 +305,7 @@ async def test_serving_chat_returns_correct_model_name():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
messages = [{"role": "user", "content": "what is 1+1?"}]

async def return_model_name(*args):
@@ -318,16 +332,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())

req = ChatCompletionRequest(
model=MODEL_NAME,
@@ -361,16 +366,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)

# Test Case 1: No max_tokens specified in request
req = ChatCompletionRequest(
@@ -415,16 +411,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)

# Test case 1: No max_tokens specified, defaults to context_window
req = ChatCompletionRequest(
@@ -471,16 +458,7 @@ async def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)

req = ChatCompletionRequest(
model=MODEL_NAME,
@@ -525,17 +503,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
serving_chat = _build_serving_chat(mock_engine, mock_model_config)

# Test cache_salt
req = ChatCompletionRequest(
@@ -549,10 +517,12 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
# By default, cache_salt in the engine prompt is not set
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
assert "cache_salt" not in engine_prompt

# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
assert engine_prompt.get("cache_salt") == "test_salt"

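Note: because prompt processing now happens in the serving layer, the cache_salt assertions above inspect the engine prompt handed to the mocked _process_inputs instead of the first argument of mock_engine.generate. A minimal sketch of that access pattern (the helper name below is hypothetical; it assumes the AsyncMock patching shown in _build_serving_chat):

    def captured_engine_prompt(serving_chat, call_index: int):
        # AsyncMock records every await; with the patched instance attribute,
        # the positional args are (request_id, engine_prompt, sampling_params),
        # so index 1 is the engine prompt dict.
        return serving_chat._process_inputs.await_args_list[call_index].args[1]

    # e.g. assert "cache_salt" not in captured_engine_prompt(serving_chat, 0)
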
@@ -19,6 +19,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Device, collect_from_async_generator, random_uuid
from vllm.v1.engine import EngineCoreRequest

logger = init_logger(__name__)

@@ -49,12 +50,16 @@ class EngineClient(ABC):
@abstractmethod
def generate(
self,
prompt: PromptType,
prompt: Union[EngineCoreRequest, PromptType],
sampling_params: SamplingParams,
request_id: str,
*,
prompt_text: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
...

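Note: the abstract generate signature now accepts either a raw PromptType or a pre-processed EngineCoreRequest, and gains keyword-only prompt_text and tokenization_kwargs parameters. A rough caller-side sketch of consuming the stream under the new signature (the last_output helper and its arguments are illustrative, not part of this change):

    from typing import Any, Optional

    async def last_output(client, prompt_or_request, sampling_params,
                          request_id: str, *,
                          prompt_text: Optional[str] = None,
                          tokenization_kwargs: Optional[dict[str, Any]] = None):
        # Works with either a raw prompt or a pre-built EngineCoreRequest,
        # since the engine client now accepts both forms.
        final = None
        async for output in client.generate(prompt_or_request, sampling_params,
                                            request_id, prompt_text=prompt_text,
                                            tokenization_kwargs=tokenization_kwargs):
            final = output  # keep the most recent RequestOutput
        return final
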
@@ -274,7 +274,8 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
prompt_text, _, _ = (self._get_prompt_components(
request_prompts[i]))

if self.default_sampling_params is None:
self.default_sampling_params = {}
@@ -285,6 +286,7 @@ class OpenAIServingChat(OpenAIServing):
input_length=len(engine_prompt["prompt_token_ids"]),
default_sampling_params=self.default_sampling_params)

sampling_params: Union[SamplingParams, BeamSearchParams]
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
@@ -309,13 +311,25 @@ class OpenAIServingChat(OpenAIServing):
lora_request=lora_request,
)
else:
engine_request, tokenization_kwargs = (
await self._process_inputs(
request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
))

generator = self.engine_client.generate(
engine_prompt,
engine_request,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
)

generators.append(generator)

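Note: in the hunk above, the old first argument to engine_client.generate (engine_prompt) sits next to its replacement (engine_request). Condensed from those lines, the non-beam-search path now pre-processes each prompt in the serving layer and hands the resulting request to the engine client (a sketch; error handling and the beam-search branch are omitted, names come from this diff):

    for i, engine_prompt in enumerate(engine_prompts):
        prompt_text, _, _ = self._get_prompt_components(request_prompts[i])

        # Serving-side processing replaces the tokenization that AsyncLLM
        # used to perform internally.
        engine_request, tokenization_kwargs = await self._process_inputs(
            request_id, engine_prompt, sampling_params,
            lora_request=lora_request, trace_headers=trace_headers,
            priority=request.priority)

        generators.append(self.engine_client.generate(
            engine_request, sampling_params, request_id,
            lora_request=lora_request, trace_headers=trace_headers,
            priority=request.priority, prompt_text=prompt_text,
            tokenization_kwargs=tokenization_kwargs))
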
@@ -9,7 +9,6 @@ from typing import Optional, Union, cast

import jinja2
from fastapi import Request
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@@ -32,8 +31,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt)
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
@@ -157,23 +155,16 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
# Mypy does not infer that engine_prompt will have only one of
# "prompt_token_ids" or "prompt_embeds" defined, and both of
# these as Union[object, the expected type], where it infers
# object if engine_prompt is a subclass of one of the
# typeddicts that defines both keys. Worse, because of
# https://github.com/python/mypy/issues/8586, mypy does not
# infer the type of engine_prompt correctly because of the
# enumerate. So we need an unnecessary cast here.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if is_embeds_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_embeds"])
elif is_tokens_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_token_ids"])
prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt))

input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
assert_never(engine_prompt)
raise NotImplementedError

if self.default_sampling_params is None:
self.default_sampling_params = {}
@@ -185,6 +176,7 @@ class OpenAIServingCompletion(OpenAIServing):
default_sampling_params=self.default_sampling_params,
)

sampling_params: Union[SamplingParams, BeamSearchParams]
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
@@ -220,13 +212,25 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
)
else:
engine_request, tokenization_kwargs = (
await self._process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
))

generator = self.engine_client.generate(
engine_prompt,
engine_request,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
)

generators.append(generator)

@@ -7,7 +7,8 @@ import traceback
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional,
TypeVar, Union)

import torch
from fastapi import Request
@@ -15,6 +16,11 @@ from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers
from typing_extensions import TypeIs

from vllm.entrypoints.utils import _validate_truncation_size
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.processor import Processor

if sys.version_info >= (3, 12):
from typing import TypedDict
else:
@@ -134,6 +140,12 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
and "prompt_embeds" in prompt)


class PromptComponents(NamedTuple):
text: Optional[str] = None
token_ids: Optional[list[int]] = None
embeds: Optional[torch.Tensor] = None


RequestT = TypeVar("RequestT", bound=AnyRequest)

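Note: PromptComponents gives the serving layer one container for the three possible views of a prompt (text, token IDs, embeddings). A standalone sketch mirroring the definition above, showing the tuple-unpacking style used later in this diff:

    from typing import NamedTuple, Optional

    import torch

    class PromptComponents(NamedTuple):
        text: Optional[str] = None
        token_ids: Optional[list[int]] = None
        embeds: Optional[torch.Tensor] = None

    # Callers unpack only the parts they need, e.g.:
    prompt_text, prompt_token_ids, _ = PromptComponents(token_ids=[1, 2, 3])
    assert prompt_text is None and prompt_token_ids == [1, 2, 3]
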
@@ -239,6 +251,16 @@ class OpenAIServing:
AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack

async def _get_processor(self) -> Processor:
if not hasattr(self, "_processor"):
vllm_config = await self.engine_client.get_vllm_config()
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_configs(self.model_config)
self._processor = Processor(vllm_config, tokenizer)
return self._processor

def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
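Note: _get_processor lazily builds a v1 Processor from the engine's config plus a tokenizer (or None when skip_tokenizer_init is set) and caches it on the serving instance; if two requests race on first use, the worst case appears to be one redundant construction. A generic sketch of the same lazy async-initialization pattern (the class and names below are hypothetical, not vLLM APIs):

    import asyncio

    class HypotheticalService:
        async def _fetch_config(self) -> dict:
            await asyncio.sleep(0)  # stands in for an async RPC to the engine
            return {"max_model_len": 4096}

        async def _get_resource(self) -> dict:
            # Build once on first use, then reuse the cached attribute.
            if not hasattr(self, "_resource"):
                self._resource = await self._fetch_config()
            return self._resource
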
@@ -850,6 +872,36 @@ class OpenAIServing:

return conversation, [request_prompt], [engine_prompt]

async def _process_inputs(
self,
request_id: str,
engine_prompt: PromptType,
sampling_params: SamplingParams,
*,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]],
priority: int,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""
using the Processor to process inputs for AsyncLLM
"""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.max_model_len,
sampling_params.truncate_prompt_tokens,
tokenization_kwargs)

processor = await self._get_processor()
engine_request = processor.process_inputs(
request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
)
return engine_request, tokenization_kwargs

async def _generate_with_builtin_tools(
self,
request_id: str,
@@ -861,6 +913,7 @@ class OpenAIServing:
priority: int = 0,
**kwargs,
):
prompt_text, _, _ = self._get_prompt_components(request_prompt)
orig_priority = priority
while True:
self._log_inputs(
@@ -869,14 +922,27 @@ class OpenAIServing:
params=sampling_params,
lora_request=lora_request,
)
generator = self.engine_client.generate(
trace_headers = kwargs.get("trace_headers")
engine_request, tokenization_kwargs = (await self._process_inputs(
request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
))

generator = self.engine_client.generate(
engine_request,
sampling_params,
request_id,
lora_request=lora_request,
priority=priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
**kwargs,
)

async for res in generator:
context.append_output(res)
# NOTE(woosuk): The stop condition is handled by the engine.
@@ -905,6 +971,28 @@ class OpenAIServing:
# OPTIMIZATION
priority = orig_priority - 1

def _get_prompt_components(
self,
inputs: Union[RequestPrompt, PromptType],
) -> PromptComponents:
if isinstance(inputs, str):
return PromptComponents(text=inputs)
if isinstance(inputs, list):
return PromptComponents(token_ids=inputs)
if isinstance(inputs, dict):
return PromptComponents(
text=inputs.get("prompt"),  # type: ignore[arg-type]
token_ids=inputs.get(
"prompt_token_ids"),  # type: ignore[arg-type]
embeds=inputs.get("prompt_embeds"),
)

return PromptComponents(
text=getattr(inputs, "prompt", None),
token_ids=getattr(inputs, "prompt_token_ids", None),
embeds=getattr(inputs, "prompt_embeds", None),
)

def _log_inputs(
self,
request_id: str,
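Note: _get_prompt_components normalizes every prompt representation (plain string, token-ID list, dict-style prompt, or prompt object) into a PromptComponents tuple, which _log_inputs below and the chat/completion endpoints now share. A hypothetical pytest-style sketch of the expected normalization (serving is any OpenAIServing instance; this helper is not part of this commit):

    def check_prompt_components(serving) -> None:
        assert serving._get_prompt_components("Hello").text == "Hello"
        assert serving._get_prompt_components([1, 2, 3]).token_ids == [1, 2, 3]

        comps = serving._get_prompt_components(
            {"prompt": "Hi", "prompt_token_ids": [7]})
        assert comps.text == "Hi"
        assert comps.token_ids == [7]
        assert comps.embeds is None
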
@@ -915,14 +1003,9 @@ class OpenAIServing:
) -> None:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = None, None, None
if isinstance(inputs, str):
prompt = inputs
elif isinstance(inputs, list):
prompt_token_ids = inputs
else:
prompt = getattr(inputs, 'prompt', None)
prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)

prompt, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(inputs))

self.request_logger.log_inputs(
request_id,

@@ -261,7 +261,7 @@ class AsyncLLM(EngineClient):
async def add_request(
self,
request_id: str,
prompt: PromptType,
prompt: Union[EngineCoreRequest, PromptType],
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@@ -269,6 +269,7 @@ class AsyncLLM(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
prompt_text: Optional[str] = None,
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""

@@ -281,13 +282,20 @@ class AsyncLLM(EngineClient):
queue = RequestOutputCollector(output_kind=params.output_kind)

# Convert Input --> Request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority,
data_parallel_rank)
prompt_text = prompt if isinstance(prompt,
str) else prompt.get("prompt")
if isinstance(prompt, EngineCoreRequest):
request = prompt
else:
assert prompt_text is None
logger.warning_once(
"Processor has been moved under OpenAIServing and will "
"be removed from AsyncLLM in v0.13.")
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
tokenization_kwargs,
trace_headers, priority,
data_parallel_rank)
prompt_text = (prompt if isinstance(prompt, str) else
prompt.get("prompt"))

if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue)
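Note: in the hunk above, the old processing block and its replacement sit side by side. Condensed, the new input handling in add_request is: an already-processed EngineCoreRequest is used as-is (the serving layer passes prompt_text separately), while a raw prompt still goes through the engine-side Processor with a deprecation warning (a sketch; names from this diff):

    if isinstance(prompt, EngineCoreRequest):
        # Already processed by the serving layer, which also supplies
        # prompt_text explicitly.
        request = prompt
    else:
        assert prompt_text is None
        logger.warning_once("Processor has been moved under OpenAIServing and "
                            "will be removed from AsyncLLM in v0.13.")
        request = self.processor.process_inputs(request_id, prompt, params,
                                                arrival_time, lora_request,
                                                tokenization_kwargs,
                                                trace_headers, priority,
                                                data_parallel_rank)
        prompt_text = (prompt if isinstance(prompt, str) else
                       prompt.get("prompt"))
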
@@ -332,10 +340,13 @@ class AsyncLLM(EngineClient):
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: PromptType,
prompt: Union[EngineCoreRequest, PromptType],
sampling_params: SamplingParams,
request_id: str,
*,
prompt_text: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
tokenization_kwargs: Optional[dict[str, Any]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
@@ -368,25 +379,25 @@ class AsyncLLM(EngineClient):
# to handle startup failure gracefully in the OpenAI server.
self._run_output_handler()

tokenization_kwargs: dict[str, Any] = {}
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
if tokenization_kwargs is None:
tokenization_kwargs = {}
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens

_validate_truncation_size(
self.model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs,
)
_validate_truncation_size(
self.model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs,
)

q = await self.add_request(
request_id,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
q = await self.add_request(request_id,
prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
prompt_text=prompt_text)

# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
@@ -535,7 +546,7 @@ class AsyncLLM(EngineClient):
self._run_output_handler()

if tokenization_kwargs is None:
tokenization_kwargs = dict[str, Any]()
tokenization_kwargs = {}
_validate_truncation_size(
self.model_config.max_model_len,
truncate_prompt_tokens,
@@ -547,9 +558,9 @@ class AsyncLLM(EngineClient):
prompt,
pooling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
)

# The output_handler task pushes items into the queue.