[VLM][Core] Support profiling with multiple multi-modal inputs per prompt (#7126)

commit 3f674a49b5
parent 70b746efcf
@@ -17,4 +17,4 @@ Input Processing Pipeline

6. If the data contains multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.

   For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
   For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision model.
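As an illustration of the mapping step above, here is a minimal sketch of calling the registered input mapper on a PIL image. The model name and image path are placeholder assumptions; the ``ModelConfig`` arguments and the ``init_mm_limits_per_prompt`` call mirror the tests added later in this commit.

.. code-block:: python

    from PIL import Image

    from vllm.config import ModelConfig, MultiModalConfig
    from vllm.multimodal import MULTIMODAL_REGISTRY

    model_config = ModelConfig(
        model="llava-hf/llava-1.5-7b-hf",  # placeholder model
        tokenizer="llava-hf/llava-1.5-7b-hf",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
    )

    # After this change, per-prompt limits must be initialized before mapping.
    MULTIMODAL_REGISTRY.init_mm_limits_per_prompt(
        model_config, MultiModalConfig(limit_per_prompt={"image": 1}))

    image = Image.open("example.jpg")  # hypothetical local image file

    # Converts the PIL image into model keyword args (e.g. "pixel_values").
    mm_kwargs = MULTIMODAL_REGISTRY.map_input(model_config, {"image": image})
    print(mm_kwargs.keys())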
@@ -15,6 +15,9 @@ by following :ref:`this guide <adding_multimodal_plugin>`.

Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.

..
   TODO: Add usage of --limit-mm-per-prompt when multi-image input is officially supported

Guides
++++++

@@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i

3. Register maximum number of multi-modal tokens
------------------------------------------------

For each modality type that the model accepts as input, calculate the maximum possible number of tokens
For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance
and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.

.. code-block:: diff
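To make the registration steps concrete, below is a sketch of how a hypothetical model might wire up the registries touched by this commit. The decorator names, placeholder token id, and per-image feature size are assumptions based on the registries referenced elsewhere in this diff, not the documented API for any specific model.

.. code-block:: python

    from typing import Mapping

    import torch.nn as nn
    from PIL import Image

    from vllm.inputs import INPUT_REGISTRY, InputContext
    from vllm.model_executor.models.interfaces import SupportsMultiModal
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.sequence import SequenceData

    # Hypothetical per-image feature size (e.g. a 24 x 24 grid of CLIP patches).
    MY_IMAGE_FEATURE_SIZE = 576
    MY_IMAGE_TOKEN_ID = 32000  # placeholder token id


    def get_max_my_image_tokens(ctx: InputContext) -> int:
        # Tokens contributed by a *single* image instance.
        return MY_IMAGE_FEATURE_SIZE


    def dummy_data_for_my_model(ctx: InputContext, seq_len: int,
                                mm_counts: Mapping[str, int]):
        # New in this change: mm_counts tells the factory how many instances
        # of each modality to include in the profiling prompt.
        num_images = mm_counts["image"]
        token_ids = [MY_IMAGE_TOKEN_ID] * MY_IMAGE_FEATURE_SIZE * num_images
        token_ids += [0] * (seq_len - len(token_ids))
        image = Image.new("RGB", (336, 336), color=0)
        mm_data = {"image": image if num_images == 1 else [image] * num_images}
        return SequenceData(token_ids), mm_data


    @MULTIMODAL_REGISTRY.register_image_input_mapper()
    @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_my_image_tokens)
    @INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_model)
    class MyMultiModalModel(nn.Module, SupportsMultiModal):
        ...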
tests/engine/test_arg_utils.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import pytest

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser


@pytest.mark.parametrize(("arg", "expected"), [
    (None, None),
    ("image=16", {
        "image": 16
    }),
    ("image=16,video=2", {
        "image": 16,
        "video": 2
    }),
])
def test_limit_mm_per_prompt_parser(arg, expected):
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    if arg is None:
        args = parser.parse_args([])
    else:
        args = parser.parse_args(["--limit-mm-per-prompt", arg])

    assert args.limit_mm_per_prompt == expected
@@ -59,7 +59,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalData objects and corresponding
    vision language config as input.
    MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
@@ -49,7 +49,7 @@ def run_test(
    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding vision language config as input.
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
@@ -117,7 +117,7 @@ def run_test(
    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding vision language config as input.
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
@@ -69,7 +69,7 @@ def run_test(
    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding vision language config as input.
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
@@ -177,7 +177,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding vision language config as input.
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
@@ -61,7 +61,7 @@ def run_test(
    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding vision language config as input.
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
@@ -176,7 +176,7 @@ def run_multi_image_test(
    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding vision language config as input.
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
@@ -197,6 +197,7 @@ def run_multi_image_test(
    with vllm_runner(model,
                     max_model_len=4096,
                     max_num_seqs=1,
                     limit_mm_per_prompt={"image": len(images)},
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
@@ -72,7 +72,7 @@ def run_test(
    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding vision language config as input.
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
@@ -73,7 +73,7 @@ def run_test(
    All the image fixtures for the test are under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding vision language config as input.
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
@@ -1,15 +1,22 @@
from contextlib import nullcontext

import numpy as np
import pytest
from transformers import CLIPImageProcessor, LlavaNextImageProcessor

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size


@pytest.fixture
def mm_registry():
    return MultiModalRegistry()


@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_clip_image_processor(image_assets, dtype, size_factor):
def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

    hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
@@ -24,6 +31,9 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
        dtype=dtype,
        revision=None,
    )
    mm_config = MultiModalConfig(limit_per_prompt={"image": 1})

    mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

    for asset in image_assets:
        image = rescale_image_size(asset.pil_image, size_factor)
@@ -32,7 +42,7 @@ def test_clip_image_processor(image_assets, dtype, size_factor):
            image,
            return_tensors="pt",
        )
        vllm_result = MULTIMODAL_REGISTRY.map_input(
        vllm_result = mm_registry.map_input(
            model_config,
            {"image": image},
        )
@@ -48,7 +58,8 @@ def test_clip_image_processor(image_assets, dtype, size_factor):

@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, dtype, size_factor):
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
                                    size_factor):
    MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"

    hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
@@ -63,6 +74,9 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
        dtype=dtype,
        revision=None,
    )
    mm_config = MultiModalConfig(limit_per_prompt={"image": 1})

    mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

    for asset in image_assets:
        image = rescale_image_size(asset.pil_image, size_factor)
@@ -71,7 +85,7 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):
            image,
            return_tensors="pt",
        )
        vllm_result = MULTIMODAL_REGISTRY.map_input(
        vllm_result = mm_registry.map_input(
            model_config,
            {"image": image},
        )
@@ -83,3 +97,61 @@ def test_llava_next_image_processor(image_assets, dtype, size_factor):

        assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
        assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

    model_config = ModelConfig(
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
    )
    mm_config = MultiModalConfig(limit_per_prompt={"image": limit})

    mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

    image = image_assets[0].pil_image
    if num_images == 0:
        mm_inputs = {}
    elif num_images == 1:
        mm_inputs = {"image": image}
    else:
        mm_inputs = {"image": [image] * num_images}

    with nullcontext() if is_valid else pytest.raises(ValueError):
        mm_registry.map_input(model_config, mm_inputs)


# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
    MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

    model_config = ModelConfig(
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
    )
    mm_config = MultiModalConfig(limit_per_prompt={"image": num_images})

    mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

    image = image_assets[0].pil_image
    mm_inputs = {"image": [image] * num_images}

    mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
    assert len(mapped_inputs["pixel_values"]) == num_images
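At the user-facing level, the multi-image mapping tested above corresponds to passing a list under the ``image`` key of ``multi_modal_data``. A minimal sketch follows; the model name, prompt placeholder syntax, and image paths are assumptions and depend on which model actually accepts several images.

.. code-block:: python

    from PIL import Image

    from vllm import LLM, SamplingParams

    # Placeholder multi-image-capable model; any model whose input processor
    # accepts several images per prompt would be used the same way.
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": 2},
    )

    images = [Image.open("a.jpg"), Image.open("b.jpg")]  # hypothetical files

    outputs = llm.generate(
        {
            # Placeholder prompt; the image placeholder tokens are model-specific.
            "prompt": "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe both images.<|end|>\n<|assistant|>\n",
            "multi_modal_data": {"image": images},
        },
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)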
@@ -1,7 +1,8 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple,
                    Type, Union)

import torch
from transformers import PretrainedConfig
@@ -1429,10 +1430,15 @@ class PromptAdapterConfig:

@dataclass
class MultiModalConfig:
    """Configs the input data format and how models should run for
    multimodal models."""
    """Controls the behavior of multimodal models."""

    limit_per_prompt: Mapping[str, int]
    """
    The maximum number of multi-modal input instances allowed per prompt
    for each :class:`~vllm.multimodal.MultiModalPlugin`.
    """

    # TODO: Add configs to init vision tower or not.
    pass


_STR_DTYPE_TO_TORCH_DTYPE = {
@@ -2,7 +2,8 @@ import argparse
import dataclasses
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
                    Union)

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -15,8 +16,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser

if TYPE_CHECKING:
    from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
        BaseTokenizerGroup)
    from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

logger = init_logger(__name__)

@@ -29,11 +29,32 @@ def nullable_str(val: str):
    return val


def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
    if len(val) == 0:
        return None

    out_dict: Dict[str, int] = {}
    for item in val.split(","):
        try:
            key, value = item.split("=")
        except TypeError as exc:
            msg = "Each item should be in the form KEY=VALUE"
            raise ValueError(msg) from exc

        try:
            out_dict[key] = int(value)
        except ValueError as exc:
            msg = f"Failed to parse value of item {key}={value}"
            raise ValueError(msg) from exc

    return out_dict
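A quick usage sketch of the parser above, matching the new test in this commit. As an aside, a malformed item such as ``"image"`` (no ``=``) makes the tuple unpacking raise ``ValueError`` rather than ``TypeError``, so whether the first ``except`` clause is ever reached is left here as an open question rather than asserted.

.. code-block:: python

    from vllm.engine.arg_utils import nullable_kvs

    assert nullable_kvs("") is None
    assert nullable_kvs("image=16") == {"image": 16}
    assert nullable_kvs("image=16,video=2") == {"image": 16, "video": 2}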
@dataclass
class EngineArgs:
    """Arguments for vLLM engine."""
    model: str = 'facebook/opt-125m'
    served_model_name: Optional[Union[List[str]]] = None
    served_model_name: Optional[Union[str, List[str]]] = None
    tokenizer: Optional[str] = None
    skip_tokenizer_init: bool = False
    tokenizer_mode: str = 'auto'
@@ -81,6 +102,7 @@ class EngineArgs:
    # notice.
    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
    tokenizer_pool_extra_config: Optional[dict] = None
    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
@@ -435,6 +457,21 @@ class EngineArgs:
            'This should be a JSON string that will be '
            'parsed into a dictionary. Ignored if '
            'tokenizer_pool_size is 0.')

        # Multimodal related configs
        parser.add_argument(
            '--limit-mm-per-prompt',
            type=nullable_kvs,
            default=EngineArgs.limit_mm_per_prompt,
            # The default value is given in
            # MultiModalRegistry.init_mm_limits_per_prompt
            help=('For each multimodal plugin, limit how many '
                  'input instances to allow for each prompt. '
                  'Expects a comma-separated list of items, '
                  'e.g.: `image=16,video=2` allows a maximum of 16 '
                  'images and 2 videos per prompt. Defaults to 1 for '
                  'each modality.'))
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
@@ -709,7 +746,8 @@ class EngineArgs:
            "CPU offload space must be non-negative"
            f", but got {self.cpu_offload_gb}")

        multimodal_config = MultiModalConfig()
        multimodal_config = MultiModalConfig(
            limit_per_prompt=self.limit_mm_per_prompt or {})

        device_config = DeviceConfig(device=self.device)
        model_config = ModelConfig(
@@ -24,8 +24,9 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs,
                         PromptInputs, SingletonPromptInputs)
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                         InputRegistry, LLMInputs, PromptInputs,
                         SingletonPromptInputs)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -180,6 +181,7 @@ class LLMEngine:
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
    ) -> None:
        logger.info(
            "Initializing an LLM engine (v%s) with config: "
@@ -265,8 +267,9 @@ class LLMEngine:
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        self.input_processor = INPUT_REGISTRY.create_input_processor(
            self.model_config)
        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            model_config)

        self.model_executor = executor_class(
            model_config=model_config,
@@ -1,6 +1,8 @@
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
                    Tuple, Type)

from torch import nn
from transformers import PretrainedConfig
@@ -12,7 +14,7 @@ from .data import LLMInputs

if TYPE_CHECKING:
    from vllm.config import ModelConfig, MultiModalConfig
    from vllm.multimodal import MultiModalDataDict
    from vllm.multimodal import MultiModalDataDict, MultiModalRegistry
    from vllm.sequence import SequenceData

logger = init_logger(__name__)
@@ -65,15 +67,38 @@ class InputContext:

N = TypeVar("N", bound=Type[nn.Module])

DummyDataFactory = Callable[[InputContext, int],
                            Tuple["SequenceData",
                                  Optional["MultiModalDataDict"]]]

class DummyDataFactory(Protocol):

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
        """
        Create dummy data to be inputted into the model.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        ...


class _MultiModalCounts(UserDict):
    """
    Wraps `mm_counts` for a more informative error message
    when attempting to access a plugin that does not exist.
    """

    def __getitem__(self, key: str) -> int:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no multi-modal plugin with the key: {key}. "
                   f"Available keys: {set(self.keys())}")
            raise KeyError(msg) from exc


InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
"""Preprocess the inputs to the model."""
@@ -95,6 +120,7 @@ class InputRegistry:
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
        """
        The default dummy data factory represents the longest possible text
@@ -133,8 +159,12 @@ class InputRegistry:

        return wrapper

    def dummy_data_for_profiling(self, model_config: "ModelConfig",
                                 seq_len: int):
    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_registry: "MultiModalRegistry",
    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
        """
        Create dummy data for profiling the memory usage of a model.

@@ -142,6 +172,10 @@ class InputRegistry:

        See also:
            :ref:`enabling_multimodal_inputs`

        Note:
            This should be called after
            :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
@@ -149,8 +183,29 @@ class InputRegistry:
        model_cls, _ = get_model_architecture(model_config)
        dummy_factory = self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)
        mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)

        return dummy_factory(InputContext(model_config), seq_len)
        seq_data, mm_data = dummy_factory(
            InputContext(model_config),
            seq_len,
            _MultiModalCounts(mm_counts),
        )

        # Having more tokens is over-conservative but otherwise fine
        num_tokens = seq_data.prompt_token_ids
        assert len(num_tokens) >= seq_len, (
            f"Expected at least {seq_len} dummy tokens for profiling, "
            f"but found {len(num_tokens)} tokens instead.")

        if mm_data is not None:
            for k, v in mm_data.items():
                num_items = len(v) if isinstance(v, list) else 1
                num_expected = mm_counts[k]
                assert num_items >= num_expected, (
                    f"Expected at least {num_expected} dummy '{k}' instances "
                    f"for profiling, but found {num_items} instances instead.")

        return seq_data, mm_data

    def _default_input_processor(self, ctx: InputContext,
                                 inputs: LLMInputs) -> LLMInputs:
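A sketch of the new call shape for profiling follows. The model name and sequence length are placeholder assumptions, and the per-prompt limits must be initialized on the multi-modal registry first, as the Note in the docstring above requires.

.. code-block:: python

    from vllm.config import ModelConfig, MultiModalConfig
    from vllm.inputs import INPUT_REGISTRY
    from vllm.multimodal import MULTIMODAL_REGISTRY

    model_config = ModelConfig(
        model="llava-hf/llava-1.5-7b-hf",  # placeholder model
        tokenizer="llava-hf/llava-1.5-7b-hf",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
    )
    MULTIMODAL_REGISTRY.init_mm_limits_per_prompt(
        model_config, MultiModalConfig(limit_per_prompt={"image": 2}))

    # The registry is now passed explicitly; the dummy prompt contains at
    # least two images' worth of placeholder tokens for memory profiling.
    seq_data, mm_data = INPUT_REGISTRY.dummy_data_for_profiling(
        model_config, 4096, MULTIMODAL_REGISTRY)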
@@ -133,7 +133,7 @@ def _get_model_initialization_kwargs(

    if supports_multimodal(model_class):
        if multimodal_config is None:
            raise ValueError("Provide vision related configurations "
            raise ValueError("Provide multi-modal related configurations "
                             "through LLM entrypoint or engine arguments.")

        extra_kwargs["multimodal_config"] = multimodal_config
@@ -31,13 +31,13 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:


def get_blip_image_feature_size(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
    return get_blip_num_patches(image_size=hf_config.image_size,
                                patch_size=hf_config.patch_size)


def get_max_blip_image_tokens(
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
        hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
    return get_blip_image_feature_size(hf_config)


@@ -60,6 +60,7 @@ def dummy_seq_data_for_blip(

def dummy_image_for_blip(
    hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
@@ -71,7 +72,7 @@ def dummy_image_for_blip(
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_blip(
@@ -1,4 +1,5 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
@@ -413,17 +414,39 @@ def get_max_blip2_image_tokens(ctx: InputContext):
    raise NotImplementedError(msg)


def dummy_data_for_blip2(ctx: InputContext, seq_len: int):
def dummy_seq_data_for_blip2(
    hf_config: Blip2Config,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip2_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_blip2_image_feature_size(hf_config)
    token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    seq_data = SequenceData(token_ids)
    seq_data = dummy_seq_data_for_blip2(
        hf_config,
        seq_len,
        num_images,
        image_token_id=BLIP2_IMAGE_TOKEN_ID,
    )

    if isinstance(vision_config, Blip2VisionConfig):
        mm_data = dummy_image_for_blip(vision_config)
        mm_data = dummy_image_for_blip(vision_config, num_images)

        return seq_data, mm_data
@@ -1,6 +1,6 @@
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple,
                    TypedDict)
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
                    Tuple, TypedDict)

import torch
import torch.nn.functional as F
@@ -19,8 +19,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -61,6 +60,7 @@ def get_max_chameleon_image_tokens(ctx: InputContext):

def dummy_seq_data_for_chameleon(
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
@@ -70,12 +70,14 @@ def dummy_seq_data_for_chameleon(
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    token_ids = [image_token_id] * image_feature_size * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_chameleon(
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
@@ -87,17 +89,20 @@ def dummy_image_for_chameleon(
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_data_for_chameleon(ctx: InputContext, seq_len: int):
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_chameleon(
        seq_len,
        num_images,
        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
    )

    mm_data = dummy_image_for_chameleon()
    mm_data = dummy_image_for_chameleon(num_images)
    return seq_data, mm_data
@@ -43,6 +43,7 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
def dummy_seq_data_for_clip(
    hf_config: CLIPVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
@@ -52,13 +53,14 @@ def dummy_seq_data_for_clip(
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    token_ids = [image_token_id] * image_feature_size * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_clip(
    hf_config: CLIPVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
@@ -70,7 +72,7 @@ def dummy_image_for_clip(
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_clip(
@@ -16,7 +16,7 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
@@ -29,8 +29,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -94,27 +93,33 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
    return (ncol + 1) * nrow


def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int):
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
    ncol, nrow = get_max_fuyu_image_feature_size()
    image_feature_size = get_max_fuyu_image_tokens(ctx)

    token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
    token_ids += [0] * (seq_len - image_feature_size)
    image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
    token_ids = image_token_ids * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_fuyu(
    num_images: int,
    *,
    image_width: int,
    image_height: int,
):
    image = Image.new("RGB", (image_width, image_height), color=0)
    return {"image": image}
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_data_for_fuyu(ctx: InputContext, seq_len: int):
    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len)
    mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                   MAX_IMAGE_FEATURE_SIZE_HEIGHT)
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]
    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
    mm_data = dummy_image_for_fuyu(num_images,
                                   image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                   image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
    return seq_data, mm_data
@@ -11,14 +11,11 @@ logger = init_logger(__name__)

@runtime_checkable
class SupportsMultiModal(Protocol):
    """
    The interface required for all multimodal (vision or audio) language
    models.
    """
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multimodal inputs.
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
@@ -5,7 +5,8 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
@@ -230,7 +231,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):


def input_mapper_for_internvl(ctx: InputContext, data: object):
    hf_config = ctx.get_hf_config(PretrainedConfig)
    hf_config = ctx.get_hf_config()

    use_thumbnail = hf_config.use_thumbnail
    min_num = hf_config.min_dynamic_patch
@@ -256,7 +257,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
    })


def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    image_feature_size = get_max_internvl_image_tokens(ctx)
    model_config = ctx.model_config
@@ -268,6 +271,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=tokenizer.encode(IMG_CONTEXT,
                                        add_special_tokens=False)[0],
        image_feature_size_override=image_feature_size,
@@ -281,6 +285,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):

    mm_data = dummy_image_for_clip(
        vision_config,
        num_images,
        image_width_override=max_image_width,
        image_height_override=max_image_height,
    )
@@ -1,5 +1,6 @@
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
@@ -9,8 +10,7 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -88,9 +88,11 @@ def get_max_llava_image_tokens(ctx: InputContext):
    raise ValueError(f"Unexpected select feature strategy: {strategy}")


def dummy_data_for_llava(ctx: InputContext, seq_len: int):
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_image_tokens(ctx)

@@ -98,21 +100,23 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(vision_config)
        mm_data = dummy_image_for_clip(vision_config, num_images)
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(vision_config)
        mm_data = dummy_image_for_siglip(vision_config, num_images)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
@@ -1,5 +1,6 @@
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
@@ -13,8 +14,7 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -158,9 +158,11 @@ def get_max_llava_next_image_tokens(ctx: InputContext):
    )


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_next_image_tokens(ctx)

@@ -168,12 +170,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )
@@ -183,12 +187,14 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )
@@ -24,8 +24,8 @@
import math
import re
from functools import partial
from typing import (Any, Callable, Iterable, List, Optional, Tuple, TypedDict,
                    Union)
from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import torch
@@ -42,8 +42,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@@ -408,22 +407,24 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
    return getattr(hf_config, "query_num", 64)


def dummy_seq_data_for_minicpmv(seq_len: int):
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
    token_ids = [0] * seq_len
    return SequenceData(token_ids)


def dummy_image_for_minicpmv(hf_config: PretrainedConfig):
def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
    width = height = hf_config.image_size
    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int):
def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config()
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_minicpmv(seq_len)
    mm_data = dummy_image_for_minicpmv(hf_config)
    seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
    mm_data = dummy_image_for_minicpmv(hf_config, num_images)

    return seq_data, mm_data
@@ -1,4 +1,5 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
from torch import nn
@@ -9,8 +10,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaModel
@@ -57,17 +57,20 @@ def get_max_paligemma_image_tokens(ctx: InputContext):
    return get_max_siglip_image_tokens(vision_config)


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config)
    mm_data = dummy_image_for_siglip(vision_config, num_images)
    return seq_data, mm_data
@@ -15,7 +15,8 @@
# limitations under the License.
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import torch
@@ -28,8 +29,7 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -347,18 +347,22 @@ def get_max_phi3v_image_tokens(ctx: InputContext):
    )


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    image_feature_size = get_max_phi3v_image_tokens(ctx)

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        seq_len,
        num_images,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )
    mm_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        num_images,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )
@@ -52,6 +52,7 @@ def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
@@ -61,13 +62,14 @@ def dummy_seq_data_for_siglip(
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    token_ids = [image_token_id] * image_feature_size * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
@@ -79,7 +81,7 @@ def dummy_image_for_siglip(
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_siglip(
@@ -1,9 +1,9 @@
import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final

import numpy as np
import torch
@@ -116,17 +116,30 @@ class MultiModalInputs(_MultiModalInputsBase):
                             batched_inputs)


_T = TypeVar("_T")

MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.

The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Modality types that are predefined by vLLM."""

    image: Image.Image
    """The input image."""
    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: Tuple[np.ndarray, Union[int, float]]
    """The input audio and its sampling rate."""
    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""


MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
MultiModalDataDict = Union[MultiModalDataBuiltins,
                           Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item for each modality type to input.

@@ -137,7 +150,8 @@ Note:
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
                                 MultiModalInputs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
@@ -181,8 +195,11 @@ class MultiModalPlugin(ABC):
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
@@ -225,7 +242,7 @@ class MultiModalPlugin(ABC):
        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: object) -> MultiModalInputs:
                  data: MultiModalData[object]) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.
@@ -254,8 +271,8 @@ class MultiModalPlugin(ABC):
    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of multimodal tokens input to the language
        model. This does not include tokens that correspond to the input text.
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

@@ -269,8 +286,9 @@ class MultiModalPlugin(ABC):
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of multi-modal tokens input to the
        language model for a model class.
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.
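To make the new ``MultiModalData`` alias concrete, a small sketch of the two accepted shapes follows; the image size is an arbitrary placeholder value.

.. code-block:: python

    from PIL import Image

    from vllm.multimodal import MultiModalDataDict

    image = Image.new("RGB", (336, 336), color=0)

    # A single instance per modality (the pre-existing behavior)...
    single: MultiModalDataDict = {"image": image}

    # ...or a list of instances, bounded by `--limit-mm-per-prompt`.
    multiple: MultiModalDataDict = {"image": [image, image]}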
@@ -11,7 +11,7 @@ from vllm.transformers_utils.image_processor import get_image_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import is_list_of

from .base import MultiModalInputs, MultiModalPlugin
from .base import MultiModalData, MultiModalInputs, MultiModalPlugin

logger = init_logger(__name__)

@@ -110,8 +110,11 @@ class ImagePlugin(MultiModalPlugin):
            model_config.model,
            trust_remote_code=model_config.trust_remote_code)

    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        model_config = ctx.model_config

        # PIL image
@ -1,19 +1,33 @@
|
||||
import functools
|
||||
from typing import Dict, Optional, Sequence
|
||||
from collections import UserDict
|
||||
from typing import Dict, Mapping, Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config import ModelConfig, MultiModalConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .audio import AudioPlugin
|
||||
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
|
||||
MultiModalPlugin, MultiModalTokensCalc)
|
||||
MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
|
||||
from .image import ImagePlugin
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class _MultiModalLimits(UserDict):
|
||||
"""
|
||||
Wraps `_limits_by_model` for a more informative error message
|
||||
when attempting to access a model that does not exist.
|
||||
"""
|
||||
|
||||
def __getitem__(self, key: ModelConfig) -> Dict[str, int]:
|
||||
try:
|
||||
return super().__getitem__(key)
|
||||
except KeyError as exc:
|
||||
msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
|
||||
"forget to call `init_mm_limits_per_prompt`?")
|
||||
raise KeyError(msg) from exc
|
||||
|
||||
|
||||
class MultiModalRegistry:
|
||||
"""
|
||||
A registry that dispatches data processing to the
|
||||
@ -28,6 +42,11 @@ class MultiModalRegistry:
|
||||
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
|
||||
self._plugins = {p.get_data_key(): p for p in plugins}
|
||||
|
||||
# This is used for non-multimodal models
|
||||
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
|
||||
|
||||
self._limits_by_model = _MultiModalLimits()
|
||||
|
||||
def register_plugin(self, plugin: MultiModalPlugin) -> None:
|
||||
"""
|
||||
Register a multi-modal plugin so it can be recognized by vLLM.
|
||||
@ -86,13 +105,24 @@ class MultiModalRegistry:
|
||||
via the input mapper registered for that model.
|
||||
|
||||
See :meth:`MultiModalPlugin.map_input` for more details.
|
||||
|
||||
Note:
|
||||
This should be called after :meth:`init_mm_limits_per_prompt`.
|
||||
"""
|
||||
merged_dict: Dict[str, torch.Tensor] = {}
|
||||
merged_dict: Dict[str, NestedTensors] = {}
|
||||
|
||||
for data_key, data_value in data.items():
|
||||
input_dict = self._get_plugin(data_key) \
|
||||
.map_input(model_config, data_value)
|
||||
plugin = self._get_plugin(data_key)
|
||||
|
||||
num_items = len(data_value) if isinstance(data_value, list) else 1
|
||||
max_items = self._limits_by_model[model_config][data_key]
|
||||
if num_items > max_items:
|
||||
raise ValueError(
|
||||
f"You set {data_key}={max_items} (or defaulted to 1) in "
|
||||
f"`--limit-mm-per-prompt`, but found {num_items} items "
|
||||
"in the same prompt.")
|
||||
|
||||
input_dict = plugin.map_input(model_config, data_value)
|
||||
for input_key, input_tensor in input_dict.items():
|
||||
if input_key in merged_dict:
|
||||
raise ValueError(f"The input mappers (keys={set(data)}) "
|
||||
@ -115,8 +145,9 @@ class MultiModalRegistry:
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, belonging to a
        specific modality, input to the language model for a model class.
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data belonging to a specific modality, that are
        passed to the language model for a model class.
        """
        return self._get_plugin(data_type_key) \
            .register_max_multimodal_tokens(max_mm_tokens)
@ -126,8 +157,8 @@ class MultiModalRegistry:
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of image tokens
        input to the language model for a model class.
        Register the maximum number of image tokens, corresponding to a single
        image, that are passed to the language model for a model class.
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)

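For context, the per-image budget registered here is typically attached to a model class as a decorator. A rough sketch under that assumption follows; `MyVLM` and the constant 576 are hypothetical, and real models may pass a callable that computes the budget from the model config instead of a constant.

import torch.nn as nn

from vllm.multimodal import MULTIMODAL_REGISTRY


# Hypothetical model class; 576 is an assumed per-image token budget.
@MULTIMODAL_REGISTRY.register_max_image_tokens(576)
class MyVLM(nn.Module):
    pass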
@ -137,7 +168,61 @@ class MultiModalRegistry:
        for profiling the memory usage of a model.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        return sum(
            plugin.get_max_multimodal_tokens(model_config)
            for plugin in self._plugins.values())
        limits_per_plugin = self._limits_by_model[model_config]

        return sum((limits_per_plugin[key] *
                    plugin.get_max_multimodal_tokens(model_config))
                   for key, plugin in self._plugins.items())

    def init_mm_limits_per_prompt(
        self,
        model_config: ModelConfig,
        multimodal_config: Optional[MultiModalConfig],
    ) -> None:
        """
        Initialize the maximum number of multi-modal input instances for each
        modality that are allowed per prompt for a model class.
        """
        if model_config in self._limits_by_model:
            logger.warning(
                "`mm_limits` has already been set for model=%s, and will "
                "be overwritten by the new values.", model_config.model)

        if multimodal_config is None:
            limits_per_plugin = self._disabled_limits_per_plugin
        else:
            config_limits_per_plugin = multimodal_config.limit_per_prompt

            extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
            if extra_keys:
                logger.warning(
                    "Detected extra keys in `--limit-mm-per-prompt` which "
                    "are not registered as multi-modal plugins: %s. "
                    "They will be ignored.", extra_keys)

            # NOTE: Currently the default is set to 1 for each plugin
            # TODO: Automatically determine the limits based on budget
            # once more models support multi-image inputs
            limits_per_plugin = {
                key: config_limits_per_plugin.get(key, 1)
                for key in self._plugins
            }

        self._limits_by_model[model_config] = limits_per_plugin

    def get_mm_limits_per_prompt(
        self,
        model_config: ModelConfig,
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        return self._limits_by_model[model_config]

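The new budget computation in `get_max_multimodal_tokens` is a weighted sum over modalities, which a short numeric sketch makes concrete. All numbers below are hypothetical examples, not defaults.

# Per-prompt budget = sum over modalities of (item limit) * (max tokens per item).
limits_per_plugin = {"image": 2, "audio": 1}        # e.g. set via --limit-mm-per-prompt
max_tokens_per_item = {"image": 576, "audio": 128}  # reported by each plugin

max_mm_tokens = sum(limits_per_plugin[key] * max_tokens_per_item[key]
                    for key in limits_per_plugin)
assert max_mm_tokens == 2 * 576 + 1 * 128  # 1280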
@ -13,7 +13,6 @@ import threading
import uuid
import warnings
from asyncio import FIRST_COMPLETED, ensure_future
from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@ -760,16 +759,6 @@ class CudaMemoryProfiler:
        gc.collect()


def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    """Convert a string to a tuple of integers."""
    try:
        return tuple(map(int, s.split(",")))
    except ValueError as e:
        raise ValueError(
            "String must be a series of integers separated by commas "
            f"(e.g., 1, 2, 3). Given input: {s}") from e


def make_ndarray_with_pad(
    x: List[List[T]],
    pad: T,
@ -863,23 +852,6 @@ def is_list_of(
    assert_never(check)


def merge_dicts(dict1: Dict[K, List[T]],
                dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
    """Merge 2 dicts that have key -> List of items.

    When a key conflicts, the values in dict1 is prioritized.
    """
    merged_dict: Dict[K, List[T]] = defaultdict(list)

    for key, value in dict1.items():
        merged_dict[key].extend(value)

    for key, value in dict2.items():
        merged_dict[key].extend(value)

    return dict(merged_dict)


JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
                 Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""

@ -12,9 +12,10 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ObservabilityConfig,
                         ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
                           SequenceGroupMetadata)
@ -83,6 +84,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
        observability_config: Optional[ObservabilityConfig] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        '''
        EncoderDecoderModelRunner constructor.
@ -271,6 +274,16 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
        seqs: List[SequenceGroupMetadata] = []

        model_config = self.model_config
        mm_config = self.multimodal_config

        input_registry = self.input_registry
        mm_registry = self.mm_registry
        mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

        max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
        if max_mm_tokens > 0:
            raise NotImplementedError(
                "Multi-modal encoder-decoder models are not supported yet")

        batch_size = 0
        for group_id in range(max_num_seqs):
@ -278,8 +291,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            seq_data, _ = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)
            seq_data, _ = input_registry \
                .dummy_data_for_profiling(model_config, seq_len, mm_registry)

            # Having more tokens is over-conservative but otherwise fine
            assert len(seq_data.prompt_token_ids) >= seq_len, (

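The new `input_registry` and `mm_registry` constructor arguments default to the global singletons but allow injecting separate registries, for example in tests. A sketch under that assumption follows; the runner constructor call is left as a comment because the real runners take many more arguments than shown here.

from vllm.inputs import INPUT_REGISTRY
from vllm.multimodal import MultiModalRegistry

# A private registry instead of the global MULTIMODAL_REGISTRY singleton.
custom_mm_registry = MultiModalRegistry()

# EncoderDecoderModelRunner(..., input_registry=INPUT_REGISTRY,
#                           mm_registry=custom_mm_registry)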
@ -31,7 +31,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
@ -43,7 +43,7 @@ from vllm.model_executor.models.interfaces import (supports_lora,
                                                    supports_multimodal)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalInputs)
                             MultiModalInputs, MultiModalRegistry)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
@ -807,6 +807,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        multimodal_config: Optional[MultiModalConfig] = None,
        return_hidden_states: bool = False,
        observability_config: Optional[ObservabilityConfig] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
@ -860,8 +862,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        ) if num_attn_heads else None

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)
        self.input_registry = input_registry
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry \
            .create_input_mapper(model_config)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
@ -902,7 +906,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            assert supports_lora(self.model), "Model does not support LoRA"
            assert not supports_multimodal(
                self.model
            ), "To be tested: multimodal language model with LoRA settings."
            ), "To be tested: Multi-modal model with LoRA settings."

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
@ -1046,17 +1050,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # vLLM blocker manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        model_config = self.model_config
        mm_config = self.multimodal_config

        if supports_multimodal(self.model):
            max_mm_tokens = MULTIMODAL_REGISTRY \
                .get_max_multimodal_tokens(model_config)
        input_registry = self.input_registry
        mm_registry = self.mm_registry
        mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

        max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
@ -1074,13 +1082,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)

            # Having more tokens is over-conservative but otherwise fine
            assert len(seq_data.prompt_token_ids) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(seq_data.prompt_token_ids)}")
            seq_data, dummy_multi_modal_data = input_registry \
                .dummy_data_for_profiling(model_config, seq_len, mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),

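The cap applied to `max_num_seqs` during profiling is easiest to follow with numbers: when the registry reports a non-zero multi-modal token budget, the batch size is shrunk so that every sequence can still carry its worst-case multi-modal payload within the batched-token budget. A numeric sketch with hypothetical values:

max_num_seqs = 256
max_num_batched_tokens = 8192
max_mm_tokens = 1280  # e.g. 2 images per prompt * 640 tokens per image

if max_mm_tokens > 0:
    max_num_seqs = min(max_num_seqs, max_num_batched_tokens // max_mm_tokens)

assert max_num_seqs == 6  # 6 * 1280 = 7680 <= 8192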
@ -9,12 +9,11 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalInputs)
                             MultiModalInputs, MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
@ -89,6 +88,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        kv_cache_dtype: Optional[str] = "auto",
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        *args,
        **kwargs,
    ):
@ -120,8 +121,10 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)
        self.input_registry = input_registry
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry \
            .create_input_mapper(model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model
@ -157,17 +160,21 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # vLLM blocker manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        model_config = self.model_config
        mm_config = self.multimodal_config

        if supports_multimodal(self.model):
            max_mm_tokens = MULTIMODAL_REGISTRY \
                .get_max_multimodal_tokens(model_config)
        input_registry = self.input_registry
        mm_registry = self.mm_registry
        mm_registry.init_mm_limits_per_prompt(model_config, mm_config)

        max_mm_tokens = mm_registry.get_max_multimodal_tokens(model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
@ -183,13 +190,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))

            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)

            # Having more tokens is over-conservative but otherwise fine
            assert len(seq_data.prompt_token_ids) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(seq_data.prompt_token_ids)}")
            seq_data, dummy_multi_modal_data = input_registry \
                .dummy_data_for_profiling(model_config, seq_len, mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
