mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 12:49:07 +08:00
[Renderer] Move Processor out of LLMEngine (#26165)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
73a99cc2a5
commit
d78fda7cda
@ -37,6 +37,7 @@ from vllm.entrypoints.utils import (_validate_truncation_size,
|
|||||||
log_non_default_args)
|
log_non_default_args)
|
||||||
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
|
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
|
||||||
TokensPrompt)
|
TokensPrompt)
|
||||||
|
from vllm.inputs.parse import get_prompt_components
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@ -49,10 +50,13 @@ from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
|
|||||||
SamplingParams)
|
SamplingParams)
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||||
get_cached_tokenizer)
|
get_cached_tokenizer,
|
||||||
|
init_tokenizer_from_configs)
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Counter, Device, as_iter, is_list_of
|
from vllm.utils import Counter, Device, as_iter, is_list_of
|
||||||
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.llm_engine import LLMEngine
|
from vllm.v1.engine.llm_engine import LLMEngine
|
||||||
|
from vllm.v1.engine.processor import Processor
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -312,6 +316,10 @@ class LLM:
|
|||||||
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
|
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
|
||||||
io_processor_plugin)
|
io_processor_plugin)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_config(self):
|
||||||
|
return self.llm_engine.model_config
|
||||||
|
|
||||||
def get_tokenizer(self) -> AnyTokenizer:
|
def get_tokenizer(self) -> AnyTokenizer:
|
||||||
return self.llm_engine.get_tokenizer()
|
return self.llm_engine.get_tokenizer()
|
||||||
|
|
||||||
@ -324,6 +332,16 @@ class LLM:
|
|||||||
else:
|
else:
|
||||||
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
|
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
|
||||||
|
|
||||||
|
def _get_processor(self) -> Processor:
|
||||||
|
if not hasattr(self, "_processor"):
|
||||||
|
vllm_config = self.llm_engine.vllm_config
|
||||||
|
if self.model_config.skip_tokenizer_init:
|
||||||
|
tokenizer = None
|
||||||
|
else:
|
||||||
|
tokenizer = init_tokenizer_from_configs(self.model_config)
|
||||||
|
self._processor = Processor(vllm_config, tokenizer)
|
||||||
|
return self._processor
|
||||||
|
|
||||||
def get_default_sampling_params(self) -> SamplingParams:
|
def get_default_sampling_params(self) -> SamplingParams:
|
||||||
if self.default_sampling_params is None:
|
if self.default_sampling_params is None:
|
||||||
self.default_sampling_params = (
|
self.default_sampling_params = (
|
||||||
@ -1497,8 +1515,6 @@ class LLM:
|
|||||||
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
||||||
it = tqdm_func(it, desc="Adding requests")
|
it = tqdm_func(it, desc="Adding requests")
|
||||||
|
|
||||||
model_config = self.llm_engine.model_config
|
|
||||||
|
|
||||||
for i, prompt in enumerate(it):
|
for i, prompt in enumerate(it):
|
||||||
|
|
||||||
if isinstance(prompt, dict):
|
if isinstance(prompt, dict):
|
||||||
@ -1506,17 +1522,9 @@ class LLM:
|
|||||||
prompt.get("multi_modal_data"),
|
prompt.get("multi_modal_data"),
|
||||||
prompt.get("multi_modal_uuids"))
|
prompt.get("multi_modal_uuids"))
|
||||||
|
|
||||||
param = params[i] if isinstance(params, Sequence) else params
|
|
||||||
|
|
||||||
tokenization_kwargs: dict[str, Any] = {}
|
|
||||||
_validate_truncation_size(model_config.max_model_len,
|
|
||||||
param.truncate_prompt_tokens,
|
|
||||||
tokenization_kwargs)
|
|
||||||
|
|
||||||
self._add_request(
|
self._add_request(
|
||||||
prompt,
|
prompt,
|
||||||
params[i] if isinstance(params, Sequence) else params,
|
params[i] if isinstance(params, Sequence) else params,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
|
||||||
lora_request=lora_request[i] if isinstance(
|
lora_request=lora_request[i] if isinstance(
|
||||||
lora_request, Sequence) else lora_request,
|
lora_request, Sequence) else lora_request,
|
||||||
priority=priority[i] if priority else 0,
|
priority=priority[i] if priority else 0,
|
||||||
@ -1557,23 +1565,59 @@ class LLM:
|
|||||||
raise ValueError(f"Multi-modal data for {modality} is None"
|
raise ValueError(f"Multi-modal data for {modality} is None"
|
||||||
f" but UUID is not provided")
|
f" but UUID is not provided")
|
||||||
|
|
||||||
def _add_request(
|
def _process_inputs(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
request_id: str,
|
||||||
|
engine_prompt: PromptType,
|
||||||
params: Union[SamplingParams, PoolingParams],
|
params: Union[SamplingParams, PoolingParams],
|
||||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
*,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest],
|
||||||
priority: int = 0,
|
priority: int,
|
||||||
) -> None:
|
) -> tuple[EngineCoreRequest, dict[str, Any]]:
|
||||||
request_id = str(next(self.request_counter))
|
"""Use the Processor to process inputs for LLMEngine."""
|
||||||
self.llm_engine.add_request(
|
tokenization_kwargs: dict[str, Any] = {}
|
||||||
|
_validate_truncation_size(self.model_config.max_model_len,
|
||||||
|
params.truncate_prompt_tokens,
|
||||||
|
tokenization_kwargs)
|
||||||
|
|
||||||
|
processor = self._get_processor()
|
||||||
|
engine_request = processor.process_inputs(
|
||||||
request_id,
|
request_id,
|
||||||
prompt,
|
engine_prompt,
|
||||||
params,
|
params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
)
|
)
|
||||||
|
return engine_request, tokenization_kwargs
|
||||||
|
|
||||||
|
def _add_request(
|
||||||
|
self,
|
||||||
|
prompt: PromptType,
|
||||||
|
params: Union[SamplingParams, PoolingParams],
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
priority: int = 0,
|
||||||
|
) -> None:
|
||||||
|
prompt_text, _, _ = get_prompt_components(prompt)
|
||||||
|
request_id = str(next(self.request_counter))
|
||||||
|
|
||||||
|
engine_request, tokenization_kwargs = self._process_inputs(
|
||||||
|
request_id,
|
||||||
|
prompt,
|
||||||
|
params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.llm_engine.add_request(
|
||||||
|
request_id,
|
||||||
|
engine_request,
|
||||||
|
params,
|
||||||
|
lora_request=lora_request,
|
||||||
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
|
priority=priority,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
)
|
||||||
|
|
||||||
def _run_engine(
|
def _run_engine(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -7,8 +7,7 @@ import traceback
|
|||||||
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional,
|
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
|
||||||
TypeVar, Union)
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@ -69,6 +68,7 @@ from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.inputs.data import PromptType
|
from vllm.inputs.data import PromptType
|
||||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||||
|
from vllm.inputs.parse import PromptComponents, get_prompt_components
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logprobs import Logprob, PromptLogprobs
|
from vllm.logprobs import Logprob, PromptLogprobs
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -140,12 +140,6 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
|
|||||||
and "prompt_embeds" in prompt)
|
and "prompt_embeds" in prompt)
|
||||||
|
|
||||||
|
|
||||||
class PromptComponents(NamedTuple):
|
|
||||||
text: Optional[str] = None
|
|
||||||
token_ids: Optional[list[int]] = None
|
|
||||||
embeds: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
|
|
||||||
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
RequestT = TypeVar("RequestT", bound=AnyRequest)
|
||||||
|
|
||||||
|
|
||||||
@ -876,25 +870,23 @@ class OpenAIServing:
|
|||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
engine_prompt: PromptType,
|
engine_prompt: PromptType,
|
||||||
sampling_params: SamplingParams,
|
params: Union[SamplingParams, PoolingParams],
|
||||||
*,
|
*,
|
||||||
lora_request: Optional[LoRARequest],
|
lora_request: Optional[LoRARequest],
|
||||||
trace_headers: Optional[Mapping[str, str]],
|
trace_headers: Optional[Mapping[str, str]],
|
||||||
priority: int,
|
priority: int,
|
||||||
) -> tuple[EngineCoreRequest, dict[str, Any]]:
|
) -> tuple[EngineCoreRequest, dict[str, Any]]:
|
||||||
"""
|
"""Use the Processor to process inputs for AsyncLLM."""
|
||||||
using the Processor to process inputs for AsyncLLM
|
|
||||||
"""
|
|
||||||
tokenization_kwargs: dict[str, Any] = {}
|
tokenization_kwargs: dict[str, Any] = {}
|
||||||
_validate_truncation_size(self.max_model_len,
|
_validate_truncation_size(self.max_model_len,
|
||||||
sampling_params.truncate_prompt_tokens,
|
params.truncate_prompt_tokens,
|
||||||
tokenization_kwargs)
|
tokenization_kwargs)
|
||||||
|
|
||||||
processor = await self._get_processor()
|
processor = await self._get_processor()
|
||||||
engine_request = processor.process_inputs(
|
engine_request = processor.process_inputs(
|
||||||
request_id,
|
request_id,
|
||||||
engine_prompt,
|
engine_prompt,
|
||||||
sampling_params,
|
params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
@ -973,25 +965,12 @@ class OpenAIServing:
|
|||||||
|
|
||||||
def _get_prompt_components(
|
def _get_prompt_components(
|
||||||
self,
|
self,
|
||||||
inputs: Union[RequestPrompt, PromptType],
|
prompt: Union[RequestPrompt, PromptType],
|
||||||
) -> PromptComponents:
|
) -> PromptComponents:
|
||||||
if isinstance(inputs, str):
|
if isinstance(prompt, list):
|
||||||
return PromptComponents(text=inputs)
|
return PromptComponents(token_ids=prompt)
|
||||||
if isinstance(inputs, list):
|
|
||||||
return PromptComponents(token_ids=inputs)
|
|
||||||
if isinstance(inputs, dict):
|
|
||||||
return PromptComponents(
|
|
||||||
text=inputs.get("prompt"), # type: ignore[arg-type]
|
|
||||||
token_ids=inputs.get(
|
|
||||||
"prompt_token_ids"), # type: ignore[arg-type]
|
|
||||||
embeds=inputs.get("prompt_embeds"),
|
|
||||||
)
|
|
||||||
|
|
||||||
return PromptComponents(
|
return get_prompt_components(prompt) # type: ignore[arg-type]
|
||||||
text=getattr(inputs, "prompt", None),
|
|
||||||
token_ids=getattr(inputs, "prompt_token_ids", None),
|
|
||||||
embeds=getattr(inputs, "prompt_embeds", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _log_inputs(
|
def _log_inputs(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal, Optional, TypedDict, Union, cast, overload
|
from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
|
||||||
|
Union, cast, overload)
|
||||||
|
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
@ -11,6 +12,9 @@ from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
|
|||||||
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
|
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
|
||||||
TokensPrompt)
|
TokensPrompt)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class ParsedText(TypedDict):
|
class ParsedText(TypedDict):
|
||||||
content: str
|
content: str
|
||||||
@ -149,3 +153,23 @@ def split_enc_dec_inputs(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return None, inputs
|
return None, inputs
|
||||||
|
|
||||||
|
|
||||||
|
class PromptComponents(NamedTuple):
|
||||||
|
text: Optional[str] = None
|
||||||
|
token_ids: Optional[list[int]] = None
|
||||||
|
embeds: Optional["torch.Tensor"] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_components(prompt: PromptType) -> PromptComponents:
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
return PromptComponents(text=prompt)
|
||||||
|
|
||||||
|
if (encoder_prompt := prompt.get("encoder_prompt")):
|
||||||
|
return get_prompt_components(encoder_prompt) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
return PromptComponents(
|
||||||
|
text=prompt.get("prompt"), # type: ignore[arg-type]
|
||||||
|
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||||
|
embeds=prompt.get("prompt_embeds"),
|
||||||
|
)
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
|||||||
init_tokenizer_from_configs)
|
init_tokenizer_from_configs)
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Device
|
from vllm.utils import Device
|
||||||
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.core_client import EngineCoreClient
|
from vllm.v1.engine.core_client import EngineCoreClient
|
||||||
from vllm.v1.engine.output_processor import OutputProcessor
|
from vllm.v1.engine.output_processor import OutputProcessor
|
||||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||||
@ -213,13 +214,14 @@ class LLMEngine:
|
|||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
prompt: PromptType,
|
prompt: Union[EngineCoreRequest, PromptType],
|
||||||
params: Union[SamplingParams, PoolingParams],
|
params: Union[SamplingParams, PoolingParams],
|
||||||
arrival_time: Optional[float] = None,
|
arrival_time: Optional[float] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
prompt_text: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Validate the request_id type.
|
# Validate the request_id type.
|
||||||
if not isinstance(request_id, str):
|
if not isinstance(request_id, str):
|
||||||
@ -227,12 +229,18 @@ class LLMEngine:
|
|||||||
f"request_id must be a string, got {type(request_id)}")
|
f"request_id must be a string, got {type(request_id)}")
|
||||||
|
|
||||||
# Process raw inputs into the request.
|
# Process raw inputs into the request.
|
||||||
request = self.processor.process_inputs(request_id, prompt, params,
|
if isinstance(prompt, EngineCoreRequest):
|
||||||
arrival_time, lora_request,
|
request = prompt
|
||||||
tokenization_kwargs,
|
else:
|
||||||
trace_headers, priority)
|
assert prompt_text is None
|
||||||
prompt_text = prompt if isinstance(prompt,
|
logger.warning_once("Processor has been moved under LLM and will "
|
||||||
str) else prompt.get("prompt")
|
"be removed from LLMEngine in v0.13.")
|
||||||
|
request = self.processor.process_inputs(request_id, prompt, params,
|
||||||
|
arrival_time, lora_request,
|
||||||
|
tokenization_kwargs,
|
||||||
|
trace_headers, priority)
|
||||||
|
prompt_text = (prompt if isinstance(prompt, str) else
|
||||||
|
prompt.get("prompt"))
|
||||||
|
|
||||||
n = params.n if isinstance(params, SamplingParams) else 1
|
n = params.n if isinstance(params, SamplingParams) else 1
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user