[Core][VLM] Support image embeddings as input (#6613)

Roger Wang 2024-08-12 01:16:06 -07:00 committed by GitHub
parent ec2affa8ae
commit e6e42e4b17
13 changed files with 517 additions and 138 deletions

View File

@ -49,6 +49,17 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
"multi_modal_data": {"image": image},
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Inference with image embeddings as input
image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": image_embeds},
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
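
For context, here is a minimal end-to-end sketch of the embedding path shown above, assuming the LLaVA-1.5 checkpoint used in the tests below and a locally saved embeddings.pt file (both placeholders, not part of this commit):

from vllm import LLM
import torch

llm = LLM(model="llava-hf/llava-1.5-7b-hf")
prompt = "USER: <image>\nWhat's the content of the image?\nASSISTANT:"

# Pre-computed embeddings of shape (1, image_feature_size, hidden_size of the LM backbone).
image_embeds = torch.load("embeddings.pt")

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": image_embeds},
})
for o in outputs:
    print(o.outputs[0].text)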

View File

@ -0,0 +1,159 @@
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
"cherry_blossom":
"USER: <image>\nWhat is the season?\nASSISTANT:",
})
models = [
"llava-hf/llava-1.5-7b-hf",
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]
assert output_str[0] == " "
hf_output_str = output_str[1:]
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by the vllm contract.
The text output is sanitized so that it can be compared with hf.
"""
# vLLM to load from image embeddings
vllm_images = [asset.image_embeds for asset in image_assets]
# transformers to load from PIL images
hf_images = [asset.pil_image for asset in image_assets]
vllm_inputs_per_image = [(
[prompt for _ in size_factors],
[image for _ in size_factors],
) for image, prompt in zip(vllm_images, HF_IMAGE_PROMPTS)]
hf_inputs_per_image = [(
[prompt for _ in size_factors],
[image for _ in size_factors],
) for image, prompt in zip(hf_images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order: run vLLM first, and then run HF.
# vLLM needs a fresh process in which CUDA has not been initialized;
# if we run HF first, CUDA will already have been initialized, and that
# interferes with the multiprocessing backend using the fork start method
# (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in vllm_inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in hf_inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)

View File

@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Literal
import torch
from PIL import Image
from vllm.assets.base import get_vllm_public_assets
@ -18,3 +19,12 @@ class ImageAsset:
image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
s3_prefix=VLM_IMAGES_DIR)
return Image.open(image_path)
@property
def image_embeds(self) -> torch.Tensor:
"""
Image embeddings, only used for testing purposes with llava 1.5.
"""
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
s3_prefix=VLM_IMAGES_DIR)
return torch.load(image_path)
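
The .pt assets above are pre-computed offline. As a rough sketch of how compatible embeddings could be produced for LLaVA-1.5 with HuggingFace transformers (the module and config attribute names below follow the HF LLaVA implementation and are assumptions rather than part of this commit; dtype/device handling is omitted):

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open("stop_sign.jpg")  # placeholder path
pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"]

with torch.no_grad():
    # CLIP vision tower -> take the configured hidden layer and drop the CLS token
    vision_out = model.vision_tower(pixel_values, output_hidden_states=True)
    features = vision_out.hidden_states[model.config.vision_feature_layer][:, 1:]
    # Project into the language model's hidden size, which is what vLLM expects here
    image_embeds = model.multi_modal_projector(features)

torch.save(image_embeds, "stop_sign.pt")  # (1, 576, 4096) for LLaVA-1.5-7B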

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -28,6 +28,29 @@ _KEYS_TO_MODIFY_MAPPING = {
"language_model.model": "language_model",
}
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class Blip2ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
class Blip2QFormerMultiHeadAttention(nn.Module):
@ -375,20 +398,6 @@ class Blip2QFormerModel(nn.Module):
return sequence_output
class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
Blip2ImageInputs = Blip2ImagePixelInputs
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
return hf_config.num_query_tokens
@ -506,18 +515,31 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Blip2ImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(self, vision_model: BlipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
@ -538,6 +560,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self,
image_input: Blip2ImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)

View File

@ -88,7 +88,13 @@ def input_processor_for_clip(
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_clip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else:
image_feature_size = image_feature_size_override

View File

@ -234,7 +234,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
cache_config=cache_config,
quant_config=quant_config)
def _parse_and_validate_image_input(self, **kwargs: object):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
image_patches = kwargs.pop("image_patches", None)
if isinstance(image_patches, torch.Tensor):
@ -249,6 +250,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
data=image_patches)
return None
def _process_image_input(
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings
def forward(
self,
input_ids: torch.Tensor,
@ -261,8 +269,7 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings, _ = self.vision_embed_tokens(
image_input["data"])
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,

View File

@ -50,6 +50,19 @@ class InternVLImagePixelInputs(TypedDict):
"""
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
@ -193,8 +206,10 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
# add thumbnail image if num_blocks > 1
if hf_config.use_thumbnail and num_blocks > 1:
num_blocks += 1
image_feature_size = num_blocks * num_patches
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
@ -205,7 +220,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
new_prompt = prompt.replace('<image>', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt)
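
As a small illustration of the placeholder expansion above (the token strings follow InternVL's convention; the feature size is made up for brevity):

IMG_START, IMG_CONTEXT, IMG_END = "<img>", "<IMG_CONTEXT>", "</img>"
image_feature_size = 3  # hypothetical; normally num_blocks * num_patches, or embeds.shape[0]

prompt = "USER: <image>\nWhat is the season?\nASSISTANT:"
image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
print(prompt.replace("<image>", image_prompt, 1))
# USER: <img><IMG_CONTEXT><IMG_CONTEXT><IMG_CONTEXT></img>
# What is the season?
# ASSISTANT: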
@ -378,23 +393,49 @@ class InternVLChatModel(nn.Module, SupportsVision):
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImagePixelInputs]:
self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_token_id = kwargs.pop("image_token_id", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return InternVLImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
self.img_context_token_id = image_token_id[0]
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"])
return image_embeds
def forward(
self,
@ -409,9 +450,9 @@ class InternVLChatModel(nn.Module, SupportsVision):
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vit_embeds = self.extract_feature(image_input["data"])
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vit_embeds,
vision_embeddings,
self.img_context_token_id)
input_ids = None
else:

View File

@ -27,6 +27,24 @@ from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
@ -49,15 +67,6 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
LlavaImageInputs = LlavaImagePixelInputs
def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
@ -210,18 +219,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
@ -258,6 +279,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
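
The embeddings returned here (whether computed from pixel values or passed in directly) are then scattered into the text embeddings at the image-token positions by merge_vision_embeddings. A simplified standalone sketch of that behavior, not the actual vLLM helper (which also handles lists of per-image tensors):

import torch

def merge_vision_embeddings_sketch(input_ids: torch.Tensor,
                                   inputs_embeds: torch.Tensor,
                                   vision_embeddings: torch.Tensor,
                                   image_token_id: int) -> torch.Tensor:
    """Replace the embeddings at <image> placeholder positions with vision embeddings.

    input_ids:         (num_tokens,)
    inputs_embeds:     (num_tokens, hidden_size)
    vision_embeddings: (num_images, image_feature_size, hidden_size)
    """
    mask = input_ids == image_token_id
    # The prompt must contain exactly image_feature_size placeholders per image.
    assert mask.sum() == vision_embeddings.shape[0] * vision_embeddings.shape[1]
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[mask] = vision_embeddings.reshape(-1, vision_embeddings.shape[-1])
    return inputs_embeds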

View File

@ -60,7 +60,17 @@ class LlavaNextImagePixelInputs(TypedDict):
"""
LlavaNextImageInputs = LlavaNextImagePixelInputs
class LlavaNextImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs]
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
@ -208,7 +218,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
input_width=width,
)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
@ -320,26 +330,40 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeds. "
f"Got type: {type(image_embeds)}")
return LlavaNextImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
@ -466,6 +490,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self,
image_input: LlavaNextImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")

View File

@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
from torch import nn
@ -31,6 +31,25 @@ _KEYS_TO_MODIFY_MAPPING = {
}
class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class PaliGemmaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
PaliGemmaImageEmbeddingInputs]
def get_max_paligemma_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(PaliGemmaConfig)
vision_config = hf_config.vision_config
@ -107,15 +126,6 @@ class PaliGemmaMultiModalProjector(nn.Module):
return hidden_states
class PaliGemmaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
PaliGemmaImageInputs = PaliGemmaImagePixelInputs
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@ -163,18 +173,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return PaliGemmaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
return PaliGemmaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return PaliGemmaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(
self,
@ -187,26 +209,20 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return image_features
def _process_image_pixels(
self,
inputs: PaliGemmaImagePixelInputs,
) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
def _process_image_input(
self,
image_input: PaliGemmaImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input, )
pixel_values = image_input["data"]
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
return self.multi_modal_projector(image_features)

View File

@ -70,6 +70,36 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim=768)
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
class Phi3VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
class Phi3ImageEmbeddingBase(nn.Module):
def __init__(self) -> None:
@ -257,24 +287,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
@ -390,7 +402,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
input_width=w,
input_height=h)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
@ -494,25 +506,55 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
self, **kwargs: object) -> Optional[Phi3VImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is None and image_embeds is None:
return None
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Phi3VImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return Phi3VImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Phi3VImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: Phi3VImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens(image_input["data"],
image_input["image_sizes"])
return image_embeds
def forward(self,
input_ids: torch.Tensor,
@ -524,8 +566,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self.vision_embed_tokens(
image_input["data"], image_input["image_sizes"])
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,

View File

@ -97,7 +97,13 @@ def input_processor_for_siglip(
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_feature_size = get_siglip_image_feature_size(hf_config)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_siglip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else:
image_feature_size = image_feature_size_override

View File

@ -115,6 +115,7 @@ class ImagePlugin(MultiModalPlugin):
data: object) -> MultiModalInputs:
model_config = ctx.model_config
# PIL image
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config)
if image_processor is None:
@ -129,8 +130,10 @@ class ImagePlugin(MultiModalPlugin):
raise
return MultiModalInputs(batch_data)
# Image embedding
elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
return MultiModalInputs({"image_embeds": data})
raise TypeError(f"Invalid image type: {type(data)}")
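
Tying the pieces together: a torch.Tensor passed under multi_modal_data["image"] is forwarded by this plugin as an image_embeds keyword argument, which is exactly what each model's _parse_and_validate_image_input pops. A simplified standalone sketch of that dispatch, using plain dicts instead of the real MultiModalInputs and TypedDict classes:

from typing import Optional, Union
import torch
from PIL import Image

def map_image_input(data: Union[Image.Image, torch.Tensor]) -> dict:
    """Mirror of the plugin dispatch: PIL image -> pixel_values, tensor -> image_embeds."""
    if isinstance(data, Image.Image):
        # The real plugin runs the HF image processor here; use a dummy tensor in this sketch.
        return {"pixel_values": torch.zeros(1, 3, 336, 336)}
    if isinstance(data, torch.Tensor):
        return {"image_embeds": data}
    raise TypeError(f"Invalid image type: {type(data)}")

def parse_image_input(**kwargs: object) -> Optional[dict]:
    """Mirror of _parse_and_validate_image_input: pixels take one branch, embeddings the other."""
    pixel_values = kwargs.pop("pixel_values", None)
    image_embeds = kwargs.pop("image_embeds", None)
    if pixel_values is None and image_embeds is None:
        return None
    if pixel_values is not None:
        return {"type": "pixel_values", "data": pixel_values}
    return {"type": "image_embeds", "data": image_embeds}

# An image-embedding input flows straight through to the language model.
embeds = torch.zeros(1, 576, 4096)  # (batch_size, image_feature_size, hidden_size)
print(parse_image_input(**map_image_input(embeds))["type"])  # image_embeds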