Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 17:35:28 +08:00)
[Core][VLM] Support image embeddings as input (#6613)
This commit is contained in:
parent ec2affa8ae
commit e6e42e4b17
@@ -49,6 +49,17 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI
         "multi_modal_data": {"image": image},
     })
 
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+    # Inference with image embeddings as input
+    image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
+    outputs = llm.generate({
+        "prompt": prompt,
+        "multi_modal_data": {"image": image_embeds},
+    })
+
     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
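The documentation above loads precomputed embeddings with torch.load(...). As a reference, here is a minimal sketch (not part of this commit) of one way such a tensor could be produced offline for llava-hf/llava-1.5-7b-hf; it assumes the HuggingFace LlavaForConditionalGeneration layout (vision_tower plus multi_modal_projector, feature layer -2, CLS token dropped), and the file names are placeholders:

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
hf_model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.half)

pixel_values = processor(images=Image.open("image.jpg"),  # placeholder image path
                         return_tensors="pt").pixel_values.half()

with torch.no_grad():
    # Mirror the model's own path: vision tower -> feature selection -> projector.
    vision_out = hf_model.vision_tower(pixel_values, output_hidden_states=True)
    patch_features = vision_out.hidden_states[-2][:, 1:]  # drop the CLS token ("default" strategy)
    image_embeds = hf_model.multi_modal_projector(patch_features)

torch.save(image_embeds, "image.pt")  # later passed back via {"image": image_embeds}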
tests/models/test_llava_image_embeds.py (new file, 159 lines)
@@ -0,0 +1,159 @@
+from typing import List, Optional, Tuple, Type
+
+import pytest
+from transformers import AutoConfig, AutoTokenizer
+
+from vllm.sequence import SampleLogprobs
+
+from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_logprobs_close
+
+pytestmark = pytest.mark.vlm
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
+    "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
+    "cherry_blossom":
+    "USER: <image>\nWhat is the season?\nASSISTANT:",
+})
+
+models = [
+    "llava-hf/llava-1.5-7b-hf",
+]
+
+
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
+                                         Optional[SampleLogprobs]],
+                      model: str):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_index
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    eos_token_id = tokenizer.eos_token_id
+
+    hf_output_ids = [
+        token_id for idx, token_id in enumerate(output_ids)
+        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
+    ]
+
+    assert output_str[0] == " "
+    hf_output_str = output_str[1:]
+    if hf_output_ids[-1] == eos_token_id:
+        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
+
+    return hf_output_ids, hf_output_str, out_logprobs
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test is under tests/images.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding vision language config as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+
+    # vLLM to load from image embeddings
+    vllm_images = [asset.image_embeds for asset in image_assets]
+
+    # transformers to load from PIL images
+    hf_images = [asset.pil_image for asset in image_assets]
+
+    vllm_inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [image for _ in size_factors],
+    ) for image, prompt in zip(vllm_images, HF_IMAGE_PROMPTS)]
+
+    hf_inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [image for _ in size_factors],
+    ) for image, prompt in zip(hf_images, HF_IMAGE_PROMPTS)]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        vllm_outputs_per_image = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images)
+            for prompts, images in vllm_inputs_per_image
+        ]
+
+    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
+        hf_outputs_per_image = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images)
+            for prompts, images in hf_inputs_per_image
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
+                                        vllm_outputs_per_image):
+        # TODO: Check whether using original CLIPVisionModel can improve
+        # consistency against HF
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                vllm_to_hf_output(vllm_output, model)
+                for vllm_output in vllm_outputs
+            ],
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
+                dtype: str, max_tokens: int, num_logprobs: int) -> None:
+    run_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 from typing import Literal
 
+import torch
 from PIL import Image
 
 from vllm.assets.base import get_vllm_public_assets
@@ -18,3 +19,12 @@ class ImageAsset:
         image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
                                             s3_prefix=VLM_IMAGES_DIR)
         return Image.open(image_path)
+
+    @property
+    def image_embeds(self) -> torch.Tensor:
+        """
+        Image embeddings, only used for testing purposes with llava 1.5.
+        """
+        image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
+                                            s3_prefix=VLM_IMAGES_DIR)
+        return torch.load(image_path)
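For reference, a hypothetical use of the new asset property, following the same pattern as the new test; it assumes the class lives in vllm.assets.image and that the published .pt files hold 2-D (image_feature_size, hidden_size) tensors:

from vllm.assets.image import ImageAsset

asset = ImageAsset("stop_sign")
embeds = asset.image_embeds        # torch.Tensor downloaded from the vLLM public assets
feature_size = embeds.shape[0]     # used downstream as the number of image placeholder tokens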
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -28,6 +28,29 @@ _KEYS_TO_MODIFY_MAPPING = {
     "language_model.model": "language_model",
 }
 
+# We use this internally as placeholders since there is no image token
+# defined on the HuggingFace repo
+BLIP2_IMAGE_TOKEN = "<image>"
+BLIP2_IMAGE_TOKEN_ID = 50265
+
+
+class Blip2ImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """Shape: (batch_size, num_channels, height, width)"""
+
+
+class Blip2ImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
+
 
 class Blip2QFormerMultiHeadAttention(nn.Module):
 
@@ -375,20 +398,6 @@ class Blip2QFormerModel(nn.Module):
         return sequence_output
 
 
-class Blip2ImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
-
-
-Blip2ImageInputs = Blip2ImagePixelInputs
-
-
-# We use this internally as placeholders since there is no image token
-# defined on the HuggingFace repo
-BLIP2_IMAGE_TOKEN = "<image>"
-BLIP2_IMAGE_TOKEN_ID = 50265
-
 def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
     return hf_config.num_query_tokens
 
@@ -506,18 +515,31 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Blip2ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, torch.Tensor):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
 
-        return Blip2ImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+            return Blip2ImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return Blip2ImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
 
     def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
@@ -538,6 +560,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
 
     def _process_image_input(self,
                              image_input: Blip2ImageInputs) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
         assert self.vision_model is not None
         image_features = self._process_image_pixels(image_input)
 
@@ -88,7 +88,13 @@ def input_processor_for_clip(
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
     if image_feature_size_override is None:
-        image_feature_size = get_clip_image_feature_size(hf_config)
+        image_data = multi_modal_data["image"]
+        if isinstance(image_data, Image.Image):
+            image_feature_size = get_clip_image_feature_size(hf_config)
+        elif isinstance(image_data, torch.Tensor):
+            image_feature_size = image_data.shape[0]
+        else:
+            raise TypeError(f"Invalid image type: {type(image_data)}")
     else:
         image_feature_size = image_feature_size_override
 
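The new branch reads the number of image placeholder tokens straight from the tensor's first dimension, so precomputed embeddings are assumed here to be 2-D, (image_feature_size, hidden_size). A small sketch with illustrative numbers for a CLIP ViT-L/14-336 tower paired with a 4096-dim language model (the concrete sizes are assumptions, not taken from this diff):

import torch

# 336 / 14 = 24 patches per side, so 24 * 24 = 576 patch features.
image_embeds = torch.randn(576, 4096)        # (image_feature_size, hidden_size of the LM)
image_feature_size = image_embeds.shape[0]   # what the branch above computes
assert image_feature_size == 576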
@@ -234,7 +234,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
                                           cache_config=cache_config,
                                           quant_config=quant_config)
 
-    def _parse_and_validate_image_input(self, **kwargs: object):
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
         image_patches = kwargs.pop("image_patches", None)
 
         if isinstance(image_patches, torch.Tensor):
@@ -249,6 +250,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
                                         data=image_patches)
         return None
 
+    def _process_image_input(
+            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
+
+        assert self.vision_embed_tokens is not None
+        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
+        return vision_embeddings
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -261,8 +269,7 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
         image_input = self._parse_and_validate_image_input(**kwargs)
 
         if image_input is not None:
-            vision_embeddings, _ = self.vision_embed_tokens(
-                image_input["data"])
+            vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.language_model.model.embed_tokens(input_ids)
             inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                     vision_embeddings,
@@ -50,6 +50,19 @@ class InternVLImagePixelInputs(TypedDict):
     """
 
 
+class InternVLImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+InternVLImageInputs = Union[InternVLImagePixelInputs,
+                            InternVLImageEmbeddingInputs]
+
+
 # copied from https://huggingface.co/OpenGVLab/InternVL2-1B
 def build_transform(input_size):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
@@ -193,8 +206,10 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
         # add thumbnail image if num_blocks > 1
         if hf_config.use_thumbnail and num_blocks > 1:
             num_blocks += 1
+        image_feature_size = num_blocks * num_patches
+
     elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
+        image_feature_size = image_data.shape[0]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
@@ -205,7 +220,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
-    image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
+    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
     new_prompt = prompt.replace('<image>', image_prompt, 1)
     new_prompt_token_ids = tokenizer.encode(new_prompt)
 
@@ -378,23 +393,49 @@ class InternVLChatModel(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[InternVLImagePixelInputs]:
+            self, **kwargs: object) -> Optional[InternVLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_token_id = kwargs.pop("image_token_id", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return InternVLImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
         self.img_context_token_id = image_token_id[0]
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
 
-        return InternVLImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+            return InternVLImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_image_input(
+        self,
+        image_input: InternVLImageInputs,
+    ) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
+        assert self.vision_model is not None
+        image_embeds = self.extract_feature(image_input["data"])
+
+        return image_embeds
+
     def forward(
         self,
@@ -409,9 +450,9 @@ class InternVLChatModel(nn.Module, SupportsVision):
         if image_input is not None:
             inputs_embeds = self.language_model.model.get_input_embeddings(
                 input_ids)
-            vit_embeds = self.extract_feature(image_input["data"])
+            vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vit_embeds,
+                                                    vision_embeddings,
                                                     self.img_context_token_id)
             input_ids = None
         else:
@@ -27,6 +27,24 @@ from .utils import (filter_weights, init_vllm_registered_model,
                     merge_vision_embeddings)
 
 
+class LlavaImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """Shape: `(batch_size, num_channels, height, width)`"""
+
+
+class LlavaImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
+
+
 # TODO(xwjiang): Run benchmark and decide if TP.
 class LlavaMultiModalProjector(nn.Module):
 
@@ -49,15 +67,6 @@ class LlavaMultiModalProjector(nn.Module):
         return hidden_states
 
 
-class LlavaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
-
-
-LlavaImageInputs = LlavaImagePixelInputs
-
-
 def get_max_llava_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config
@@ -210,18 +219,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, torch.Tensor):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+            return LlavaImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
 
-        return LlavaImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return LlavaImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
@@ -258,6 +279,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
     def _process_image_input(self,
                              image_input: LlavaImageInputs) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
         assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)
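Several forward paths in this diff hand the embeddings returned by _process_image_input to merge_vision_embeddings together with the image placeholder token id. A simplified sketch of that step (an assumption about its behaviour, not the actual vLLM implementation): rows of inputs_embeds at placeholder positions are overwritten, in order, by rows of the vision embeddings, which is why precomputed image_embeds must already use the language model's hidden size.

import torch

def merge_vision_embeddings_sketch(input_ids: torch.Tensor,
                                   inputs_embeds: torch.Tensor,
                                   vision_embeddings: torch.Tensor,
                                   image_token_id: int) -> torch.Tensor:
    # Positions holding the placeholder token receive the flattened vision rows;
    # the counts must match exactly or the prompt was expanded incorrectly.
    mask = input_ids == image_token_id
    flat = vision_embeddings.reshape(-1, inputs_embeds.shape[-1])
    assert int(mask.sum()) == flat.shape[0]
    inputs_embeds[mask] = flat
    return inputs_embeds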
@@ -60,7 +60,17 @@ class LlavaNextImagePixelInputs(TypedDict):
     """
 
 
-LlavaNextImageInputs = LlavaNextImagePixelInputs
+class LlavaNextImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
+                             LlavaNextImageEmbeddingInputs]
 
 
 # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
@@ -208,7 +218,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             input_width=width,
         )
     elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
+        image_feature_size = image_data.shape[0]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
@@ -320,26 +330,40 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
+            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
 
-        if not isinstance(image_sizes, torch.Tensor):
-            raise ValueError("Incorrect type of image sizes. "
-                             f"Got type: {type(image_sizes)}")
+            if not isinstance(image_sizes, torch.Tensor):
+                raise ValueError("Incorrect type of image sizes. "
+                                 f"Got type: {type(image_sizes)}")
 
-        return LlavaNextImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-            image_sizes=self._validate_image_sizes(image_sizes),
-        )
+            return LlavaNextImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+                image_sizes=self._validate_image_sizes(image_sizes),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeds. "
+                                 f"Got type: {type(image_embeds)}")
+
+            return LlavaNextImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
 
     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
@@ -466,6 +490,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self,
         image_input: LlavaNextImageInputs,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
+
+        if image_input["type"] == "image_embeds":
+            return [image_input["data"]]
+
         patch_embeddings = self._process_image_pixels(image_input)
 
         image_sizes = image_input.get("image_sizes")
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 from torch import nn
@@ -31,6 +31,25 @@ _KEYS_TO_MODIFY_MAPPING = {
 }
 
 
+class PaliGemmaImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """Shape: (batch_size, num_channels, height, width)"""
+
+
+class PaliGemmaImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
+                             PaliGemmaImageEmbeddingInputs]
+
+
 def get_max_paligemma_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PaliGemmaConfig)
     vision_config = hf_config.vision_config
@@ -107,15 +126,6 @@ class PaliGemmaMultiModalProjector(nn.Module):
         return hidden_states
 
 
-class PaliGemmaImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
-
-
-PaliGemmaImageInputs = PaliGemmaImagePixelInputs
-
-
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@@ -163,18 +173,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
-        if pixel_values is None:
+        if pixel_values is None and image_embeds is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, torch.Tensor):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+            return PaliGemmaImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+            )
 
-        return PaliGemmaImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-        )
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return PaliGemmaImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
 
     def _image_pixels_to_features(
         self,
@@ -187,26 +209,20 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
 
         return image_features
 
-    def _process_image_pixels(
-        self,
-        inputs: PaliGemmaImagePixelInputs,
-    ) -> torch.Tensor:
-        assert self.vision_tower is not None
-
-        pixel_values = inputs["data"]
-
-        return self._image_pixels_to_features(
-            self.vision_tower,
-            pixel_values,
-        )
-
     def _process_image_input(
         self,
         image_input: PaliGemmaImageInputs,
     ) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
         assert self.vision_tower is not None
-        image_features = self._process_image_pixels(image_input, )
+        pixel_values = image_input["data"]
+        image_features = self._image_pixels_to_features(
+            self.vision_tower,
+            pixel_values,
+        )
 
         return self.multi_modal_projector(image_features)
 
@@ -70,6 +70,36 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                      projection_dim=768)
 
 
+class Phi3VImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
+    """
+
+    image_sizes: torch.Tensor
+    """
+    Shape: `(batch_size, 2)`
+
+    This should be in `(height, width)` format.
+    """
+
+
+class Phi3VImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """Shape: `(batch_size, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    """
+
+
+Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
+
+
 class Phi3ImageEmbeddingBase(nn.Module):
 
     def __init__(self) -> None:
@@ -257,24 +287,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline
 
 
-class Phi3VImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
-    """
-
-    image_sizes: torch.Tensor
-    """
-    Shape: `(batch_size, 2)`
-
-    This should be in `(height, width)` format.
-    """
-
-
 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
     target_height = int(np.ceil(height / padding_unit) * padding_unit)
@@ -390,7 +402,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
                                               input_width=w,
                                               input_height=h)
     elif isinstance(image_data, torch.Tensor):
-        raise NotImplementedError("Embeddings input is not supported yet")
+        image_feature_size = image_data.shape[0]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
@@ -494,25 +506,55 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
+            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
+        image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
+        if pixel_values is None and image_embeds is None:
+            return None
 
-        if not isinstance(image_sizes, torch.Tensor):
-            raise ValueError("Incorrect type of image sizes. "
-                             f"Got type: {type(image_sizes)}")
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
 
-        return Phi3VImagePixelInputs(
-            type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
-            image_sizes=self._validate_image_sizes(image_sizes))
+            if not isinstance(image_sizes, torch.Tensor):
+                raise ValueError("Incorrect type of image sizes. "
+                                 f"Got type: {type(image_sizes)}")
+
+            return Phi3VImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(pixel_values),
+                image_sizes=self._validate_image_sizes(image_sizes))
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+            return Phi3VImageEmbeddingInputs(
+                type="image_embeds",
+                data=image_embeds,
+            )
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_image_input(
+        self,
+        image_input: Phi3VImageInputs,
+    ) -> torch.Tensor:
+
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
+        assert self.vision_embed_tokens is not None
+        image_embeds = self.vision_embed_tokens(image_input["data"],
+                                                image_input["image_sizes"])
+
+        return image_embeds
 
     def forward(self,
                 input_ids: torch.Tensor,
@@ -524,8 +566,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         image_input = self._parse_and_validate_image_input(**kwargs)
 
         if image_input is not None:
-            vision_embeddings = self.vision_embed_tokens(
-                image_input["data"], image_input["image_sizes"])
+            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.model.get_input_embeddings(input_ids)
             inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                     vision_embeddings,
@@ -97,7 +97,13 @@ def input_processor_for_siglip(
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
     if image_feature_size_override is None:
-        image_feature_size = get_siglip_image_feature_size(hf_config)
+        image_data = multi_modal_data["image"]
+        if isinstance(image_data, Image.Image):
+            image_feature_size = get_siglip_image_feature_size(hf_config)
+        elif isinstance(image_data, torch.Tensor):
+            image_feature_size = image_data.shape[0]
+        else:
+            raise TypeError(f"Invalid image type: {type(image_data)}")
     else:
         image_feature_size = image_feature_size_override
 
@@ -115,6 +115,7 @@ class ImagePlugin(MultiModalPlugin):
                        data: object) -> MultiModalInputs:
         model_config = ctx.model_config
 
+        # PIL image
         if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
             image_processor = self._get_hf_image_processor(model_config)
             if image_processor is None:
@@ -129,8 +130,10 @@ class ImagePlugin(MultiModalPlugin):
                 raise
 
             return MultiModalInputs(batch_data)
+
+        # Image embedding
         elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
-            raise NotImplementedError("Embeddings input is not supported yet")
+            return MultiModalInputs({"image_embeds": data})
 
         raise TypeError(f"Invalid image type: {type(data)}")
 
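Taken together with the model changes above, the assumed path for embedding inputs is: the user passes {"image": tensor}, the plugin maps it to {"image_embeds": tensor}, and each model's _parse_and_validate_image_input receives it as the image_embeds keyword argument. A minimal end-to-end sketch mirroring the documentation snippet earlier in this commit (model name, file name, and tensor shape are illustrative):

import torch
from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

image_embeds = torch.load("image.pt")  # e.g. (576, 4096) precomputed LLaVA-1.5 features

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
    "multi_modal_data": {"image": image_embeds},
})
print(outputs[0].outputs[0].text)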