[VLM][Core] Support profiling with multiple multi-modal inputs per prompt (#7126)

This commit is contained in:
Cyrus Leung 2024-08-15 01:55:42 +08:00 committed by GitHub
parent 70b746efcf
commit 3f674a49b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 572 additions and 216 deletions

View File

@ -17,4 +17,4 @@ Input Processing Pipeline
6. If the data contains multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.
- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision model.

View File

@ -15,6 +15,9 @@ by following :ref:`this guide <adding_multimodal_plugin>`.
Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.
..
TODO: Add usage of --limit-mm-per-prompt when multi-image input is officially supported
Guides
++++++

View File

@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i
3. Register maximum number of multi-modal tokens
------------------------------------------------
For each modality type that the model accepts as input, calculate the maximum possible number of tokens
For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance
and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.
.. code-block:: diff

View File

@ -0,0 +1,24 @@
import pytest
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("arg", "expected"), [
(None, None),
("image=16", {
"image": 16
}),
("image=16,video=2", {
"image": 16,
"video": 2
}),
])
def test_limit_mm_per_prompt_parser(arg, expected):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None:
args = parser.parse_args([])
else:
args = parser.parse_args(["--limit-mm-per-prompt", arg])
assert args.limit_mm_per_prompt == expected

View File

@ -59,7 +59,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""

View File

@ -49,7 +49,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""

View File

@ -117,7 +117,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""

View File

@ -69,7 +69,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""

View File

@ -177,7 +177,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""

View File

@ -61,7 +61,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
@ -176,7 +176,7 @@ def run_multi_image_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
@ -197,6 +197,7 @@ def run_multi_image_test(
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
limit_mm_per_prompt={"image": len(images)},
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,

View File

@ -72,7 +72,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""

View File

@ -73,7 +73,7 @@ def run_test(
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""

View File

@ -1,15 +1,22 @@
from contextlib import nullcontext
import numpy as np
import pytest
from transformers import CLIPImageProcessor, LlavaNextImageProcessor
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
@pytest.fixture
def mm_registry():
return MultiModalRegistry()
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_clip_image_processor(image_assets, dtype, size_factor):
def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
@ -24,6 +31,9 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
dtype=dtype,
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": 1})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
@ -32,7 +42,7 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
image,
return_tensors="pt",
)
vllm_result = MULTIMODAL_REGISTRY.map_input(
vllm_result = mm_registry.map_input(
model_config,
{"image": image},
)
@ -48,7 +58,8 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, dtype, size_factor):
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
size_factor):
MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"
hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
@ -63,6 +74,9 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
dtype=dtype,
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": 1})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
@ -71,7 +85,7 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
image,
return_tensors="pt",
)
vllm_result = MULTIMODAL_REGISTRY.map_input(
vllm_result = mm_registry.map_input(
model_config,
{"image": image},
)
@ -83,3 +97,61 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.parametrize(
("num_images", "limit", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": limit})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
image = image_assets[0].pil_image
if num_images == 0:
mm_inputs = {}
elif num_images == 1:
mm_inputs = {"image": image}
else:
mm_inputs = {"image": [image] * num_images}
with nullcontext() if is_valid else pytest.raises(ValueError):
mm_registry.map_input(model_config, mm_inputs)
# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
)
mm_config = MultiModalConfig(limit_per_prompt={"image": num_images})
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
image = image_assets[0].pil_image
mm_inputs = {"image": [image] * num_images}
mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
assert len(mapped_inputs["pixel_values"]) == num_images

View File

@ -1,7 +1,8 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple,
Type, Union)
import torch
from transformers import PretrainedConfig
@ -1429,10 +1430,15 @@ class PromptAdapterConfig:
@dataclass
class MultiModalConfig:
"""Configs the input data format and how models should run for
multimodal models."""
"""Controls the behavior of multimodal models."""
limit_per_prompt: Mapping[str, int]
"""
The maximum number of multi-modal input instances allowed per prompt
for each :class:`~vllm.multimodal.MultiModalPlugin`.
"""
# TODO: Add configs to init vision tower or not.
pass
_STR_DTYPE_TO_TORCH_DTYPE = {

View File

@ -2,7 +2,8 @@ import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
Union)
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@ -15,8 +16,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
logger = init_logger(__name__)
@ -29,11 +29,32 @@ def nullable_str(val: str):
return val
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
if len(val) == 0:
return None
out_dict: Dict[str, int] = {}
for item in val.split(","):
try:
key, value = item.split("=")
except TypeError as exc:
msg = "Each item should be in the form KEY=VALUE"
raise ValueError(msg) from exc
try:
out_dict[key] = int(value)
except ValueError as exc:
msg = f"Failed to parse value of item {key}={value}"
raise ValueError(msg) from exc
return out_dict
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[List[str]]] = None
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
@ -81,6 +102,7 @@ class EngineArgs:
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
@ -435,6 +457,21 @@ class EngineArgs:
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')
# Multimodal related configs
parser.add_argument(
'--limit-mm-per-prompt',
type=nullable_kvs,
default=EngineArgs.limit_mm_per_prompt,
# The default value is given in
# MultiModalRegistry.init_mm_limits_per_prompt
help=('For each multimodal plugin, limit how many '
'input instances to allow for each prompt. '
'Expects a comma-separated list of items, '
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to 1 for '
'each modality.'))
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
@ -709,7 +746,8 @@ class EngineArgs:
"CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}")
multimodal_config = MultiModalConfig()
multimodal_config = MultiModalConfig(
limit_per_prompt=self.limit_mm_per_prompt or {})
device_config = DeviceConfig(device=self.device)
model_config = ModelConfig(

View File

@ -24,8 +24,9 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs,
PromptInputs, SingletonPromptInputs)
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptInputs,
SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -180,6 +181,7 @@ class LLMEngine:
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
@ -265,8 +267,9 @@ class LLMEngine:
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.input_processor = INPUT_REGISTRY.create_input_processor(
self.model_config)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
model_config)
self.model_executor = executor_class(
model_config=model_config,

View File

@ -1,6 +1,8 @@
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
Tuple, Type)
from torch import nn
from transformers import PretrainedConfig
@ -12,7 +14,7 @@ from .data import LLMInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal import MultiModalDataDict, MultiModalRegistry
from vllm.sequence import SequenceData
logger = init_logger(__name__)
@ -65,15 +67,38 @@ class InputContext:
N = TypeVar("N", bound=Type[nn.Module])
DummyDataFactory = Callable[[InputContext, int],
Tuple["SequenceData",
Optional["MultiModalDataDict"]]]
"""
Create dummy data to be inputted into the model.
Note:
class DummyDataFactory(Protocol):
def __call__(
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data to be inputted into the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
"""
"""
...
class _MultiModalCounts(UserDict):
"""
Wraps `mm_counts` for a more informative error message
when attempting to access a plugin that does not exist.
"""
def __getitem__(self, key: str) -> int:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"There is no multi-modal plugin with the key: {key}. "
f"Available keys: {set(self.keys())}")
raise KeyError(msg) from exc
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
"""Preprocess the inputs to the model."""
@ -95,6 +120,7 @@ class InputRegistry:
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
The default dummy data factory represents the longest possible text
@ -133,8 +159,12 @@ class InputRegistry:
return wrapper
def dummy_data_for_profiling(self, model_config: "ModelConfig",
seq_len: int):
def dummy_data_for_profiling(
self,
model_config: "ModelConfig",
seq_len: int,
mm_registry: "MultiModalRegistry",
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data for profiling the memory usage of a model.
@ -142,6 +172,10 @@ class InputRegistry:
See also:
:ref:`enabling_multimodal_inputs`
Note:
This should be called after
:meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
@ -149,8 +183,29 @@ class InputRegistry:
model_cls, _ = get_model_architecture(model_config)
dummy_factory = self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
return dummy_factory(InputContext(model_config), seq_len)
seq_data, mm_data = dummy_factory(
InputContext(model_config),
seq_len,
_MultiModalCounts(mm_counts),
)
# Having more tokens is over-conservative but otherwise fine
num_tokens = seq_data.prompt_token_ids
assert len(num_tokens) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if mm_data is not None:
for k, v in mm_data.items():
num_items = len(v) if isinstance(v, list) else 1
num_expected = mm_counts[k]
assert num_items >= num_expected, (
f"Expected at least {num_expected} dummy '{k}' instances "
f"for profiling, but found {num_items} instances instead.")
return seq_data, mm_data
def _default_input_processor(self, ctx: InputContext,
inputs: LLMInputs) -> LLMInputs:

View File

@ -133,7 +133,7 @@ def _get_model_initialization_kwargs(
if supports_multimodal(model_class):
if multimodal_config is None:
raise ValueError("Provide vision related configurations "
raise ValueError("Provide multi-modal related configurations "
"through LLM entrypoint or engine arguments.")
extra_kwargs["multimodal_config"] = multimodal_config

View File

@ -31,13 +31,13 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
def get_blip_image_feature_size(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
def get_max_blip_image_tokens(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_image_feature_size(hf_config)
@ -60,6 +60,7 @@ def dummy_seq_data_for_blip(
def dummy_image_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
@ -71,7 +72,7 @@ def dummy_image_for_blip(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_blip(

View File

@ -1,4 +1,5 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
@ -413,17 +414,39 @@ def get_max_blip2_image_tokens(ctx: InputContext):
raise NotImplementedError(msg)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int):
def dummy_seq_data_for_blip2(
hf_config: Blip2Config,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip2_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_blip2_image_feature_size(hf_config)
token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
seq_data = SequenceData(token_ids)
seq_data = dummy_seq_data_for_blip2(
hf_config,
seq_len,
num_images,
image_token_id=BLIP2_IMAGE_TOKEN_ID,
)
if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config)
mm_data = dummy_image_for_blip(vision_config, num_images)
return seq_data, mm_data

View File

@ -1,6 +1,6 @@
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple,
TypedDict)
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict)
import torch
import torch.nn.functional as F
@ -19,8 +19,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -61,6 +60,7 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
def dummy_seq_data_for_chameleon(
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
@ -70,12 +70,14 @@ def dummy_seq_data_for_chameleon(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_image_for_chameleon(
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
@ -87,17 +89,20 @@ def dummy_image_for_chameleon(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int):
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_chameleon(
seq_len,
num_images,
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
)
mm_data = dummy_image_for_chameleon()
mm_data = dummy_image_for_chameleon(num_images)
return seq_data, mm_data

View File

@ -43,6 +43,7 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
def dummy_seq_data_for_clip(
hf_config: CLIPVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
@ -52,13 +53,14 @@ def dummy_seq_data_for_clip(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_image_for_clip(
hf_config: CLIPVisionConfig,
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
@ -70,7 +72,7 @@ def dummy_image_for_clip(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_clip(

View File

@ -16,7 +16,7 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
@ -29,8 +29,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -94,27 +93,33 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
return (ncol + 1) * nrow
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int):
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
ncol, nrow = get_max_fuyu_image_feature_size()
image_feature_size = get_max_fuyu_image_tokens(ctx)
token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
token_ids += [0] * (seq_len - image_feature_size)
image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
token_ids = image_token_ids * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_image_for_fuyu(
num_images: int,
*,
image_width: int,
image_height: int,
):
image = Image.new("RGB", (image_width, image_height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_fuyu(ctx, seq_len)
mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH,
MAX_IMAGE_FEATURE_SIZE_HEIGHT)
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
mm_data = dummy_image_for_fuyu(num_images,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
return seq_data, mm_data

View File

@ -11,14 +11,11 @@ logger = init_logger(__name__)
@runtime_checkable
class SupportsMultiModal(Protocol):
"""
The interface required for all multimodal (vision or audio) language
models.
"""
"""The interface required for all multi-modal models."""
supports_multimodal: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multimodal inputs.
A flag that indicates this model supports multi-modal inputs.
Note:
There is no need to redefine this flag if this class is in the

View File

@ -5,7 +5,8 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
@ -230,7 +231,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
def input_mapper_for_internvl(ctx: InputContext, data: object):
hf_config = ctx.get_hf_config(PretrainedConfig)
hf_config = ctx.get_hf_config()
use_thumbnail = hf_config.use_thumbnail
min_num = hf_config.min_dynamic_patch
@ -256,7 +257,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
})
def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
image_feature_size = get_max_internvl_image_tokens(ctx)
model_config = ctx.model_config
@ -268,6 +271,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False)[0],
image_feature_size_override=image_feature_size,
@ -281,6 +285,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
mm_data = dummy_image_for_clip(
vision_config,
num_images,
image_width_override=max_image_width,
image_height_override=max_image_height,
)

View File

@ -1,5 +1,6 @@
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
@ -9,8 +10,7 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -88,9 +88,11 @@ def get_max_llava_image_tokens(ctx: InputContext):
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_max_llava_image_tokens(ctx)
@ -98,21 +100,23 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(vision_config)
mm_data = dummy_image_for_clip(vision_config, num_images)
return seq_data, mm_data
elif isinstance(vision_config, SiglipVisionConfig):
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_siglip(vision_config)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"

View File

@ -1,5 +1,6 @@
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
@ -13,8 +14,7 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -158,9 +158,11 @@ def get_max_llava_next_image_tokens(ctx: InputContext):
)
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_max_llava_next_image_tokens(ctx)
@ -168,12 +170,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
vision_config,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
@ -183,12 +187,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_siglip(
vision_config,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)

View File

@ -24,8 +24,8 @@
import math
import re
from functools import partial
from typing import (Any, Callable, Iterable, List, Optional, Tuple, TypedDict,
Union)
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np
import torch
@ -42,8 +42,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@ -408,22 +407,24 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
return getattr(hf_config, "query_num", 64)
def dummy_seq_data_for_minicpmv(seq_len: int):
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
token_ids = [0] * seq_len
return SequenceData(token_ids)
def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
width = height = hf_config.image_size
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config()
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_minicpmv(seq_len)
mm_data = dummy_image_for_minicpmv(hf_config)
seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
mm_data = dummy_image_for_minicpmv(hf_config, num_images)
return seq_data, mm_data

View File

@ -1,4 +1,5 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
from torch import nn
@ -9,8 +10,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaModel
@ -57,17 +57,20 @@ def get_max_paligemma_image_tokens(ctx: InputContext):
return get_max_siglip_image_tokens(vision_config)
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
)
mm_data = dummy_image_for_siglip(vision_config)
mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data

View File

@ -15,7 +15,8 @@
# limitations under the License.
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import numpy as np
import torch
@ -28,8 +29,7 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -347,18 +347,22 @@ def get_max_phi3v_image_tokens(ctx: InputContext):
)
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
image_feature_size = get_max_phi3v_image_tokens(ctx)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
num_images,
image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)

View File

@ -52,6 +52,7 @@ def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
def dummy_seq_data_for_siglip(
hf_config: SiglipVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
@ -61,13 +62,14 @@ def dummy_seq_data_for_siglip(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
token_ids = [image_token_id] * image_feature_size * num_images
token_ids += [0] * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_image_for_siglip(
hf_config: SiglipVisionConfig,
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
@ -79,7 +81,7 @@ def dummy_image_for_siglip(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_siglip(

View File

@ -1,9 +1,9 @@
import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final
import numpy as np
import torch
@ -116,17 +116,30 @@ class MultiModalInputs(_MultiModalInputsBase):
batched_inputs)
_T = TypeVar("_T")
MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.
The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
"""Modality types that are predefined by vLLM."""
image: Image.Image
"""The input image."""
image: MultiModalData[Image.Image]
"""The input image(s)."""
audio: Tuple[np.ndarray, Union[int, float]]
"""The input audio and its sampling rate."""
audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
"""The input audio item(s) and corresponding sampling rate(s)."""
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
MultiModalDataDict = Union[MultiModalDataBuiltins,
Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item for each modality type to input.
@ -137,7 +150,8 @@ Note:
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
MultiModalInputs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
@ -181,8 +195,11 @@ class MultiModalPlugin(ABC):
raise NotImplementedError
@abstractmethod
def _default_input_mapper(self, ctx: InputContext,
data: object) -> MultiModalInputs:
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[object],
) -> MultiModalInputs:
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
@ -225,7 +242,7 @@ class MultiModalPlugin(ABC):
return wrapper
def map_input(self, model_config: ModelConfig,
data: object) -> MultiModalInputs:
data: MultiModalData[object]) -> MultiModalInputs:
"""
Transform the data into a dictionary of model inputs using the
input mapper registered for that model.
@ -254,8 +271,8 @@ class MultiModalPlugin(ABC):
@abstractmethod
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
Calculate the maximum number of tokens, corresponding to a single
instance of multimodal data, that are passed to the language model.
"""
raise NotImplementedError
@ -269,8 +286,9 @@ class MultiModalPlugin(ABC):
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of multi-modal tokens input to the
language model for a model class.
Register the maximum number of tokens, corresponding to a single
instance of multimodal data, that are passed to the language model
for a model class.
If `None` is provided, then the default calculation is used instead.

View File

@ -11,7 +11,7 @@ from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import is_list_of
from .base import MultiModalInputs, MultiModalPlugin
from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
logger = init_logger(__name__)
@ -110,8 +110,11 @@ class ImagePlugin(MultiModalPlugin):
model_config.model,
trust_remote_code=model_config.trust_remote_code)
def _default_input_mapper(self, ctx: InputContext,
data: object) -> MultiModalInputs:
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[object],
) -> MultiModalInputs:
model_config = ctx.model_config
# PIL image

View File

@ -1,19 +1,33 @@
import functools
from typing import Dict, Optional, Sequence
from collections import UserDict
from typing import Dict, Mapping, Optional, Sequence
import torch
from vllm.config import ModelConfig
from vllm.config import ModelConfig, MultiModalConfig
from vllm.logger import init_logger
from .audio import AudioPlugin
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
MultiModalPlugin, MultiModalTokensCalc)
MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
from .image import ImagePlugin
logger = init_logger(__name__)
class _MultiModalLimits(UserDict):
"""
Wraps `_limits_by_model` for a more informative error message
when attempting to access a model that does not exist.
"""
def __getitem__(self, key: ModelConfig) -> Dict[str, int]:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
"forget to call `init_mm_limits_per_prompt`?")
raise KeyError(msg) from exc
class MultiModalRegistry:
"""
A registry that dispatches data processing to the
@ -28,6 +42,11 @@ class MultiModalRegistry:
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
self._plugins = {p.get_data_key(): p for p in plugins}
# This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
self._limits_by_model = _MultiModalLimits()
def register_plugin(self, plugin: MultiModalPlugin) -> None:
"""
Register a multi-modal plugin so it can be recognized by vLLM.
@ -86,13 +105,24 @@ class MultiModalRegistry:
via the input mapper registered for that model.
See :meth:`MultiModalPlugin.map_input` for more details.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
merged_dict: Dict[str, torch.Tensor] = {}
merged_dict: Dict[str, NestedTensors] = {}
for data_key, data_value in data.items():
input_dict = self._get_plugin(data_key) \
.map_input(model_config, data_value)
plugin = self._get_plugin(data_key)
num_items = len(data_value) if isinstance(data_value, list) else 1
max_items = self._limits_by_model[model_config][data_key]
if num_items > max_items:
raise ValueError(
f"You set {data_key}={max_items} (or defaulted to 1) in "
f"`--limit-mm-per-prompt`, but found {num_items} items "
"in the same prompt.")
input_dict = plugin.map_input(model_config, data_value)
for input_key, input_tensor in input_dict.items():
if input_key in merged_dict:
raise ValueError(f"The input mappers (keys={set(data)}) "
@ -115,8 +145,9 @@ class MultiModalRegistry:
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of tokens, belonging to a
specific modality, input to the language model for a model class.
Register the maximum number of tokens, corresponding to a single
instance of multimodal data belonging to a specific modality, that are
passed to the language model for a model class.
"""
return self._get_plugin(data_type_key) \
.register_max_multimodal_tokens(max_mm_tokens)
@ -126,8 +157,8 @@ class MultiModalRegistry:
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of image tokens
input to the language model for a model class.
Register the maximum number of image tokens, corresponding to a single
image, that are passed to the language model for a model class.
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)
@ -137,7 +168,61 @@ class MultiModalRegistry:
for profiling the memory usage of a model.
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
return sum(
plugin.get_max_multimodal_tokens(model_config)
for plugin in self._plugins.values())
limits_per_plugin = self._limits_by_model[model_config]
return sum((limits_per_plugin[key] *
plugin.get_max_multimodal_tokens(model_config))
for key, plugin in self._plugins.items())
def init_mm_limits_per_prompt(
self,
model_config: ModelConfig,
multimodal_config: Optional[MultiModalConfig],
) -> None:
"""
Initialize the maximum number of multi-modal input instances for each
modality that are allowed per prompt for a model class.
"""
if model_config in self._limits_by_model:
logger.warning(
"`mm_limits` has already been set for model=%s, and will "
"be overwritten by the new values.", model_config.model)
if multimodal_config is None:
limits_per_plugin = self._disabled_limits_per_plugin
else:
config_limits_per_plugin = multimodal_config.limit_per_prompt
extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
if extra_keys:
logger.warning(
"Detected extra keys in `--limit-mm-per-prompt` which "
"are not registered as multi-modal plugins: %s. "
"They will be ignored.", extra_keys)
# NOTE: Currently the default is set to 1 for each plugin
# TODO: Automatically determine the limits based on budget
# once more models support multi-image inputs
limits_per_plugin = {
key: config_limits_per_plugin.get(key, 1)
for key in self._plugins
}
self._limits_by_model[model_config] = limits_per_plugin
def get_mm_limits_per_prompt(
self,
model_config: ModelConfig,
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
that are allowed per prompt for a model class.
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
return self._limits_by_model[model_config]

View File

@ -13,7 +13,6 @@ import threading
import uuid
import warnings
from asyncio import FIRST_COMPLETED, ensure_future
from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@ -760,16 +759,6 @@ class CudaMemoryProfiler:
gc.collect()
def str_to_int_tuple(s: str) -> Tuple[int, ...]:
"""Convert a string to a tuple of integers."""
try:
return tuple(map(int, s.split(",")))
except ValueError as e:
raise ValueError(
"String must be a series of integers separated by commas "
f"(e.g., 1, 2, 3). Given input: {s}") from e
def make_ndarray_with_pad(
x: List[List[T]],
pad: T,
@ -863,23 +852,6 @@ def is_list_of(
assert_never(check)
def merge_dicts(dict1: Dict[K, List[T]],
dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
"""Merge 2 dicts that have key -> List of items.
When a key conflicts, the values in dict1 is prioritized.
"""
merged_dict: Dict[K, List[T]] = defaultdict(list)
for key, value in dict1.items():
merged_dict[key].extend(value)
for key, value in dict2.items():
merged_dict[key].extend(value)
return dict(merged_dict)
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""

View File

@ -12,9 +12,10 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
SequenceGroupMetadata)
@ -83,6 +84,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
'''
EncoderDecoderModelRunner constructor.
@ -271,6 +274,16 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
seqs: List[SequenceGroupMetadata] = []
model_config = self.model_config
mm_config = self.multimodal_config
input_registry = self.input_registry
mm_registry = self.mm_registry
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
if max_mm_tokens > 0:
raise NotImplementedError(
"Multi-modal encoder-decoder models are not supported yet")
batch_size = 0
for group_id in range(max_num_seqs):
@ -278,8 +291,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, _ = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
seq_data, _ = input_registry \
.dummy_data_for_profiling(model_config, seq_len, mm_registry)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (

View File

@ -31,7 +31,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@ -43,7 +43,7 @@ from vllm.model_executor.models.interfaces import (supports_lora,
supports_multimodal)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
MultiModalInputs, MultiModalRegistry)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
@ -807,6 +807,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.parallel_config = parallel_config
@ -860,8 +862,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
) if num_attn_heads else None
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
self.input_registry = input_registry
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config)
# Lazy initialization
self.model: nn.Module # Set after load_model
@ -902,7 +906,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
assert supports_lora(self.model), "Model does not support LoRA"
assert not supports_multimodal(
self.model
), "To be tested: multimodal language model with LoRA settings."
), "To be tested: Multi-modal model with LoRA settings."
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
@ -1046,17 +1050,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for vision encoding, which needs
# to be accounted for when calculating the GPU blocks for
# Additional GPU memory may be needed for multi-modal encoding, which
# needs to be accounted for when calculating the GPU blocks for
# vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
model_config = self.model_config
mm_config = self.multimodal_config
if supports_multimodal(self.model):
max_mm_tokens = MULTIMODAL_REGISTRY \
.get_max_multimodal_tokens(model_config)
input_registry = self.input_registry
mm_registry = self.mm_registry
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens)
@ -1074,13 +1082,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq_data, dummy_multi_modal_data = input_registry \
.dummy_data_for_profiling(model_config, seq_len, mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),

View File

@ -9,12 +9,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
MultiModalInputs, MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
@ -89,6 +88,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
*args,
**kwargs,
):
@ -120,8 +121,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
)
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
self.input_registry = input_registry
self.mm_registry = mm_registry
self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config)
# Lazy initialization.
self.model: nn.Module # Set after init_Model
@ -157,17 +160,21 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
# Additional GPU memory may be needed for vision encoding, which needs
# to be accounted for when calculating the GPU blocks for
# Additional GPU memory may be needed for multi-modal encoding, which
# needs to be accounted for when calculating the GPU blocks for
# vLLM blocker manager.
# To exercise the worst scenario for GPU memory consumption,
# the number of seqs (batch_size) is chosen to maximize the number
# of images processed.
model_config = self.model_config
mm_config = self.multimodal_config
if supports_multimodal(self.model):
max_mm_tokens = MULTIMODAL_REGISTRY \
.get_max_multimodal_tokens(model_config)
input_registry = self.input_registry
mm_registry = self.mm_registry
mm_registry.init_mm_limits_per_prompt(model_config, mm_config)
max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
if max_mm_tokens > 0:
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens)
@ -183,13 +190,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
# Having more tokens is over-conservative but otherwise fine
assert len(seq_data.prompt_token_ids) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but got: {len(seq_data.prompt_token_ids)}")
seq_data, dummy_multi_modal_data = input_registry \
.dummy_data_for_profiling(model_config, seq_len, mm_registry)
seq = SequenceGroupMetadata(
request_id=str(group_id),