[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2025-01-03 00:39:27 +08:00; committed by GitHub
parent b6087a6bee
commit 8c38ee7007
14 changed files with 609 additions and 555 deletions

View File

@@ -1,70 +0,0 @@
import pytest
from vllm.inputs import InputContext
from ....utils import build_model_context
@pytest.fixture()
def get_max_llava_next_image_tokens():
from vllm.model_executor.models.llava_next import (
get_max_llava_next_image_tokens)
return get_max_llava_next_image_tokens
@pytest.fixture()
def dummy_data_for_llava_next():
from vllm.model_executor.models.llava_next import dummy_data_for_llava_next
return dummy_data_for_llava_next
@pytest.mark.parametrize("gridpoints,expected_max_tokens", [
([[336, 336]], 1176),
([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 2928),
])
def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens,
get_max_llava_next_image_tokens):
ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
# Update the config image_grid_pinpoints
# and calculate the resulting max tokens
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
actual_max_tokens = get_max_llava_next_image_tokens(
InputContext(ctx.model_config))
assert expected_max_tokens == actual_max_tokens
@pytest.mark.parametrize(
"gridpoints,expected_size",
[
# One point; it has to be the largest
([[336, 336]], (336, 336)),
# Default for most llava next models; the 2x2 tile is the largest
([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
(672, 672)),
# If two rectangular gridpoints are transposes of each other, the more
# vertical one has the higher feature count due to newline features
([[336, 672], [672, 336]], (672, 336))
])
def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
gridpoints, expected_size):
ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
# Update the config image_grid_pinpoints
ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
seq_len = 5000 # bigger than the max feature size for any image
dummy_data = dummy_data_for_llava_next(
ctx,
seq_len=seq_len,
mm_counts={"image": 1},
)
seq_data = dummy_data.seq_data
mm_data = dummy_data.multi_modal_data
# The dummy data dims should match the gridpoint with the biggest feat size
assert mm_data["image"].height == expected_size[0]
assert mm_data["image"].width == expected_size[1]
assert len(seq_data.get_token_ids()) >= seq_len
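For reference, the expected values above (1176 and 2928) follow directly from the CLIP ViT-L/14-336 geometry of llava-v1.6-mistral-7b-hf. Below is a minimal sketch of that arithmetic, not part of the diff, assuming image_size=336, patch_size=14, the "default" feature-select strategy, and images that exactly match a grid pinpoint:
def llava_next_feature_size(height: int, width: int) -> int:
    npatches = 336 // 14                          # 24 patches per tile side
    base = npatches * npatches                    # 576 features after dropping CLS
    grid_h, grid_w = height // 336, width // 336  # anyres tile grid for exact pinpoints
    cur_h, cur_w = npatches * grid_h, npatches * grid_w
    unpadded = cur_h * cur_w                      # no aspect-ratio cropping in this case
    newline = cur_h                               # one newline feature per patch row
    return unpadded + newline + base

assert llava_next_feature_size(336, 336) == 1176
assert llava_next_feature_size(672, 672) == 2928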

View File

@@ -1,118 +0,0 @@
from contextlib import nullcontext
import numpy as np
import pytest
from transformers import LlavaNextImageProcessor
from vllm.config import ModelConfig
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.image import rescale_image_size
@pytest.fixture
def mm_registry():
return MultiModalRegistry()
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
size_factor):
MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"
hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
assert isinstance(hf_processor, LlavaNextImageProcessor)
model_config = ModelConfig(
model=MODEL_NAME,
task="auto",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype=dtype,
revision=None,
limit_mm_per_prompt={"image": 1},
)
mm_registry.init_mm_limits_per_prompt(model_config)
for asset in image_assets:
image = rescale_image_size(asset.pil_image, size_factor)
hf_result = hf_processor.preprocess(
image,
return_tensors="pt",
)
vllm_result = mm_registry.map_input(
model_config,
{"image": image},
)
assert hf_result.keys() == vllm_result.keys()
for key, hf_tensor in hf_result.items():
hf_arr: np.ndarray = hf_tensor.numpy()
vllm_arr: np.ndarray = vllm_result[key].numpy()
assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.parametrize(
("num_images", "limit", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,
task="auto",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
limit_mm_per_prompt={"image": limit},
)
mm_registry.init_mm_limits_per_prompt(model_config)
image = image_assets[0].pil_image
if num_images == 0:
mm_inputs = {}
elif num_images == 1:
mm_inputs = {"image": image}
else:
mm_inputs = {"image": [image] * num_images}
with nullcontext() if is_valid else pytest.raises(ValueError):
mm_registry.map_input(model_config, mm_inputs)
# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
model_config = ModelConfig(
model=MODEL_NAME,
task="auto",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
limit_mm_per_prompt={"image": num_images},
)
mm_registry.init_mm_limits_per_prompt(model_config)
image = image_assets[0].pil_image
mm_inputs = {"image": [image] * num_images}
mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
assert len(mapped_inputs["pixel_values"]) == num_images

View File

@@ -1,5 +1,7 @@
from contextlib import nullcontext
from functools import partial
from typing import cast
from unittest.mock import MagicMock
import numpy as np
import pytest
@@ -526,6 +528,100 @@ def _rand_audio(
return rng.rand(audio_len), sr
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("limit", "num_supported", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
limit_mm_per_prompt = {"image": limit}
model_config = ModelConfig(
model=model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
processor = processor_factory(ctx, cache=None)
mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
processor.get_supported_mm_limits = mock_supported_mm_limits
if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="this model only supports")
with exc_ctx:
processor._get_and_validate_dummy_mm_counts()
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("num_images", "limit", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
limit_mm_per_prompt = {"image": limit}
model_config = ModelConfig(
model=model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
processor = processor_factory(ctx, cache=None)
rng = np.random.RandomState(0)
image = _rand_img(rng, min_wh=128, max_wh=256)
if num_images == 0:
mm_data = {}
elif num_images == 1:
mm_data = {"image": image}
else:
mm_data = {"image": [image] * num_images}
if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")
with exc_ctx:
processor.apply(
"<image>" * num_images,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
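# Not part of the diff: a condensed, stand-alone version of the pattern the
# two tests above exercise -- building the merged processor through the
# registry (same private accessors as the tests) and applying it to a prompt.
# The prompt text and image size below are illustrative.
from PIL import Image

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import InputProcessingContext
from vllm.multimodal.utils import cached_get_tokenizer

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
model_config = ModelConfig(
    model=model_id,
    task="auto",
    tokenizer=model_id,
    tokenizer_mode="auto",
    trust_remote_code=False,
    seed=0,
    dtype="half",
    revision=None,
    limit_mm_per_prompt={"image": 1},
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor = MULTIMODAL_REGISTRY._processor_factories[model_cls](
    InputProcessingContext(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    ),
    cache=None,
)
mm_inputs = processor.apply(
    "USER: <image>\nDescribe the image. ASSISTANT:",
    mm_data={"image": Image.new("RGB", (336, 336))},
    hf_processor_mm_kwargs={},
)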
def _test_processing_cache_correctness(
model_id: str,
modalities: dict[str, bool],
@@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
("facebook/chameleon-7b", {"image": False}),
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
("mistral-community/pixtral-12b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),

View File

@@ -3,13 +3,11 @@ from typing import Optional
import torch
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaMultiModalProcessor,
get_max_llava_image_tokens)
LlavaMultiModalProcessor)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration):

View File

@@ -24,6 +24,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData
from .vision import VisionEncoderInfo
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
@@ -149,6 +151,29 @@ def input_processor_for_clip(
multi_modal_placeholders={"image": ranges})
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
return get_clip_image_feature_size(self.vision_config)
def get_max_image_tokens(self) -> int:
return get_max_clip_image_tokens(self.vision_config)
def get_num_patches(self) -> int:
return get_clip_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)
def get_image_size(self) -> int:
return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

View File

@@ -76,7 +76,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
return ImageSize(width=target_size["width"],
height=target_size["height"])
def _get_image_grid_size(
def _get_image_feature_grid_size(
self,
*,
image_width: int,
@@ -99,7 +99,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
target_width, target_height = self._get_image_target_size()
max_ncols, max_nrows = self._get_image_grid_size(
max_ncols, max_nrows = self._get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
@@ -172,7 +172,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = self._get_image_grid_size(
ncols, nrows = self._get_image_feature_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)

View File

@@ -1,6 +1,7 @@
from abc import abstractmethod
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
Tuple, TypedDict, Union)
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
import torch
import torch.nn as nn
@@ -12,7 +13,6 @@ from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -23,23 +23,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement,
InputProcessingContext,
MultiModalDataItems, ProcessingCache,
ProcessorInputs, PromptReplacement,
full_groupby_modality)
from vllm.sequence import IntermediateTensors
from .clip import (CLIPVisionModel, dummy_image_for_clip,
get_max_clip_image_tokens)
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
get_max_pixtral_hf_image_tokens,
get_pixtral_hf_image_feature_size)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
get_max_siglip_image_tokens)
from .pixtral import (PixtralHFVisionModel,
get_pixtral_hf_image_feature_grid_size)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import vision_encoder_info
class LlavaImagePixelInputs(TypedDict):
@@ -94,39 +94,167 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states
def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
num_image_tokens = get_max_clip_image_tokens(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_image_tokens = get_max_siglip_image_tokens(vision_config)
elif isinstance(vision_config, PixtralVisionConfig):
num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
return num_image_tokens - 1
elif strategy == "full":
return num_image_tokens
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
class LlavaLikeConfig(Protocol):
vision_config: Final[PretrainedConfig]
vision_feature_select_strategy: Final[str]
vision_feature_layer: Final[Union[int, List[int]]]
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks)
vision_config = self._get_hf_config().vision_config
self._vision_encoder_info = vision_encoder_info(vision_config)
@abstractmethod
def _get_hf_config(self) -> LlavaLikeConfig:
raise NotImplementedError
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": get_max_llava_image_tokens(self.ctx)}
def _apply_feature_select_strategy(
self,
strategy: str,
encoder_num_image_tokens: int,
) -> int:
if strategy == "default":
return encoder_num_image_tokens - 1
if strategy == "full":
return encoder_num_image_tokens
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
msg = f"Unexpected feature select strategy: {strategy!r}"
raise NotImplementedError(msg)
def _get_max_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
self._vision_encoder_info.get_max_image_tokens(),
)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_dummy_image_size(self) -> ImageSize:
image_size = self._vision_encoder_info.get_image_size()
return ImageSize(image_size, image_size)
@abstractmethod
def _get_image_token(self) -> str:
raise NotImplementedError
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
image_token = self._get_image_token()
target_width, target_height = self._get_dummy_image_size()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
def _get_hf_config(self) -> LlavaConfig:
return self.ctx.get_hf_config(LlavaConfig)
def _get_hf_processor(self) -> LlavaProcessor:
return self.ctx.get_hf_processor(LlavaProcessor)
def _get_image_token(self) -> str:
return self._get_hf_processor().image_token
def _get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
self._vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
image_token_id = hf_config.image_token_index
def get_replacement(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
return [image_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement,
),
]
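# Not part of the diff: the net effect of the replacement above on the token
# stream, with illustrative token ids (32000 as the image placeholder id and
# 576 features per image, i.e. a CLIP-336 tower under the "default" strategy).
image_token_id, num_image_tokens = 32000, 576
prompt_token_ids = [1, 3148, image_token_id, 29871]
expanded = []
for tok in prompt_token_ids:
    expanded.extend([image_token_id] * num_image_tokens
                    if tok == image_token_id else [tok])
assert expanded.count(image_token_id) == num_image_tokens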
class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
def _get_hf_config(self) -> LlavaConfig:
return self.ctx.get_hf_config(LlavaConfig)
def _get_hf_processor(self) -> PixtralProcessor:
return self.ctx.get_hf_processor(PixtralProcessor)
def _get_image_token(self) -> str:
return self._get_hf_processor().image_token
def _call_hf_processor(
self,
@@ -140,119 +268,82 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
mm_kwargs=mm_kwargs,
)
# NOTE: pixel_values=None for MLlavaProcessor
pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None:
images = mm_data["images"]
assert isinstance(images, list)
if isinstance(self._get_hf_processor(), PixtralProcessor):
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
assert (isinstance(pixel_values, list)
and len(pixel_values) == 1)
assert (isinstance(pixel_values[0], list)
and len(pixel_values[0]) == len(images))
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
assert (isinstance(pixel_values[0], list)
and len(pixel_values[0]) == len(images))
processed_outputs["pixel_values"] = pixel_values[0]
processed_outputs["pixel_values"] = pixel_values[0]
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(LlavaConfig)
hf_config = self._get_hf_config()
image_token_id = hf_config.image_token_index
processor = self._get_hf_processor()
if isinstance(processor, PixtralProcessor):
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
def get_replacement_pixtral(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
def get_replacement(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
(
num_width_tokens,
num_height_tokens,
) = get_pixtral_hf_image_feature_size(
vision_config,
image_width=image_size.width,
image_height=image_size.height,
)
ncols, nrows = get_pixtral_hf_image_feature_grid_size(
vision_config,
image_width=image_size.width,
image_height=image_size.height,
)
tokens = ([image_token] * num_width_tokens +
[image_break_token]) * num_height_tokens
tokens[-1] = image_end_token
tokens = ([image_token] * ncols + [image_break_token]) * nrows
tokens[-1] = image_end_token
return "".join(tokens)
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_pixtral,
),
]
max_image_tokens = get_max_llava_image_tokens(self.ctx)
return "".join(tokens)
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * max_image_tokens,
)
replacement=get_replacement,
),
]
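# Not part of the diff: what the replacement above produces for a 3x2 patch
# grid, using Pixtral's special-token strings for readability. A break token
# closes every row except the last, which ends with the end token.
ncols, nrows = 3, 2
tokens = (["[IMG]"] * ncols + ["[IMG_BREAK]"]) * nrows
tokens[-1] = "[IMG_END]"
assert "".join(tokens) == ("[IMG][IMG][IMG][IMG_BREAK]"
                           "[IMG][IMG][IMG][IMG_END]")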
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts.get("image", 0)
if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
elif isinstance(vision_config, SiglipVisionConfig):
data = dummy_image_for_siglip(vision_config, num_images)
elif isinstance(vision_config, PixtralVisionConfig):
data = dummy_image_for_pixtral_hf(vision_config, num_images)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def _build_llava_or_pixtral_hf_processor(
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseLlavaMultiModalProcessor:
hf_config = ctx.get_hf_config(LlavaConfig)
hf_processor = self._get_hf_processor()
image_token = hf_processor.image_token
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
if isinstance(hf_config.vision_config, PixtralVisionConfig):
return PixtralHFMultiModalProcessor(
ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
class LlavaLikeConfig(Protocol):
vision_config: PretrainedConfig
vision_feature_layer: Union[int, List[int]]
return LlavaMultiModalProcessor(
ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
@@ -330,7 +421,7 @@ def init_vision_tower_for_llava(
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
@@ -596,7 +687,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
) -> MultiModalInputsV2:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
max_image_tokens = get_max_llava_image_tokens(self.ctx)
# Assume that the number of image tokens doesn't depend on the image size
num_image_tokens = self._get_num_image_tokens(
image_width=-1,
image_height=-1,
)
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
@@ -609,14 +705,14 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def get_replacement_mantis(item_idx: int):
return "".join([
f"(image {item_idx+1}: <Image>", # 7 tokens
"<image>" * max_image_tokens,
"<image>" * num_image_tokens,
"</Image>)", # 3 tokens
])
mantis_repls = self._bind_prompt_replacements([
PromptReplacement(
modality="image",
target=[image_token_id] * max_image_tokens,
target=[image_token_id] * num_image_tokens,
replacement=get_replacement_mantis,
)
])

View File

@@ -4,31 +4,25 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector,
init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model, maybe_prefix)
@@ -65,218 +59,127 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs]
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
def _get_hf_config(self) -> LlavaNextConfig:
return self.ctx.get_hf_config(LlavaNextConfig)
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
current_height -= 2 * padding
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
current_width -= 2 * padding
def _get_hf_processor(self) -> LlavaNextProcessor:
return self.ctx.get_hf_processor(LlavaNextProcessor)
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
def _get_image_token(self) -> str:
return self._get_hf_processor().image_token
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
hf_config: LlavaNextConfig,
*,
input_height: int,
input_width: int,
) -> int:
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
num_patches = get_clip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_clip_image_feature_size(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_patches = get_siglip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_siglip_image_feature_size(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
base_feature_size -= 1
elif strategy == "full":
pass
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
)
(
unpadded_feature_size,
newline_feature_size,
) = _get_llava_next_num_unpadded_features(input_height, input_width,
num_patches, num_patch_height,
num_patch_width)
return unpadded_feature_size + newline_feature_size + base_feature_size
def get_max_llava_next_image_tokens(ctx: InputContext):
"""Compute the max feature size for all possible image grid pinpoints."""
return _get_pinpoint_with_largest_features(ctx)[0]
def _get_pinpoint_with_largest_features(
ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
"""Get the grid pinpoint with the largest features & its feature size."""
hf_config = ctx.get_hf_config(LlavaNextConfig)
largest_feature_size = 0
largest_feature_pinpoint = None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = (height, width)
if not largest_feature_size or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_size, largest_feature_pinpoint
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
max_feat_height, max_feat_width = pinpoint
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
mm_data = dummy_image_for_clip(
vision_config,
num_images,
image_width_override=max_feat_width,
image_height_override=max_feat_height,
def _get_max_image_tokens(self) -> int:
largest_feature_size, _ = self._get_pinpoint_with_most_features()
return largest_feature_size
def _get_dummy_image_size(self) -> ImageSize:
_, pinpoint = self._get_pinpoint_with_most_features()
return pinpoint
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def _get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
base_feature_size = self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
self._vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
num_patches = self._vision_encoder_info.get_num_patches()
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(image_height, image_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=self._vision_encoder_info.get_image_size(),
)
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
(
unpadded_feature_size,
newline_feature_size,
) = self._get_num_unpadded_features(
original_height=image_height,
original_width=image_width,
npatches=num_patches,
num_patch_height=num_patch_height,
num_patch_width=num_patch_width,
)
mm_data = dummy_image_for_siglip(
vision_config,
num_images,
image_width_override=max_feat_width,
image_height_override=max_feat_height,
)
return unpadded_feature_size + newline_feature_size + base_feature_size
return DummyData(seq_data, mm_data, ranges)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_num_unpadded_features(
self,
*,
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
current_height -= 2 * padding
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
current_width -= 2 * padding
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
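# Not part of the diff: the unpadding arithmetic above, traced for a
# 576x768 (H x W) image on a 2x2 tile grid with 24 patches per tile side
# (all values are illustrative).
original_height, original_width = 576, 768
npatches, num_patch_height, num_patch_width = 24, 2, 2
current_height = npatches * num_patch_height            # 48 patch rows
current_width = npatches * num_patch_width               # 48 patch columns
# the image is wider (4:3) than the padded grid (1:1), so rows are cropped
scale_factor = current_width / original_width             # 0.0625
padding = (current_height - int(original_height * scale_factor)) // 2   # 6
current_height -= 2 * padding                              # 36 rows remain
unpadded_features = current_height * current_width         # 1728
newline_features = current_height                          # 36
assert (unpadded_features, newline_features) == (1728, 36)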
def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
"""
Get the grid pinpoint with the most features and
the corresponding feature size.
"""
hf_config = self._get_hf_config()
largest_feature_size, largest_feature_pinpoint = 0, None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = self._get_num_image_tokens(image_width=width,
image_height=height)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_size, largest_feature_pinpoint
def input_processor_for_llava_next(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image_feature_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif is_list_of(image_data, Image.Image):
image_feature_size = [
get_llava_next_image_feature_size(hf_config,
input_height=img.height,
input_width=img.width)
for img in image_data
]
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
@@ -507,7 +410,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_pixels(
self,
inputs: LlavaNextImagePixelInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
assert self.vision_tower is not None
pixel_values = inputs["data"]

View File

@@ -34,7 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.parse import ImageEmbeddingItems, ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement,
@@ -388,15 +388,19 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
assert isinstance(bos_token_id, int)
def get_replacement_phi3v(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
num_tokens = self._get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
num_images = mm_items.get_count("image", strict=False)

View File

@@ -38,6 +38,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import VisionEncoderInfo
try:
from xformers import ops as xops
@@ -697,10 +698,18 @@ def get_pixtral_hf_patch_grid_length(*, image_size: int,
return image_size // patch_size
def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length
def get_pixtral_hf_image_feature_size(
*,
image_size: int,
patch_size: int,
) -> int:
grid_length = get_pixtral_hf_patch_grid_length(
image_size=image_size,
patch_size=patch_size,
)
# Each of the grid_length rows gains one image_break_token
return (grid_length + 1) * grid_length
def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
@@ -730,13 +739,16 @@ def dummy_image_for_pixtral_hf(
return {"image": image if num_images == 1 else [image] * num_images}
def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
image_width: int,
image_height: int) -> Tuple[int, int]:
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
max_width, max_height = hf_config.image_size, hf_config.image_size
patch_width, patch_height = hf_config.patch_size, hf_config.patch_size
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180
def get_pixtral_hf_image_feature_grid_size(
hf_config: PixtralVisionConfig,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
max_width = max_height = hf_config.image_size
patch_width = patch_height = hf_config.patch_size
ratio = max(image_width / max_width, image_height / max_height)
@@ -744,12 +756,38 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
image_width = int(math.ceil(image_width / ratio))
image_height = int(math.ceil(image_height / ratio))
num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
nrows, ncols = _get_pixtral_hf_num_image_tokens(
(image_height, image_width),
(patch_height, patch_width),
)
) # type: ignore
return num_width_tokens, num_height_tokens
return ncols, nrows
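# Not part of the diff: example values for the renamed helper, assuming the
# default Pixtral-HF vision config (image_size=1024, patch_size=16). A
# 512x1024 (W x H) image needs no downscaling, so the grid has
# 512 // 16 = 32 columns and 1024 // 16 = 64 rows.
from transformers import PixtralVisionConfig
from vllm.model_executor.models.pixtral import (
    get_pixtral_hf_image_feature_grid_size)

ncols, nrows = get_pixtral_hf_image_feature_grid_size(
    PixtralVisionConfig(),
    image_width=512,
    image_height=1024,
)
assert (ncols, nrows) == (32, 64)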
class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
return get_pixtral_hf_image_feature_size(
image_size=self.vision_config.image_size,
patch_size=self.get_image_size(),
)
def get_max_image_tokens(self) -> int:
return get_max_pixtral_hf_image_tokens(self.vision_config)
def get_num_patches(self) -> int:
return get_pixtral_hf_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)
def get_image_size(self) -> int:
return self.vision_config.image_size
class PixtralHFMLP(nn.Module):

View File

@@ -28,6 +28,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData
from .vision import VisionEncoderInfo
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible
@@ -156,6 +158,29 @@ def input_processor_for_siglip(
multi_modal_placeholders={"image": ranges})
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
return get_siglip_image_feature_size(self.vision_config)
def get_max_image_tokens(self) -> int:
return get_max_siglip_image_tokens(self.vision_config)
def get_num_patches(self) -> int:
return get_siglip_patch_grid_length(
image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size,
)
def get_image_size(self) -> int:
return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):

View File

@@ -373,7 +373,7 @@ def embed_multimodal(
input_ids: torch.Tensor,
multimodal_token_id: int,
get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
multimodal_embeds: Union[torch.Tensor, List[torch.Tensor]],
multimodal_embeds: NestedTensors,
) -> torch.Tensor:
"""
Embed token IDs and multimodal inputs and combine their embeddings.

View File

@@ -0,0 +1,52 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from transformers import PretrainedConfig
_C = TypeVar("_C", bound=PretrainedConfig)
class VisionEncoderInfo(ABC, Generic[_C]):
def __init__(self, vision_config: _C) -> None:
super().__init__()
self.vision_config = vision_config
@abstractmethod
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
raise NotImplementedError
@abstractmethod
def get_max_image_tokens(self) -> int:
raise NotImplementedError
@abstractmethod
def get_num_patches(self) -> int:
raise NotImplementedError
@abstractmethod
def get_image_size(self) -> int:
raise NotImplementedError
def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
# Avoid circular imports
from .clip import CLIPEncoderInfo, CLIPVisionConfig
from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
from .siglip import SiglipEncoderInfo, SiglipVisionConfig
if isinstance(vision_config, CLIPVisionConfig):
return CLIPEncoderInfo(vision_config)
if isinstance(vision_config, PixtralVisionConfig):
return PixtralHFEncoderInfo(vision_config)
if isinstance(vision_config, SiglipVisionConfig):
return SiglipEncoderInfo(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
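A short usage sketch of the new factory, not part of the diff; the assertions assume a CLIP tower with image_size=336 and patch_size=14, and the module path implied by the relative imports above:
from transformers import CLIPVisionConfig

from vllm.model_executor.models.vision import vision_encoder_info

info = vision_encoder_info(CLIPVisionConfig(image_size=336, patch_size=14))
assert info.get_num_patches() == 24        # 336 // 14 patches per side
assert info.get_max_image_tokens() == 577  # 24 * 24 patches + CLS token
assert info.get_image_size() == 336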

View File

@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
Union)
import numpy as np
import torch
@@ -87,7 +88,7 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
def get_count(self) -> int:
return len(self.data)
def get(self, index: int) -> object:
def get(self, index: int) -> torch.Tensor:
return self.data[index]
def get_processor_data(self) -> Mapping[str, object]:
@@ -96,6 +97,9 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
def get_passthrough_data(self) -> Mapping[str, object]:
return {f"{self.modality}_embeds": self.data}
def get_feature_size(self, item_idx: int) -> int:
return len(self.get(item_idx))
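# Not part of the diff: what the new get_feature_size() reports for batched
# image embeddings. For an image_embeds tensor of shape
# (num_images, feature_size, hidden_size) -- shapes below are illustrative --
# each item's feature size is len(data[idx]), i.e. 576 here.
import torch

image_embeds = torch.zeros(2, 576, 4096)
assert all(len(item) == 576 for item in image_embeds)  # == get_feature_size(i)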
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
@@ -182,7 +186,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
def get_items(
self,
modality: str,
typ: type[_D],
typ: Union[type[_D], tuple[type[_D], ...]],
) -> _D:
"""
Get the data items belonging to a modality,
@@ -199,7 +203,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
f"Expected type: {typ}, but "
f"found type: {type(items)}")
return items
return items # type: ignore[return-value]
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],