Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Core] Make encoder-decoder inputs a nested structure to be more composable (#9604)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 04bbf38e05
commit bbc3619dc8
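The shape of the change is easiest to see side by side. Before, an encoder-decoder request was a single flat dict whose encoder fields carried an `encoder_` prefix; after, it is a nested `EncoderDecoderInputs` with one `TokenInputs` entry per portion, each built with `token_inputs`. A minimal sketch based on the types touched in this diff (the token IDs and prompt strings are made up for illustration):

```python
from vllm.inputs import EncoderDecoderInputs, token_inputs

# Old flat layout: decoder fields plus "encoder_*"-prefixed fields in one dict.
old_style = {
    "prompt": "translate: hello",
    "prompt_token_ids": [0, 1, 2],
    "encoder_prompt": "hello",
    "encoder_prompt_token_ids": [3, 4],
    "multi_modal_data": None,
}

# New nested layout: one TokenInputs per portion, built via token_inputs().
new_style: EncoderDecoderInputs = {
    "encoder": token_inputs([3, 4], prompt="hello"),
    "decoder": token_inputs([0, 1, 2], prompt="translate: hello"),
}
```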
@@ -4,6 +4,7 @@ from typing import Sequence as GenericSequence
from typing import Tuple

from vllm import SamplingParams
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup

@@ -27,10 +28,7 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id),
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
},
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[prompt],
@@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder(
encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])

inputs = {
"prompt": decoder_prompt_str,
"prompt_token_ids": decoder_prompt_tokens,
"encoder_prompt": encoder_prompt_str,
"encoder_prompt_token_ids": encoder_prompt_tokens,
"multi_modal_data": None,
inputs: EncoderDecoderInputs = {
"decoder": token_inputs(decoder_prompt_tokens,
prompt=decoder_prompt_str),
"encoder": token_inputs(encoder_prompt_tokens,
prompt=encoder_prompt_str),
}

decoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=True)
inputs=inputs["decoder"],
block_size=block_size)

encoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=False)
inputs=inputs["encoder"],
block_size=block_size)

seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
sampling_params=SamplingParams(best_of=best_of),
@ -108,7 +104,7 @@ def create_seq_group(
|
||||
for seq_id_offset, output_len in enumerate(seq_output_lens):
|
||||
seq = Sequence(
|
||||
seq_id=seq_id_start + seq_id_offset,
|
||||
inputs={"prompt_token_ids": prompt_token_ids},
|
||||
inputs=token_inputs(prompt_token_ids),
|
||||
block_size=16,
|
||||
)
|
||||
|
||||
@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder(
|
||||
|
||||
prompt_token_ids = [0] * seq_prompt_len
|
||||
|
||||
inputs = {
|
||||
"prompt": "",
|
||||
"prompt_token_ids": prompt_token_ids,
|
||||
"encoder_prompt": "",
|
||||
"encoder_prompt_token_ids": prompt_token_ids,
|
||||
"multi_modal_data": None,
|
||||
inputs: EncoderDecoderInputs = {
|
||||
"decoder": token_inputs(prompt_token_ids),
|
||||
"encoder": token_inputs(prompt_token_ids),
|
||||
}
|
||||
|
||||
seqs = []
|
||||
for seq_id_offset, output_len in enumerate(seq_output_lens):
|
||||
# Construct decoder input sequences
|
||||
seq = Sequence(seq_id=seq_id_start + seq_id_offset,
|
||||
inputs=inputs,
|
||||
block_size=16,
|
||||
from_decoder_prompt=True)
|
||||
seq = Sequence(
|
||||
seq_id=seq_id_start + seq_id_offset,
|
||||
inputs=inputs["decoder"],
|
||||
block_size=16,
|
||||
)
|
||||
|
||||
for i in range(output_len):
|
||||
seq.append_token_id(
|
||||
@ -167,10 +161,11 @@ def create_seq_group_encoder_decoder(
|
||||
seqs.append(seq)
|
||||
|
||||
# Encoder input sequence
|
||||
encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
|
||||
inputs=inputs,
|
||||
block_size=16,
|
||||
from_decoder_prompt=False)
|
||||
encoder_seq = Sequence(
|
||||
seq_id=seq_id_start + len(seq_output_lens),
|
||||
inputs=inputs["encoder"],
|
||||
block_size=16,
|
||||
)
|
||||
|
||||
return SequenceGroup(request_id=request_id,
|
||||
seqs=seqs,
|
||||
|
||||
@ -4,6 +4,7 @@ import pytest
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.inputs import token_inputs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Logprob, Sequence, SequenceStatus
|
||||
|
||||
@ -15,7 +16,7 @@ def sequence_with_eos(text: str, eos_token: str,
|
||||
"""
|
||||
seq = Sequence(
|
||||
seq_id=0,
|
||||
inputs={"prompt_token_ids": []},
|
||||
inputs=token_inputs([]),
|
||||
block_size=16,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.inputs import token_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Sequence
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
@ -70,10 +71,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
|
||||
hashes[-1].append([])
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
seq = Sequence(seq_id,
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"prompt_token_ids": prompt_token_ids,
|
||||
},
|
||||
inputs=token_inputs(prompt_token_ids,
|
||||
prompt=prompt),
|
||||
block_size=block_size,
|
||||
eos_token_id=tokenizer.tokenizer.eos_token_id,
|
||||
lora_request=lora_request)
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any, Dict, Generator, List, Optional
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.inputs import token_inputs
|
||||
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
|
||||
from vllm.transformers_utils.detokenizer import (Detokenizer,
|
||||
detokenize_incrementally)
|
||||
@ -169,10 +170,7 @@ def create_sequence(prompt_token_ids=None):
|
||||
prompt_token_ids = prompt_token_ids or [1]
|
||||
return Sequence(
|
||||
seq_id=0,
|
||||
inputs={
|
||||
"prompt": "<s>",
|
||||
"prompt_token_ids": prompt_token_ids,
|
||||
},
|
||||
inputs=token_inputs(prompt_token_ids, prompt="<s>"),
|
||||
block_size=16,
|
||||
)
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, Union, cast, overload
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeIs, TypeVar
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
@ -29,9 +29,9 @@ from vllm.entrypoints.openai.logits_processors import (
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
|
||||
EncoderDecoderInputs, InputRegistry, PromptType,
|
||||
TokensPrompt)
|
||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
PromptType)
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import get_bad_words_logits_processors
|
||||
@ -638,7 +638,7 @@ class LLMEngine:
|
||||
def _add_processed_request(
|
||||
self,
|
||||
request_id: str,
|
||||
processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
|
||||
processed_inputs: ProcessorInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
@@ -669,18 +669,19 @@ class LLMEngine:
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = processed_inputs["decoder"]
encoder_inputs = processed_inputs["encoder"]
else:
decoder_inputs = processed_inputs
encoder_inputs = None

seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)

encoder_seq = None
if 'encoder_prompt_token_ids' in processed_inputs:
encoder_seq = Sequence(seq_id,
processed_inputs,
block_size,
eos_token_id,
lora_request,
prompt_adapter_request,
from_decoder_prompt=False)
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
prompt_adapter_request))

# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
@ -874,7 +875,7 @@ class LLMEngine:
|
||||
# This needs to happen before multimodal input pre-processing, which
|
||||
# may add dummy <image> tokens that aren't part of the tokenizer's
|
||||
# vocabulary.
|
||||
if self._is_token_prompt(prompt):
|
||||
if is_token_prompt(prompt):
|
||||
prompt_ids = prompt["prompt_token_ids"]
|
||||
if len(prompt_ids) == 0:
|
||||
# Empty prompt check is handled later
|
||||
@ -884,10 +885,6 @@ class LLMEngine:
|
||||
raise ValueError(
|
||||
"Token id {} is out of vocabulary".format(max_input_id))
|
||||
|
||||
@staticmethod
|
||||
def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
|
||||
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
|
||||
def _create_sequence_group_with_sampling(
|
||||
self,
|
||||
request_id: str,
|
||||
@ -1978,17 +1975,17 @@ class LLMEngine:
|
||||
def is_encoder_decoder_model(self):
|
||||
return self.input_preprocessor.is_encoder_decoder_model()
|
||||
|
||||
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
|
||||
EncoderDecoderInputs],
|
||||
def _validate_model_inputs(self, inputs: ProcessorInputs,
|
||||
lora_request: Optional[LoRARequest]):
|
||||
if self.model_config.is_multimodal_model:
|
||||
if is_encoder_decoder_inputs(inputs):
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
prompt_ids = inputs.get("prompt_token_ids")
|
||||
elif self.is_encoder_decoder_model():
|
||||
prompt_ids = inputs.get("encoder_prompt_token_ids")
|
||||
prompt_inputs = inputs["decoder" if self.model_config.
|
||||
is_multimodal_model else "encoder"]
|
||||
else:
|
||||
prompt_ids = inputs.get("prompt_token_ids")
|
||||
prompt_inputs = inputs
|
||||
|
||||
prompt_ids = prompt_inputs.get("prompt_token_ids")
|
||||
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncGenerator, List, Mapping, Optional, Union
|
||||
from typing import AsyncGenerator, List, Mapping, Optional
|
||||
|
||||
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||
from vllm.config import DecodingConfig, ModelConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -60,7 +61,7 @@ class EngineClient(ABC):
|
||||
|
||||
async def beam_search(
|
||||
self,
|
||||
prompt: Union[PromptType, List[int]],
|
||||
prompt: PromptType,
|
||||
model_config: ModelConfig,
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
@ -76,11 +77,19 @@ class EngineClient(ABC):
|
||||
tokenizer = await self.get_tokenizer()
|
||||
input_preprocessor = InputPreprocessor(model_config, tokenizer)
|
||||
|
||||
(prompt_text, prompt_token_ids, multi_modal_data,
|
||||
mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
processed_inputs = input_preprocessor._prompt_to_llm_inputs(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
prompt_token_ids = processed_inputs["prompt_token_ids"]
|
||||
prompt_text = processed_inputs.get("prompt")
|
||||
multi_modal_data = processed_inputs.get("multi_modal_data")
|
||||
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
|
||||
|
||||
tokenized_length = len(prompt_token_ids)
|
||||
|
||||
sort_beams_key = create_sort_beams_key_function(
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
|
||||
ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
|
||||
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
|
||||
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
|
||||
token_inputs, zip_enc_dec_prompts)
|
||||
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
|
||||
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
|
||||
TokensPrompt, build_explicit_enc_dec_prompt,
|
||||
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
|
||||
from .registry import DummyData, InputContext, InputRegistry
|
||||
|
||||
INPUT_REGISTRY = InputRegistry()
|
||||
@ -22,9 +22,10 @@ __all__ = [
|
||||
"ExplicitEncoderDecoderPrompt",
|
||||
"TokenInputs",
|
||||
"token_inputs",
|
||||
"SingletonInputs",
|
||||
"DecoderOnlyInputs",
|
||||
"EncoderDecoderInputs",
|
||||
"ProcessorInputs",
|
||||
"SingletonInputs",
|
||||
"build_explicit_enc_dec_prompt",
|
||||
"to_enc_dec_tuple_list",
|
||||
"zip_enc_dec_prompts",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
|
||||
Optional, Tuple, Union, cast)
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict, TypeVar
|
||||
@ -122,27 +122,30 @@ both decoder-only and encoder/decoder input types:
|
||||
|
||||
class TokenInputs(TypedDict):
|
||||
"""Represents token-based inputs."""
|
||||
|
||||
type: Literal["token"]
|
||||
"""The type of inputs."""
|
||||
|
||||
prompt_token_ids: List[int]
|
||||
"""The token IDs of the prompt."""
|
||||
|
||||
prompt: NotRequired[Optional[str]]
|
||||
prompt: NotRequired[str]
|
||||
"""
|
||||
The original prompt text corresponding to the token IDs, if available.
|
||||
"""
|
||||
|
||||
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
|
||||
multi_modal_data: NotRequired["MultiModalDataDict"]
|
||||
"""
|
||||
Optional multi-modal data to pass to the model,
|
||||
if the model supports it.
|
||||
"""
|
||||
|
||||
multi_modal_placeholders: NotRequired[
|
||||
Optional["MultiModalPlaceholderDict"]]
|
||||
multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
|
||||
"""
|
||||
Placeholder ranges for the multi-modal data.
|
||||
"""
|
||||
|
||||
mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
|
||||
mm_processor_kwargs: NotRequired[Dict[str, Any]]
|
||||
"""
|
||||
Optional multi-modal processor kwargs to be forwarded to the
|
||||
multimodal input mapper & processor. Note that if multiple modalities
|
||||
@ -159,7 +162,7 @@ def token_inputs(
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> TokenInputs:
|
||||
"""Construct :class:`TokenInputs` from optional values."""
|
||||
inputs = TokenInputs(prompt_token_ids=prompt_token_ids)
|
||||
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
|
||||
|
||||
if prompt is not None:
|
||||
inputs["prompt"] = prompt
|
||||
@ -173,12 +176,6 @@ def token_inputs(
|
||||
return inputs
|
||||
|
||||
|
||||
SingletonInputs = TokenInputs
|
||||
"""
|
||||
A processed :class:`SingletonPrompt` which can be passed to
|
||||
:class:`vllm.sequence.Sequence`.
|
||||
"""
|
||||
|
||||
DecoderOnlyInputs = TokenInputs
|
||||
"""
|
||||
The inputs in :class:`~vllm.LLMEngine` before they are
|
||||
@@ -187,28 +184,30 @@ This specifies the data required for decoder-only models.
"""


class EncoderDecoderInputs(TokenInputs):
class EncoderDecoderInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.

This specifies the required data for encoder-decoder models.
"""
encoder_prompt_token_ids: List[int]
"""The token IDs of the encoder prompt."""
encoder: TokenInputs
"""The inputs for the encoder portion."""

encoder_prompt: NotRequired[Optional[str]]
"""
The original encoder prompt text corresponding to the token IDs, if
available.
"""
decoder: TokenInputs
"""The inputs for the decoder portion."""

encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the encoder model,
if the model supports it.
"""

SingletonInputs = TokenInputs
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""

ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
"""

_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)

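One consequence of the nesting worth noting: per-portion extras such as multi-modal data now live inside the corresponding `TokenInputs` (via the optional fields of `token_inputs`) rather than in top-level `encoder_*` keys. A hedged sketch, with the image object left as a placeholder and the token IDs borrowed from the mllama example later in this diff:

```python
from vllm.inputs import EncoderDecoderInputs, token_inputs

image = ...  # placeholder for e.g. a PIL.Image.Image

inputs: EncoderDecoderInputs = {
    "encoder": token_inputs(
        prompt_token_ids=[128256],
        prompt="<|image|>",
        multi_modal_data={"image": image},
    ),
    "decoder": token_inputs([128000]),
}
```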
@ -4,9 +4,9 @@ from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
|
||||
ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
|
||||
TextPrompt, TokensPrompt)
|
||||
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
|
||||
TokensPrompt)
|
||||
|
||||
|
||||
class ParsedText(TypedDict):
|
||||
@ -98,12 +98,15 @@ def parse_singleton_prompt(
|
||||
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
|
||||
|
||||
|
||||
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
|
||||
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
|
||||
|
||||
def is_explicit_encoder_decoder_prompt(
|
||||
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||
|
||||
|
||||
def is_encoder_decoder_inputs(
inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
) -> TypeIs[EncoderDecoderInputs]:
return "encoder_prompt_token_ids" in inputs
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
return "encoder" in inputs and "decoder" in inputs

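Since `is_encoder_decoder_inputs` is now a `TypeIs` guard over the nested shape, callers can branch on it and get the right type on each side; this is essentially what `LLMEngine._add_processed_request` does earlier in this diff. A minimal sketch under that assumption (the helper name is mine, not part of the change):

```python
from typing import Optional, Tuple

from vllm.inputs import ProcessorInputs, TokenInputs
from vllm.inputs.parse import is_encoder_decoder_inputs


def split_processor_inputs(
    inputs: ProcessorInputs,
) -> Tuple[TokenInputs, Optional[TokenInputs]]:
    """Return (decoder_inputs, encoder_inputs or None)."""
    if is_encoder_decoder_inputs(inputs):
        # Narrowed to EncoderDecoderInputs: separate encoder/decoder entries.
        return inputs["decoder"], inputs["encoder"]
    # Narrowed to DecoderOnlyInputs, which is an alias of TokenInputs.
    return inputs, None
```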
@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
@ -10,22 +10,12 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
|
||||
SingletonPrompt)
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
|
||||
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
|
||||
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
PromptComponents = Tuple[Optional[str], List[int],
|
||||
Optional["MultiModalDataDict"], Optional[Dict[str,
|
||||
Any]]]
|
||||
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
|
||||
Optional["MultiModalDataDict"],
|
||||
Optional[Dict[str, Any]]]
|
||||
|
||||
|
||||
class InputPreprocessor:
|
||||
|
||||
@ -115,7 +105,7 @@ class InputPreprocessor:
|
||||
"default" decoder prompt be <BOS>.
|
||||
|
||||
However, it is possible that in the future
|
||||
other models may have different or more
|
||||
other models may have different or more
|
||||
complex logic for the default decoder prompt.
|
||||
This motivates having a special helper method
|
||||
for default decoder prompts.
|
||||
@ -132,7 +122,6 @@ class InputPreprocessor:
|
||||
def _prepare_decoder_input_ids_for_generation(
|
||||
self,
|
||||
decoder_input_ids: Optional[List[int]],
|
||||
force_bos: bool = True,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Prepares `decoder_input_ids` for generation with encoder-decoder models.
|
||||
@ -162,8 +151,8 @@ class InputPreprocessor:
|
||||
# use decoder_start_token_id as decoder_input_ids
|
||||
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
|
||||
|
||||
if force_bos and (len(decoder_input_ids) == 0
|
||||
or decoder_input_ids[0] != decoder_start_token_id):
|
||||
if (len(decoder_input_ids) == 0
|
||||
or decoder_input_ids[0] != decoder_start_token_id):
|
||||
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
|
||||
|
||||
return decoder_input_ids
|
||||
@ -209,12 +198,12 @@ class InputPreprocessor:
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
|
||||
def _extract_prompt_components(
|
||||
def _prompt_to_llm_inputs(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> PromptComponents:
|
||||
) -> SingletonInputs:
|
||||
'''
|
||||
Extract the components of any single encoder or decoder input prompt.
|
||||
|
||||
@ -241,34 +230,52 @@ class InputPreprocessor:
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
multi_modal_data = None
|
||||
mm_processor_kwargs = None
|
||||
elif parsed["type"] == "tokens":
|
||||
prompt_text = None
|
||||
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
||||
multi_modal_data = parsed["content"].get("multi_modal_data")
|
||||
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
|
||||
elif parsed["type"] == "text":
|
||||
prompt_text = parsed["content"]["prompt"]
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
)
|
||||
|
||||
if parsed["type"] == "tokens":
|
||||
tokens_content = parsed["content"]
|
||||
|
||||
prompt_token_ids = tokens_content["prompt_token_ids"]
|
||||
multi_modal_data = tokens_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
if parsed["type"] == "text":
|
||||
text_content = parsed["content"]
|
||||
|
||||
prompt_text = text_content["prompt"]
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt_text,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
multi_modal_data = parsed["content"].get("multi_modal_data")
|
||||
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
|
||||
else:
|
||||
assert_never(parsed)
|
||||
multi_modal_data = text_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
|
||||
|
||||
return (prompt_text, prompt_token_ids, multi_modal_data,
|
||||
mm_processor_kwargs)
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
async def _extract_prompt_components_async(
|
||||
assert_never(parsed)
|
||||
|
||||
async def _prompt_to_llm_inputs_async(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> PromptComponents:
|
||||
) -> SingletonInputs:
|
||||
"""Async version of :meth:`_extract_prompt_components`."""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
|
||||
@ -279,59 +286,74 @@ class InputPreprocessor:
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
multi_modal_data = None
|
||||
mm_processor_kwargs = None
|
||||
elif parsed["type"] == "tokens":
|
||||
prompt_text = None
|
||||
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
||||
multi_modal_data = parsed["content"].get("multi_modal_data")
|
||||
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
|
||||
elif parsed["type"] == "text":
|
||||
prompt_text = parsed["content"]["prompt"]
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
)
|
||||
|
||||
if parsed["type"] == "tokens":
|
||||
tokens_content = parsed["content"]
|
||||
|
||||
prompt_token_ids = tokens_content["prompt_token_ids"]
|
||||
multi_modal_data = tokens_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
|
||||
|
||||
return token_inputs(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
if parsed["type"] == "text":
|
||||
text_content = parsed["content"]
|
||||
|
||||
prompt_text = text_content["prompt"]
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
prompt_text,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
multi_modal_data = parsed["content"].get("multi_modal_data")
|
||||
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
|
||||
else:
|
||||
assert_never(parsed)
|
||||
multi_modal_data = text_content.get("multi_modal_data")
|
||||
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
|
||||
|
||||
return (prompt_text, prompt_token_ids, multi_modal_data,
|
||||
mm_processor_kwargs)
|
||||
return token_inputs(
|
||||
prompt=prompt_text,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
assert_never(parsed)
|
||||
|
||||
def _build_enc_dec_llm_inputs(
|
||||
self,
|
||||
encoder_comps: PromptComponents,
|
||||
decoder_comps: DecoderPromptComponents,
|
||||
mm_processor_kwargs: Dict[str, Any],
|
||||
encoder_inputs: SingletonInputs,
|
||||
decoder_inputs: Optional[SingletonInputs],
|
||||
) -> EncoderDecoderInputs:
|
||||
encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
|
||||
decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
|
||||
if encoder_inputs["type"] == "token":
|
||||
pass
|
||||
else:
|
||||
assert_never(encoder_inputs)
|
||||
|
||||
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
||||
# If the feature combo become valid
|
||||
if decoder_mm_data is not None:
|
||||
raise ValueError(
|
||||
"Multi-modality decoder inputs of encoder-decoder models are "
|
||||
"not supported yet")
|
||||
if decoder_inputs is None:
|
||||
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
|
||||
None)
|
||||
decoder_inputs = token_inputs(dec_token_ids)
|
||||
elif decoder_inputs["type"] == "token":
|
||||
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
|
||||
decoder_inputs["prompt_token_ids"])
|
||||
decoder_inputs["prompt_token_ids"] = dec_token_ids
|
||||
|
||||
# For Multi-Modal models (e.g., mllama), the text input can be
|
||||
# <|image|><|begin_of_text|>hello world. And we should not add
|
||||
# another <|begin_of_text|> to the beginning.
|
||||
decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation(
|
||||
decoder_prompt_ids,
|
||||
force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
|
||||
if "multi_modal_data" in decoder_inputs:
|
||||
raise ValueError("Multi-modal decoder inputs of encoder-"
|
||||
"decoder models are not supported yet")
|
||||
else:
|
||||
assert_never(encoder_inputs)
|
||||
|
||||
return EncoderDecoderInputs(
|
||||
prompt_token_ids=decoder_prompt_ids,
|
||||
prompt=decoder_prompt,
|
||||
multi_modal_data=decoder_mm_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
encoder_prompt_token_ids=encoder_prompt_ids,
|
||||
encoder_prompt=encoder_prompt,
|
||||
encoder_multi_modal_data=encoder_mm_data,
|
||||
encoder=encoder_inputs,
|
||||
decoder=decoder_inputs,
|
||||
)
|
||||
|
||||
def _process_encoder_decoder_prompt(
|
||||
@ -341,8 +363,7 @@ class InputPreprocessor:
|
||||
) -> EncoderDecoderInputs:
|
||||
'''
|
||||
For encoder/decoder models only:
|
||||
Process an input prompt into an
|
||||
:class:`EncoderDecoderInputs` instance.
|
||||
Process an input prompt into an :class:`EncoderDecoderInputs` instance.
|
||||
|
||||
There are two types of input prompts:
|
||||
singleton prompts which carry only the
|
||||
@ -361,7 +382,7 @@ class InputPreprocessor:
|
||||
have any possible singleton type; thus this
|
||||
method relies on helper functions to obtain
|
||||
token ids for the sub-prompts.
|
||||
|
||||
|
||||
Arguments:
|
||||
|
||||
* prompt: an input prompt
|
||||
@ -372,40 +393,31 @@ class InputPreprocessor:
|
||||
* :class:`EncoderDecoderInputs` instance
|
||||
'''
|
||||
|
||||
encoder_comps: PromptComponents
|
||||
decoder_comps: DecoderPromptComponents
|
||||
encoder_inputs: SingletonInputs
|
||||
decoder_inputs: Optional[SingletonInputs]
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
encoder_comps = self._extract_prompt_components(
|
||||
encoder_inputs = self._prompt_to_llm_inputs(
|
||||
prompt["encoder_prompt"],
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
if (decoder_input := prompt["decoder_prompt"]) is None:
|
||||
decoder_comps = None, None, None, None
|
||||
decoder_inputs = None
|
||||
else:
|
||||
decoder_comps = self._extract_prompt_components(
|
||||
decoder_inputs = self._prompt_to_llm_inputs(
|
||||
decoder_input,
|
||||
request_id=request_id,
|
||||
)
|
||||
# Handle this carefully in case it was directly initialized by user
|
||||
mm_processor_kwargs = prompt.get("mm_processor_kwargs", {})
|
||||
else:
|
||||
encoder_comps = self._extract_prompt_components(
|
||||
encoder_inputs = self._prompt_to_llm_inputs(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
# If there are no decoder components, we assume the
|
||||
# mm_processor_kwargs are in the encoder prompt
|
||||
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
|
||||
-1] is not None else {}
|
||||
decoder_comps = None, None, None, None
|
||||
|
||||
return self._build_enc_dec_llm_inputs(
|
||||
encoder_comps,
|
||||
decoder_comps,
|
||||
mm_processor_kwargs,
|
||||
)
|
||||
decoder_inputs = None
|
||||
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
async def _process_encoder_decoder_prompt_async(
|
||||
self,
|
||||
@ -413,59 +425,50 @@ class InputPreprocessor:
|
||||
request_id: str,
|
||||
) -> EncoderDecoderInputs:
|
||||
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
|
||||
encoder_comps: PromptComponents
|
||||
decoder_comps: DecoderPromptComponents
|
||||
encoder_inputs: SingletonInputs
|
||||
decoder_inputs: Optional[SingletonInputs]
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
encoder_task = self._extract_prompt_components_async(
|
||||
encoder_task = self._prompt_to_llm_inputs_async(
|
||||
prompt["encoder_prompt"],
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
if (decoder_input := prompt["decoder_prompt"]) is None:
|
||||
encoder_comps = await encoder_task
|
||||
decoder_comps = None, None, None, None
|
||||
encoder_inputs = await encoder_task
|
||||
decoder_inputs = None
|
||||
else:
|
||||
decoder_task = self._extract_prompt_components_async(
|
||||
decoder_task = self._prompt_to_llm_inputs_async(
|
||||
decoder_input,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
encoder_comps, decoder_comps = await asyncio.gather(
|
||||
encoder_inputs, decoder_inputs = await asyncio.gather(
|
||||
encoder_task, decoder_task)
|
||||
mm_processor_kwargs = prompt["mm_processor_kwargs"]
|
||||
else:
|
||||
encoder_comps = await self._extract_prompt_components_async(
|
||||
encoder_inputs = await self._prompt_to_llm_inputs_async(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
# If there are no decoder components, we assume the
|
||||
# mm_processor_kwargs are in the encoder prompt
|
||||
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
|
||||
-1] is not None else {}
|
||||
decoder_comps = None, None, None, None
|
||||
|
||||
return self._build_enc_dec_llm_inputs(
|
||||
encoder_comps,
|
||||
decoder_comps,
|
||||
mm_processor_kwargs,
|
||||
)
|
||||
decoder_inputs = None
|
||||
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
def _build_decoder_only_llm_inputs(
|
||||
self,
|
||||
prompt_comps: PromptComponents,
|
||||
prompt_inputs: DecoderOnlyInputs,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> DecoderOnlyInputs:
|
||||
(prompt, prompt_token_ids, multi_modal_data,
|
||||
mm_processor_kwargs) = prompt_comps
|
||||
if prompt_inputs["type"] == "token":
|
||||
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
||||
prompt_inputs["prompt_token_ids"],
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
else:
|
||||
assert_never(prompt_inputs)
|
||||
|
||||
prompt_token_ids = self._apply_prompt_adapter(
|
||||
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs)
|
||||
return prompt_inputs
|
||||
|
||||
def _process_decoder_only_prompt(
|
||||
self,
|
||||
@ -490,7 +493,7 @@ class InputPreprocessor:
|
||||
* :class:`DecoderOnlyInputs` instance
|
||||
'''
|
||||
|
||||
prompt_comps = self._extract_prompt_components(
|
||||
prompt_comps = self._prompt_to_llm_inputs(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
@ -509,7 +512,7 @@ class InputPreprocessor:
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
"""Async version of :meth:`_process_decoder_only_prompt`."""
|
||||
prompt_comps = await self._extract_prompt_components_async(
|
||||
prompt_comps = await self._prompt_to_llm_inputs_async(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
@ -526,7 +529,7 @@ class InputPreprocessor:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
|
||||
) -> ProcessorInputs:
|
||||
"""Preprocess the input prompt."""
|
||||
if self.is_encoder_decoder_model():
|
||||
# Encoder-decoder model requires special mapping of
|
||||
@ -554,7 +557,7 @@ class InputPreprocessor:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
|
||||
) -> ProcessorInputs:
|
||||
"""Async version of :meth:`preprocess`."""
|
||||
if self.is_encoder_decoder_model():
|
||||
# Encoder-decoder model requires special mapping of
|
||||
|
||||
@ -2,7 +2,7 @@ import functools
|
||||
from collections import UserDict
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
|
||||
Optional, Protocol, Type)
|
||||
Optional, Protocol, Type, cast)
|
||||
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
@ -12,7 +12,7 @@ from vllm.logger import init_logger
|
||||
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
|
||||
resolve_mm_processor_kwargs)
|
||||
|
||||
from .data import DecoderOnlyInputs
|
||||
from .data import ProcessorInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@ -109,7 +109,7 @@ class _MultiModalCounts(UserDict):
|
||||
raise KeyError(msg) from exc
|
||||
|
||||
|
||||
InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs]
|
||||
InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
|
||||
"""Preprocess the inputs to the model."""
|
||||
|
||||
|
||||
@ -254,8 +254,8 @@ class InputRegistry:
|
||||
def _default_input_processor(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs,
|
||||
) -> DecoderOnlyInputs:
|
||||
inputs: ProcessorInputs,
|
||||
) -> ProcessorInputs:
|
||||
"""The default input processor is a no-op."""
|
||||
return inputs
|
||||
|
||||
@ -288,7 +288,7 @@ class InputRegistry:
|
||||
.get(model_cls, self._default_input_processor)
|
||||
|
||||
def process_input(self, model_config: "ModelConfig",
|
||||
inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
|
||||
inputs: ProcessorInputs) -> ProcessorInputs:
|
||||
"""
|
||||
Apply an input processor to an instance of model inputs.
|
||||
|
||||
@ -308,7 +308,7 @@ class InputRegistry:
|
||||
# If it's empty, it'll fall back to the default kwarg values
|
||||
mm_processor_kwargs = resolve_mm_processor_kwargs(
|
||||
model_config.mm_processor_kwargs,
|
||||
inputs.get("mm_processor_kwargs"),
|
||||
cast(Dict[str, Any], inputs.get("mm_processor_kwargs")),
|
||||
processor,
|
||||
)
|
||||
|
||||
|
||||
@ -36,8 +36,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
EncoderDecoderInputs, InputContext)
|
||||
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
|
||||
InputContext, TokenInputs, token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -52,6 +52,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import SequenceData
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .clip import CLIPMLP
|
||||
from .interfaces import SupportsMultiModal
|
||||
@ -86,41 +87,58 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
|
||||
return num_images
|
||||
|
||||
|
||||
def input_processor_for_mllama(ctx: InputContext,
|
||||
inputs: Union[DecoderOnlyInputs,
|
||||
EncoderDecoderInputs]):
|
||||
# move encoder_prompt to prompt
|
||||
if inputs.get("prompt") is None:
|
||||
inputs["prompt"] = inputs["encoder_prompt"]
|
||||
inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"]
|
||||
def input_processor_for_mllama(
|
||||
ctx: InputContext,
|
||||
inputs: EncoderDecoderInputs,
|
||||
) -> EncoderDecoderInputs:
|
||||
# Example input to processor:
|
||||
# {
|
||||
# 'encoder': {
|
||||
# 'type': 'token',
|
||||
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
|
||||
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
|
||||
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
|
||||
# },
|
||||
# 'decoder': {
|
||||
# 'type': 'token',
|
||||
# 'prompt_token_ids': [128000],
|
||||
# },
|
||||
# }
|
||||
|
||||
# process multi-modal data
|
||||
multi_modal_data = inputs.get("encoder_multi_modal_data")
|
||||
# move encoder prompt to decoder
|
||||
dec_inputs = TokenInputs(**inputs["encoder"])
|
||||
|
||||
if multi_modal_data is None or "image" not in multi_modal_data \
|
||||
or multi_modal_data["image"] is None:
|
||||
multi_modal_data = dec_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
# text-only
|
||||
inputs["encoder_prompt"] = ""
|
||||
inputs["encoder_prompt_token_ids"] = []
|
||||
inputs["encoder_multi_modal_data"] = {}
|
||||
return inputs
|
||||
return EncoderDecoderInputs(
|
||||
encoder=token_inputs([]),
|
||||
decoder=dec_inputs,
|
||||
)
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
image_data = [image_data]
|
||||
|
||||
assert is_list_of(image_data, Image.Image)
|
||||
|
||||
if isinstance(multi_modal_data['image'], Image.Image):
|
||||
multi_modal_data['image'] = [multi_modal_data['image']]
|
||||
# Since only the last group of consecutive images
|
||||
# are attended by the decoded tokens, we only need to
|
||||
# get the number of tiles for those images.
|
||||
num_decode_images = _get_num_image_in_last_group(
|
||||
inputs["prompt_token_ids"])
|
||||
dec_inputs["prompt_token_ids"])
|
||||
|
||||
hf_config = ctx.model_config.hf_config
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
num_tiles = 0
|
||||
for image in multi_modal_data["image"][::-1]:
|
||||
for image in image_data[::-1]:
|
||||
width, height = image.size
|
||||
tile_size = hf_config.vision_config.image_size
|
||||
tile_size = vision_config.image_size
|
||||
canvas_height, canvas_width = get_optimal_tiled_canvas(
|
||||
image_height=height,
|
||||
image_width=width,
|
||||
max_image_tiles=hf_config.vision_config.max_num_tiles,
|
||||
max_image_tiles=vision_config.max_num_tiles,
|
||||
tile_size=tile_size,
|
||||
)
|
||||
num_tiles_height = canvas_height // tile_size
|
||||
@ -133,14 +151,34 @@ def input_processor_for_mllama(ctx: InputContext,
|
||||
# Set encoder prompt length based on the number of tiles.
|
||||
# This tells the block manager to allocate correct number
|
||||
# of slots for encoder tokens.
|
||||
assert hf_config.vision_config.image_size % 14 == 0, \
|
||||
assert vision_config.image_size % 14 == 0, \
|
||||
"chunk size should be multiple of 14"
|
||||
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
|
||||
token_per_chunk = (vision_config.image_size // 14)**2 + 1
|
||||
num_tokens = num_tiles * token_per_chunk
|
||||
inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
|
||||
inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens
|
||||
|
||||
return inputs
|
||||
# Example output from processor:
|
||||
# {
|
||||
# 'encoder': {
|
||||
# 'type': 'token',
|
||||
# 'prompt_token_ids': [128256, 128256, ..., 128256],
|
||||
# 'prompt': '<|image|><|image|>...<|image|>',
|
||||
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
|
||||
# },
|
||||
# 'decoder': {
|
||||
# 'type': 'token',
|
||||
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
|
||||
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
|
||||
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
|
||||
# },
|
||||
# }
|
||||
return EncoderDecoderInputs(
|
||||
encoder=token_inputs(
|
||||
prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
|
||||
prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
|
||||
multi_modal_data=multi_modal_data,
|
||||
),
|
||||
decoder=dec_inputs,
|
||||
)
|
||||
|
||||
|
||||
def get_max_mllama_image_tokens(ctx: InputContext) -> int:
|
||||
|
||||
@ -343,6 +343,11 @@ class _ModelRegistry:
|
||||
def _raise_for_unsupported(self, architectures: List[str]):
|
||||
all_supported_archs = self.get_supported_archs()
|
||||
|
||||
if any(arch in all_supported_archs for arch in architectures):
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} failed "
|
||||
"to be inspected. Please check the logs for more details.")
|
||||
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {all_supported_archs}")
|
||||
|
||||
vllm/sequence.py (117 changed lines)
@ -9,12 +9,12 @@ from functools import cached_property, reduce
|
||||
from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List,
|
||||
Mapping, Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple, Union, cast
|
||||
from typing import Set, Tuple, Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@ -379,15 +379,10 @@ class SequenceData(msgspec.Struct,
|
||||
|
||||
class Sequence:
|
||||
"""Stores the data, status, and block information of a sequence.
|
||||
|
||||
The sequence is constructed from the :code:`SingletonInputs` instance
|
||||
passed in through the :code:`inputs` constructor argument.
|
||||
|
||||
For encoder/decoder models, SingletonInputs encapsulates both a
|
||||
decoder and encoder prompt, creating an ambiguity about which
|
||||
prompt to construct the sequence from. The `from_decoder_prompt`
|
||||
constructor argument signals whether to construct the Sequence
|
||||
from the SingletonInputs decoder prompt, or encoder prompt.
|
||||
|
||||
The sequence is constructed from the :data:`DecoderOnlyInputs`
|
||||
(for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
|
||||
instance passed in through the :code:`inputs` constructor argument.
|
||||
|
||||
Args:
|
||||
seq_id: The ID of the sequence.
|
||||
@ -397,10 +392,6 @@ class Sequence:
|
||||
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
|
||||
lora_request: LoRA request.
|
||||
prompt_adapter_request: Prompt Adapter request.
|
||||
from_decoder_prompt: Construct Sequence from SingletonInputs decoder
|
||||
prompt (True) or encoder prompt (False.) Must be
|
||||
True for decoder-only model.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -411,7 +402,6 @@ class Sequence:
|
||||
eos_token_id: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
from_decoder_prompt: bool = True,
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.inputs = inputs
|
||||
@ -419,33 +409,6 @@ class Sequence:
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.from_decoder_prompt = from_decoder_prompt
|
||||
|
||||
# For decoder-only models, a Sequence is constructed
|
||||
# from an DecoderOnlyInputs instance (the `inputs` arg.)
|
||||
#
|
||||
# For encoder/decoder models the same `inputs`
|
||||
# instance could be utilized to construct either an
|
||||
# encoder sequence or a decoder sequence, because
|
||||
# `DecoderOnlyInputs` has both decoder- and encoder-oriented
|
||||
# member variables (i.e. it encapsulates both an encoder
|
||||
# and a decoder prompt.) The decision of which type of sequence
|
||||
# to generate is determined by the `from_decoder_prompt` argument.
|
||||
#
|
||||
# When constructing a encoder sequence
|
||||
# (`from_decoder_prompt` False) it matters that
|
||||
# the `DecoderOnlyInputs` instance stored in `inputs` is valid
|
||||
# in the sense that its encoder-related member variables are
|
||||
# populated; below, an exception is raised if this is
|
||||
# not the case.
|
||||
#
|
||||
# When constructing a decoder sequence (`from_decoder_prompt` True)
|
||||
# it does not matter whether `inputs` has its encoder-related
|
||||
# member variables populated.
|
||||
if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)):
|
||||
raise ValueError("Cannot extract encoder input prompt from "
|
||||
f"invalid input {inputs}; did you forget the "
|
||||
"encoder input prompt fields?")
|
||||
|
||||
self.data = SequenceData.from_seqs(self.prompt_token_ids)
|
||||
self.output_logprobs: SampleLogprobs = []
|
||||
@ -470,45 +433,57 @@ class Sequence:
|
||||
|
||||
@cached_property
|
||||
def prompt(self) -> Optional[str]:
|
||||
# Select decoder or encoder input prompt str, as appropriate
|
||||
prompt_key: str = ("prompt"
|
||||
if self.from_decoder_prompt else "encoder_prompt")
|
||||
inputs = self.inputs
|
||||
|
||||
return cast(Optional[str], self.inputs.get(prompt_key))
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("prompt")
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def prompt_token_ids(self) -> List[int]:
|
||||
# Select decoder or encoder input prompt token ids, as appropriate
|
||||
prompt_token_ids_key: str = ("prompt_token_ids"
|
||||
if self.from_decoder_prompt else
|
||||
"encoder_prompt_token_ids")
|
||||
|
||||
# Cache computed prompt token ids
|
||||
return cast(List[int], self.inputs.get(prompt_token_ids_key))
|
||||
|
||||
@property
|
||||
def multi_modal_data(self) -> MultiModalDataDict:
|
||||
inputs = self.inputs
|
||||
|
||||
if (inputs.get("multi_modal_data")
|
||||
and inputs.get("encoder_multi_modal_data")):
|
||||
raise ValueError(
|
||||
"Multi-modal data in both encoder and decoder is not supported."
|
||||
)
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("prompt_token_ids", [])
|
||||
|
||||
return cast(
|
||||
MultiModalDataDict,
|
||||
(inputs.get("multi_modal_data")
|
||||
or inputs.get("encoder_multi_modal_data") or {}),
|
||||
)
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return None
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def multi_modal_data(self) -> "MultiModalDataDict":
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("multi_modal_data", {})
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@cached_property
|
||||
def mm_processor_kwargs(self) -> Dict[str, Any]:
|
||||
inputs = self.inputs
|
||||
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("mm_processor_kwargs", {})
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@property
|
||||
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
|
||||
return self.inputs.get("multi_modal_placeholders") or {}
|
||||
inputs = self.inputs
|
||||
|
||||
@property
|
||||
def mm_processor_kwargs(self) -> Dict[str, Any]:
|
||||
return self.inputs.get("mm_processor_kwargs") or {}
|
||||
if inputs["type"] == "token":
|
||||
return inputs.get("multi_modal_placeholders", {})
|
||||
|
||||
assert_never(inputs)
|
||||
|
||||
@property
|
||||
def lora_int_id(self) -> int:
|
||||
|
||||