[Core] Make encoder-decoder inputs a nested structure to be more composable (#9604)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-11-05 10:07:31 +08:00 committed by GitHub
parent 04bbf38e05
commit bbc3619dc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 369 additions and 346 deletions

View File

@@ -4,6 +4,7 @@ from typing import Sequence as GenericSequence
 from typing import Tuple
 from vllm import SamplingParams
+from vllm.inputs import EncoderDecoderInputs, token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.sequence import Logprob, Sequence, SequenceGroup

@@ -27,10 +28,7 @@ def create_dummy_prompt(
     prompt_tokens = list(range(prompt_length))
     prompt_str = " ".join([str(t) for t in prompt_tokens])
     prompt = Sequence(int(request_id),
-                      inputs={
-                          "prompt": prompt_str,
-                          "prompt_token_ids": prompt_tokens,
-                      },
+                      inputs=token_inputs(prompt_tokens, prompt=prompt_str),
                       block_size=block_size)
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[prompt],

@@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder(
     encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
     encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])

-    inputs = {
-        "prompt": decoder_prompt_str,
-        "prompt_token_ids": decoder_prompt_tokens,
-        "encoder_prompt": encoder_prompt_str,
-        "encoder_prompt_token_ids": encoder_prompt_tokens,
-        "multi_modal_data": None,
+    inputs: EncoderDecoderInputs = {
+        "decoder": token_inputs(decoder_prompt_tokens,
+                                prompt=decoder_prompt_str),
+        "encoder": token_inputs(encoder_prompt_tokens,
+                                prompt=encoder_prompt_str),
     }

     decoder_prompt = Sequence(int(request_id),
-                              inputs=inputs,
-                              block_size=block_size,
-                              from_decoder_prompt=True)
+                              inputs=inputs["decoder"],
+                              block_size=block_size)

     encoder_prompt = Sequence(int(request_id),
-                              inputs=inputs,
-                              block_size=block_size,
-                              from_decoder_prompt=False)
+                              inputs=inputs["encoder"],
+                              block_size=block_size)
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[decoder_prompt],
                               sampling_params=SamplingParams(best_of=best_of),

@@ -108,7 +104,7 @@ def create_seq_group(
     for seq_id_offset, output_len in enumerate(seq_output_lens):
         seq = Sequence(
             seq_id=seq_id_start + seq_id_offset,
-            inputs={"prompt_token_ids": prompt_token_ids},
+            inputs=token_inputs(prompt_token_ids),
             block_size=16,
         )

@@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder(
     prompt_token_ids = [0] * seq_prompt_len

-    inputs = {
-        "prompt": "",
-        "prompt_token_ids": prompt_token_ids,
-        "encoder_prompt": "",
-        "encoder_prompt_token_ids": prompt_token_ids,
-        "multi_modal_data": None,
+    inputs: EncoderDecoderInputs = {
+        "decoder": token_inputs(prompt_token_ids),
+        "encoder": token_inputs(prompt_token_ids),
     }

     seqs = []
     for seq_id_offset, output_len in enumerate(seq_output_lens):
         # Construct decoder input sequences
-        seq = Sequence(seq_id=seq_id_start + seq_id_offset,
-                       inputs=inputs,
-                       block_size=16,
-                       from_decoder_prompt=True)
+        seq = Sequence(
+            seq_id=seq_id_start + seq_id_offset,
+            inputs=inputs["decoder"],
+            block_size=16,
+        )

         for i in range(output_len):
             seq.append_token_id(

@@ -167,10 +161,11 @@ def create_seq_group_encoder_decoder(
         seqs.append(seq)

     # Encoder input sequence
-    encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
-                           inputs=inputs,
-                           block_size=16,
-                           from_decoder_prompt=False)
+    encoder_seq = Sequence(
+        seq_id=seq_id_start + len(seq_output_lens),
+        inputs=inputs["encoder"],
+        block_size=16,
+    )

     return SequenceGroup(request_id=request_id,
                          seqs=seqs,

View File

@@ -4,6 +4,7 @@ import pytest
 from transformers import PreTrainedTokenizer
 from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.inputs import token_inputs
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob, Sequence, SequenceStatus

@@ -15,7 +16,7 @@ def sequence_with_eos(text: str, eos_token: str,
     """
     seq = Sequence(
         seq_id=0,
-        inputs={"prompt_token_ids": []},
+        inputs=token_inputs([]),
         block_size=16,
         eos_token_id=eos_token_id,
     )

View File

@@ -6,6 +6,7 @@ from typing import List, Optional
 import pytest
+from vllm.inputs import token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.sequence import Sequence
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup

@@ -70,10 +71,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
         hashes[-1].append([])
         prompt_token_ids = tokenizer.encode(prompt)
         seq = Sequence(seq_id,
-                       inputs={
-                           "prompt": prompt,
-                           "prompt_token_ids": prompt_token_ids,
-                       },
+                       inputs=token_inputs(prompt_token_ids,
+                                           prompt=prompt),
                        block_size=block_size,
                        eos_token_id=tokenizer.tokenizer.eos_token_id,
                        lora_request=lora_request)

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, Generator, List, Optional
 import pytest
 from transformers import AutoTokenizer
+from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
 from vllm.transformers_utils.detokenizer import (Detokenizer,
                                                  detokenize_incrementally)

@@ -169,10 +170,7 @@ def create_sequence(prompt_token_ids=None):
     prompt_token_ids = prompt_token_ids or [1]
     return Sequence(
         seq_id=0,
-        inputs={
-            "prompt": "<s>",
-            "prompt_token_ids": prompt_token_ids,
-        },
+        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
         block_size=16,
     )

View File

@@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload
 import torch
-from typing_extensions import TypeIs, TypeVar
+from typing_extensions import TypeVar
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,

@@ -29,9 +29,9 @@ from vllm.entrypoints.openai.logits_processors import (
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
-                         EncoderDecoderInputs, InputRegistry, PromptType,
-                         TokensPrompt)
+from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
+                         PromptType)
+from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.logits_process import get_bad_words_logits_processors

@@ -638,7 +638,7 @@ class LLMEngine:
     def _add_processed_request(
         self,
         request_id: str,
-        processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
+        processed_inputs: ProcessorInputs,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: float,
         lora_request: Optional[LoRARequest],

@@ -669,18 +669,19 @@ class LLMEngine:
         seq_id = next(self.seq_counter)
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

-        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
+        if is_encoder_decoder_inputs(processed_inputs):
+            decoder_inputs = processed_inputs["decoder"]
+            encoder_inputs = processed_inputs["encoder"]
+        else:
+            decoder_inputs = processed_inputs
+            encoder_inputs = None
+
+        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                        lora_request, prompt_adapter_request)

-        encoder_seq = None
-        if 'encoder_prompt_token_ids' in processed_inputs:
-            encoder_seq = Sequence(seq_id,
-                                   processed_inputs,
-                                   block_size,
-                                   eos_token_id,
-                                   lora_request,
-                                   prompt_adapter_request,
-                                   from_decoder_prompt=False)
+        encoder_seq = (None if encoder_inputs is None else Sequence(
+            seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
+            prompt_adapter_request))

         # Create a SequenceGroup based on SamplingParams or PoolingParams
         if isinstance(params, SamplingParams):

@@ -874,7 +875,7 @@ class LLMEngine:
         # This needs to happen before multimodal input pre-processing, which
         # may add dummy <image> tokens that aren't part of the tokenizer's
         # vocabulary.
-        if self._is_token_prompt(prompt):
+        if is_token_prompt(prompt):
             prompt_ids = prompt["prompt_token_ids"]
             if len(prompt_ids) == 0:
                 # Empty prompt check is handled later

@@ -884,10 +885,6 @@ class LLMEngine:
                 raise ValueError(
                     "Token id {} is out of vocabulary".format(max_input_id))

-    @staticmethod
-    def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
-        return isinstance(prompt, dict) and "prompt_token_ids" in prompt
-
     def _create_sequence_group_with_sampling(
         self,
         request_id: str,

@@ -1978,17 +1975,17 @@ class LLMEngine:
     def is_encoder_decoder_model(self):
         return self.input_preprocessor.is_encoder_decoder_model()

-    def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
-                                                   EncoderDecoderInputs],
+    def _validate_model_inputs(self, inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest]):
-        if self.model_config.is_multimodal_model:
+        if is_encoder_decoder_inputs(inputs):
             # For encoder-decoder multimodal models, the max_prompt_len
             # restricts the decoder prompt length
-            prompt_ids = inputs.get("prompt_token_ids")
-        elif self.is_encoder_decoder_model():
-            prompt_ids = inputs.get("encoder_prompt_token_ids")
+            prompt_inputs = inputs["decoder" if self.model_config.
+                                   is_multimodal_model else "encoder"]
         else:
-            prompt_ids = inputs.get("prompt_token_ids")
+            prompt_inputs = inputs
+
+        prompt_ids = prompt_inputs.get("prompt_token_ids")

         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")
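A minimal sketch of the dispatch pattern that `_add_processed_request` now follows: nested encoder-decoder inputs are unpacked into separate decoder/encoder sequence inputs instead of one shared flat dict. This is not part of the diff; the helper name `split_processed_inputs` is hypothetical and only mirrors the engine code above.

```python
from typing import Optional, Tuple

from vllm.inputs import ProcessorInputs, SingletonInputs
from vllm.inputs.parse import is_encoder_decoder_inputs


def split_processed_inputs(
    processed_inputs: ProcessorInputs,
) -> Tuple[SingletonInputs, Optional[SingletonInputs]]:
    """Return (decoder_inputs, encoder_inputs) for any processed inputs."""
    if is_encoder_decoder_inputs(processed_inputs):
        # Nested structure: each half is its own TokenInputs dict.
        return processed_inputs["decoder"], processed_inputs["encoder"]
    # Decoder-only models keep the flat TokenInputs dict; no encoder half.
    return processed_inputs, None
```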

View File

@@ -1,11 +1,12 @@
 import asyncio
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, List, Mapping, Optional, Union
+from typing import AsyncGenerator, List, Mapping, Optional
 from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest

@@ -60,7 +61,7 @@ class EngineClient(ABC):
     async def beam_search(
         self,
-        prompt: Union[PromptType, List[int]],
+        prompt: PromptType,
         model_config: ModelConfig,
         request_id: str,
         params: BeamSearchParams,

@@ -76,11 +77,19 @@
         tokenizer = await self.get_tokenizer()
         input_preprocessor = InputPreprocessor(model_config, tokenizer)

-        (prompt_text, prompt_token_ids, multi_modal_data,
-         mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
-             prompt,
-             request_id=request_id,
-         )
+        if is_explicit_encoder_decoder_prompt(prompt):
+            raise NotImplementedError
+        else:
+            processed_inputs = input_preprocessor._prompt_to_llm_inputs(
+                prompt,
+                request_id=request_id,
+            )
+
+        prompt_token_ids = processed_inputs["prompt_token_ids"]
+        prompt_text = processed_inputs.get("prompt")
+        multi_modal_data = processed_inputs.get("multi_modal_data")
+        mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")

         tokenized_length = len(prompt_token_ids)

         sort_beams_key = create_sort_beams_key_function(

View File

@@ -1,8 +1,8 @@
 from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
-                   ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
-                   SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
-                   build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
-                   token_inputs, zip_enc_dec_prompts)
+                   ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
+                   SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
+                   TokensPrompt, build_explicit_enc_dec_prompt,
+                   to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
 from .registry import DummyData, InputContext, InputRegistry

 INPUT_REGISTRY = InputRegistry()

@@ -22,9 +22,10 @@ __all__ = [
     "ExplicitEncoderDecoderPrompt",
     "TokenInputs",
     "token_inputs",
-    "SingletonInputs",
     "DecoderOnlyInputs",
     "EncoderDecoderInputs",
+    "ProcessorInputs",
+    "SingletonInputs",
     "build_explicit_enc_dec_prompt",
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",

View File

@@ -1,4 +1,4 @@
-from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
+from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
                     Optional, Tuple, Union, cast)
 from typing_extensions import NotRequired, TypedDict, TypeVar

@@ -122,27 +122,30 @@ both decoder-only and encoder/decoder input types:
 class TokenInputs(TypedDict):
     """Represents token-based inputs."""

+    type: Literal["token"]
+    """The type of inputs."""
+
     prompt_token_ids: List[int]
     """The token IDs of the prompt."""

-    prompt: NotRequired[Optional[str]]
+    prompt: NotRequired[str]
     """
     The original prompt text corresponding to the token IDs, if available.
     """

-    multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
+    multi_modal_data: NotRequired["MultiModalDataDict"]
     """
     Optional multi-modal data to pass to the model,
     if the model supports it.
     """

-    multi_modal_placeholders: NotRequired[
-        Optional["MultiModalPlaceholderDict"]]
+    multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
     """
     Placeholder ranges for the multi-modal data.
     """

-    mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
+    mm_processor_kwargs: NotRequired[Dict[str, Any]]
     """
     Optional multi-modal processor kwargs to be forwarded to the
     multimodal input mapper & processor. Note that if multiple modalities

@@ -159,7 +162,7 @@ def token_inputs(
     mm_processor_kwargs: Optional[Dict[str, Any]] = None,
 ) -> TokenInputs:
     """Construct :class:`TokenInputs` from optional values."""
-    inputs = TokenInputs(prompt_token_ids=prompt_token_ids)
+    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

     if prompt is not None:
         inputs["prompt"] = prompt

@@ -173,12 +176,6 @@ def token_inputs(
     return inputs

-SingletonInputs = TokenInputs
-"""
-A processed :class:`SingletonPrompt` which can be passed to
-:class:`vllm.sequence.Sequence`.
-"""
-
 DecoderOnlyInputs = TokenInputs
 """
 The inputs in :class:`~vllm.LLMEngine` before they are

@@ -187,28 +184,30 @@ This specifies the data required for decoder-only models.
 """

-class EncoderDecoderInputs(TokenInputs):
+class EncoderDecoderInputs(TypedDict):
     """
     The inputs in :class:`~vllm.LLMEngine` before they are
     passed to the model executor.

     This specifies the required data for encoder-decoder models.
     """
-    encoder_prompt_token_ids: List[int]
-    """The token IDs of the encoder prompt."""
-
-    encoder_prompt: NotRequired[Optional[str]]
-    """
-    The original encoder prompt text corresponding to the token IDs, if
-    available.
-    """
-
-    encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
-    """
-    Optional multi-modal data to pass to the encoder model,
-    if the model supports it.
-    """
+    encoder: TokenInputs
+    """The inputs for the encoder portion."""
+
+    decoder: TokenInputs
+    """The inputs for the decoder portion."""
+
+
+SingletonInputs = TokenInputs
+"""
+A processed :class:`SingletonPrompt` which can be passed to
+:class:`vllm.sequence.Sequence`.
+"""
+
+ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
+"""
+The inputs to :data:`vllm.inputs.InputProcessor`.
+"""

 _T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
 _T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
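For reference, a minimal sketch (derived from the test utilities earlier in this commit) of how the new nested structure is assembled; the token IDs here are made up:

```python
from vllm.inputs import EncoderDecoderInputs, token_inputs

inputs: EncoderDecoderInputs = {
    "encoder": token_inputs([10, 11, 12], prompt="10 11 12"),
    "decoder": token_inputs([0]),
}

# Each half is an ordinary TokenInputs and now carries its own discriminator,
# which downstream code can switch on via inputs[...]["type"] == "token".
assert inputs["encoder"]["type"] == "token"
assert inputs["decoder"]["prompt_token_ids"] == [0]
```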

View File

@@ -4,9 +4,9 @@ from typing_extensions import TypeIs
 from vllm.utils import is_list_of

-from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
-                   ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
-                   TextPrompt, TokensPrompt)
+from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
+                   ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
+                   TokensPrompt)

 class ParsedText(TypedDict):

@@ -98,12 +98,15 @@ def parse_singleton_prompt(
     raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")

+def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
+    return isinstance(prompt, dict) and "prompt_token_ids" in prompt
+
 def is_explicit_encoder_decoder_prompt(
         prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
     return isinstance(prompt, dict) and "encoder_prompt" in prompt

 def is_encoder_decoder_inputs(
-    inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
-) -> TypeIs[EncoderDecoderInputs]:
-    return "encoder_prompt_token_ids" in inputs
+        inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
+    return "encoder" in inputs and "decoder" in inputs

View File

@@ -1,5 +1,5 @@
 import asyncio
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional

 from typing_extensions import assert_never

@@ -10,22 +10,12 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.utils import print_warning_once

-from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
-                   SingletonPrompt)
+from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
+                   PromptType, SingletonInputs, SingletonPrompt, token_inputs)
 from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt

-if TYPE_CHECKING:
-    from vllm.multimodal import MultiModalDataDict
-
 logger = init_logger(__name__)

-PromptComponents = Tuple[Optional[str], List[int],
-                         Optional["MultiModalDataDict"], Optional[Dict[str,
-                                                                       Any]]]
-DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
-                                Optional["MultiModalDataDict"],
-                                Optional[Dict[str, Any]]]
-
 class InputPreprocessor:

@@ -115,7 +105,7 @@ class InputPreprocessor:
         "default" decoder prompt be <BOS>.
         However, it is possible that in the future
         other models may have different or more
         complex logic for the default decoder prompt.
         This motivates having a special helper method
         for default decoder prompts.

@@ -132,7 +122,6 @@ class InputPreprocessor:
     def _prepare_decoder_input_ids_for_generation(
         self,
         decoder_input_ids: Optional[List[int]],
-        force_bos: bool = True,
     ) -> List[int]:
         """
         Prepares `decoder_input_ids` for generation with encoder-decoder models.

@@ -162,8 +151,8 @@ class InputPreprocessor:
             # use decoder_start_token_id as decoder_input_ids
             decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

-        if force_bos and (len(decoder_input_ids) == 0
-                          or decoder_input_ids[0] != decoder_start_token_id):
+        if (len(decoder_input_ids) == 0
+                or decoder_input_ids[0] != decoder_start_token_id):
             decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

         return decoder_input_ids
@@ -209,12 +198,12 @@ class InputPreprocessor:
             prompt=prompt,
             lora_request=lora_request)

-    def _extract_prompt_components(
+    def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-    ) -> PromptComponents:
+    ) -> SingletonInputs:
         '''
         Extract the components of any single encoder or decoder input prompt.

@@ -241,34 +230,52 @@ class InputPreprocessor:
                 request_id=request_id,
                 lora_request=lora_request,
             )
-            multi_modal_data = None
-            mm_processor_kwargs = None
-        elif parsed["type"] == "tokens":
-            prompt_text = None
-            prompt_token_ids = parsed["content"]["prompt_token_ids"]
-            multi_modal_data = parsed["content"].get("multi_modal_data")
-            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
-        elif parsed["type"] == "text":
-            prompt_text = parsed["content"]["prompt"]
+
+            return token_inputs(
+                prompt=prompt_text,
+                prompt_token_ids=prompt_token_ids,
+            )
+
+        if parsed["type"] == "tokens":
+            tokens_content = parsed["content"]
+
+            prompt_token_ids = tokens_content["prompt_token_ids"]
+            multi_modal_data = tokens_content.get("multi_modal_data")
+            mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
+
+            return token_inputs(
+                prompt_token_ids=prompt_token_ids,
+                multi_modal_data=multi_modal_data,
+                mm_processor_kwargs=mm_processor_kwargs,
+            )
+
+        if parsed["type"] == "text":
+            text_content = parsed["content"]
+
+            prompt_text = text_content["prompt"]
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
             )
-            multi_modal_data = parsed["content"].get("multi_modal_data")
-            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
-        else:
-            assert_never(parsed)
+            multi_modal_data = text_content.get("multi_modal_data")
+            mm_processor_kwargs = text_content.get("mm_processor_kwargs")

-        return (prompt_text, prompt_token_ids, multi_modal_data,
-                mm_processor_kwargs)
+            return token_inputs(
+                prompt=prompt_text,
+                prompt_token_ids=prompt_token_ids,
+                multi_modal_data=multi_modal_data,
+                mm_processor_kwargs=mm_processor_kwargs,
+            )

-    async def _extract_prompt_components_async(
+        assert_never(parsed)
+
+    async def _prompt_to_llm_inputs_async(
         self,
         prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-    ) -> PromptComponents:
+    ) -> SingletonInputs:
         """Async version of :meth:`_extract_prompt_components`."""
         parsed = parse_singleton_prompt(prompt)

@@ -279,59 +286,74 @@ class InputPreprocessor:
                 request_id=request_id,
                 lora_request=lora_request,
             )
-            multi_modal_data = None
-            mm_processor_kwargs = None
-        elif parsed["type"] == "tokens":
-            prompt_text = None
-            prompt_token_ids = parsed["content"]["prompt_token_ids"]
-            multi_modal_data = parsed["content"].get("multi_modal_data")
-            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
-        elif parsed["type"] == "text":
-            prompt_text = parsed["content"]["prompt"]
+
+            return token_inputs(
+                prompt=prompt_text,
+                prompt_token_ids=prompt_token_ids,
+            )
+
+        if parsed["type"] == "tokens":
+            tokens_content = parsed["content"]
+
+            prompt_token_ids = tokens_content["prompt_token_ids"]
+            multi_modal_data = tokens_content.get("multi_modal_data")
+            mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
+
+            return token_inputs(
+                prompt_token_ids=prompt_token_ids,
+                multi_modal_data=multi_modal_data,
+                mm_processor_kwargs=mm_processor_kwargs,
+            )
+
+        if parsed["type"] == "text":
+            text_content = parsed["content"]
+
+            prompt_text = text_content["prompt"]
             prompt_token_ids = await self._tokenize_prompt_async(
                 prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
             )
-            multi_modal_data = parsed["content"].get("multi_modal_data")
-            mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
-        else:
-            assert_never(parsed)
+            multi_modal_data = text_content.get("multi_modal_data")
+            mm_processor_kwargs = text_content.get("mm_processor_kwargs")

-        return (prompt_text, prompt_token_ids, multi_modal_data,
-                mm_processor_kwargs)
+            return token_inputs(
+                prompt=prompt_text,
+                prompt_token_ids=prompt_token_ids,
+                multi_modal_data=multi_modal_data,
+                mm_processor_kwargs=mm_processor_kwargs,
+            )
+
+        assert_never(parsed)
     def _build_enc_dec_llm_inputs(
         self,
-        encoder_comps: PromptComponents,
-        decoder_comps: DecoderPromptComponents,
-        mm_processor_kwargs: Dict[str, Any],
+        encoder_inputs: SingletonInputs,
+        decoder_inputs: Optional[SingletonInputs],
     ) -> EncoderDecoderInputs:
-        encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
-        decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
-
-        # Reminder: Please update docs/source/serving/compatibility_matrix.rst
-        # If the feature combo become valid
-        if decoder_mm_data is not None:
-            raise ValueError(
-                "Multi-modality decoder inputs of encoder-decoder models are "
-                "not supported yet")
-
-        # For Multi-Modal models (e.g., mllama), the text input can be
-        # <|image|><|begin_of_text|>hello world. And we should not add
-        # another <|begin_of_text|> to the beginning.
-        decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation(
-            decoder_prompt_ids,
-            force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
+        if encoder_inputs["type"] == "token":
+            pass
+        else:
+            assert_never(encoder_inputs)
+
+        if decoder_inputs is None:
+            dec_token_ids = self._prepare_decoder_input_ids_for_generation(
+                None)
+            decoder_inputs = token_inputs(dec_token_ids)
+        elif decoder_inputs["type"] == "token":
+            dec_token_ids = self._prepare_decoder_input_ids_for_generation(
+                decoder_inputs["prompt_token_ids"])
+            decoder_inputs["prompt_token_ids"] = dec_token_ids
+
+            if "multi_modal_data" in decoder_inputs:
+                raise ValueError("Multi-modal decoder inputs of encoder-"
+                                 "decoder models are not supported yet")
+        else:
+            assert_never(encoder_inputs)

         return EncoderDecoderInputs(
-            prompt_token_ids=decoder_prompt_ids,
-            prompt=decoder_prompt,
-            multi_modal_data=decoder_mm_data,
-            mm_processor_kwargs=mm_processor_kwargs,
-            encoder_prompt_token_ids=encoder_prompt_ids,
-            encoder_prompt=encoder_prompt,
-            encoder_multi_modal_data=encoder_mm_data,
+            encoder=encoder_inputs,
+            decoder=decoder_inputs,
         )

     def _process_encoder_decoder_prompt(

@@ -341,8 +363,7 @@ class InputPreprocessor:
     ) -> EncoderDecoderInputs:
         '''
         For encoder/decoder models only:
-        Process an input prompt into an
-        :class:`EncoderDecoderInputs` instance.
+        Process an input prompt into an :class:`EncoderDecoderInputs` instance.

         There are two types of input prompts:
         singleton prompts which carry only the

@@ -361,7 +382,7 @@ class InputPreprocessor:
         have any possible singleton type; thus this
         method relies on helper functions to obtain
         token ids for the sub-prompts.

         Arguments:

         * prompt: an input prompt

@@ -372,40 +393,31 @@ class InputPreprocessor:
         * :class:`EncoderDecoderInputs` instance
         '''
-        encoder_comps: PromptComponents
-        decoder_comps: DecoderPromptComponents
+        encoder_inputs: SingletonInputs
+        decoder_inputs: Optional[SingletonInputs]

         if is_explicit_encoder_decoder_prompt(prompt):
-            encoder_comps = self._extract_prompt_components(
+            encoder_inputs = self._prompt_to_llm_inputs(
                 prompt["encoder_prompt"],
                 request_id=request_id,
             )
+
             if (decoder_input := prompt["decoder_prompt"]) is None:
-                decoder_comps = None, None, None, None
+                decoder_inputs = None
             else:
-                decoder_comps = self._extract_prompt_components(
+                decoder_inputs = self._prompt_to_llm_inputs(
                     decoder_input,
                     request_id=request_id,
                 )
-            # Handle this carefully in case it was directly initialized by user
-            mm_processor_kwargs = prompt.get("mm_processor_kwargs", {})
         else:
-            encoder_comps = self._extract_prompt_components(
+            encoder_inputs = self._prompt_to_llm_inputs(
                 prompt,
                 request_id=request_id,
             )
-            # If there are no decoder components, we assume the
-            # mm_processor_kwargs are in the encoder prompt
-            mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
-                -1] is not None else {}
-            decoder_comps = None, None, None, None
-
-        return self._build_enc_dec_llm_inputs(
-            encoder_comps,
-            decoder_comps,
-            mm_processor_kwargs,
-        )
+
+            decoder_inputs = None
+
+        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

     async def _process_encoder_decoder_prompt_async(
         self,

@@ -413,59 +425,50 @@ class InputPreprocessor:
         request_id: str,
     ) -> EncoderDecoderInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
-        encoder_comps: PromptComponents
-        decoder_comps: DecoderPromptComponents
+        encoder_inputs: SingletonInputs
+        decoder_inputs: Optional[SingletonInputs]

         if is_explicit_encoder_decoder_prompt(prompt):
-            encoder_task = self._extract_prompt_components_async(
+            encoder_task = self._prompt_to_llm_inputs_async(
                 prompt["encoder_prompt"],
                 request_id=request_id,
             )

             if (decoder_input := prompt["decoder_prompt"]) is None:
-                encoder_comps = await encoder_task
-                decoder_comps = None, None, None, None
+                encoder_inputs = await encoder_task
+                decoder_inputs = None
             else:
-                decoder_task = self._extract_prompt_components_async(
+                decoder_task = self._prompt_to_llm_inputs_async(
                     decoder_input,
                     request_id=request_id,
                 )

-                encoder_comps, decoder_comps = await asyncio.gather(
+                encoder_inputs, decoder_inputs = await asyncio.gather(
                     encoder_task, decoder_task)
-            mm_processor_kwargs = prompt["mm_processor_kwargs"]
         else:
-            encoder_comps = await self._extract_prompt_components_async(
+            encoder_inputs = await self._prompt_to_llm_inputs_async(
                 prompt,
                 request_id=request_id,
             )
-            # If there are no decoder components, we assume the
-            # mm_processor_kwargs are in the encoder prompt
-            mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
-                -1] is not None else {}
-            decoder_comps = None, None, None, None
-
-        return self._build_enc_dec_llm_inputs(
-            encoder_comps,
-            decoder_comps,
-            mm_processor_kwargs,
-        )
+
+            decoder_inputs = None
+
+        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

     def _build_decoder_only_llm_inputs(
         self,
-        prompt_comps: PromptComponents,
+        prompt_inputs: DecoderOnlyInputs,
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> DecoderOnlyInputs:
-        (prompt, prompt_token_ids, multi_modal_data,
-         mm_processor_kwargs) = prompt_comps
+        if prompt_inputs["type"] == "token":
+            prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
+                prompt_inputs["prompt_token_ids"],
+                prompt_adapter_request=prompt_adapter_request,
+            )
+        else:
+            assert_never(prompt_inputs)

-        prompt_token_ids = self._apply_prompt_adapter(
-            prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
-
-        return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
-                                 prompt=prompt,
-                                 multi_modal_data=multi_modal_data,
-                                 mm_processor_kwargs=mm_processor_kwargs)
+        return prompt_inputs

     def _process_decoder_only_prompt(
         self,

@@ -490,7 +493,7 @@ class InputPreprocessor:
         * :class:`DecoderOnlyInputs` instance
         '''

-        prompt_comps = self._extract_prompt_components(
+        prompt_comps = self._prompt_to_llm_inputs(
             prompt,
             request_id=request_id,
             lora_request=lora_request,

@@ -509,7 +512,7 @@ class InputPreprocessor:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> DecoderOnlyInputs:
         """Async version of :meth:`_process_decoder_only_prompt`."""
-        prompt_comps = await self._extract_prompt_components_async(
+        prompt_comps = await self._prompt_to_llm_inputs_async(
             prompt,
             request_id=request_id,
             lora_request=lora_request,

@@ -526,7 +529,7 @@ class InputPreprocessor:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
+    ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.is_encoder_decoder_model():
             # Encoder-decoder model requires special mapping of

@@ -554,7 +557,7 @@ class InputPreprocessor:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
+    ) -> ProcessorInputs:
         """Async version of :meth:`preprocess`."""
         if self.is_encoder_decoder_model():
             # Encoder-decoder model requires special mapping of
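A hedged sketch of what the refactored preprocessing now produces for an explicit encoder/decoder prompt; the tokenizer/preprocessor setup is omitted and the exact decoder token IDs are an assumption (they depend on the model's decoder start token):

```python
from vllm.inputs import build_explicit_enc_dec_prompt

enc_dec_prompt = build_explicit_enc_dec_prompt(
    encoder_prompt="What is the content of this image?",
    decoder_prompt=None,  # decoder half is synthesized by the preprocessor
)

# processed = input_preprocessor.preprocess(enc_dec_prompt, request_id="0")
# processed["encoder"]["prompt_token_ids"]  -> tokenized encoder prompt
# processed["decoder"]["prompt_token_ids"]  -> e.g. [decoder_start_token_id]
```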

View File

@@ -2,7 +2,7 @@ import functools
 from collections import UserDict
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
-                    Optional, Protocol, Type)
+                    Optional, Protocol, Type, cast)

 from torch import nn
 from transformers import PretrainedConfig

@@ -12,7 +12,7 @@ from vllm.logger import init_logger
 from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
                         resolve_mm_processor_kwargs)

-from .data import DecoderOnlyInputs
+from .data import ProcessorInputs

 if TYPE_CHECKING:
     from vllm.config import ModelConfig

@@ -109,7 +109,7 @@ class _MultiModalCounts(UserDict):
             raise KeyError(msg) from exc

-InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs]
+InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
 """Preprocess the inputs to the model."""

@@ -254,8 +254,8 @@ class InputRegistry:
     def _default_input_processor(
         self,
         ctx: InputContext,
-        inputs: DecoderOnlyInputs,
-    ) -> DecoderOnlyInputs:
+        inputs: ProcessorInputs,
+    ) -> ProcessorInputs:
         """The default input processor is a no-op."""
         return inputs

@@ -288,7 +288,7 @@ class InputRegistry:
             .get(model_cls, self._default_input_processor)

     def process_input(self, model_config: "ModelConfig",
-                      inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
+                      inputs: ProcessorInputs) -> ProcessorInputs:
         """
         Apply an input processor to an instance of model inputs.

@@ -308,7 +308,7 @@ class InputRegistry:
         # If it's empty, it'll fall back to the default kwarg values
         mm_processor_kwargs = resolve_mm_processor_kwargs(
             model_config.mm_processor_kwargs,
-            inputs.get("mm_processor_kwargs"),
+            cast(Dict[str, Any], inputs.get("mm_processor_kwargs")),
             processor,
         )
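A hedged sketch of a model input processor under the widened `InputProcessor` signature: it now receives `ProcessorInputs`, so an encoder-decoder model gets the nested dict directly. The function name and its no-op body are illustrative only.

```python
from vllm.inputs import InputContext, ProcessorInputs
from vllm.inputs.parse import is_encoder_decoder_inputs


def example_input_processor(ctx: InputContext,
                            inputs: ProcessorInputs) -> ProcessorInputs:
    if is_encoder_decoder_inputs(inputs):
        # Encoder-decoder models: rewrite inputs["encoder"] / inputs["decoder"]
        # separately instead of mutating one flat dict.
        return inputs
    # Decoder-only models still receive a flat TokenInputs dict.
    return inputs
```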

View File

@@ -36,8 +36,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         EncoderDecoderInputs, InputContext)
+from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
+                         InputContext, TokenInputs, token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,

@@ -52,6 +52,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import SequenceData
+from vllm.utils import is_list_of

 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal

@@ -86,41 +87,58 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
     return num_images

-def input_processor_for_mllama(ctx: InputContext,
-                               inputs: Union[DecoderOnlyInputs,
-                                             EncoderDecoderInputs]):
-    # move encoder_prompt to prompt
-    if inputs.get("prompt") is None:
-        inputs["prompt"] = inputs["encoder_prompt"]
-        inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"]
+def input_processor_for_mllama(
+    ctx: InputContext,
+    inputs: EncoderDecoderInputs,
+) -> EncoderDecoderInputs:
+    # Example input to processor:
+    # {
+    #     'encoder': {
+    #         'type': 'token',
+    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+    #     },
+    #     'decoder': {
+    #         'type': 'token',
+    #         'prompt_token_ids': [128000],
+    #     },
+    # }

-    # process multi-modal data
-    multi_modal_data = inputs.get("encoder_multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data \
-        or multi_modal_data["image"] is None:
+    # move encoder prompt to decoder
+    dec_inputs = TokenInputs(**inputs["encoder"])
+
+    multi_modal_data = dec_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
         # text-only
-        inputs["encoder_prompt"] = ""
-        inputs["encoder_prompt_token_ids"] = []
-        inputs["encoder_multi_modal_data"] = {}
-        return inputs
+        return EncoderDecoderInputs(
+            encoder=token_inputs([]),
+            decoder=dec_inputs,
+        )

-    if isinstance(multi_modal_data['image'], Image.Image):
-        multi_modal_data['image'] = [multi_modal_data['image']]
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        image_data = [image_data]
+
+    assert is_list_of(image_data, Image.Image)

     # Since only the last group of consecutive images
     # are attended by the decoded tokens, we only need to
     # get the number of tiles for those images.
     num_decode_images = _get_num_image_in_last_group(
-        inputs["prompt_token_ids"])
+        dec_inputs["prompt_token_ids"])

     hf_config = ctx.model_config.hf_config
+    vision_config = hf_config.vision_config

     num_tiles = 0
-    for image in multi_modal_data["image"][::-1]:
+    for image in image_data[::-1]:
         width, height = image.size
-        tile_size = hf_config.vision_config.image_size
+        tile_size = vision_config.image_size
         canvas_height, canvas_width = get_optimal_tiled_canvas(
             image_height=height,
             image_width=width,
-            max_image_tiles=hf_config.vision_config.max_num_tiles,
+            max_image_tiles=vision_config.max_num_tiles,
             tile_size=tile_size,
         )
         num_tiles_height = canvas_height // tile_size

@@ -133,14 +151,34 @@ def input_processor_for_mllama(
     # Set encoder prompt length based on the number of tiles.
     # This tells the block manager to allocate correct number
     # of slots for encoder tokens.
-    assert hf_config.vision_config.image_size % 14 == 0, \
+    assert vision_config.image_size % 14 == 0, \
         "chunk size should be multiple of 14"
-    token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
+    token_per_chunk = (vision_config.image_size // 14)**2 + 1
     num_tokens = num_tiles * token_per_chunk
-    inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
-    inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens

-    return inputs
+    # Example output from processor:
+    # {
+    #     'encoder': {
+    #         'type': 'token',
+    #         'prompt_token_ids': [128256, 128256, ..., 128256],
+    #         'prompt': '<|image|><|image|>...<|image|>',
+    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+    #     },
+    #     'decoder': {
+    #         'type': 'token',
+    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
+    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
+    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
+    #     },
+    # }
+    return EncoderDecoderInputs(
+        encoder=token_inputs(
+            prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
+            prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
+            multi_modal_data=multi_modal_data,
+        ),
+        decoder=dec_inputs,
+    )

 def get_max_mllama_image_tokens(ctx: InputContext) -> int:

View File

@@ -343,6 +343,11 @@ class _ModelRegistry:
     def _raise_for_unsupported(self, architectures: List[str]):
         all_supported_archs = self.get_supported_archs()

+        if any(arch in all_supported_archs for arch in architectures):
+            raise ValueError(
+                f"Model architectures {architectures} failed "
+                "to be inspected. Please check the logs for more details.")
+
         raise ValueError(
             f"Model architectures {architectures} are not supported for now. "
             f"Supported architectures: {all_supported_archs}")

View File

@@ -9,12 +9,12 @@ from functools import cached_property, reduce
 from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List,
                     Mapping, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Tuple, Union, cast
+from typing import Set, Tuple, Union

 import msgspec
 import torch
+from typing_extensions import assert_never

-from vllm.inputs.parse import is_encoder_decoder_inputs
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
 from vllm.pooling_params import PoolingParams

@@ -379,15 +379,10 @@ class SequenceData(msgspec.Struct,
 class Sequence:
     """Stores the data, status, and block information of a sequence.

-    The sequence is constructed from the :code:`SingletonInputs` instance
-    passed in through the :code:`inputs` constructor argument.
-
-    For encoder/decoder models, SingletonInputs encapsulates both a
-    decoder and encoder prompt, creating an ambiguity about which
-    prompt to construct the sequence from. The `from_decoder_prompt`
-    constructor argument signals whether to construct the Sequence
-    from the SingletonInputs decoder prompt, or encoder prompt.
+    The sequence is constructed from the :data:`DecoderOnlyInputs`
+    (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
+    instance passed in through the :code:`inputs` constructor argument.

     Args:
         seq_id: The ID of the sequence.

@@ -397,10 +392,6 @@ class Sequence:
         eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
         lora_request: LoRA request.
         prompt_adapter_request: Prompt Adapter request.
-        from_decoder_prompt: Construct Sequence from SingletonInputs decoder
-                             prompt (True) or encoder prompt (False.) Must be
-                             True for decoder-only model.
     """

     def __init__(

@@ -411,7 +402,6 @@ class Sequence:
         eos_token_id: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        from_decoder_prompt: bool = True,
     ) -> None:
         self.seq_id = seq_id
         self.inputs = inputs

@@ -419,33 +409,6 @@ class Sequence:
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
-        self.from_decoder_prompt = from_decoder_prompt
-
-        # For decoder-only models, a Sequence is constructed
-        # from an DecoderOnlyInputs instance (the `inputs` arg.)
-        #
-        # For encoder/decoder models the same `inputs`
-        # instance could be utilized to construct either an
-        # encoder sequence or a decoder sequence, because
-        # `DecoderOnlyInputs` has both decoder- and encoder-oriented
-        # member variables (i.e. it encapsulates both an encoder
-        # and a decoder prompt.) The decision of which type of sequence
-        # to generate is determined by the `from_decoder_prompt` argument.
-        #
-        # When constructing a encoder sequence
-        # (`from_decoder_prompt` False) it matters that
-        # the `DecoderOnlyInputs` instance stored in `inputs` is valid
-        # in the sense that its encoder-related member variables are
-        # populated; below, an exception is raised if this is
-        # not the case.
-        #
-        # When constructing a decoder sequence (`from_decoder_prompt` True)
-        # it does not matter whether `inputs` has its encoder-related
-        # member variables populated.
-        if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)):
-            raise ValueError("Cannot extract encoder input prompt from "
-                             f"invalid input {inputs}; did you forget the "
-                             "encoder input prompt fields?")

         self.data = SequenceData.from_seqs(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []

@@ -470,45 +433,57 @@ class Sequence:
     @cached_property
     def prompt(self) -> Optional[str]:
-        # Select decoder or encoder input prompt str, as appropriate
-        prompt_key: str = ("prompt"
-                           if self.from_decoder_prompt else "encoder_prompt")
-
-        return cast(Optional[str], self.inputs.get(prompt_key))
+        inputs = self.inputs
+
+        if inputs["type"] == "token":
+            return inputs.get("prompt")
+
+        assert_never(inputs)

     @cached_property
     def prompt_token_ids(self) -> List[int]:
-        # Select decoder or encoder input prompt token ids, as appropriate
-        prompt_token_ids_key: str = ("prompt_token_ids"
-                                     if self.from_decoder_prompt else
-                                     "encoder_prompt_token_ids")
-
-        # Cache computed prompt token ids
-        return cast(List[int], self.inputs.get(prompt_token_ids_key))
-
-    @property
-    def multi_modal_data(self) -> MultiModalDataDict:
         inputs = self.inputs

-        if (inputs.get("multi_modal_data")
-                and inputs.get("encoder_multi_modal_data")):
-            raise ValueError(
-                "Multi-modal data in both encoder and decoder is not supported."
-            )
+        if inputs["type"] == "token":
+            return inputs.get("prompt_token_ids", [])

-        return cast(
-            MultiModalDataDict,
-            (inputs.get("multi_modal_data")
-             or inputs.get("encoder_multi_modal_data") or {}),
-        )
+        assert_never(inputs)
+
+    @cached_property
+    def prompt_embeds(self) -> Optional[torch.Tensor]:
+        inputs = self.inputs
+
+        if inputs["type"] == "token":
+            return None
+
+        assert_never(inputs)
+
+    @cached_property
+    def multi_modal_data(self) -> "MultiModalDataDict":
+        inputs = self.inputs
+
+        if inputs["type"] == "token":
+            return inputs.get("multi_modal_data", {})
+
+        assert_never(inputs)
+
+    @cached_property
+    def mm_processor_kwargs(self) -> Dict[str, Any]:
+        inputs = self.inputs
+
+        if inputs["type"] == "token":
+            return inputs.get("mm_processor_kwargs", {})
+
+        assert_never(inputs)

     @property
     def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
-        return self.inputs.get("multi_modal_placeholders") or {}
+        inputs = self.inputs

-    @property
-    def mm_processor_kwargs(self) -> Dict[str, Any]:
-        return self.inputs.get("mm_processor_kwargs") or {}
+        if inputs["type"] == "token":
+            return inputs.get("multi_modal_placeholders", {})
+
+        assert_never(inputs)

     @property
     def lora_int_id(self) -> int: