[Core] Make encoder-decoder inputs a nested structure to be more composable (#9604)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2024-11-05 10:07:31 +08:00, committed by GitHub
parent 04bbf38e05
commit bbc3619dc8
14 changed files with 369 additions and 346 deletions
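In short, the flat encoder/decoder input dict (decoder fields plus encoder_prompt, encoder_prompt_token_ids, and so on in one mapping) is replaced by a nested EncoderDecoderInputs whose "encoder" and "decoder" entries are each a TokenInputs built with token_inputs(). A minimal before/after sketch; the token IDs and prompt strings below are illustrative and not taken from the diff:

# Example (not part of the diff).
from vllm.inputs import EncoderDecoderInputs, token_inputs

# Before: one flat dict mixing decoder and encoder fields.
old_style = {
    "prompt": "hello",
    "prompt_token_ids": [1, 2, 3],
    "encoder_prompt": "hi",
    "encoder_prompt_token_ids": [4, 5, 6],
}

# After: each side is its own TokenInputs, nested under "encoder"/"decoder".
new_style: EncoderDecoderInputs = {
    "encoder": token_inputs([4, 5, 6], prompt="hi"),
    "decoder": token_inputs([1, 2, 3], prompt="hello"),
}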

View File

@ -4,6 +4,7 @@ from typing import Sequence as GenericSequence
from typing import Tuple
from vllm import SamplingParams
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup
@ -27,10 +28,7 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id),
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
},
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[prompt],
@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder(
encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
inputs = {
"prompt": decoder_prompt_str,
"prompt_token_ids": decoder_prompt_tokens,
"encoder_prompt": encoder_prompt_str,
"encoder_prompt_token_ids": encoder_prompt_tokens,
"multi_modal_data": None,
inputs: EncoderDecoderInputs = {
"decoder": token_inputs(decoder_prompt_tokens,
prompt=decoder_prompt_str),
"encoder": token_inputs(encoder_prompt_tokens,
prompt=encoder_prompt_str),
}
decoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=True)
inputs=inputs["decoder"],
block_size=block_size)
encoder_prompt = Sequence(int(request_id),
inputs=inputs,
block_size=block_size,
from_decoder_prompt=False)
inputs=inputs["encoder"],
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
sampling_params=SamplingParams(best_of=best_of),
@ -108,7 +104,7 @@ def create_seq_group(
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs={"prompt_token_ids": prompt_token_ids},
inputs=token_inputs(prompt_token_ids),
block_size=16,
)
@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder(
prompt_token_ids = [0] * seq_prompt_len
inputs = {
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"encoder_prompt": "",
"encoder_prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
inputs: EncoderDecoderInputs = {
"decoder": token_inputs(prompt_token_ids),
"encoder": token_inputs(prompt_token_ids),
}
seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
# Construct decoder input sequences
seq = Sequence(seq_id=seq_id_start + seq_id_offset,
inputs=inputs,
block_size=16,
from_decoder_prompt=True)
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
inputs=inputs["decoder"],
block_size=16,
)
for i in range(output_len):
seq.append_token_id(
@ -167,10 +161,11 @@ def create_seq_group_encoder_decoder(
seqs.append(seq)
# Encoder input sequence
encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
inputs=inputs,
block_size=16,
from_decoder_prompt=False)
encoder_seq = Sequence(
seq_id=seq_id_start + len(seq_output_lens),
inputs=inputs["encoder"],
block_size=16,
)
return SequenceGroup(request_id=request_id,
seqs=seqs,

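The test helpers above show the user-visible effect of the change: a Sequence is now constructed from exactly one TokenInputs, so the from_decoder_prompt flag disappears. A small sketch along the same lines, with made-up values:

# Example (not part of the diff); values are illustrative.
from vllm.inputs import token_inputs
from vllm.sequence import Sequence

inputs = {
    "encoder": token_inputs([4, 5, 6], prompt="4 5 6"),
    "decoder": token_inputs([1, 2, 3], prompt="1 2 3"),
}

# One Sequence per side, instead of one shared dict plus from_decoder_prompt.
decoder_seq = Sequence(0, inputs=inputs["decoder"], block_size=16)
encoder_seq = Sequence(0, inputs=inputs["encoder"], block_size=16)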
View File

@ -4,6 +4,7 @@ import pytest
from transformers import PreTrainedTokenizer
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, Sequence, SequenceStatus
@ -15,7 +16,7 @@ def sequence_with_eos(text: str, eos_token: str,
"""
seq = Sequence(
seq_id=0,
inputs={"prompt_token_ids": []},
inputs=token_inputs([]),
block_size=16,
eos_token_id=eos_token_id,
)

View File

@ -6,6 +6,7 @@ from typing import List, Optional
import pytest
from vllm.inputs import token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Sequence
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -70,10 +71,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id,
inputs={
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
},
inputs=token_inputs(prompt_token_ids,
prompt=prompt),
block_size=block_size,
eos_token_id=tokenizer.tokenizer.eos_token_id,
lora_request=lora_request)

View File

@ -3,6 +3,7 @@ from typing import Any, Dict, Generator, List, Optional
import pytest
from transformers import AutoTokenizer
from vllm.inputs import token_inputs
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
from vllm.transformers_utils.detokenizer import (Detokenizer,
detokenize_incrementally)
@ -169,10 +170,7 @@ def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
return Sequence(
seq_id=0,
inputs={
"prompt": "<s>",
"prompt_token_ids": prompt_token_ids,
},
inputs=token_inputs(prompt_token_ids, prompt="<s>"),
block_size=16,
)

View File

@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
import torch
from typing_extensions import TypeIs, TypeVar
from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@ -29,9 +29,9 @@ from vllm.entrypoints.openai.logits_processors import (
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputRegistry, PromptType,
TokensPrompt)
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
@ -638,7 +638,7 @@ class LLMEngine:
def _add_processed_request(
self,
request_id: str,
processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
processed_inputs: ProcessorInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
@ -669,18 +669,19 @@ class LLMEngine:
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = processed_inputs["decoder"]
encoder_inputs = processed_inputs["encoder"]
else:
decoder_inputs = processed_inputs
encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
encoder_seq = None
if 'encoder_prompt_token_ids' in processed_inputs:
encoder_seq = Sequence(seq_id,
processed_inputs,
block_size,
eos_token_id,
lora_request,
prompt_adapter_request,
from_decoder_prompt=False)
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
prompt_adapter_request))
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
@ -874,7 +875,7 @@ class LLMEngine:
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if self._is_token_prompt(prompt):
if is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
@ -884,10 +885,6 @@ class LLMEngine:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
@staticmethod
def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
def _create_sequence_group_with_sampling(
self,
request_id: str,
@ -1978,17 +1975,17 @@ class LLMEngine:
def is_encoder_decoder_model(self):
return self.input_preprocessor.is_encoder_decoder_model()
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs],
def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]):
if self.model_config.is_multimodal_model:
if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_ids = inputs.get("prompt_token_ids")
elif self.is_encoder_decoder_model():
prompt_ids = inputs.get("encoder_prompt_token_ids")
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
else:
prompt_ids = inputs.get("prompt_token_ids")
prompt_inputs = inputs
prompt_ids = prompt_inputs.get("prompt_token_ids")
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")

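A hedged restatement of the dispatch that _add_processed_request and _validate_model_inputs now share: ProcessorInputs is a union, so callers first narrow it with is_encoder_decoder_inputs and then index the nested dict. The helper below is hypothetical (split_inputs is not part of the diff) and only illustrates the pattern:

# Example (not part of the diff); split_inputs is a hypothetical helper.
from typing import Optional, Tuple

from vllm.inputs import ProcessorInputs, TokenInputs
from vllm.inputs.parse import is_encoder_decoder_inputs

def split_inputs(
        processed: ProcessorInputs
) -> Tuple[TokenInputs, Optional[TokenInputs]]:
    # Decoder inputs always exist; encoder inputs only for enc-dec models.
    if is_encoder_decoder_inputs(processed):
        return processed["decoder"], processed["encoder"]
    return processed, None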
View File

@ -1,11 +1,12 @@
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional, Union
from typing import AsyncGenerator, List, Mapping, Optional
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -60,7 +61,7 @@ class EngineClient(ABC):
async def beam_search(
self,
prompt: Union[PromptType, List[int]],
prompt: PromptType,
model_config: ModelConfig,
request_id: str,
params: BeamSearchParams,
@ -76,11 +77,19 @@ class EngineClient(ABC):
tokenizer = await self.get_tokenizer()
input_preprocessor = InputPreprocessor(model_config, tokenizer)
(prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
else:
processed_inputs = input_preprocessor._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")
mm_processor_kwargs = processed_inputs.get("mm_processor_kwargs")
tokenized_length = len(prompt_token_ids)
sort_beams_key = create_sort_beams_key_function(

View File

@ -1,8 +1,8 @@
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts)
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import DummyData, InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
@ -22,9 +22,10 @@ __all__ = [
"ExplicitEncoderDecoderPrompt",
"TokenInputs",
"token_inputs",
"SingletonInputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",

View File

@ -1,4 +1,4 @@
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
Optional, Tuple, Union, cast)
from typing_extensions import NotRequired, TypedDict, TypeVar
@ -122,27 +122,30 @@ both decoder-only and encoder/decoder input types:
class TokenInputs(TypedDict):
"""Represents token-based inputs."""
type: Literal["token"]
"""The type of inputs."""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
prompt: NotRequired[Optional[str]]
prompt: NotRequired[str]
"""
The original prompt text corresponding to the token IDs, if available.
"""
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
multi_modal_placeholders: NotRequired[
Optional["MultiModalPlaceholderDict"]]
multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
"""
Placeholder ranges for the multi-modal data.
"""
mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
@ -159,7 +162,7 @@ def token_inputs(
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values."""
inputs = TokenInputs(prompt_token_ids=prompt_token_ids)
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
if prompt is not None:
inputs["prompt"] = prompt
@ -173,12 +176,6 @@ def token_inputs(
return inputs
SingletonInputs = TokenInputs
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""
DecoderOnlyInputs = TokenInputs
"""
The inputs in :class:`~vllm.LLMEngine` before they are
@ -187,28 +184,30 @@ This specifies the data required for decoder-only models.
"""
class EncoderDecoderInputs(TokenInputs):
class EncoderDecoderInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder_prompt_token_ids: List[int]
"""The token IDs of the encoder prompt."""
encoder: TokenInputs
"""The inputs for the encoder portion."""
encoder_prompt: NotRequired[Optional[str]]
"""
The original encoder prompt text corresponding to the token IDs, if
available.
"""
decoder: TokenInputs
"""The inputs for the decoder portion."""
encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the encoder model,
if the model supports it.
"""
SingletonInputs = TokenInputs
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
"""
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)

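Note the new "type" discriminator on TokenInputs: token_inputs() stamps it automatically, which is what lets downstream code branch on inputs["type"] and keep the branch exhaustive. A tiny sketch of the assumed usage:

# Example (not part of the diff).
from vllm.inputs import TokenInputs, token_inputs

inputs: TokenInputs = token_inputs([1, 2, 3], prompt="hello")
assert inputs["type"] == "token"
assert inputs["prompt_token_ids"] == [1, 2, 3]
# Optional fields are now simply absent rather than present with value None.
assert "multi_modal_data" not in inputs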
View File

@ -4,9 +4,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
TextPrompt, TokensPrompt)
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
class ParsedText(TypedDict):
@ -98,12 +98,15 @@ def parse_singleton_prompt(
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_encoder_decoder_inputs(
inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
) -> TypeIs[EncoderDecoderInputs]:
return "encoder_prompt_token_ids" in inputs
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
return "encoder" in inputs and "decoder" in inputs

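is_token_prompt, previously the private LLMEngine._is_token_prompt, now lives here next to the other TypeIs helpers. A sketch of how the engine's vocabulary check shown earlier can rely on it (prompt_length is a hypothetical name):

# Example (not part of the diff); prompt_length is a hypothetical helper.
from typing import Optional

from vllm.inputs import PromptType
from vllm.inputs.parse import is_token_prompt

def prompt_length(prompt: PromptType) -> Optional[int]:
    # Token prompts can be inspected directly; text prompts need tokenizing.
    if is_token_prompt(prompt):
        return len(prompt["prompt_token_ids"])
    return None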
View File

@ -1,5 +1,5 @@
import asyncio
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import List, Optional
from typing_extensions import assert_never
@ -10,22 +10,12 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_warning_once
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
SingletonPrompt)
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
logger = init_logger(__name__)
PromptComponents = Tuple[Optional[str], List[int],
Optional["MultiModalDataDict"], Optional[Dict[str,
Any]]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional["MultiModalDataDict"],
Optional[Dict[str, Any]]]
class InputPreprocessor:
@ -115,7 +105,7 @@ class InputPreprocessor:
"default" decoder prompt be <BOS>.
However, it is possible that in the future
other models may have different or more
complex logic for the default decoder prompt.
This motivates having a special helper method
for default decoder prompts.
@ -132,7 +122,6 @@ class InputPreprocessor:
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]],
force_bos: bool = True,
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
@ -162,8 +151,8 @@ class InputPreprocessor:
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if force_bos and (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
return decoder_input_ids
@ -209,12 +198,12 @@ class InputPreprocessor:
prompt=prompt,
lora_request=lora_request)
def _extract_prompt_components(
def _prompt_to_llm_inputs(
self,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
) -> SingletonInputs:
'''
Extract the components of any single encoder or decoder input prompt.
@ -241,34 +230,52 @@ class InputPreprocessor:
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
mm_processor_kwargs = None
elif parsed["type"] == "tokens":
prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
elif parsed["type"] == "text":
prompt_text = parsed["content"]["prompt"]
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
if parsed["type"] == "tokens":
tokens_content = parsed["content"]
prompt_token_ids = tokens_content["prompt_token_ids"]
multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
return token_inputs(
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
if parsed["type"] == "text":
text_content = parsed["content"]
prompt_text = text_content["prompt"]
prompt_token_ids = self._tokenize_prompt(
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
else:
assert_never(parsed)
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
return (prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs)
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
async def _extract_prompt_components_async(
assert_never(parsed)
async def _prompt_to_llm_inputs_async(
self,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
) -> SingletonInputs:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(prompt)
@ -279,59 +286,74 @@ class InputPreprocessor:
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
mm_processor_kwargs = None
elif parsed["type"] == "tokens":
prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
elif parsed["type"] == "text":
prompt_text = parsed["content"]["prompt"]
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
if parsed["type"] == "tokens":
tokens_content = parsed["content"]
prompt_token_ids = tokens_content["prompt_token_ids"]
multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
return token_inputs(
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
if parsed["type"] == "text":
text_content = parsed["content"]
prompt_text = text_content["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = parsed["content"].get("multi_modal_data")
mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs")
else:
assert_never(parsed)
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
return (prompt_text, prompt_token_ids, multi_modal_data,
mm_processor_kwargs)
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs,
)
assert_never(parsed)
def _build_enc_dec_llm_inputs(
self,
encoder_comps: PromptComponents,
decoder_comps: DecoderPromptComponents,
mm_processor_kwargs: Dict[str, Any],
encoder_inputs: SingletonInputs,
decoder_inputs: Optional[SingletonInputs],
) -> EncoderDecoderInputs:
encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
if encoder_inputs["type"] == "token":
pass
else:
assert_never(encoder_inputs)
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo becomes valid
if decoder_mm_data is not None:
raise ValueError(
"Multi-modality decoder inputs of encoder-decoder models are "
"not supported yet")
if decoder_inputs is None:
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
None)
decoder_inputs = token_inputs(dec_token_ids)
elif decoder_inputs["type"] == "token":
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids
# For Multi-Modal models (e.g., mllama), the text input can be
# <|image|><|begin_of_text|>hello world. And we should not add
# another <|begin_of_text|> to the beginning.
decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation(
decoder_prompt_ids,
force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
if "multi_modal_data" in decoder_inputs:
raise ValueError("Multi-modal decoder inputs of encoder-"
"decoder models are not supported yet")
else:
assert_never(encoder_inputs)
return EncoderDecoderInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
multi_modal_data=decoder_mm_data,
mm_processor_kwargs=mm_processor_kwargs,
encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt,
encoder_multi_modal_data=encoder_mm_data,
encoder=encoder_inputs,
decoder=decoder_inputs,
)
def _process_encoder_decoder_prompt(
@ -341,8 +363,7 @@ class InputPreprocessor:
) -> EncoderDecoderInputs:
'''
For encoder/decoder models only:
Process an input prompt into an
:class:`EncoderDecoderInputs` instance.
Process an input prompt into an :class:`EncoderDecoderInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
@ -361,7 +382,7 @@ class InputPreprocessor:
have any possible singleton type; thus this
method relies on helper functions to obtain
token ids for the sub-prompts.
Arguments:
* prompt: an input prompt
@ -372,40 +393,31 @@ class InputPreprocessor:
* :class:`EncoderDecoderInputs` instance
'''
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs]
if is_explicit_encoder_decoder_prompt(prompt):
encoder_comps = self._extract_prompt_components(
encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_comps = None, None, None, None
decoder_inputs = None
else:
decoder_comps = self._extract_prompt_components(
decoder_inputs = self._prompt_to_llm_inputs(
decoder_input,
request_id=request_id,
)
# Handle this carefully in case it was directly initialized by user
mm_processor_kwargs = prompt.get("mm_processor_kwargs", {})
else:
encoder_comps = self._extract_prompt_components(
encoder_inputs = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
# If there are no decoder components, we assume the
# mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
-1] is not None else {}
decoder_comps = None, None, None, None
return self._build_enc_dec_llm_inputs(
encoder_comps,
decoder_comps,
mm_processor_kwargs,
)
decoder_inputs = None
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
async def _process_encoder_decoder_prompt_async(
self,
@ -413,59 +425,50 @@ class InputPreprocessor:
request_id: str,
) -> EncoderDecoderInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs]
if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._extract_prompt_components_async(
encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None, None
encoder_inputs = await encoder_task
decoder_inputs = None
else:
decoder_task = self._extract_prompt_components_async(
decoder_task = self._prompt_to_llm_inputs_async(
decoder_input,
request_id=request_id,
)
encoder_comps, decoder_comps = await asyncio.gather(
encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)
mm_processor_kwargs = prompt["mm_processor_kwargs"]
else:
encoder_comps = await self._extract_prompt_components_async(
encoder_inputs = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
)
# If there are no decoder components, we assume the
# mm_processor_kwargs are in the encoder prompt
mm_processor_kwargs = encoder_comps[-1] if encoder_comps[
-1] is not None else {}
decoder_comps = None, None, None, None
return self._build_enc_dec_llm_inputs(
encoder_comps,
decoder_comps,
mm_processor_kwargs,
)
decoder_inputs = None
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs(
self,
prompt_comps: PromptComponents,
prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs:
(prompt, prompt_token_ids, multi_modal_data,
mm_processor_kwargs) = prompt_comps
if prompt_inputs["type"] == "token":
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
else:
assert_never(prompt_inputs)
prompt_token_ids = self._apply_prompt_adapter(
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data,
mm_processor_kwargs=mm_processor_kwargs)
return prompt_inputs
def _process_decoder_only_prompt(
self,
@ -490,7 +493,7 @@ class InputPreprocessor:
* :class:`DecoderOnlyInputs` instance
'''
prompt_comps = self._extract_prompt_components(
prompt_comps = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
lora_request=lora_request,
@ -509,7 +512,7 @@ class InputPreprocessor:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> DecoderOnlyInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
lora_request=lora_request,
@ -526,7 +529,7 @@ class InputPreprocessor:
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
@ -554,7 +557,7 @@ class InputPreprocessor:
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
) -> ProcessorInputs:
"""Async version of :meth:`preprocess`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of

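Taken together, preprocess() and preprocess_async() now return ProcessorInputs: a plain TokenInputs for decoder-only models, or the nested structure for encoder/decoder models, with a missing decoder prompt filled in by _prepare_decoder_input_ids_for_generation. A rough illustration of the resulting shapes; the values are made up and the decoder start token id depends on the model:

# Example (not part of the diff); values are illustrative only.
enc_dec_result = {
    "encoder": {"type": "token", "prompt_token_ids": [4, 5, 6], "prompt": "hi"},
    # No decoder prompt was given, so only the decoder start token remains.
    "decoder": {"type": "token", "prompt_token_ids": [2]},
}
decoder_only_result = {
    "type": "token",
    "prompt_token_ids": [1, 2, 3],
    "prompt": "hello",
}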
View File

@ -2,7 +2,7 @@ import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type)
Optional, Protocol, Type, cast)
from torch import nn
from transformers import PretrainedConfig
@ -12,7 +12,7 @@ from vllm.logger import init_logger
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
resolve_mm_processor_kwargs)
from .data import DecoderOnlyInputs
from .data import ProcessorInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -109,7 +109,7 @@ class _MultiModalCounts(UserDict):
raise KeyError(msg) from exc
InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs]
InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
"""Preprocess the inputs to the model."""
@ -254,8 +254,8 @@ class InputRegistry:
def _default_input_processor(
self,
ctx: InputContext,
inputs: DecoderOnlyInputs,
) -> DecoderOnlyInputs:
inputs: ProcessorInputs,
) -> ProcessorInputs:
"""The default input processor is a no-op."""
return inputs
@ -288,7 +288,7 @@ class InputRegistry:
.get(model_cls, self._default_input_processor)
def process_input(self, model_config: "ModelConfig",
inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
inputs: ProcessorInputs) -> ProcessorInputs:
"""
Apply an input processor to an instance of model inputs.
@ -308,7 +308,7 @@ class InputRegistry:
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs,
inputs.get("mm_processor_kwargs"),
cast(Dict[str, Any], inputs.get("mm_processor_kwargs")),
processor,
)

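Because InputProcessor now takes and returns ProcessorInputs, a model-specific input processor has to handle (or pass through) the nested encoder/decoder case. A hypothetical pass-through sketch, not tied to any real model:

# Example (not part of the diff); my_input_processor is hypothetical.
from vllm.inputs import InputContext, ProcessorInputs
from vllm.inputs.parse import is_encoder_decoder_inputs

def my_input_processor(ctx: InputContext,
                       inputs: ProcessorInputs) -> ProcessorInputs:
    if is_encoder_decoder_inputs(inputs):
        # Nested case: the decoder side carries the prompt token ids here.
        _ = inputs["decoder"]["prompt_token_ids"]
        return inputs
    # Decoder-only case: inputs is a plain TokenInputs.
    _ = inputs["prompt_token_ids"]
    return inputs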
View File

@ -36,8 +36,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
EncoderDecoderInputs, InputContext)
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
InputContext, TokenInputs, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -52,6 +52,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SequenceData
from vllm.utils import is_list_of
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
@ -86,41 +87,58 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
return num_images
def input_processor_for_mllama(ctx: InputContext,
inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
# move encoder_prompt to prompt
if inputs.get("prompt") is None:
inputs["prompt"] = inputs["encoder_prompt"]
inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"]
def input_processor_for_mllama(
ctx: InputContext,
inputs: EncoderDecoderInputs,
) -> EncoderDecoderInputs:
# Example input to processor:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000],
# },
# }
# process multi-modal data
multi_modal_data = inputs.get("encoder_multi_modal_data")
# move encoder prompt to decoder
dec_inputs = TokenInputs(**inputs["encoder"])
if multi_modal_data is None or "image" not in multi_modal_data \
or multi_modal_data["image"] is None:
multi_modal_data = dec_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
# text-only
inputs["encoder_prompt"] = ""
inputs["encoder_prompt_token_ids"] = []
inputs["encoder_multi_modal_data"] = {}
return inputs
return EncoderDecoderInputs(
encoder=token_inputs([]),
decoder=dec_inputs,
)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_data = [image_data]
assert is_list_of(image_data, Image.Image)
if isinstance(multi_modal_data['image'], Image.Image):
multi_modal_data['image'] = [multi_modal_data['image']]
# Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to
# get the number of tiles for those images.
num_decode_images = _get_num_image_in_last_group(
inputs["prompt_token_ids"])
dec_inputs["prompt_token_ids"])
hf_config = ctx.model_config.hf_config
vision_config = hf_config.vision_config
num_tiles = 0
for image in multi_modal_data["image"][::-1]:
for image in image_data[::-1]:
width, height = image.size
tile_size = hf_config.vision_config.image_size
tile_size = vision_config.image_size
canvas_height, canvas_width = get_optimal_tiled_canvas(
image_height=height,
image_width=width,
max_image_tiles=hf_config.vision_config.max_num_tiles,
max_image_tiles=vision_config.max_num_tiles,
tile_size=tile_size,
)
num_tiles_height = canvas_height // tile_size
@ -133,14 +151,34 @@ def input_processor_for_mllama(ctx: InputContext,
# Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate correct number
# of slots for encoder tokens.
assert hf_config.vision_config.image_size % 14 == 0, \
assert vision_config.image_size % 14 == 0, \
"chunk size should be multiple of 14"
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
token_per_chunk = (vision_config.image_size // 14)**2 + 1
num_tokens = num_tiles * token_per_chunk
inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens
return inputs
# Example output from processor:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128256, 128256, ..., 128256],
# 'prompt': '<|image|><|image|>...<|image|>',
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# }
return EncoderDecoderInputs(
encoder=token_inputs(
prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
multi_modal_data=multi_modal_data,
),
decoder=dec_inputs,
)
def get_max_mllama_image_tokens(ctx: InputContext) -> int:

View File

@ -343,6 +343,11 @@ class _ModelRegistry:
def _raise_for_unsupported(self, architectures: List[str]):
all_supported_archs = self.get_supported_archs()
if any(arch in all_supported_archs for arch in architectures):
raise ValueError(
f"Model architectures {architectures} failed "
"to be inspected. Please check the logs for more details.")
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {all_supported_archs}")

View File

@ -9,12 +9,12 @@ from functools import cached_property, reduce
from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
from typing import Set, Tuple, Union
import msgspec
import torch
from typing_extensions import assert_never
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
@ -379,15 +379,10 @@ class SequenceData(msgspec.Struct,
class Sequence:
"""Stores the data, status, and block information of a sequence.
The sequence is constructed from the :code:`SingletonInputs` instance
passed in through the :code:`inputs` constructor argument.
For encoder/decoder models, SingletonInputs encapsulates both a
decoder and encoder prompt, creating an ambiguity about which
prompt to construct the sequence from. The `from_decoder_prompt`
constructor argument signals whether to construct the Sequence
from the SingletonInputs decoder prompt, or encoder prompt.
The sequence is constructed from the :data:`DecoderOnlyInputs`
(for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
instance passed in through the :code:`inputs` constructor argument.
Args:
seq_id: The ID of the sequence.
@ -397,10 +392,6 @@ class Sequence:
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request.
from_decoder_prompt: Construct Sequence from SingletonInputs decoder
prompt (True) or encoder prompt (False.) Must be
True for decoder-only model.
"""
def __init__(
@ -411,7 +402,6 @@ class Sequence:
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
from_decoder_prompt: bool = True,
) -> None:
self.seq_id = seq_id
self.inputs = inputs
@ -419,33 +409,6 @@ class Sequence:
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.from_decoder_prompt = from_decoder_prompt
# For decoder-only models, a Sequence is constructed
# from an DecoderOnlyInputs instance (the `inputs` arg.)
#
# For encoder/decoder models the same `inputs`
# instance could be utilized to construct either an
# encoder sequence or a decoder sequence, because
# `DecoderOnlyInputs` has both decoder- and encoder-oriented
# member variables (i.e. it encapsulates both an encoder
# and a decoder prompt.) The decision of which type of sequence
# to generate is determined by the `from_decoder_prompt` argument.
#
# When constructing a encoder sequence
# (`from_decoder_prompt` False) it matters that
# the `DecoderOnlyInputs` instance stored in `inputs` is valid
# in the sense that its encoder-related member variables are
# populated; below, an exception is raised if this is
# not the case.
#
# When constructing a decoder sequence (`from_decoder_prompt` True)
# it does not matter whether `inputs` has its encoder-related
# member variables populated.
if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)):
raise ValueError("Cannot extract encoder input prompt from "
f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?")
self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
@ -470,45 +433,57 @@ class Sequence:
@cached_property
def prompt(self) -> Optional[str]:
# Select decoder or encoder input prompt str, as appropriate
prompt_key: str = ("prompt"
if self.from_decoder_prompt else "encoder_prompt")
inputs = self.inputs
return cast(Optional[str], self.inputs.get(prompt_key))
if inputs["type"] == "token":
return inputs.get("prompt")
assert_never(inputs)
@cached_property
def prompt_token_ids(self) -> List[int]:
# Select decoder or encoder input prompt token ids, as appropriate
prompt_token_ids_key: str = ("prompt_token_ids"
if self.from_decoder_prompt else
"encoder_prompt_token_ids")
# Cache computed prompt token ids
return cast(List[int], self.inputs.get(prompt_token_ids_key))
@property
def multi_modal_data(self) -> MultiModalDataDict:
inputs = self.inputs
if (inputs.get("multi_modal_data")
and inputs.get("encoder_multi_modal_data")):
raise ValueError(
"Multi-modal data in both encoder and decoder is not supported."
)
if inputs["type"] == "token":
return inputs.get("prompt_token_ids", [])
return cast(
MultiModalDataDict,
(inputs.get("multi_modal_data")
or inputs.get("encoder_multi_modal_data") or {}),
)
assert_never(inputs)
@cached_property
def prompt_embeds(self) -> Optional[torch.Tensor]:
inputs = self.inputs
if inputs["type"] == "token":
return None
assert_never(inputs)
@cached_property
def multi_modal_data(self) -> "MultiModalDataDict":
inputs = self.inputs
if inputs["type"] == "token":
return inputs.get("multi_modal_data", {})
assert_never(inputs)
@cached_property
def mm_processor_kwargs(self) -> Dict[str, Any]:
inputs = self.inputs
if inputs["type"] == "token":
return inputs.get("mm_processor_kwargs", {})
assert_never(inputs)
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
return self.inputs.get("multi_modal_placeholders") or {}
inputs = self.inputs
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.inputs.get("mm_processor_kwargs") or {}
if inputs["type"] == "token":
return inputs.get("multi_modal_placeholders", {})
assert_never(inputs)
@property
def lora_int_id(self) -> int:
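The rewritten Sequence properties all follow the same narrowing idiom: check inputs["type"] == "token", read the optional field with a default, and finish with assert_never so the branch stays exhaustive if more SingletonInputs variants are added later. A standalone sketch of that idiom (read_prompt is a hypothetical free function):

# Example (not part of the diff); read_prompt is a hypothetical helper.
from typing import Optional

from typing_extensions import assert_never

from vllm.inputs import SingletonInputs

def read_prompt(inputs: SingletonInputs) -> Optional[str]:
    if inputs["type"] == "token":
        return inputs.get("prompt")
    assert_never(inputs)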