[Core] Factor out input preprocessing to a separate class (#7329)
This commit is contained in:
parent 8f44a92d85
commit 5ec9c0fb3c
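In brief, this change moves prompt tokenization and prompt-component extraction out of LLMEngine/_AsyncLLMEngine into a new vllm.inputs.preprocess.InputPreprocessor class; the engines now call the preprocessor first and then run the registry's input processor on its output. The sketch below is illustrative only (it is not part of the commit, and wiring the preprocessor by hand outside the engine is an assumption); the class and method names match the diff that follows.

# Illustrative sketch of the new flow introduced by this commit (not from the diff).
from vllm.inputs.preprocess import InputPreprocessor

def build_processed_inputs(engine, prompt_inputs, request_id):
    # LLMEngine builds the preprocessor from its model config and tokenizer group,
    # exactly as shown in the __init__ hunk below.
    preprocessor = InputPreprocessor(engine.model_config, engine.tokenizer)
    preprocessed = preprocessor.preprocess(
        prompt_inputs,
        request_id=request_id,
    )  # -> LLMInputs or EncoderDecoderLLMInputs
    # The registry's input processor is still applied afterwards, as in add_request.
    return engine.input_processor(preprocessed)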
@@ -11,9 +11,10 @@ def test_skip_tokenizer_initialization(model: str):
    # token ids.
    llm = LLM(model=model, skip_tokenizer_init=True)
    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
    with pytest.raises(ValueError) as err:

    with pytest.raises(ValueError, match="cannot pass text prompts when"):
        llm.generate("abc", sampling_params)
    assert "prompts must be None if" in str(err.value)

    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
                           sampling_params=sampling_params)
    assert len(outputs) > 0
@@ -4,22 +4,17 @@ from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                    Mapping, Optional, Set, Tuple, Type, Union)

from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
                                    PromptComponents, SchedulerOutputState)
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -403,139 +398,6 @@ class _AsyncLLMEngine(LLMEngine):
        """Stop the remote worker execution loop."""
        await self.model_executor.stop_remote_worker_execution_loop_async()

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`."""
        tokenizer = self.get_tokenizer_group(
            missing_msg="prompts must be None if skip_tokenizer_init is True")

        return await tokenizer.encode_async(request_id=request_id,
                                            prompt=prompt,
                                            lora_request=lora_request)

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`."""
        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif isinstance(inputs, dict):
            if "prompt_token_ids" in inputs:
                prompt = None
                prompt_token_ids = inputs["prompt_token_ids"]
            else:
                # NOTE: This extra assignment is required to pass mypy
                prompt = parsed_prompt = inputs["prompt"]
                prompt_token_ids = await self._tokenize_prompt_async(
                    parsed_prompt,
                    request_id=request_id,
                    lora_request=lora_request,
                )

            multi_modal_data = inputs.get("multi_modal_data")
        else:
            assert_never(inputs)

        return prompt, prompt_token_ids, multi_modal_data

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`."""
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_task = self._extract_prompt_components_async(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := inputs["decoder_prompt"]) is None:
                encoder_comps = await encoder_task
                decoder_comps = None, None, None
            else:
                decoder_task = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )

                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_task, decoder_task)
        else:
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        prompt_comps = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def process_model_inputs_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`process_model_inputs`."""
        if self.is_encoder_decoder_model():
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            model_inputs = await self._process_encoder_decoder_prompt_async(
                inputs,
                request_id=request_id,
            )
        else:
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            model_inputs = await self._process_decoder_only_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        return self.input_processor(model_inputs)

    async def add_request_async(
        self,
        request_id: str,
@@ -553,12 +415,13 @@ class _AsyncLLMEngine(LLMEngine):
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = await self.process_model_inputs_async(
        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
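For the async path shown in this hunk, the engine awaits the preprocessor before applying the (synchronous) input processor. The following minimal sketch of driving that coroutine outside the engine is illustrative, not from the diff; the model_config and tokenizer_group arguments are assumed to be constructed elsewhere.

import asyncio

from vllm.inputs.preprocess import InputPreprocessor

async def preprocess_one(model_config, tokenizer_group, prompt_inputs, request_id):
    # preprocess_async tokenizes text prompts with the async tokenizer group and,
    # for explicit encoder/decoder prompts, gathers both sub-prompts concurrently.
    preprocessor = InputPreprocessor(model_config, tokenizer_group)
    return await preprocessor.preprocess_async(prompt_inputs,
                                               request_id=request_id)

# Example (assumes model_config and tokenizer_group already exist):
# asyncio.run(preprocess_one(model_config, tokenizer_group, "Hello", "req-0"))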
@@ -6,10 +6,10 @@ from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                    Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, Union
from typing import Set, Type, Union

import torch
from typing_extensions import TypeVar, assert_never
from typing_extensions import TypeVar

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@@ -28,13 +28,11 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                         InputRegistry, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
                         InputRegistry, LLMInputs, PromptInputs)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
                          RequestOutputFactory)
from vllm.pooling_params import PoolingParams
@@ -75,11 +73,6 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

PromptComponents = Tuple[Optional[str], List[int],
                         Optional[MultiModalDataDict]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                Optional[MultiModalDataDict]]


@dataclass
class SchedulerOutputState:
@@ -313,6 +306,9 @@ class LLMEngine:
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_preprocessor = InputPreprocessor(model_config,
                                                    self.tokenizer)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            model_config)
@@ -571,19 +567,15 @@ class LLMEngine:
        if model_executor := getattr(self, "model_executor", None):
            model_executor.shutdown()

    MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
                                   "skip_tokenizer_init is True")

    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
        *,
        missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
    ) -> _G:
        tokenizer_group = self.tokenizer

        if tokenizer_group is None:
            raise ValueError(missing_msg)
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
@@ -615,52 +607,6 @@ class LLMEngine:
            self.prompt_adapter_config.verify_with_model_config(
                self.model_config)

    def _get_bos_token_id(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for BOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id

    def _get_eos_token_id(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def _get_decoder_start_token_id(self) -> Optional[int]:
        '''
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.
        '''

        if not self.is_encoder_decoder_model():
            logger.warning("Using None for decoder start token id because "
                           "this is not an encoder/decoder model.")
            return None

        if (self.model_config is None or self.model_config.hf_config is None):
            logger.warning("Using None for decoder start token id because "
                           "model config is not available.")
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
                                     'decoder_start_token_id', None)
        if dec_start_token_id is None:
            logger.warning("Falling back on <BOS> for decoder start token id "
                           "because decoder start token id is not available.")
            dec_start_token_id = self._get_bos_token_id()

        return dec_start_token_id

    def _add_processed_request(
        self,
        request_id: str,
@@ -675,7 +621,7 @@ class LLMEngine:
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self._get_eos_token_id(lora_request)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
                       lora_request, prompt_adapter_request)
@@ -725,334 +671,6 @@ class LLMEngine:
    def stop_remote_worker_execution_loop(self) -> None:
        self.model_executor.stop_remote_worker_execution_loop()

    _LLMInputComponentsType = Tuple[str, List[int]]

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[List[int]],
    ) -> List[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on

        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py

        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """

        decoder_start_token_id = self._get_decoder_start_token_id()
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # no decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

    def _tokenize_prompt(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        '''
        Wrapper around application of the model's tokenizer.

        Arguments:

        * prompt
        * request_id
        * lora_request

        Returns:

        * prompt token ids
        '''

        tokenizer = self.get_tokenizer_group(
            missing_msg="prompts must be None if skip_tokenizer_init is True")

        return tokenizer.encode(request_id=request_id,
                                prompt=prompt,
                                lora_request=lora_request)

    def _extract_prompt_components(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        '''
        Extract the components of any single encoder or decoder input prompt.

        Arguments:

        * request_id
        * inputs: single encoder or decoder input prompt
        * lora_request: this is only valid for decoder prompts

        Returns:

        * prompt
        * prompt_token_ids
        * multi_modal_data
        '''

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = self._tokenize_prompt(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif isinstance(inputs, dict):
            if "prompt_token_ids" in inputs:
                prompt = None
                prompt_token_ids = inputs["prompt_token_ids"]
            else:
                # NOTE: This extra assignment is required to pass mypy
                prompt = parsed_prompt = inputs["prompt"]
                prompt_token_ids = self._tokenize_prompt(
                    parsed_prompt,
                    request_id=request_id,
                    lora_request=lora_request,
                )

            multi_modal_data = inputs.get("multi_modal_data")
        else:
            assert_never(inputs)

        return prompt, prompt_token_ids, multi_modal_data

    def _apply_prompt_adapter(
        self,
        prompt_token_ids: List[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> List[int]:
        if prompt_adapter_request:
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids)

        return prompt_token_ids

    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
        '''
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
        other models may have different or more
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
        '''

        bos_token_id = self._get_bos_token_id()
        assert bos_token_id is not None
        return [bos_token_id]

    def _build_enc_dec_llm_inputs(
        self,
        encoder_comps: PromptComponents,
        decoder_comps: DecoderPromptComponents,
    ) -> EncoderDecoderLLMInputs:
        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps

        if encoder_mm_data is not None or decoder_mm_data is not None:
            raise ValueError("Multi-modal encoder-decoder models are "
                             "not supported yet")

        decoder_prompt_ids = (
            self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))

        return EncoderDecoderLLMInputs(
            prompt_token_ids=decoder_prompt_ids,
            prompt=decoder_prompt,
            encoder_prompt_token_ids=encoder_prompt_ids,
            encoder_prompt=encoder_prompt,
        )

    def _process_encoder_decoder_prompt(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        '''
        For encoder/decoder models only:
        Process an input prompt into an
        :class:`EncoderDecoderLLMInputs` instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:
        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for Explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * inputs: an input prompt
        * request_id

        Returns:

        * :class:`EncoderDecoderLLMInputs` instance
        '''

        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_comps = self._extract_prompt_components(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := inputs["decoder_prompt"]) is None:
                decoder_comps = None, None, None
            else:
                decoder_comps = self._extract_prompt_components(
                    decoder_input,
                    request_id=request_id,
                )
        else:
            encoder_comps = self._extract_prompt_components(
                inputs,
                request_id=request_id,
            )

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_comps: PromptComponents,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> LLMInputs:
        prompt, prompt_token_ids, multi_modal_data = prompt_comps

        prompt_token_ids = self._apply_prompt_adapter(
            prompt_token_ids, prompt_adapter_request=prompt_adapter_request)

        return LLMInputs(prompt_token_ids=prompt_token_ids,
                         prompt=prompt,
                         multi_modal_data=multi_modal_data)

    def _process_decoder_only_prompt(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        '''
        For decoder-only models:
        Process an input prompt into an :class:`LLMInputs` instance.

        Arguments:

        * inputs: input prompt
        * request_id
        * lora_request
        * prompt_adapter_request

        Returns:

        * :class:`LLMInputs` instance
        '''

        prompt_comps = self._extract_prompt_components(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    def process_model_inputs(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:

        if self.is_encoder_decoder_model():
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            model_inputs = self._process_encoder_decoder_prompt(
                inputs,
                request_id=request_id,
            )
        else:
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            model_inputs = self._process_decoder_only_prompt(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        return self.input_processor(model_inputs)

    def add_request(
        self,
        request_id: str,
@@ -1111,12 +729,13 @@ class LLMEngine:
        if arrival_time is None:
            arrival_time = time.time()

        processed_inputs = self.process_model_inputs(
        preprocessed_inputs = self.input_preprocessor.preprocess(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        processed_inputs = self.input_processor(preprocessed_inputs)

        self._add_processed_request(
            request_id=request_id,
@@ -2043,7 +1662,7 @@ class LLMEngine:
                metrics.model_execute_time)

    def is_encoder_decoder_model(self):
        return self.model_config.is_encoder_decoder_model
        return self.input_preprocessor.is_encoder_decoder_model()

    def is_embedding_model(self):
        return self.model_config.is_embedding_model
@@ -5,7 +5,8 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of

from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
                   LLMInputs, PromptInputs)
                   LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
                   TokensPrompt)


class ParsedText(TypedDict):
@@ -60,8 +61,38 @@ def parse_and_batch_prompt(
            for elem in prompt
        ]

    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
    raise TypeError("prompt must be a string, array of strings, "
                    "array of tokens, or array of token arrays")


class ParsedStrPrompt(TypedDict):
    type: Literal["str"]
    content: str


class ParsedTextPrompt(TypedDict):
    type: Literal["text"]
    content: TextPrompt


class ParsedTokensPrompt(TypedDict):
    type: Literal["tokens"]
    content: TokensPrompt


def parse_singleton_prompt(
    inputs: SingletonPromptInputs,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
    if isinstance(inputs, str):
        return ParsedStrPrompt(type="str", content=inputs)
    elif isinstance(inputs, dict):
        if "prompt_token_ids" in inputs:
            return ParsedTokensPrompt(type="tokens",
                                      content=inputs)  # type: ignore
        elif "prompt" in inputs:
            return ParsedTextPrompt(type="text", content=inputs)

    raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")


def is_explicit_encoder_decoder_prompt(
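As a quick illustration of the helper added in the hunk above (a sketch, not part of the diff), parse_singleton_prompt tags the three accepted singleton input shapes so callers can branch on the "type" field; the module path vllm.inputs.parse is inferred from the new file below, which imports it via "from .parse import ... parse_singleton_prompt".

# Illustrative usage of the parse_singleton_prompt helper introduced above.
from vllm.inputs.parse import parse_singleton_prompt

assert parse_singleton_prompt("Hello")["type"] == "str"
assert parse_singleton_prompt({"prompt": "Hello"})["type"] == "text"
assert parse_singleton_prompt({"prompt_token_ids": [1, 2, 3]})["type"] == "tokens"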
vllm/inputs/preprocess.py (new file, 536 lines)
@@ -0,0 +1,536 @@
import asyncio
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
                   SingletonPromptInputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt

if TYPE_CHECKING:
    from vllm.multimodal import MultiModalDataDict

logger = init_logger(__name__)

PromptComponents = Tuple[Optional[str], List[int],
                         Optional["MultiModalDataDict"]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                Optional["MultiModalDataDict"]]


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[BaseTokenizerGroup],
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.tokenizer = tokenizer

    def get_tokenizer_group(self) -> BaseTokenizerGroup:
        if self.tokenizer is None:
            raise ValueError("You cannot pass text prompts when "
                             "`skip_tokenizer_init` is True")

        return self.tokenizer

    def get_bos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for BOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id

    def get_eos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

    def get_decoder_start_token_id(self) -> Optional[int]:
        '''
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.
        '''

        if not self.is_encoder_decoder_model():
            logger.warning("Using None for decoder start token id because "
                           "this is not an encoder/decoder model.")
            return None

        if (self.model_config is None or self.model_config.hf_config is None):
            logger.warning("Using None for decoder start token id because "
                           "model config is not available.")
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
                                     'decoder_start_token_id', None)
        if dec_start_token_id is None:
            logger.warning("Falling back on <BOS> for decoder start token id "
                           "because decoder start token id is not available.")
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
        '''
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
        other models may have different or more
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
        '''

        bos_token_id = self.get_bos_token_id()
        assert bos_token_id is not None
        return [bos_token_id]
    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[List[int]],
    ) -> List[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on

        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py

        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """

        decoder_start_token_id = self.get_decoder_start_token_id()
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # no decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

    def _apply_prompt_adapter(
        self,
        prompt_token_ids: List[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> List[int]:
        if prompt_adapter_request:
            prompt_token_ids = (
                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
                + prompt_token_ids)

        return prompt_token_ids

    def _tokenize_prompt(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.
        """
        tokenizer = self.get_tokenizer_group()

        return tokenizer.encode(request_id=request_id,
                                prompt=prompt,
                                lora_request=lora_request)

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`."""
        tokenizer = self.get_tokenizer_group()

        return await tokenizer.encode_async(request_id=request_id,
                                            prompt=prompt,
                                            lora_request=lora_request)

    def _extract_prompt_components(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        '''
        Extract the components of any single encoder or decoder input prompt.

        Arguments:

        * request_id
        * inputs: single encoder or decoder input prompt
        * lora_request: this is only valid for decoder prompts

        Returns:

        * prompt
        * prompt_token_ids
        * multi_modal_data
        '''

        parsed = parse_singleton_prompt(inputs)

        if parsed["type"] == "str":
            prompt = parsed["content"]
            prompt_token_ids = self._tokenize_prompt(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif parsed["type"] == "tokens":
            prompt = None
            prompt_token_ids = parsed["content"]["prompt_token_ids"]
            multi_modal_data = parsed["content"].get("multi_modal_data")
        elif parsed["type"] == "text":
            prompt = parsed["content"]["prompt"]
            prompt_token_ids = self._tokenize_prompt(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = parsed["content"].get("multi_modal_data")
        else:
            assert_never(parsed)

        return prompt, prompt_token_ids, multi_modal_data

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`."""
        parsed = parse_singleton_prompt(inputs)

        if parsed["type"] == "str":
            prompt = parsed["content"]
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = None
        elif parsed["type"] == "tokens":
            prompt = None
            prompt_token_ids = parsed["content"]["prompt_token_ids"]
            multi_modal_data = parsed["content"].get("multi_modal_data")
        elif parsed["type"] == "text":
            prompt = parsed["content"]["prompt"]
            prompt_token_ids = await self._tokenize_prompt_async(
                prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
            multi_modal_data = parsed["content"].get("multi_modal_data")
        else:
            assert_never(parsed)

        return prompt, prompt_token_ids, multi_modal_data
    def _build_enc_dec_llm_inputs(
        self,
        encoder_comps: PromptComponents,
        decoder_comps: DecoderPromptComponents,
    ) -> EncoderDecoderLLMInputs:
        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps

        if encoder_mm_data is not None or decoder_mm_data is not None:
            raise ValueError("Multi-modal encoder-decoder models are "
                             "not supported yet")

        decoder_prompt_ids = (
            self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))

        return EncoderDecoderLLMInputs(
            prompt_token_ids=decoder_prompt_ids,
            prompt=decoder_prompt,
            encoder_prompt_token_ids=encoder_prompt_ids,
            encoder_prompt=encoder_prompt,
        )

    def _process_encoder_decoder_prompt(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        '''
        For encoder/decoder models only:
        Process an input prompt into an
        :class:`EncoderDecoderLLMInputs` instance.

        There are two types of input prompts:
        singleton prompts which carry only the
        encoder prompt, and explicit encoder/decoder
        prompts which carry both the encoder and the
        decoder prompts as member variables.

        This function handles the following scenarios:
        * Singleton encoder prompt: extract encoder prompt
          token ids & infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder
          and decoder prompt token ids

        Note that for Explicit encoder/decoder prompts,
        each sub-prompt (encoder or decoder prompt) can
        have any possible singleton type; thus this
        method relies on helper functions to obtain
        token ids for the sub-prompts.

        Arguments:

        * inputs: an input prompt
        * request_id

        Returns:

        * :class:`EncoderDecoderLLMInputs` instance
        '''

        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_comps = self._extract_prompt_components(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := inputs["decoder_prompt"]) is None:
                decoder_comps = None, None, None
            else:
                decoder_comps = self._extract_prompt_components(
                    decoder_input,
                    request_id=request_id,
                )
        else:
            encoder_comps = self._extract_prompt_components(
                inputs,
                request_id=request_id,
            )

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`."""
        encoder_comps: PromptComponents
        decoder_comps: DecoderPromptComponents

        if is_explicit_encoder_decoder_prompt(inputs):
            encoder_task = self._extract_prompt_components_async(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := inputs["decoder_prompt"]) is None:
                encoder_comps = await encoder_task
                decoder_comps = None, None, None
            else:
                decoder_task = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )

                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_task, decoder_task)
        else:
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )

            decoder_comps = None, None, None

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_comps: PromptComponents,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> LLMInputs:
        prompt, prompt_token_ids, multi_modal_data = prompt_comps

        prompt_token_ids = self._apply_prompt_adapter(
            prompt_token_ids, prompt_adapter_request=prompt_adapter_request)

        return LLMInputs(prompt_token_ids=prompt_token_ids,
                         prompt=prompt,
                         multi_modal_data=multi_modal_data)

    def _process_decoder_only_prompt(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        '''
        For decoder-only models:
        Process an input prompt into an :class:`LLMInputs` instance.

        Arguments:

        * inputs: input prompt
        * request_id
        * lora_request
        * prompt_adapter_request

        Returns:

        * :class:`LLMInputs` instance
        '''

        prompt_comps = self._extract_prompt_components(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        prompt_comps = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )

        return self._build_decoder_only_llm_inputs(
            prompt_comps,
            prompt_adapter_request=prompt_adapter_request,
        )

    def preprocess(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Preprocess the input prompt."""
        if self.is_encoder_decoder_model():
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            return self._process_encoder_decoder_prompt(
                inputs,
                request_id=request_id,
            )

        if is_explicit_encoder_decoder_prompt(inputs):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return self._process_decoder_only_prompt(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def preprocess_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`preprocess`."""
        if self.is_encoder_decoder_model():
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder
            return await self._process_encoder_decoder_prompt_async(
                inputs,
                request_id=request_id,
            )

        if is_explicit_encoder_decoder_prompt(inputs):
            raise ValueError("Cannot pass encoder-decoder prompt "
                             "to decoder-only models")

        # Decoder-only operation
        return await self._process_decoder_only_prompt_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    def is_encoder_decoder_model(self):
        return self.model_config.is_encoder_decoder_model
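To make the decoder-prompt rules in the new file concrete, here is a small self-contained sketch (not part of the diff) that reproduces the core of _prepare_decoder_input_ids_for_generation with a hard-coded start token; the real method instead looks the id up via get_decoder_start_token_id() and falls back to a <BOS>-only default prompt, as shown above.

from typing import List, Optional

def prepare_decoder_input_ids(decoder_input_ids: Optional[List[int]],
                              decoder_start_token_id: int) -> List[int]:
    # Simplified mirror of InputPreprocessor._prepare_decoder_input_ids_for_generation:
    # a missing decoder prompt falls back to a default (the method above uses <BOS>),
    # and any prompt that does not already begin with the decoder start token
    # gets that token prepended.
    if decoder_input_ids is None:
        decoder_input_ids = [decoder_start_token_id]
    if not decoder_input_ids or decoder_input_ids[0] != decoder_start_token_id:
        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
    return decoder_input_ids

# Worked examples, assuming a decoder start token id of 2:
assert prepare_decoder_input_ids(None, 2) == [2]
assert prepare_decoder_input_ids([5, 6], 2) == [2, 5, 6]
assert prepare_decoder_input_ids([2, 7], 2) == [2, 7]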