[Core] Rename input data types (#8688)
parent 1de76a0e55
commit cee711fdbb
@@ -25,7 +25,7 @@ Module Contents
 LLM Engine Inputs
 -----------------

-.. autoclass:: vllm.inputs.LLMInputs
+.. autoclass:: vllm.inputs.DecoderOnlyInputs
     :members:
     :show-inheritance:

@@ -1,12 +1,12 @@
 import os
 import re
-from typing import Callable, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

 import pytest
 import torch
 from transformers import AutoImageProcessor, AutoTokenizer

-from vllm.inputs import InputContext, LLMInputs
+from vllm.inputs import InputContext, token_inputs
 from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
 from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.utils import rescale_image_size
@@ -311,7 +311,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets,
     (4, 781),
     (16, 2653),
 ])
-def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
+def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
                              num_crops: int, expected_max_tokens: int):
     """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
     # NOTE: mm_processor_kwargs on the context in this test is unused, since
@@ -343,8 +343,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
     (16, 2653, 1),
     (16, 2653, 2),
 ])
-def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
-                             num_crops: int, toks_per_img: int, num_imgs: int):
+def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
+                             toks_per_img: int, num_imgs: int):
     """Ensure dummy_data_for_phi3v handles num_crops properly."""
     # Same as the previous test - don't initialize mm_processor_kwargs
     # in this test and assume that the kwargs will be correctly expanded by
@@ -374,7 +374,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
     (16, 1921, 1),
     (16, 1921, 2),
 ])
-def test_input_processor_override(input_processor_for_phi3v: Callable,
+def test_input_processor_override(input_processor_for_phi3v,
                                   image_assets: _ImageAssets, model: str,
                                   num_crops: int, expected_toks_per_img: int,
                                   num_imgs: int):
@@ -393,16 +393,14 @@ def test_input_processor_override(input_processor_for_phi3v: Callable,
     prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
     images = [image_assets[0].pil_image] * num_imgs

-    llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt),
-                           prompt=prompt,
-                           multi_modal_data={"image": images})
+    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
+                          prompt=prompt,
+                          multi_modal_data={"image": images})

-    proc_llm_inputs = input_processor_for_phi3v(
-        ctx=ctx,
-        llm_inputs=llm_inputs,
-        num_crops=num_crops,
-    )
+    processed_inputs = input_processor_for_phi3v(ctx,
+                                                 inputs,
+                                                 num_crops=num_crops)

     # Ensure we have the right number of placeholders per num_crops size
-    img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
+    img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
     assert img_tok_count == expected_toks_per_img * num_imgs

@ -5,7 +5,7 @@ import pytest
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
|
||||
from vllm.inputs import InputContext, LLMInputs
|
||||
from vllm.inputs import InputContext, token_inputs
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size
|
||||
|
||||
@ -71,12 +71,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen,
|
||||
"""Happy cases for image inputs to Qwen's multimodal input processor."""
|
||||
prompt = "".join(
|
||||
[f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
|
||||
inputs = LLMInputs(
|
||||
inputs = token_inputs(
|
||||
prompt=prompt,
|
||||
# When processing multimodal data for a multimodal model, the qwen
|
||||
# input processor will overwrite the provided prompt_token_ids with
|
||||
# the image prompts
|
||||
prompt_token_ids=None,
|
||||
prompt_token_ids=[],
|
||||
multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
|
||||
)
|
||||
proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
|
||||
@ -134,9 +134,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen,
|
||||
trust_remote_code=True)
|
||||
prompt = "Picture 1: <img></img>\n"
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
inputs = LLMInputs(prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=mm_data)
|
||||
inputs = token_inputs(prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=mm_data)
|
||||
# Should fail since we have too many or too few dimensions for embeddings
|
||||
with pytest.raises(ValueError):
|
||||
input_processor_for_qwen(qwen_vl_context, inputs)
|
||||
|
||||
@ -5,7 +5,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.inputs import InputContext, LLMInputs
|
||||
from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
|
||||
from vllm.inputs.registry import InputRegistry
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
@ -31,7 +31,7 @@ def use_processor_mock():
|
||||
"""Patches the internal model input processor with an override callable."""
|
||||
|
||||
def custom_processor(ctx: InputContext,
|
||||
llm_inputs: LLMInputs,
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
num_crops=DEFAULT_NUM_CROPS):
|
||||
# For testing purposes, we don't worry about the llm inputs / return
|
||||
@ -84,7 +84,7 @@ def test_default_processor_is_a_noop():
|
||||
dummy_registry = InputRegistry()
|
||||
ctx = build_model_context(DUMMY_MODEL_ID)
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
|
||||
proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
|
||||
proc_outputs = processor(inputs=proc_inputs)
|
||||
assert proc_inputs is proc_outputs
|
||||
|
||||
@ -125,9 +125,9 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,
|
||||
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
num_crops_val = processor(
|
||||
LLMInputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=inference_kwargs))
|
||||
token_inputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=inference_kwargs))
|
||||
assert num_crops_val == expected_seq_count
|
||||
|
||||
|
||||
@ -154,9 +154,9 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
# Should filter out the inference time kwargs
|
||||
num_crops_val = processor(
|
||||
LLMInputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=mm_processor_kwargs))
|
||||
token_inputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=mm_processor_kwargs))
|
||||
assert num_crops_val == DEFAULT_NUM_CROPS
|
||||
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
-                         InputRegistry, LLMInputs, PromptType)
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
+                         EncoderDecoderInputs, InputRegistry, PromptType)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -635,7 +635,7 @@ class LLMEngine:
     def _add_processed_request(
         self,
         request_id: str,
-        processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
+        processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
         params: Union[SamplingParams, PoolingParams],
         arrival_time: float,
         lora_request: Optional[LoRARequest],
@@ -1855,8 +1855,8 @@ class LLMEngine:
     def is_embedding_model(self):
         return self.model_config.is_embedding_model

-    def _validate_model_inputs(self, inputs: Union[LLMInputs,
-                                                   EncoderDecoderLLMInputs]):
+    def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
+                                                   EncoderDecoderInputs]):
         if self.model_config.is_multimodal_model:
             # For encoder-decoder multimodal models, the max_prompt_len
             # restricts the decoder prompt length
@@ -1,7 +1,8 @@
-from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptType, SingletonPrompt, TextPrompt,
-                   TokensPrompt, build_explicit_enc_dec_prompt,
-                   to_enc_dec_tuple_list, zip_enc_dec_prompts)
+from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
+                   ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
+                   SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
+                   build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
+                   token_inputs, zip_enc_dec_prompts)
 from .registry import InputContext, InputRegistry

 INPUT_REGISTRY = InputRegistry()
@@ -19,8 +20,11 @@ __all__ = [
     "PromptType",
     "SingletonPrompt",
     "ExplicitEncoderDecoderPrompt",
-    "LLMInputs",
-    "EncoderDecoderLLMInputs",
+    "TokenInputs",
+    "token_inputs",
+    "SingletonInputs",
+    "DecoderOnlyInputs",
+    "EncoderDecoderInputs",
     "build_explicit_enc_dec_prompt",
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",
@@ -31,9 +35,9 @@ __all__ = [


 def __getattr__(name: str):
-    if name == "PromptInput":
-        import warnings
+    import warnings

+    if name == "PromptInput":
         msg = ("PromptInput has been renamed to PromptType. "
                "The original name will be removed in an upcoming version.")

@@ -41,4 +45,21 @@ def __getattr__(name: str):

         return PromptType

+    if name == "LLMInputs":
+        msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return DecoderOnlyInputs
+
+    if name == "EncoderDecoderLLMInputs":
+        msg = (
+            "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
+            "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return EncoderDecoderInputs
+
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
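The old names stay importable through the module-level __getattr__ shim above, but resolve to the renamed types and emit a DeprecationWarning. A minimal sketch of what downstream code can expect (illustrative only; assumes a vLLM build that contains this commit):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old name: handled by the __getattr__ shim, aliased to DecoderOnlyInputs.
    from vllm.inputs import LLMInputs  # noqa: F401

assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# New code should import the renamed types directly.
from vllm.inputs import DecoderOnlyInputs, EncoderDecoderInputs, token_inputs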
@@ -1,5 +1,5 @@
 from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
-                    Optional, Tuple, Union)
+                    Optional, Tuple, Union, cast)

 from typing_extensions import NotRequired, TypedDict, TypeVar

@@ -51,7 +51,7 @@ class TokensPrompt(TypedDict):

 SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
 """
-Set of possible schemas for a single LLM input:
+Set of possible schemas for a single prompt:

 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
@@ -120,13 +120,8 @@ both decoder-only and encoder/decoder input types:
 """


-class LLMInputs(TypedDict):
-    """
-    The inputs in :class:`~vllm.LLMEngine` before they are
-    passed to the model executor.
-
-    This specifies the data required for decoder-only models.
-    """
+class TokenInputs(TypedDict):
+    """Represents token-based inputs."""
     prompt_token_ids: List[int]
     """The token IDs of the prompt."""

@@ -150,7 +145,40 @@ class LLMInputs(TypedDict):
     """


-class EncoderDecoderLLMInputs(LLMInputs):
+def token_inputs(
+    prompt_token_ids: List[int],
+    prompt: Optional[str] = None,
+    multi_modal_data: Optional["MultiModalDataDict"] = None,
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+) -> TokenInputs:
+    """Construct :class:`TokenInputs` from optional values."""
+    inputs = TokenInputs(prompt_token_ids=prompt_token_ids)
+
+    if prompt is not None:
+        inputs["prompt"] = prompt
+    if multi_modal_data is not None:
+        inputs["multi_modal_data"] = multi_modal_data
+    if mm_processor_kwargs is not None:
+        inputs["mm_processor_kwargs"] = mm_processor_kwargs
+
+    return inputs
+
+
+SingletonInputs = TokenInputs
+"""
+A processed :class:`SingletonPrompt` which can be passed to
+:class:`vllm.sequence.Sequence`.
+"""
+
+DecoderOnlyInputs = TokenInputs
+"""
+The inputs in :class:`~vllm.LLMEngine` before they are
+passed to the model executor.
+This specifies the data required for decoder-only models.
+"""
+
+
+class EncoderDecoderInputs(TokenInputs):
     """
     The inputs in :class:`~vllm.LLMEngine` before they are
     passed to the model executor.
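For callers, token_inputs replaces direct construction of the old LLMInputs TypedDict and only sets the optional keys that are actually provided. A minimal usage sketch (the token ids and prompt text are made-up values):

from vllm.inputs import token_inputs

# Build decoder-only inputs from pre-tokenized data.
inputs = token_inputs(
    prompt_token_ids=[101, 2009, 2003, 102],  # hypothetical ids
    prompt="Example prompt",
)
assert inputs["prompt_token_ids"] == [101, 2009, 2003, 102]
assert "multi_modal_data" not in inputs  # optional keys appear only when given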
@@ -204,11 +232,12 @@ def zip_enc_dec_prompts(
     be zipped with the encoder/decoder prompts.
     """
     if mm_processor_kwargs is None:
-        mm_processor_kwargs = {}
-    if isinstance(mm_processor_kwargs, Dict):
+        mm_processor_kwargs = cast(Dict[str, Any], {})
+    if isinstance(mm_processor_kwargs, dict):
         return [
-            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
-                                          mm_processor_kwargs)
+            build_explicit_enc_dec_prompt(
+                encoder_prompt, decoder_prompt,
+                cast(Dict[str, Any], mm_processor_kwargs))
             for (encoder_prompt,
                  decoder_prompt) in zip(enc_prompts, dec_prompts)
         ]
@@ -229,9 +258,9 @@ def to_enc_dec_tuple_list(


 def __getattr__(name: str):
-    if name == "PromptInput":
-        import warnings
+    import warnings

+    if name == "PromptInput":
         msg = ("PromptInput has been renamed to PromptType. "
                "The original name will be removed in an upcoming version.")

@@ -239,4 +268,21 @@ def __getattr__(name: str):

         return PromptType

+    if name == "LLMInputs":
+        msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return DecoderOnlyInputs
+
+    if name == "EncoderDecoderLLMInputs":
+        msg = (
+            "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
+            "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return EncoderDecoderInputs
+
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -4,9 +4,9 @@ from typing_extensions import TypeIs

 from vllm.utils import is_list_of

-from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptType, SingletonPrompt, TextPrompt,
-                   TokensPrompt)
+from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
+                   ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
+                   TextPrompt, TokensPrompt)


 class ParsedText(TypedDict):
@@ -100,7 +100,7 @@ def is_explicit_encoder_decoder_prompt(
     return isinstance(prompt, dict) and "encoder_prompt" in prompt


-def is_valid_encoder_decoder_llm_inputs(
-    inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
-) -> TypeIs[EncoderDecoderLLMInputs]:
+def is_encoder_decoder_inputs(
+    inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
+) -> TypeIs[EncoderDecoderInputs]:
     return "encoder_prompt_token_ids" in inputs
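The renamed is_encoder_decoder_inputs guard narrows the Union through TypeIs by checking for the encoder-specific key. A self-contained sketch of the same pattern, using toy TypedDicts rather than vLLM's real classes:

from typing import List, Union

from typing_extensions import TypedDict, TypeIs


class ToyDecoderOnlyInputs(TypedDict):
    prompt_token_ids: List[int]


class ToyEncoderDecoderInputs(ToyDecoderOnlyInputs):
    encoder_prompt_token_ids: List[int]


def is_encoder_decoder(
    inputs: Union[ToyDecoderOnlyInputs, ToyEncoderDecoderInputs],
) -> TypeIs[ToyEncoderDecoderInputs]:
    # A TypedDict is a plain dict at runtime, so a key check narrows the type.
    return "encoder_prompt_token_ids" in inputs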
@ -10,7 +10,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
|
||||
SingletonPrompt)
|
||||
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
|
||||
|
||||
@ -306,7 +306,7 @@ class InputPreprocessor:
|
||||
encoder_comps: PromptComponents,
|
||||
decoder_comps: DecoderPromptComponents,
|
||||
mm_processor_kwargs: Dict[str, Any],
|
||||
) -> EncoderDecoderLLMInputs:
|
||||
) -> EncoderDecoderInputs:
|
||||
encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
|
||||
decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
|
||||
|
||||
@ -324,7 +324,7 @@ class InputPreprocessor:
|
||||
decoder_prompt_ids,
|
||||
force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
|
||||
|
||||
return EncoderDecoderLLMInputs(
|
||||
return EncoderDecoderInputs(
|
||||
prompt_token_ids=decoder_prompt_ids,
|
||||
prompt=decoder_prompt,
|
||||
multi_modal_data=decoder_mm_data,
|
||||
@ -338,11 +338,11 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt: PromptType,
|
||||
request_id: str,
|
||||
) -> EncoderDecoderLLMInputs:
|
||||
) -> EncoderDecoderInputs:
|
||||
'''
|
||||
For encoder/decoder models only:
|
||||
Process an input prompt into an
|
||||
:class:`EncoderDecoderLLMInputs` instance.
|
||||
:class:`EncoderDecoderInputs` instance.
|
||||
|
||||
There are two types of input prompts:
|
||||
singleton prompts which carry only the
|
||||
@ -369,7 +369,7 @@ class InputPreprocessor:
|
||||
|
||||
Returns:
|
||||
|
||||
* :class:`EncoderDecoderLLMInputs` instance
|
||||
* :class:`EncoderDecoderInputs` instance
|
||||
'''
|
||||
|
||||
encoder_comps: PromptComponents
|
||||
@ -411,7 +411,7 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt: PromptType,
|
||||
request_id: str,
|
||||
) -> EncoderDecoderLLMInputs:
|
||||
) -> EncoderDecoderInputs:
|
||||
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
|
||||
encoder_comps: PromptComponents
|
||||
decoder_comps: DecoderPromptComponents
|
||||
@ -455,17 +455,17 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt_comps: PromptComponents,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> LLMInputs:
|
||||
) -> DecoderOnlyInputs:
|
||||
(prompt, prompt_token_ids, multi_modal_data,
|
||||
mm_processor_kwargs) = prompt_comps
|
||||
|
||||
prompt_token_ids = self._apply_prompt_adapter(
|
||||
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
return LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs)
|
||||
return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs)
|
||||
|
||||
def _process_decoder_only_prompt(
|
||||
self,
|
||||
@ -473,10 +473,10 @@ class InputPreprocessor:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
) -> DecoderOnlyInputs:
|
||||
'''
|
||||
For decoder-only models:
|
||||
Process an input prompt into an :class:`LLMInputs` instance.
|
||||
Process an input prompt into an :class:`DecoderOnlyInputs` instance.
|
||||
|
||||
Arguments:
|
||||
|
||||
@ -487,7 +487,7 @@ class InputPreprocessor:
|
||||
|
||||
Returns:
|
||||
|
||||
* :class:`LLMInputs` instance
|
||||
* :class:`DecoderOnlyInputs` instance
|
||||
'''
|
||||
|
||||
prompt_comps = self._extract_prompt_components(
|
||||
@ -507,7 +507,7 @@ class InputPreprocessor:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
) -> DecoderOnlyInputs:
|
||||
"""Async version of :meth:`_process_decoder_only_prompt`."""
|
||||
prompt_comps = await self._extract_prompt_components_async(
|
||||
prompt,
|
||||
@ -526,7 +526,7 @@ class InputPreprocessor:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
|
||||
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
|
||||
"""Preprocess the input prompt."""
|
||||
if self.is_encoder_decoder_model():
|
||||
# Encoder-decoder model requires special mapping of
|
||||
@ -554,7 +554,7 @@ class InputPreprocessor:
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
|
||||
) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
|
||||
"""Async version of :meth:`preprocess`."""
|
||||
if self.is_encoder_decoder_model():
|
||||
# Encoder-decoder model requires special mapping of
|
||||
|
||||
@@ -12,7 +12,7 @@ from vllm.logger import init_logger
 from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
                         resolve_mm_processor_kwargs)

-from .data import LLMInputs
+from .data import DecoderOnlyInputs

 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -100,7 +100,7 @@ class _MultiModalCounts(UserDict):
         raise KeyError(msg) from exc


-InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
+InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs]
 """Preprocess the inputs to the model."""

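With the alias change, custom input processors registered through InputRegistry take and return DecoderOnlyInputs instead of LLMInputs. A hedged sketch of a conforming processor; the function name is made up and the registration comment shows the typical decorator pattern, not code from this commit:

from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext


def noop_input_processor(ctx: InputContext,
                         inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
    # Matches the alias:
    # InputProcessor = Callable[[InputContext, DecoderOnlyInputs],
    #                           DecoderOnlyInputs]
    return inputs

# Registration is usually done as a decorator on a model class, e.g.:
# @INPUT_REGISTRY.register_input_processor(noop_input_processor)
# class MyModel(nn.Module): ...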
@ -134,7 +134,7 @@ class InputRegistry:
|
||||
# Avoid circular import
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
dummy_seq_data = SequenceData.from_token_counts((0, seq_len))
|
||||
dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
|
||||
dummy_multi_modal_data = None
|
||||
|
||||
return dummy_seq_data, dummy_multi_modal_data
|
||||
@ -245,8 +245,11 @@ class InputRegistry:
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
def _default_input_processor(self, ctx: InputContext,
|
||||
inputs: LLMInputs) -> LLMInputs:
|
||||
def _default_input_processor(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs,
|
||||
) -> DecoderOnlyInputs:
|
||||
"""The default input processor is a no-op."""
|
||||
return inputs
|
||||
|
||||
@ -279,7 +282,7 @@ class InputRegistry:
|
||||
.get(model_cls, self._default_input_processor)
|
||||
|
||||
def process_input(self, model_config: "ModelConfig",
|
||||
inputs: LLMInputs) -> LLMInputs:
|
||||
inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
|
||||
"""
|
||||
Apply an input processor to an instance of model inputs.
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from transformers.models.blip.modeling_blip import BlipAttention
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -63,7 +63,7 @@ def dummy_seq_data_for_blip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_token_counts(
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
@ -89,14 +89,14 @@ def dummy_image_for_blip(
|
||||
def input_processor_for_blip(
|
||||
model_config: ModelConfig,
|
||||
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
|
||||
llm_inputs: LLMInputs,
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[int] = None,
|
||||
):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
@ -107,16 +107,16 @@ def input_processor_for_blip(
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
|
||||
|
||||
@ -9,7 +9,8 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
@ -421,7 +422,7 @@ def dummy_seq_data_for_blip2(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_token_counts(
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
@ -449,10 +450,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
hf_config = ctx.get_hf_config(Blip2Config)
|
||||
image_feature_size = get_blip2_image_feature_size(hf_config)
|
||||
@ -460,15 +461,15 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
# The original model places image tokens at the front
|
||||
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
|
||||
new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
|
||||
new_token_ids += llm_inputs["prompt_token_ids"]
|
||||
new_token_ids += inputs["prompt_token_ids"]
|
||||
|
||||
new_prompt = llm_inputs.get("prompt")
|
||||
new_prompt = inputs.get("prompt")
|
||||
if new_prompt is not None:
|
||||
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
|
||||
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||
|
||||
@ -11,7 +11,8 @@ from transformers import ChameleonConfig, ChameleonVQVAEConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
@ -69,7 +70,7 @@ def dummy_seq_data_for_chameleon(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_token_counts(
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
@ -106,7 +107,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
|
||||
return seq_data, mm_data
|
||||
|
||||
|
||||
def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
def input_processor_for_chameleon(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
|
||||
"""
|
||||
Processing input prompt to insert required tokens for image placeholder.
|
||||
@ -114,16 +116,16 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
|
||||
""" # noqa
|
||||
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
|
||||
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
|
||||
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
|
||||
@ -137,9 +139,9 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
class ChameleonLayerNorm(nn.LayerNorm):
|
||||
|
||||
@ -14,7 +14,7 @@ from torch.nn import LayerNorm
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -149,20 +149,20 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
|
||||
return [index for index, value in enumerate(input_ids) if value == target]
|
||||
|
||||
|
||||
def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
hf_config = ctx.get_hf_config(ChatGLMConfig)
|
||||
vision_config = getattr(hf_config, 'vision_config', None)
|
||||
|
||||
if vision_config is None:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
elif isinstance(vision_config, dict):
|
||||
image_placeholder_length = calculate_image_placeholder(vision_config)
|
||||
else:
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
input_ids = llm_inputs.get("prompt_token_ids")
|
||||
position_ids = llm_inputs.get("position_ids")
|
||||
input_ids = inputs.get("prompt_token_ids")
|
||||
position_ids = inputs.get("position_ids")
|
||||
tokenizer = cached_get_tokenizer(
|
||||
ctx.model_config.model,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code)
|
||||
@ -171,15 +171,15 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
raw_batch_data = tokenizer.apply_chat_template(
|
||||
conversation=[{
|
||||
"role": "user",
|
||||
"image": llm_inputs['multi_modal_data']["image"],
|
||||
"content": llm_inputs['prompt']
|
||||
"image": inputs['multi_modal_data']["image"],
|
||||
"content": inputs['prompt']
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True).data
|
||||
except Exception:
|
||||
logger.error("Failed to process content (%s)", llm_inputs['prompt'])
|
||||
logger.error("Failed to process content (%s)", inputs['prompt'])
|
||||
raise
|
||||
input_ids = raw_batch_data['input_ids'][0].tolist()
|
||||
|
||||
@ -214,9 +214,9 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
|
||||
assert len(new_input_ids) == len(new_position_ids)
|
||||
|
||||
llm_inputs["prompt_token_ids"] = new_input_ids
|
||||
llm_inputs["position_ids"] = new_position_ids
|
||||
return llm_inputs
|
||||
inputs["prompt_token_ids"] = new_input_ids
|
||||
inputs["position_ids"] = new_position_ids
|
||||
return inputs
|
||||
|
||||
|
||||
class GLMAttention(nn.Module):
|
||||
|
||||
@ -11,7 +11,7 @@ from transformers.models.clip.modeling_clip import CLIPSdpaAttention
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import LLMInputs
|
||||
from vllm.inputs import DecoderOnlyInputs, token_inputs
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -62,7 +62,7 @@ def dummy_seq_data_for_clip(
|
||||
else:
|
||||
image_feature_size = image_feature_size_override
|
||||
|
||||
return SequenceData.from_token_counts(
|
||||
return SequenceData.from_prompt_token_counts(
|
||||
(image_token_id, image_feature_size * num_images),
|
||||
(0, seq_len - image_feature_size * num_images),
|
||||
)
|
||||
@ -106,14 +106,14 @@ def dummy_video_for_clip(
|
||||
def input_processor_for_clip(
|
||||
model_config: ModelConfig,
|
||||
hf_config: CLIPVisionConfig,
|
||||
llm_inputs: LLMInputs,
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
image_token_id: int,
|
||||
image_feature_size_override: Optional[Union[int, List[int]]] = None,
|
||||
):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
@ -130,16 +130,16 @@ def input_processor_for_clip(
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=image_token_id,
|
||||
repeat_count=image_feature_size,
|
||||
)
|
||||
|
||||
# NOTE: Create a defensive copy of the original inputs
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
|
||||
|
||||
@ -27,7 +27,8 @@ from transformers import FuyuConfig, FuyuImageProcessor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
@ -149,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
|
||||
return model_image_input
|
||||
|
||||
|
||||
def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
image_data = multi_modal_data["image"]
|
||||
@ -176,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
# process prompts
|
||||
prompt = llm_inputs.get("prompt")
|
||||
prompt_token_ids = llm_inputs["prompt_token_ids"]
|
||||
prompt = inputs.get("prompt")
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
tokenizer = cached_get_tokenizer(model_config.model)
|
||||
# dim0 is batch_size, dim1 is subseq_size which will always be 1
|
||||
image_input_ids: List[List[
|
||||
@ -190,9 +191,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
|
||||
1:] + boa_token
|
||||
|
||||
return LLMInputs(prompt=new_prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=new_multi_modal_data)
|
||||
return token_inputs(prompt=new_prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=new_multi_modal_data)
|
||||
|
||||
|
||||
def input_mapper_for_fuyu(ctx: InputContext, data: object):
|
||||
|
||||
@ -17,7 +17,8 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.models.intern_vit import InternVisionModel
|
||||
@ -276,13 +277,13 @@ class InternVLInputPipeline:
|
||||
def input_processor(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
llm_inputs: LLMInputs,
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
) -> LLMInputs:
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
) -> DecoderOnlyInputs:
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config()
|
||||
@ -311,8 +312,8 @@ class InternVLInputPipeline:
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
|
||||
prompt = llm_inputs.get("prompt")
|
||||
prompt_token_ids = llm_inputs["prompt_token_ids"]
|
||||
prompt = inputs.get("prompt")
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
|
||||
@ -320,9 +321,9 @@ class InternVLInputPipeline:
|
||||
num_patches)
|
||||
new_prompt_token_ids = tokenizer.encode(new_prompt)
|
||||
|
||||
return LLMInputs(prompt=prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data)
|
||||
return token_inputs(prompt=prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
def input_mapper(
|
||||
self,
|
||||
|
||||
@ -9,7 +9,7 @@ from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
@ -125,10 +125,10 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
@ -151,7 +151,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
return input_processor_for_clip(
|
||||
model_config,
|
||||
vision_config,
|
||||
llm_inputs,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
@ -159,7 +159,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
return input_processor_for_siglip(
|
||||
model_config,
|
||||
vision_config,
|
||||
llm_inputs,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
@ -12,7 +12,7 @@ from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@ -201,10 +201,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
def input_processor_for_llava_next(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaNextConfig)
|
||||
@ -239,7 +240,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
return input_processor_for_clip(
|
||||
model_config,
|
||||
vision_config,
|
||||
llm_inputs,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
@ -247,7 +248,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
return input_processor_for_siglip(
|
||||
model_config,
|
||||
vision_config,
|
||||
llm_inputs,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
@ -11,7 +11,8 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
@ -139,10 +140,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
|
||||
|
||||
|
||||
def input_processor_for_llava_next_video(ctx: InputContext,
|
||||
llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "video" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
video_data = multi_modal_data["video"]
|
||||
|
||||
model_config = ctx.model_config
|
||||
@ -160,15 +161,15 @@ def input_processor_for_llava_next_video(ctx: InputContext,
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=hf_config.video_token_index,
|
||||
repeat_count=video_feature_size,
|
||||
)
|
||||
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
elif is_list_of(video_data, np.ndarray):
|
||||
raise NotImplementedError(
|
||||
|
||||
@ -15,8 +15,8 @@ from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
@ -37,8 +37,6 @@ from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
|
||||
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
|
||||
|
||||
@ -252,10 +250,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
|
||||
|
||||
|
||||
def input_processor_when_multimodal_input_image(ctx: InputContext,
|
||||
llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
|
||||
@ -290,7 +288,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
|
||||
return input_processor_for_clip(
|
||||
model_config,
|
||||
vision_config,
|
||||
llm_inputs,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
@ -298,7 +296,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
|
||||
return input_processor_for_siglip(
|
||||
model_config,
|
||||
vision_config,
|
||||
llm_inputs,
|
||||
inputs,
|
||||
image_token_id=hf_config.image_token_index,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
@ -308,10 +306,10 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
|
||||
|
||||
|
||||
def input_processor_when_multimodal_input_video(ctx: InputContext,
|
||||
llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "video" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
video_data = multi_modal_data["video"]
|
||||
|
||||
model_config = ctx.model_config
|
||||
@ -326,15 +324,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
|
||||
|
||||
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
llm_inputs.get("prompt"),
|
||||
llm_inputs["prompt_token_ids"],
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=hf_config.video_token_index,
|
||||
repeat_count=video_feature_size,
|
||||
)
|
||||
|
||||
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
elif is_list_of(video_data, np.ndarray):
|
||||
raise NotImplementedError(
|
||||
@ -345,15 +343,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
|
||||
|
||||
|
||||
def input_processor_for_llava_onevision(ctx: InputContext,
|
||||
llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or ("video" not in multi_modal_data
|
||||
and "image" not in multi_modal_data):
|
||||
return llm_inputs
|
||||
return inputs
|
||||
if "image" in multi_modal_data:
|
||||
return input_processor_when_multimodal_input_image(ctx, llm_inputs)
|
||||
return input_processor_when_multimodal_input_image(ctx, inputs)
|
||||
if "video" in multi_modal_data:
|
||||
return input_processor_when_multimodal_input_video(ctx, llm_inputs)
|
||||
return input_processor_when_multimodal_input_video(ctx, inputs)
|
||||
|
||||
msg = "Unsupported multi data type"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
@ -36,7 +36,8 @@ from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
|
||||
@ -256,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
|
||||
|
||||
|
||||
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
|
||||
return SequenceData.from_token_counts((0, seq_len))
|
||||
return SequenceData.from_prompt_token_counts((0, seq_len))
|
||||
|
||||
|
||||
def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
|
||||
@ -279,10 +280,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
|
||||
return seq_data, mm_data
|
||||
|
||||
|
||||
def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
return inputs
|
||||
model_config = ctx.model_config
|
||||
version = get_version_by_config(model_config.hf_config)
|
||||
tokenizer = cached_get_tokenizer(
|
||||
@ -297,8 +298,8 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
return image_processor. \
|
||||
get_slice_image_placeholder(image_size, num_image)
|
||||
|
||||
prompt = llm_inputs.get("prompt")
|
||||
token_ids = llm_inputs.get("prompt_token_ids")
|
||||
prompt = inputs.get("prompt")
|
||||
token_ids = inputs.get("prompt_token_ids")
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(token_ids)
|
||||
|
||||
@ -332,12 +333,11 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
_build_image_input(ctx, image) for image in images
|
||||
]
|
||||
|
||||
llm_inputs = LLMInputs(
|
||||
return token_inputs(
|
||||
prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
return llm_inputs
|
||||
|
||||
|
||||
def input_mapper_for_minicpmv(ctx: InputContext, data: object):
|
||||
|
||||
@ -14,7 +14,6 @@
# limitations under the License.
"""PyTorch Mllama model."""
import math
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)

@ -37,7 +36,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputContext)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -51,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from vllm.sequence import SequenceData

from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
@ -86,24 +86,24 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
return num_images


def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
def input_processor_for_mllama(ctx: InputContext,
inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
# move encoder_prompt to prompt
if llm_inputs.get("prompt") is None:
llm_inputs["prompt"] = llm_inputs["encoder_prompt"]
llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"]
if inputs.get("prompt") is None:
inputs["prompt"] = inputs["encoder_prompt"]
inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"]

# process multi-modal data
assert "decoder_multi_modal_data" not in llm_inputs, \
"multi-modal data should be put in encoder message of mllama"
multi_modal_data = llm_inputs.get("encoder_multi_modal_data")
multi_modal_data = inputs.get("encoder_multi_modal_data")

if multi_modal_data is None or "image" not in multi_modal_data \
or multi_modal_data["image"] is None:
# text-only
llm_inputs["encoder_prompt"] = ""
llm_inputs["encoder_prompt_token_ids"] = []
llm_inputs["encoder_multi_modal_data"] = {}
return llm_inputs
inputs["encoder_prompt"] = ""
inputs["encoder_prompt_token_ids"] = []
inputs["encoder_multi_modal_data"] = {}
return inputs

if isinstance(multi_modal_data['image'], Image.Image):
multi_modal_data['image'] = [multi_modal_data['image']]
@ -111,7 +111,7 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
# are attended by the decoded tokens, we only need to
# get the number of tiles for those images.
num_decode_images = _get_num_image_in_last_group(
llm_inputs["prompt_token_ids"])
inputs["prompt_token_ids"])
hf_config = ctx.model_config.hf_config
num_tiles = 0
for image in multi_modal_data["image"][::-1]:
@ -137,11 +137,10 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
"chunk size should be multiple of 14"
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
num_tokens = num_tiles * token_per_chunk
llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID
] * num_tokens
inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens

return llm_inputs
return inputs


def get_max_mllama_image_tokens(ctx: InputContext) -> int:
@ -154,17 +153,18 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int):
# <|image|> * num_images + 0 * (seq_len - num_images)
assert seq_len >= num_images, \
"seq_len should be greater than or equal to num_images"
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[MLLAMA_IMAGE_TOKEN_ID]) * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images)
return SequenceData(token_ids)

return SequenceData.from_prompt_token_counts(
(MLLAMA_IMAGE_TOKEN_ID, num_images),
(0, seq_len - num_images),
)


def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
num_tokens = get_max_mllama_image_tokens(ctx) * num_images
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[MLLAMA_IMAGE_TOKEN_ID]) * num_tokens
return SequenceData(token_ids)

return SequenceData.from_prompt_token_counts(
(MLLAMA_IMAGE_TOKEN_ID, num_tokens))


def dummy_image(num_images: int, ):

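Note: a tiny, illustrative use of the renamed SequenceData.from_prompt_token_counts helper that the dummy-data builders above now rely on; the token id 128256 is just an example value, not taken from this diff.

from vllm.sequence import SequenceData

# Two placeholder tokens followed by six padding zeros -> an 8-token prompt.
seq_data = SequenceData.from_prompt_token_counts(
    (128256, 2),
    (0, 6),
)
assert len(seq_data.prompt_token_ids) == 8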
@ -23,7 +23,8 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@ -945,9 +946,9 @@ def pad_images(
return images, image_input_idx, image_masks


def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
prompt = llm_inputs.get("prompt", None)
multi_modal_data = llm_inputs.get("multi_modal_data", None)
def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
prompt = inputs.get("prompt", None)
multi_modal_data = inputs.get("multi_modal_data", None)
if multi_modal_data is not None:
image = multi_modal_data.get("image", None)
else:
@ -965,9 +966,7 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
elif prompt is not None:
out = processor.process(prompt, image)
else:
out = processor.process(None,
image,
tokens=llm_inputs["prompt_token_ids"])
out = processor.process(None, image, tokens=inputs["prompt_token_ids"])

image_processor = processor.image_processor
max_total_crops = 1 + image_processor.max_crops
@ -1020,9 +1019,9 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):

multi_modal_data = dict(image=image_data)

return LLMInputs(
return token_inputs(
prompt_token_ids=out["input_ids"],
prompt=llm_inputs["prompt"],
prompt=inputs["prompt"],
multi_modal_data=multi_modal_data,
)


@ -7,7 +7,8 @@ from transformers import PaliGemmaConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
@ -68,7 +69,8 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
return seq_data, mm_data


def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
def input_processor_for_paligemma(ctx: InputContext,
inputs: DecoderOnlyInputs):

"""
The correct prompt format needs to be:
@ -77,9 +79,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
""" # noqa

multi_modal_data = llm_inputs.get("multi_modal_data")
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs

model_config = ctx.model_config
hf_config = ctx.get_hf_config(PaliGemmaConfig)
@ -91,8 +93,8 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
image_token_str_pad = image_token_str * image_feature_size
image_token_ids_pad = [hf_config.image_token_index] * image_feature_size

orig_prompt = llm_inputs.get("prompt")
orig_prompt_ids = llm_inputs.get("prompt_token_ids")
orig_prompt = inputs.get("prompt")
orig_prompt_ids = inputs.get("prompt_token_ids")

if orig_prompt is not None and image_token_str in orig_prompt:
logger.warning(
@ -106,9 +108,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline

# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)


class PaliGemmaMultiModalProjector(nn.Module):

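Note: a toy, self-contained illustration of the padding that the PaliGemma processor above performs; all token id values here are made up for the example, and 108 stands for the trailing newline token noted in the diff comment.

image_token_index = 257152          # hypothetical <image> token id
image_feature_size = 4              # kept tiny for the example
orig_prompt_ids = [2, 6974, 603]    # hypothetical ids for the original prompt

image_token_ids_pad = [image_token_index] * image_feature_size
new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline id
print(new_token_ids)
# [257152, 257152, 257152, 257152, 2, 6974, 603, 108]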
@ -27,7 +27,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
@ -410,12 +411,12 @@ def _get_image_placeholder_token_id_candidates(


def input_processor_for_phi3v(ctx: InputContext,
llm_inputs: LLMInputs,
inputs: DecoderOnlyInputs,
*,
num_crops: Optional[int] = None):
multi_modal_data = llm_inputs.get("multi_modal_data")
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs

model_config = ctx.model_config
hf_config = ctx.get_hf_image_processor_config()
@ -447,7 +448,7 @@ def input_processor_for_phi3v(ctx: InputContext,
else:
raise TypeError(f"Invalid image type: {type(image_data)}")

prompt = llm_inputs.get("prompt")
prompt = inputs.get("prompt")
if prompt is None:
# for async server request, we assume prompt and its token_ids is always
# in correct format. And num_image_tags == len(image_data) always True.
@ -464,7 +465,7 @@ def input_processor_for_phi3v(ctx: InputContext,
image_data), "The count of image_placeholder not match image's"
new_prompt = prompt

prompt_token_ids = llm_inputs["prompt_token_ids"].copy()
prompt_token_ids = inputs["prompt_token_ids"].copy()

print("prompt_token_ids (old)", prompt_token_ids)

@ -506,10 +507,9 @@ def input_processor_for_phi3v(ctx: InputContext,
new_token_ids.append(token_id)

# NOTE: Create a defensive copy of the original inputs
llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
return llm_inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)


@MULTIMODAL_REGISTRY.register_image_input_mapper()

@ -14,7 +14,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@ -62,7 +62,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
image_feature_size = (size**2) // (patch_size**2)

num_image_tokens = image_feature_size * num_images
seq_data = SequenceData.from_token_counts(
seq_data = SequenceData.from_prompt_token_counts(
(image_token_id, num_image_tokens),
(0, seq_len - num_image_tokens),
)
@ -102,8 +102,8 @@ def input_mapper_for_pixtral(ctx: InputContext,
return MultiModalInputs({"images": images})


def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is not None and "image" in multi_modal_data:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
@ -112,15 +112,15 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img

if image_token_id not in llm_inputs['prompt_token_ids']:
if image_token_id not in inputs['prompt_token_ids']:
raise ValueError(
(f"You've passed {llm_inputs=} without {image_token_id=}"
(f"You've passed {inputs=} without {image_token_id=}"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."))

return llm_inputs
return inputs


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)

@ -22,7 +22,8 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
@ -652,30 +653,30 @@ def get_image_text(image_num: int, padding: bool) -> str:


def input_processor_for_qwen(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs:
inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
"""Processes the inputs, which may or may not be multimodal.
Multimodal inputs will only be processed if the model has a "visual"
component in its model config, otherwise they'll be ignored.

Args:
ctx: Context of the loaded model.
llm_inputs: LLM inputs which may have a multi_modal_data attribute.
inputs: LLM inputs which may have a multi_modal_data attribute.

Returns:
If the model is language only or not multimodal inputs were provided,
returns llm_inputs unmodified. Otherwise, processes the multimodal
returns inputs unmodified. Otherwise, processes the multimodal
images / image embeddings and adds the fixed-length image placeholders.
"""
multi_modal_data = llm_inputs.get("multi_modal_data")
multi_modal_data = inputs.get("multi_modal_data")

# Only process images if we have multimodal data and a visual config
hf_config = ctx.get_hf_config()
if (multi_modal_data is None or "image" not in multi_modal_data
or not hasattr(hf_config, "visual")):
return llm_inputs
return inputs

prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
@ -713,9 +714,9 @@ def input_processor_for_qwen(ctx: InputContext,

new_prompt_token_ids = tokenizer.encode(new_prompt)

return LLMInputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)


def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
@ -822,7 +823,7 @@ def dummy_data_for_qwen(
# The presence of a visual config indicates this is a multimodal model.
# If we don't have it, the model is considered an LLM for warmup purposes.
if not hasattr(hf_config, "visual"):
seq_data = SequenceData.from_token_counts((0, seq_len))
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
mm_data = None
return seq_data, mm_data


@ -46,7 +46,8 @@ from vllm.attention.selector import (_Backend, backend_name_to_enum,
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
@ -716,7 +717,7 @@ def dummy_data_for_qwen2_vl(

hf_config = ctx.get_hf_config(Qwen2VLConfig)

dummy_seqdata = SequenceData.from_token_counts(
dummy_seqdata = SequenceData.from_prompt_token_counts(
(hf_config.vision_start_token_id, 1),
(hf_config.image_token_id, max_llm_image_tokens),
(hf_config.vision_end_token_id, 1),
@ -799,11 +800,13 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
return prompt_token_ids_with_data


def input_processor_for_qwen2_vl(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs:
multi_modal_data = llm_inputs.get("multi_modal_data", None)
def input_processor_for_qwen2_vl(
ctx: InputContext,
inputs: DecoderOnlyInputs,
) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data", None)
if multi_modal_data is None:
return llm_inputs
return inputs

image_inputs = multi_modal_data.get("image", None)
video_inputs = multi_modal_data.get("video", None)
@ -817,7 +820,7 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
# `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
#
# The following code is equivalent to:
# prompt = llm_inputs["prompt"]
# prompt = inputs["prompt"]
# inputs = processor(text=[prompt],
# images=image_inputs,
# videos=video_inputs,
@ -825,9 +828,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
# return_tensors="pt")
# prompt_token_ids = inputs["input_ids"][0].tolist()

prompt_token_ids = llm_inputs.get("prompt_token_ids", None)
prompt_token_ids = inputs.get("prompt_token_ids", None)
if prompt_token_ids is None:
prompt = llm_inputs["prompt"]
prompt = inputs["prompt"]
prompt_token_ids = processor.tokenizer(
prompt,
padding=True,
@ -868,9 +871,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
image_processor,
prompt_token_ids)

return LLMInputs(
return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt=llm_inputs["prompt"],
prompt=inputs["prompt"],
multi_modal_data=multi_modal_data,
)


@ -13,7 +13,7 @@ from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention

from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
@ -67,7 +67,7 @@ def dummy_seq_data_for_siglip(
else:
image_feature_size = image_feature_size_override

return SequenceData.from_token_counts(
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
)
@ -111,14 +111,14 @@ def dummy_video_for_siglip(
def input_processor_for_siglip(
model_config: ModelConfig,
hf_config: SiglipVisionConfig,
llm_inputs: LLMInputs,
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
return inputs

tokenizer = cached_get_tokenizer(model_config.tokenizer)

@ -135,14 +135,14 @@ def input_processor_for_siglip(

new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)

# NOTE: Create a defensive copy of the original inputs
return LLMInputs(
return token_inputs(
prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,

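Note: for intuition, a toy stand-in for the placeholder expansion that repeat_and_pad_placeholder_tokens performs in the call above; this is not the real helper, and the token ids are made up.

placeholder_token_id = 32044   # hypothetical image placeholder id
image_feature_size = 3         # kept tiny for the example
prompt_token_ids = [1, 32044, 9, 10]

new_token_ids = []
for token_id in prompt_token_ids:
    if token_id == placeholder_token_id:
        # Repeat the placeholder so one image token becomes a fixed-length run.
        new_token_ids.extend([placeholder_token_id] * image_feature_size)
    else:
        new_token_ids.append(token_id)
print(new_token_ids)  # [1, 32044, 32044, 32044, 9, 10]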
@ -18,7 +18,7 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.data import DecoderOnlyInputs, token_inputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
@ -156,10 +156,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
return MultiModalInputs({"audio_features": audio_features})


def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return llm_inputs
return inputs

feature_extractor = whisper_feature_extractor(ctx)
audios = multi_modal_data["audio"]
@ -196,16 +196,16 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):

new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_token_counts,
)

# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)


class StackAudioFrames(nn.Module):

@ -13,8 +13,7 @@ from typing import Set, Tuple, Union, cast
import msgspec
import torch

from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
@ -22,6 +21,7 @@ from vllm.sampling_params import SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

if TYPE_CHECKING:
from vllm.inputs import SingletonInputs
from vllm.multimodal.base import MultiModalDataDict

VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@ -29,6 +29,11 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l"
VLLM_INVALID_TOKEN_ID = -1


def array_full(token_id: int, count: int):
""":class:`array` equivalent of :func:`numpy.full`."""
return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count


# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
@ -173,22 +178,34 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta: Optional[int] = None

@staticmethod
def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData":
"""
Construct a :class:`SequenceData` instance by concatenating
prompt token sequences.

Each tuple represents one token sequence, expressed in the form
:code:`(token_id, count)`.
"""
if len(token_counts) == 0:
return SequenceData.from_seqs([])

arrs = [
array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
for token_id, count in token_counts
]
prompt_token_ids_arr = reduce(
array.__iadd__,
(array_full(token_id, count) for token_id, count in token_counts),
)

return SequenceData(reduce(array.__add__, arrs))
return SequenceData(prompt_token_ids_arr)

@staticmethod
def from_seqs(
prompt_token_ids: GenericSequence[int],
output_token_ids: Optional[GenericSequence[int]] = None,
) -> "SequenceData":
"""
Construct a :class:`SequenceData` instance from prompt and output
token sequences.
"""
prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
prompt_token_ids)

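Note: a quick, standalone illustration of the concatenation used by from_prompt_token_counts above; reduce with array.__iadd__ appends each run in place onto the first one. Token ids below are arbitrary.

from array import array
from functools import reduce

runs = [array("l", [32000]) * 2, array("l", [0]) * 3]
merged = reduce(array.__iadd__, runs)
print(list(merged))  # [32000, 32000, 0, 0, 0]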
@ -362,14 +379,14 @@ class SequenceData(msgspec.Struct,
class Sequence:
"""Stores the data, status, and block information of a sequence.

The sequence is constructed from the LLMInputs instance passed
in through the `inputs` constructor argument.
The sequence is constructed from the :code:`SingletonInputs` instance
passed in through the :code:`inputs` constructor argument.

For encoder/decoder models, LLMInputs encapsulates both a
For encoder/decoder models, SingletonInputs encapsulates both a
decoder and encoder prompt, creating an ambiguity about which
prompt to construct the sequence from. The `from_decoder_prompt`
constructor argument signals whether to construct the Sequence
from the LLMInputs decoder prompt, or encoder prompt.
from the SingletonInputs decoder prompt, or encoder prompt.

Args:
seq_id: The ID of the sequence.
@ -379,16 +396,16 @@ class Sequence:
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request.
from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
(True) or encoder prompt (False.) Must be True
for decoder-only model.
from_decoder_prompt: Construct Sequence from SingletonInputs decoder
prompt (True) or encoder prompt (False.) Must be
True for decoder-only model.

"""

def __init__(
self,
seq_id: int,
inputs: "LLMInputs",
inputs: "SingletonInputs",
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
@ -404,19 +421,19 @@ class Sequence:
self.from_decoder_prompt = from_decoder_prompt

# For decoder-only models, a Sequence is constructed
# from an LLMInputs instance (the `inputs` arg.)
# from an DecoderOnlyInputs instance (the `inputs` arg.)
#
# For encoder/decoder models the same `inputs`
# instance could be utilized to construct either an
# encoder sequence or a decoder sequence, because
# `LLMInputs` has both decoder- and encoder-oriented
# `DecoderOnlyInputs` has both decoder- and encoder-oriented
# member variables (i.e. it encapsulates both an encoder
# and a decoder prompt.) The decision of which type of sequence
# to generate is determined by the `from_decoder_prompt` argument.
#
# When constructing a encoder sequence
# (`from_decoder_prompt` False) it matters that
# the `LLMInputs` instance stored in `inputs` is valid
# the `DecoderOnlyInputs` instance stored in `inputs` is valid
# in the sense that its encoder-related member variables are
# populated; below, an exception is raised if this is
# not the case.
@ -424,8 +441,7 @@ class Sequence:
# When constructing a decoder sequence (`from_decoder_prompt` True)
# it does not matter whether `inputs` has its encoder-related
# member variables populated.
if not (from_decoder_prompt
or is_valid_encoder_decoder_llm_inputs(inputs)):
if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)):
raise ValueError("Cannot extract encoder input prompt from "
f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?")
@ -471,15 +487,19 @@ class Sequence:

@property
def multi_modal_data(self) -> "MultiModalDataDict":
if self.inputs.get("multi_modal_data") and self.inputs.get(
"encoder_multi_modal_data"):
inputs = self.inputs

if (inputs.get("multi_modal_data")
and inputs.get("encoder_multi_modal_data")):
raise ValueError(
"Multi-modal data in both encoder and decoder is not supported."
)
inputs = self.inputs
return self.inputs.get("multi_modal_data") or (cast(
EncoderDecoderLLMInputs,
inputs).get("encoder_multi_modal_data")) or {}

return cast(
"MultiModalDataDict",
(inputs.get("multi_modal_data")
or inputs.get("encoder_multi_modal_data") or {}),
)

@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
