mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 08:55:46 +08:00
[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent b6087a6bee
commit 8c38ee7007
@@ -1,70 +0,0 @@
import pytest

from vllm.inputs import InputContext

from ....utils import build_model_context


@pytest.fixture()
def get_max_llava_next_image_tokens():
    from vllm.model_executor.models.llava_next import (
        get_max_llava_next_image_tokens)
    return get_max_llava_next_image_tokens


@pytest.fixture()
def dummy_data_for_llava_next():
    from vllm.model_executor.models.llava_next import dummy_data_for_llava_next
    return dummy_data_for_llava_next


@pytest.mark.parametrize("gridpoints,expected_max_tokens", [
    ([[336, 336]], 1176),
    ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 2928),
])
def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens,
                                         get_max_llava_next_image_tokens):
    ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")

    # Update the config image_grid_pinpoints
    # and calculate the resulting max tokens
    ctx.model_config.hf_config.image_grid_pinpoints = gridpoints

    actual_max_tokens = get_max_llava_next_image_tokens(
        InputContext(ctx.model_config))

    assert expected_max_tokens == actual_max_tokens


@pytest.mark.parametrize(
    "gridpoints,expected_size",
    [
        # One point; it has to be the largest
        ([[336, 336]], (336, 336)),
        # Default for most llava next models; the 2x2 tile is the largest
        ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
         (672, 672)),
        # If two rectangular gridpoints are the same, the more vertical
        # one has the higher feature count due to newline features
        ([[336, 672], [672, 336]], (672, 336))
    ])
def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
                                                gridpoints, expected_size):
    ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")

    # Update the config image_grid_pinpoints
    ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
    seq_len = 5000  # bigger than the max feature size for any image

    dummy_data = dummy_data_for_llava_next(
        ctx,
        seq_len=seq_len,
        mm_counts={"image": 1},
    )
    seq_data = dummy_data.seq_data
    mm_data = dummy_data.multi_modal_data

    # The dummy data dims should match the gridpoint with the biggest feat size
    assert mm_data["image"].height == expected_size[0]
    assert mm_data["image"].width == expected_size[1]
    assert len(seq_data.get_token_ids()) >= seq_len
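For reference, the expected token counts in the first parametrization above can be checked by hand. This is only a sketch; it assumes the CLIP ViT-L/14-336 vision tower used by llava-v1.6-mistral-7b-hf (24 patches per side) and the "default" feature-select strategy, which drops the CLS token:

    # max tokens = base features + unpadded tile features + newline features
    patches_per_side = 336 // 14        # 24
    base = patches_per_side ** 2        # 576 base features after dropping CLS
    # single [336, 336] pinpoint: a 1x1 tile grid, nothing is cropped away
    assert base + 24 * 24 + 24 == 1176
    # largest default pinpoint is [672, 672]: a 2x2 tile grid, i.e. 48x48 patches
    assert base + 48 * 48 + 48 == 2928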
@@ -1,118 +0,0 @@
from contextlib import nullcontext

import numpy as np
import pytest
from transformers import LlavaNextImageProcessor

from vllm.config import ModelConfig
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.image import rescale_image_size


@pytest.fixture
def mm_registry():
    return MultiModalRegistry()


@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
                                    size_factor):
    MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"

    hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
    assert isinstance(hf_processor, LlavaNextImageProcessor)

    model_config = ModelConfig(
        model=MODEL_NAME,
        task="auto",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype=dtype,
        revision=None,
        limit_mm_per_prompt={"image": 1},
    )

    mm_registry.init_mm_limits_per_prompt(model_config)

    for asset in image_assets:
        image = rescale_image_size(asset.pil_image, size_factor)

        hf_result = hf_processor.preprocess(
            image,
            return_tensors="pt",
        )
        vllm_result = mm_registry.map_input(
            model_config,
            {"image": image},
        )

        assert hf_result.keys() == vllm_result.keys()
        for key, hf_tensor in hf_result.items():
            hf_arr: np.ndarray = hf_tensor.numpy()
            vllm_arr: np.ndarray = vllm_result[key].numpy()

            assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
            assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
    MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"

    model_config = ModelConfig(
        model=MODEL_NAME,
        task="auto",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": limit},
    )

    mm_registry.init_mm_limits_per_prompt(model_config)

    image = image_assets[0].pil_image
    if num_images == 0:
        mm_inputs = {}
    elif num_images == 1:
        mm_inputs = {"image": image}
    else:
        mm_inputs = {"image": [image] * num_images}

    with nullcontext() if is_valid else pytest.raises(ValueError):
        mm_registry.map_input(model_config, mm_inputs)


# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
    MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"

    model_config = ModelConfig(
        model=MODEL_NAME,
        task="auto",
        tokenizer=MODEL_NAME,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": num_images},
    )

    mm_registry.init_mm_limits_per_prompt(model_config)

    image = image_assets[0].pil_image
    mm_inputs = {"image": [image] * num_images}

    mapped_inputs = mm_registry.map_input(model_config, mm_inputs)
    assert len(mapped_inputs["pixel_values"]) == num_images
@@ -1,5 +1,7 @@
from contextlib import nullcontext
from functools import partial
from typing import cast
from unittest.mock import MagicMock

import numpy as np
import pytest
@@ -526,6 +528,100 @@ def _rand_audio(
    return rng.rand(audio_len), sr


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )

    processor = processor_factory(ctx, cache=None)

    mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
    processor.get_supported_mm_limits = mock_supported_mm_limits

    if is_valid:
        exc_ctx = nullcontext()
    else:
        exc_ctx = pytest.raises(ValueError, match="this model only supports")

    with exc_ctx:
        processor._get_and_validate_dummy_mm_counts()


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )

    processor = processor_factory(ctx, cache=None)

    rng = np.random.RandomState(0)
    image = _rand_img(rng, min_wh=128, max_wh=256)
    if num_images == 0:
        mm_data = {}
    elif num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {"image": [image] * num_images}

    if is_valid:
        exc_ctx = nullcontext()
    else:
        exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")

    with exc_ctx:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )


def _test_processing_cache_correctness(
    model_id: str,
    modalities: dict[str, bool],
@@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
    ("facebook/chameleon-7b", {"image": False}),
    ("adept/fuyu-8b", {"image": False}),
    ("llava-hf/llava-1.5-7b-hf", {"image": True}),
    ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
    ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
    ("mistral-community/pixtral-12b", {"image": True}),
    ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
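The two new parametrizations above encode simple validity rules. The following sketch merely restates the tables (it is not a vLLM API):

    # dummy-data check: the configured limit must not exceed what the model supports
    for limit, num_supported, is_valid in [(0, 0, True), (0, 1, True),
                                           (1, 0, False), (1, 1, True),
                                           (1, 2, True), (2, 1, False),
                                           (2, 2, True)]:
        assert is_valid == (limit <= num_supported)

    # apply() check: the number of images passed must not exceed the limit
    for num_images, limit, is_valid in [(0, 0, True), (0, 1, True),
                                        (1, 0, False), (1, 1, True),
                                        (1, 2, True), (2, 1, False),
                                        (2, 2, True)]:
        assert is_valid == (num_images <= limit)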
@@ -3,13 +3,11 @@ from typing import Optional
import torch

from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
                                               LlavaMultiModalProcessor,
                                               get_max_llava_image_tokens)
                                               LlavaMultiModalProcessor)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration):
@@ -24,6 +24,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
                                   resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData

from .vision import VisionEncoderInfo


def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    assert image_size % patch_size == 0
@@ -149,6 +151,29 @@ def input_processor_for_clip(
                        multi_modal_placeholders={"image": ranges})


class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        return get_clip_image_feature_size(self.vision_config)

    def get_max_image_tokens(self) -> int:
        return get_max_clip_image_tokens(self.vision_config)

    def get_num_patches(self) -> int:
        return get_clip_patch_grid_length(
            image_size=self.vision_config.image_size,
            patch_size=self.vision_config.patch_size,
        )

    def get_image_size(self) -> int:
        return self.vision_config.image_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
@@ -76,7 +76,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
        return ImageSize(width=target_size["width"],
                         height=target_size["height"])

    def _get_image_grid_size(
    def _get_image_feature_grid_size(
        self,
        *,
        image_width: int,
@@ -99,7 +99,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        target_width, target_height = self._get_image_target_size()

        max_ncols, max_nrows = self._get_image_grid_size(
        max_ncols, max_nrows = self._get_image_feature_grid_size(
            image_width=target_width,
            image_height=target_height,
        )
@@ -172,7 +172,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self._get_image_grid_size(
            ncols, nrows = self._get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
@@ -1,6 +1,7 @@
from abc import abstractmethod
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
                    Tuple, TypedDict, Union)
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
                    Protocol, Set, Tuple, TypedDict, Union)

import torch
import torch.nn as nn
@@ -12,7 +13,6 @@ from transformers.models.pixtral import PixtralProcessor

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
@@ -23,23 +23,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputsV2, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement,
                                        InputProcessingContext,
                                        MultiModalDataItems, ProcessingCache,
                                        ProcessorInputs, PromptReplacement,
                                        full_groupby_modality)
from vllm.sequence import IntermediateTensors

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   get_max_clip_image_tokens)
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
                      get_max_pixtral_hf_image_tokens,
                      get_pixtral_hf_image_feature_size)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     get_max_siglip_image_tokens)
from .pixtral import (PixtralHFVisionModel,
                      get_pixtral_hf_image_feature_grid_size)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)
from .vision import vision_encoder_info


class LlavaImagePixelInputs(TypedDict):
@@ -94,39 +94,167 @@ class LlavaMultiModalProjector(nn.Module):
        return hidden_states


def get_max_llava_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    elif isinstance(vision_config, PixtralVisionConfig):
        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
class LlavaLikeConfig(Protocol):
    vision_config: Final[PretrainedConfig]
    vision_feature_select_strategy: Final[str]
    vision_feature_layer: Final[Union[int, List[int]]]


class LlavaMultiModalProcessor(BaseMultiModalProcessor):
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):

    def __init__(self,
                 ctx: InputProcessingContext,
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        super().__init__(ctx,
                         cache=cache,
                         enable_sanity_checks=enable_sanity_checks)

        vision_config = self._get_hf_config().vision_config
        self._vision_encoder_info = vision_encoder_info(vision_config)

    @abstractmethod
    def _get_hf_config(self) -> LlavaLikeConfig:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {"image": get_max_llava_image_tokens(self.ctx)}
    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        if strategy == "default":
            return encoder_num_image_tokens - 1
        if strategy == "full":
            return encoder_num_image_tokens

    def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
        return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def _get_max_image_tokens(self) -> int:
        hf_config = self._get_hf_config()

        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            self._vision_encoder_info.get_max_image_tokens(),
        )

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {"image": self._get_max_image_tokens()}

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_dummy_image_size(self) -> ImageSize:
        image_size = self._vision_encoder_info.get_image_size()
        return ImageSize(image_size, image_size)

    @abstractmethod
    def _get_image_token(self) -> str:
        raise NotImplementedError

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        image_token = self._get_image_token()
        target_width, target_height = self._get_dummy_image_size()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=mm_data,
        )


class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):

    def _get_hf_config(self) -> LlavaConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def _get_hf_processor(self) -> LlavaProcessor:
        return self.ctx.get_hf_processor(LlavaProcessor)

    def _get_image_token(self) -> str:
        return self._get_hf_processor().image_token

    def _get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self._get_hf_config()

        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            self._vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self._get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self._get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):

    def _get_hf_config(self) -> LlavaConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def _get_hf_processor(self) -> PixtralProcessor:
        return self.ctx.get_hf_processor(PixtralProcessor)

    def _get_image_token(self) -> str:
        return self._get_hf_processor().image_token

    def _call_hf_processor(
        self,
@@ -140,119 +268,82 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
            mm_kwargs=mm_kwargs,
        )

        # NOTE: pixel_values=None for MLlavaProcessor
        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            if isinstance(self._get_hf_processor(), PixtralProcessor):
                # Original output: (1, num_images, C, H, W)
                # New output: (num_images, C, H, W)
                assert (isinstance(pixel_values, list)
                        and len(pixel_values) == 1)
                assert (isinstance(pixel_values[0], list)
                        and len(pixel_values[0]) == len(images))
            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
            assert (isinstance(pixel_values[0], list)
                    and len(pixel_values[0]) == len(images))

                processed_outputs["pixel_values"] = pixel_values[0]
            processed_outputs["pixel_values"] = pixel_values[0]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        hf_config = self._get_hf_config()
        image_token_id = hf_config.image_token_index

        processor = self._get_hf_processor()
        if isinstance(processor, PixtralProcessor):
            image_token = processor.image_token
            image_break_token = processor.image_break_token
            image_end_token = processor.image_end_token
        image_token = processor.image_token
        image_break_token = processor.image_break_token
        image_end_token = processor.image_end_token

            vision_config = hf_config.vision_config
            assert isinstance(vision_config, PixtralVisionConfig)
        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

            def get_replacement_pixtral(item_idx: int):
                images = mm_items.get_items("image", ImageProcessorItems)
                image_size = images.get_image_size(item_idx)
        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

                (
                    num_width_tokens,
                    num_height_tokens,
                ) = get_pixtral_hf_image_feature_size(
                    vision_config,
                    image_width=image_size.width,
                    image_height=image_size.height,
                )
            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
                vision_config,
                image_width=image_size.width,
                image_height=image_size.height,
            )

                tokens = ([image_token] * num_width_tokens +
                          [image_break_token]) * num_height_tokens
                tokens[-1] = image_end_token
            tokens = ([image_token] * ncols + [image_break_token]) * nrows
            tokens[-1] = image_end_token

                return "".join(tokens)

            return [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id],
                    replacement=get_replacement_pixtral,
                ),
            ]

        max_image_tokens = get_max_llava_image_tokens(self.ctx)
            return "".join(tokens)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=[image_token_id] * max_image_tokens,
            )
                replacement=get_replacement,
            ),
        ]

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        vision_config = hf_config.vision_config
        num_images = mm_counts.get("image", 0)

        if isinstance(vision_config, CLIPVisionConfig):
            data = dummy_image_for_clip(vision_config, num_images)
        elif isinstance(vision_config, SiglipVisionConfig):
            data = dummy_image_for_siglip(vision_config, num_images)
        elif isinstance(vision_config, PixtralVisionConfig):
            data = dummy_image_for_pixtral_hf(vision_config, num_images)
        else:
            msg = f"Unsupported vision config: {type(vision_config)}"
            raise NotImplementedError(msg)
def _build_llava_or_pixtral_hf_processor(
    ctx: InputProcessingContext,
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseLlavaMultiModalProcessor:
    hf_config = ctx.get_hf_config(LlavaConfig)

        hf_processor = self._get_hf_processor()
        image_token = hf_processor.image_token

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=data,
    if isinstance(hf_config.vision_config, PixtralVisionConfig):
        return PixtralHFMultiModalProcessor(
            ctx,
            cache=cache,
            enable_sanity_checks=enable_sanity_checks,
        )


class LlavaLikeConfig(Protocol):
    vision_config: PretrainedConfig
    vision_feature_layer: Union[int, List[int]]
    return LlavaMultiModalProcessor(
        ctx,
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
@@ -330,7 +421,7 @@ def init_vision_tower_for_llava(
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
@@ -596,7 +687,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    ) -> MultiModalInputsV2:
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        image_token_id = hf_config.image_token_index
        max_image_tokens = get_max_llava_image_tokens(self.ctx)

        # Assume that it doesn't depend on the image size
        num_image_tokens = self._get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)

@@ -609,14 +705,14 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * max_image_tokens,
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        mantis_repls = self._bind_prompt_replacements([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * max_image_tokens,
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])
@@ -4,31 +4,25 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector,
                    init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
                    init_vllm_registered_model, maybe_prefix)
@@ -65,218 +59,127 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width
class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height
    def _get_hf_config(self) -> LlavaNextConfig:
        return self.ctx.get_hf_config(LlavaNextConfig)

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding
    def _get_hf_processor(self) -> LlavaNextProcessor:
        return self.ctx.get_hf_processor(LlavaNextProcessor)

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)
    def _get_image_token(self) -> str:
        return self._get_hf_processor().image_token


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        base_feature_size -= 1
    elif strategy == "full":
        pass
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (
        unpadded_feature_size,
        newline_feature_size,
    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size


def get_max_llava_next_image_tokens(ctx: InputContext):
    """Compute the max feature size for all possible image grid pinpoints."""
    return _get_pinpoint_with_largest_features(ctx)[0]


def _get_pinpoint_with_largest_features(
        ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
    """Get the grid pinpoint with the largest features & its feature size."""
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    largest_feature_size = 0
    largest_feature_pinpoint = None
    for (height, width) in hf_config.image_grid_pinpoints:
        feat_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
        if feat_size > largest_feature_size:
            largest_feature_size = feat_size
            largest_feature_pinpoint = (height, width)
    if not largest_feature_size or largest_feature_pinpoint is None:
        raise ValueError("Cannot have a largest feature size of 0!")
    return largest_feature_size, largest_feature_pinpoint


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
    max_feat_height, max_feat_width = pinpoint

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data, ranges = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            num_images,
            image_width_override=max_feat_width,
            image_height_override=max_feat_height,
    def _get_max_image_tokens(self) -> int:
        largest_feature_size, _ = self._get_pinpoint_with_most_features()
        return largest_feature_size

    def _get_dummy_image_size(self) -> ImageSize:
        _, pinpoint = self._get_pinpoint_with_most_features()
        return pinpoint

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
    def _get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self._get_hf_config()

        base_feature_size = self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            self._vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )
        num_patches = self._vision_encoder_info.get_num_patches()

        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=hf_config.image_grid_pinpoints,
            patch_size=self._vision_encoder_info.get_image_size(),
        )

        return DummyData(seq_data, mm_data, ranges)
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data, ranges = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        (
            unpadded_feature_size,
            newline_feature_size,
        ) = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=num_patches,
            num_patch_height=num_patch_height,
            num_patch_width=num_patch_width,
        )

        mm_data = dummy_image_for_siglip(
            vision_config,
            num_images,
            image_width_override=max_feat_width,
            image_height_override=max_feat_height,
        )
        return unpadded_feature_size + newline_feature_size + base_feature_size

        return DummyData(seq_data, mm_data, ranges)
    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
        original_aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        if original_aspect_ratio > current_aspect_ratio:
            scale_factor = current_width / original_width
            new_height = int(original_height * scale_factor)
            padding = (current_height - new_height) // 2
            current_height -= 2 * padding
        else:
            scale_factor = current_height / original_height
            new_width = int(original_width * scale_factor)
            padding = (current_width - new_width) // 2
            current_width -= 2 * padding

        unpadded_features = current_height * current_width
        newline_features = current_height
        return (unpadded_features, newline_features)

    def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
        """
        Get the grid pinpoint with the most features and
        the corresponding feature size.
        """
        hf_config = self._get_hf_config()

        largest_feature_size, largest_feature_pinpoint = 0, None
        for (height, width) in hf_config.image_grid_pinpoints:
            feat_size = self._get_num_image_tokens(image_width=width,
                                                   image_height=height)
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_size, largest_feature_pinpoint
def input_processor_for_llava_next(ctx: InputContext,
                                   inputs: DecoderOnlyInputs):
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [
            get_llava_next_image_feature_size(hf_config,
                                              input_height=img.height,
                                              input_width=img.width)
            for img in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
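A worked example of the feature-size computation in the hunk above, matching the "more vertical gridpoint wins" comment in the removed tests. This is only a sketch; it assumes the CLIP ViT-L/14-336 tower (24 patches per side, 576 base features after dropping the CLS token) and the gridpoints [[336, 672], [672, 336]]:

    patches, base = 24, 576
    # portrait 672x336 image -> 2x1 tile grid, no padding is removed
    unpadded, newline = (2 * patches) * (1 * patches), 2 * patches
    assert unpadded + newline + base == 1776
    # landscape 336x672 image -> 1x2 tile grid, fewer newline features
    unpadded, newline = (1 * patches) * (2 * patches), 1 * patches
    assert unpadded + newline + base == 1752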
@@ -507,7 +410,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]
@@ -34,7 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputsV2, MultiModalKwargs,
                                    NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.parse import ImageEmbeddingItems, ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement,
@@ -388,15 +388,19 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
        assert isinstance(bos_token_id, int)

        def get_replacement_phi3v(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            num_tokens = self._get_num_image_tokens(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self._get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
            return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]

        num_images = mm_items.get_count("image", strict=False)
@@ -38,6 +38,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
from .vision import VisionEncoderInfo

try:
    from xformers import ops as xops
@@ -697,10 +698,18 @@ def get_pixtral_hf_patch_grid_length(*, image_size: int,
    return image_size // patch_size


def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
    grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size,
                                                   patch_size=patch_size)
    return grid_length * grid_length
def get_pixtral_hf_image_feature_size(
    *,
    image_size: int,
    patch_size: int,
) -> int:
    grid_length = get_pixtral_hf_patch_grid_length(
        image_size=image_size,
        patch_size=patch_size,
    )

    # Consider the image_break_token
    return (grid_length + 1) * grid_length


def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
@@ -730,13 +739,16 @@ def dummy_image_for_pixtral_hf(
    return {"image": image if num_images == 1 else [image] * num_images}


def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
                                      image_width: int,
                                      image_height: int) -> Tuple[int, int]:
    # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
    # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
    max_width, max_height = hf_config.image_size, hf_config.image_size
    patch_width, patch_height = hf_config.patch_size, hf_config.patch_size
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180
def get_pixtral_hf_image_feature_grid_size(
    hf_config: PixtralVisionConfig,
    *,
    image_width: int,
    image_height: int,
) -> tuple[int, int]:
    max_width = max_height = hf_config.image_size
    patch_width = patch_height = hf_config.patch_size

    ratio = max(image_width / max_width, image_height / max_height)

@@ -744,12 +756,38 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
        image_width = int(math.ceil(image_width / ratio))
        image_height = int(math.ceil(image_height / ratio))

    num_height_tokens, num_width_tokens = _get_pixtral_hf_num_image_tokens(
    nrows, ncols = _get_pixtral_hf_num_image_tokens(
        (image_height, image_width),
        (patch_height, patch_width),
    )
    )  # type: ignore

    return num_width_tokens, num_height_tokens
    return ncols, nrows


class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        return get_pixtral_hf_image_feature_size(
            image_size=self.vision_config.image_size,
            patch_size=self.get_image_size(),
        )

    def get_max_image_tokens(self) -> int:
        return get_max_pixtral_hf_image_tokens(self.vision_config)

    def get_num_patches(self) -> int:
        return get_pixtral_hf_patch_grid_length(
            image_size=self.vision_config.image_size,
            patch_size=self.vision_config.patch_size,
        )

    def get_image_size(self) -> int:
        return self.vision_config.image_size


class PixtralHFMLP(nn.Module):
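A quick numeric check of the grid-size helper above (a sketch; it assumes the PixtralVisionConfig defaults image_size=1024 and patch_size=16):

    import math

    image_size, patch_size = 1024, 16
    width, height = 512, 256                  # fits within 1024x1024, no resize
    ratio = max(width / image_size, height / image_size)
    if ratio > 1:
        width = math.ceil(width / ratio)
        height = math.ceil(height / ratio)
    ncols, nrows = math.ceil(width / patch_size), math.ceil(height / patch_size)
    assert (ncols, nrows) == (32, 16)
    # each of the 16 rows gets 32 image tokens plus one break/end token
    assert (ncols + 1) * nrows == 528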
@@ -28,6 +28,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
                                   resolve_visual_encoder_outputs)
from vllm.sequence import SequenceData

from .vision import VisionEncoderInfo


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    # Since interpolation is applied, the image size need not be divisible
@@ -156,6 +158,29 @@ def input_processor_for_siglip(
                        multi_modal_placeholders={"image": ranges})


class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        return get_siglip_image_feature_size(self.vision_config)

    def get_max_image_tokens(self) -> int:
        return get_max_siglip_image_tokens(self.vision_config)

    def get_num_patches(self) -> int:
        return get_siglip_patch_grid_length(
            image_size=self.vision_config.image_size,
            patch_size=self.vision_config.patch_size,
        )

    def get_image_size(self) -> int:
        return self.vision_config.image_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
@@ -373,7 +373,7 @@ def embed_multimodal(
    input_ids: torch.Tensor,
    multimodal_token_id: int,
    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
    multimodal_embeds: Union[torch.Tensor, List[torch.Tensor]],
    multimodal_embeds: NestedTensors,
) -> torch.Tensor:
    """
    Embed token IDs and multimodal inputs and combine their embeddings.
vllm/model_executor/models/vision.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from transformers import PretrainedConfig

_C = TypeVar("_C", bound=PretrainedConfig)


class VisionEncoderInfo(ABC, Generic[_C]):

    def __init__(self, vision_config: _C) -> None:
        super().__init__()

        self.vision_config = vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_max_image_tokens(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_num_patches(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        raise NotImplementedError


def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
    # Avoid circular imports
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPEncoderInfo(vision_config)
    if isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFEncoderInfo(vision_config)
    if isinstance(vision_config, SiglipVisionConfig):
        return SiglipEncoderInfo(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
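A minimal usage sketch for the new abstraction (it assumes a CLIP-style config at this commit; the helper simply dispatches on the config class):

    from transformers import CLIPVisionConfig

    from vllm.model_executor.models.vision import vision_encoder_info

    info = vision_encoder_info(CLIPVisionConfig(image_size=336, patch_size=14))
    print(info.get_num_patches())       # 24 patches per side for ViT-L/14-336
    print(info.get_max_image_tokens())  # encoder feature count for one image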
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar,
                    Union)

import numpy as np
import torch
@@ -87,7 +88,7 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> object:
    def get(self, index: int) -> torch.Tensor:
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
@@ -96,6 +97,9 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
    def get_passthrough_data(self) -> Mapping[str, object]:
        return {f"{self.modality}_embeds": self.data}

    def get_feature_size(self, item_idx: int) -> int:
        return len(self.get(item_idx))


class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):

@@ -182,7 +186,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    def get_items(
        self,
        modality: str,
        typ: type[_D],
        typ: Union[type[_D], tuple[type[_D], ...]],
    ) -> _D:
        """
        Get the data items belonging to a modality,
@@ -199,7 +203,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
                            f"Expected type: {typ}, but "
                            f"found type: {type(items)}")

        return items
        return items  # type: ignore[return-value]


ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],