[Core] Factor out input preprocessing to a separate class (#7329)

Cyrus Leung 2024-09-13 10:56:13 +08:00 committed by GitHub
parent 8f44a92d85
commit 5ec9c0fb3c
5 changed files with 590 additions and 540 deletions
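The net effect of the refactor: the prompt parsing and tokenization helpers that previously lived on LLMEngine (process_model_inputs, _extract_prompt_components, and friends) move into the new vllm.inputs.preprocess.InputPreprocessor, and the engine only wraps the preprocessor's output with the registry-created input processor. A rough sketch of the new call path; here, engine stands in for an LLMEngine instance and carries the same model_config, tokenizer, and input_processor attributes it already has:

from typing import Union

from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs, PromptInputs
from vllm.inputs.preprocess import InputPreprocessor


def preprocess_request(engine, inputs: PromptInputs, request_id: str
                       ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
    # In the diff the preprocessor is built once in LLMEngine.__init__;
    # it is rebuilt here only to keep the sketch self-contained.
    preprocessor = InputPreprocessor(engine.model_config, engine.tokenizer)
    preprocessed_inputs = preprocessor.preprocess(inputs, request_id=request_id)
    # The input registry's processor still runs on the preprocessed inputs.
    return engine.input_processor(preprocessed_inputs)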

View File

@@ -11,9 +11,10 @@ def test_skip_tokenizer_initialization(model: str):
# token ids.
llm = LLM(model=model, skip_tokenizer_init=True)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
-    with pytest.raises(ValueError) as err:
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
llm.generate("abc", sampling_params)
assert "prompts must be None if" in str(err.value)
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
sampling_params=sampling_params)
assert len(outputs) > 0

View File

@@ -4,22 +4,17 @@ from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
-from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
-                                    PromptComponents, SchedulerOutputState)
+from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
-                         SingletonPromptInputs)
-from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
+from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -403,139 +398,6 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def _tokenize_prompt_async(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")
return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = inputs.get("multi_modal_data")
else:
assert_never(inputs)
return prompt, prompt_token_ids, multi_modal_data
async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
encoder_task = self._extract_prompt_components_async(
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
else:
decoder_task = self._extract_prompt_components_async(
decoder_input,
request_id=request_id,
)
encoder_comps, decoder_comps = await asyncio.gather(
encoder_task, decoder_task)
else:
encoder_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
)
decoder_comps = None, None, None
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
lora_request=lora_request,
)
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
async def process_model_inputs_async(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Async version of :meth:`process_model_inputs`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = await self._process_encoder_decoder_prompt_async(
inputs,
request_id=request_id,
)
else:
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
model_inputs = await self._process_decoder_only_prompt_async(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
return self.input_processor(model_inputs)
async def add_request_async(
self,
request_id: str,
@@ -553,12 +415,13 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None:
arrival_time = time.time()
-        processed_inputs = await self.process_model_inputs_async(
+        preprocessed_inputs = await self.input_preprocessor.preprocess_async(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
+        processed_inputs = self.input_processor(preprocessed_inputs)
self._add_processed_request(
request_id=request_id,

View File

@@ -6,10 +6,10 @@ from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
-from typing import Set, Tuple, Type, Union
+from typing import Set, Type, Union
import torch
-from typing_extensions import TypeVar, assert_never
+from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@@ -28,13 +28,11 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
-                         InputRegistry, LLMInputs, PromptInputs,
-                         SingletonPromptInputs)
-from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
+                         InputRegistry, LLMInputs, PromptInputs)
+from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
@@ -75,11 +73,6 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
PromptComponents = Tuple[Optional[str], List[int],
Optional[MultiModalDataDict]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional[MultiModalDataDict]]
@dataclass
class SchedulerOutputState:
@@ -313,6 +306,9 @@ class LLMEngine:
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
model_config)
@@ -571,19 +567,15 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
"skip_tokenizer_init is True")
def get_tokenizer_group(
self,
group_type: Type[_G] = BaseTokenizerGroup,
*,
missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
) -> _G:
tokenizer_group = self.tokenizer
if tokenizer_group is None:
raise ValueError(missing_msg)
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
@@ -615,52 +607,6 @@ class LLMEngine:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def _get_bos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for BOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
def _get_eos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _get_decoder_start_token_id(self) -> Optional[int]:
'''
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
model config is unavailable.
'''
if not self.is_encoder_decoder_model():
logger.warning("Using None for decoder start token id because "
"this is not an encoder/decoder model.")
return None
if (self.model_config is None or self.model_config.hf_config is None):
logger.warning("Using None for decoder start token id because "
"model config is not available.")
return None
dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None)
if dec_start_token_id is None:
logger.warning("Falling back on <BOS> for decoder start token id "
"because decoder start token id is not available.")
dec_start_token_id = self._get_bos_token_id()
return dec_start_token_id
def _add_processed_request(
self,
request_id: str,
@@ -675,7 +621,7 @@ class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
-        eos_token_id = self._get_eos_token_id(lora_request)
+        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
@@ -725,334 +671,6 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()
_LLMInputComponentsType = Tuple[str, List[int]]
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]],
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Based on
https://github.com/huggingface/transformers/blob/
4037a2b5b1278736e566aec12e169100275545ea/
src/transformers/generation/utils.py
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
Arguments:
* decoder_input_ids: input token ids to preprocess
Returns:
* Processed token list
"""
decoder_start_token_id = self._get_decoder_start_token_id()
assert decoder_start_token_id is not None
if decoder_input_ids is None:
# no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
return decoder_input_ids
def _tokenize_prompt(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
'''
Wrapper around application of the model's tokenizer.
Arguments:
* prompt
* request_id
* lora_request
Returns:
* prompt token ids
'''
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")
return tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
def _extract_prompt_components(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
'''
Extract the components of any single encoder or decoder input prompt.
Arguments:
* request_id
* inputs: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
* prompt
* prompt_token_ids
* multi_modal_data
'''
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = self._tokenize_prompt(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = inputs.get("multi_modal_data")
else:
assert_never(inputs)
return prompt, prompt_token_ids, multi_modal_data
def _apply_prompt_adapter(
self,
prompt_token_ids: List[int],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> List[int]:
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return prompt_token_ids
def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
'''
Specifically for encoder/decoder models:
generate a default decoder prompt for when
the user specifies only the encoder prompt.
Encoder/decoder models utilize the decoder
prompt in different ways; as new models are
added, it is intended that this function
will be extended to produce differing
default decoder prompts, depending on the
model variety.
Absent a special case, the default behavior
of this method is to mirror the behavior of
the HuggingFace (HF) GenerationMixin for a None
decoder prompt, which is to employ a logit processor
setting to force the first decoded token to be <BOS>.
Here, this behavior is approximated by having the
"default" decoder prompt be <BOS>.
However, it is possible that in the future
other models may have different or more
complex logic for the default decoder prompt.
This motivates having a special helper method
for default decoder prompts.
Returns:
* prompt_token_ids
'''
bos_token_id = self._get_bos_token_id()
assert bos_token_id is not None
return [bos_token_id]
def _build_enc_dec_llm_inputs(
self,
encoder_comps: PromptComponents,
decoder_comps: DecoderPromptComponents,
) -> EncoderDecoderLLMInputs:
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
if encoder_mm_data is not None or decoder_mm_data is not None:
raise ValueError("Multi-modal encoder-decoder models are "
"not supported yet")
decoder_prompt_ids = (
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
return EncoderDecoderLLMInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt,
)
def _process_encoder_decoder_prompt(
self,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
'''
For encoder/decoder models only:
Process an input prompt into an
:class:`EncoderDecoderLLMInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
encoder prompt, and explicit encoder/decoder
prompts which carry both the encoder and the
decoder prompts as member variables.
This function handles the following scenarios:
* Singleton encoder prompt: extract encoder prompt
token ids & infer default decoder prompt token ids
* Explicit encoder/decoder prompt: extract encoder
and decoder prompt token ids
Note that for Explicit encoder/decoder prompts,
each sub-prompt (encoder or decoder prompt) can
have any possible singleton type; thus this
method relies on helper functions to obtain
token ids for the sub-prompts.
Arguments:
* inputs: an input prompt
* request_id
Returns:
* :class:`EncoderDecoderLLMInputs` instance
'''
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
encoder_comps = self._extract_prompt_components(
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
decoder_comps = None, None, None
else:
decoder_comps = self._extract_prompt_components(
decoder_input,
request_id=request_id,
)
else:
encoder_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
)
decoder_comps = None, None, None
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
def _build_decoder_only_llm_inputs(
self,
prompt_comps: PromptComponents,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
prompt, prompt_token_ids, multi_modal_data = prompt_comps
prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data)
def _process_decoder_only_prompt(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
'''
For decoder-only models:
Process an input prompt into an :class:`LLMInputs` instance.
Arguments:
* inputs: input prompt
* request_id
* lora_request
* prompt_adapter_request
Returns:
* :class:`LLMInputs` instance
'''
prompt_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
lora_request=lora_request,
)
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
def process_model_inputs(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = self._process_encoder_decoder_prompt(
inputs,
request_id=request_id,
)
else:
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
model_inputs = self._process_decoder_only_prompt(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
return self.input_processor(model_inputs)
def add_request(
self,
request_id: str,
@@ -1111,12 +729,13 @@ class LLMEngine:
if arrival_time is None:
arrival_time = time.time()
-        processed_inputs = self.process_model_inputs(
+        preprocessed_inputs = self.input_preprocessor.preprocess(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
+        processed_inputs = self.input_processor(preprocessed_inputs)
self._add_processed_request(
request_id=request_id,
@@ -2043,7 +1662,7 @@ class LLMEngine:
metrics.model_execute_time)
def is_encoder_decoder_model(self):
-        return self.model_config.is_encoder_decoder_model
+        return self.input_preprocessor.is_encoder_decoder_model()
def is_embedding_model(self):
return self.model_config.is_embedding_model

View File

@@ -5,7 +5,8 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptInputs)
+                   LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
+                   TokensPrompt)
class ParsedText(TypedDict):
@@ -60,8 +61,38 @@ def parse_and_batch_prompt(
for elem in prompt
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
raise TypeError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class ParsedStrPrompt(TypedDict):
type: Literal["str"]
content: str
class ParsedTextPrompt(TypedDict):
type: Literal["text"]
content: TextPrompt
class ParsedTokensPrompt(TypedDict):
type: Literal["tokens"]
content: TokensPrompt
def parse_singleton_prompt(
inputs: SingletonPromptInputs,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
if isinstance(inputs, str):
return ParsedStrPrompt(type="str", content=inputs)
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
return ParsedTokensPrompt(type="tokens",
content=inputs) # type: ignore
elif "prompt" in inputs:
return ParsedTextPrompt(type="text", content=inputs)
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
def is_explicit_encoder_decoder_prompt(

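For reference, the new parse_singleton_prompt helper above just tags which of the three singleton prompt forms it was given; a minimal usage sketch (values are made up):

from vllm.inputs.parse import parse_singleton_prompt

parse_singleton_prompt("Hello")                           # {'type': 'str', 'content': 'Hello'}
parse_singleton_prompt({"prompt": "Hello"})               # {'type': 'text', 'content': {'prompt': 'Hello'}}
parse_singleton_prompt({"prompt_token_ids": [1, 2, 3]})   # {'type': 'tokens', 'content': {'prompt_token_ids': [1, 2, 3]}}
# Anything else raises TypeError("inputs must be a string, TextPrompt, or TokensPrompt").
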
vllm/inputs/preprocess.py (new file, 536 additions)
View File

@@ -0,0 +1,536 @@
import asyncio
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
logger = init_logger(__name__)
PromptComponents = Tuple[Optional[str], List[int],
Optional["MultiModalDataDict"]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional["MultiModalDataDict"]]
class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[BaseTokenizerGroup],
) -> None:
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
def get_tokenizer_group(self) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError("You cannot pass text prompts when "
"`skip_tokenizer_init` is True")
return self.tokenizer
def get_bos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for BOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
def get_eos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def get_decoder_start_token_id(self) -> Optional[int]:
'''
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
model config is unavailable.
'''
if not self.is_encoder_decoder_model():
logger.warning("Using None for decoder start token id because "
"this is not an encoder/decoder model.")
return None
if (self.model_config is None or self.model_config.hf_config is None):
logger.warning("Using None for decoder start token id because "
"model config is not available.")
return None
dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None)
if dec_start_token_id is None:
logger.warning("Falling back on <BOS> for decoder start token id "
"because decoder start token id is not available.")
dec_start_token_id = self.get_bos_token_id()
return dec_start_token_id
def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
'''
Specifically for encoder/decoder models:
generate a default decoder prompt for when
the user specifies only the encoder prompt.
Encoder/decoder models utilize the decoder
prompt in different ways; as new models are
added, it is intended that this function
will be extended to produce differing
default decoder prompts, depending on the
model variety.
Absent a special case, the default behavior
of this method is to mirror the behavior of
the HuggingFace (HF) GenerationMixin for a None
decoder prompt, which is to employ a logit processor
setting to force the first decoded token to be <BOS>.
Here, this behavior is approximated by having the
"default" decoder prompt be <BOS>.
However, it is possible that in the future
other models may have different or more
complex logic for the default decoder prompt.
This motivates having a special helper method
for default decoder prompts.
Returns:
* prompt_token_ids
'''
bos_token_id = self.get_bos_token_id()
assert bos_token_id is not None
return [bos_token_id]
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]],
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Based on
https://github.com/huggingface/transformers/blob/
4037a2b5b1278736e566aec12e169100275545ea/
src/transformers/generation/utils.py
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
Arguments:
* decoder_input_ids: input token ids to preprocess
Returns:
* Processed token list
"""
decoder_start_token_id = self.get_decoder_start_token_id()
assert decoder_start_token_id is not None
if decoder_input_ids is None:
# no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
return decoder_input_ids
def _apply_prompt_adapter(
self,
prompt_token_ids: List[int],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> List[int]:
if prompt_adapter_request:
prompt_token_ids = (
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+ prompt_token_ids)
return prompt_token_ids
def _tokenize_prompt(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
"""
Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
return tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
async def _tokenize_prompt_async(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group()
return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
def _extract_prompt_components(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
'''
Extract the components of any single encoder or decoder input prompt.
Arguments:
* request_id
* inputs: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
* prompt
* prompt_token_ids
* multi_modal_data
'''
parsed = parse_singleton_prompt(inputs)
if parsed["type"] == "str":
prompt = parsed["content"]
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
prompt = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"]
prompt_token_ids = self._tokenize_prompt(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = parsed["content"].get("multi_modal_data")
else:
assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data
async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(inputs)
if parsed["type"] == "str":
prompt = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
prompt = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = parsed["content"].get("multi_modal_data")
else:
assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data
def _build_enc_dec_llm_inputs(
self,
encoder_comps: PromptComponents,
decoder_comps: DecoderPromptComponents,
) -> EncoderDecoderLLMInputs:
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
if encoder_mm_data is not None or decoder_mm_data is not None:
raise ValueError("Multi-modal encoder-decoder models are "
"not supported yet")
decoder_prompt_ids = (
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
return EncoderDecoderLLMInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt,
)
def _process_encoder_decoder_prompt(
self,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
'''
For encoder/decoder models only:
Process an input prompt into an
:class:`EncoderDecoderLLMInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
encoder prompt, and explicit encoder/decoder
prompts which carry both the encoder and the
decoder prompts as member variables.
This function handles the following scenarios:
* Singleton encoder prompt: extract encoder prompt
token ids & infer default decoder prompt token ids
* Explicit encoder/decoder prompt: extract encoder
and decoder prompt token ids
Note that for Explicit encoder/decoder prompts,
each sub-prompt (encoder or decoder prompt) can
have any possible singleton type; thus this
method relies on helper functions to obtain
token ids for the sub-prompts.
Arguments:
* inputs: an input prompt
* request_id
Returns:
* :class:`EncoderDecoderLLMInputs` instance
'''
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
encoder_comps = self._extract_prompt_components(
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
decoder_comps = None, None, None
else:
decoder_comps = self._extract_prompt_components(
decoder_input,
request_id=request_id,
)
else:
encoder_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
)
decoder_comps = None, None, None
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
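
The two prompt shapes this method distinguishes, spelled out with illustrative contents; each half of an explicit prompt may itself be a plain string, a TextPrompt, or a TokensPrompt:

# 1) Singleton prompt: only the encoder input is given; the decoder prompt
#    falls back to the default built around the decoder start token.
singleton_prompt = "An encoder prompt"

# 2) Explicit encoder/decoder prompt: both halves are given.
explicit_prompt = {
    "encoder_prompt": "An encoder prompt",
    "decoder_prompt": {"prompt_token_ids": [2, 5, 6]},
}
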
async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
encoder_task = self._extract_prompt_components_async(
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
else:
decoder_task = self._extract_prompt_components_async(
decoder_input,
request_id=request_id,
)
encoder_comps, decoder_comps = await asyncio.gather(
encoder_task, decoder_task)
else:
encoder_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
)
decoder_comps = None, None, None
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
def _build_decoder_only_llm_inputs(
self,
prompt_comps: PromptComponents,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
prompt, prompt_token_ids, multi_modal_data = prompt_comps
prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data)
def _process_decoder_only_prompt(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
'''
For decoder-only models:
Process an input prompt into an :class:`LLMInputs` instance.
Arguments:
* inputs: input prompt
* request_id
* lora_request
* prompt_adapter_request
Returns:
* :class:`LLMInputs` instance
'''
prompt_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
lora_request=lora_request,
)
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
lora_request=lora_request,
)
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
def preprocess(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Preprocess the input prompt."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
inputs,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return self._process_decoder_only_prompt(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
async def preprocess_async(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Async version of :meth:`preprocess`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(
inputs,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return await self._process_decoder_only_prompt_async(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
def is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
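
Taken together, the class reproduces the old engine behavior when the tokenizer is skipped; a behavior sketch, where model_config stands for the engine's existing ModelConfig of a decoder-only model:

from vllm.inputs.preprocess import InputPreprocessor

preprocessor = InputPreprocessor(model_config, None)   # tokenizer=None, i.e. skip_tokenizer_init

preprocessor.get_eos_token_id()    # logs a warning and returns None
preprocessor.preprocess({"prompt_token_ids": [1, 2, 3]}, request_id="0")   # fine, nothing to tokenize
# preprocessor.preprocess("some text", request_id="0") would raise
# ValueError("You cannot pass text prompts when `skip_tokenizer_init` is True"),
# which is exactly what the updated test at the top of this diff asserts.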