Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 21:55:50 +08:00)
[Core][VLM] Support image embeddings as input (#6613)
parent ec2affa8ae
commit e6e42e4b17
docs/source/models/vlm.rst

@@ -49,6 +49,17 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI
        "multi_modal_data": {"image": image},
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)

    # Inference with image embeddings as input
    image_embeds = torch.load(...)  # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {"image": image_embeds},
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
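Editor's note (not part of the diff): the `torch.load(...)` above assumes you already have a projected embedding tensor. One way to produce such a tensor offline for LLaVA-1.5 is sketched below; it relies on HuggingFace transformers' `LlavaForConditionalGeneration` (`vision_tower`, `multi_modal_projector`, `config.vision_feature_layer`), the file names are hypothetical, and the resulting shape (1, 576, 4096) is what the 7B checkpoint is expected to emit for a 336x336 crop.

    # Hedged sketch: precompute LLaVA-1.5 image embeddings for use with
    # multi_modal_data={"image": image_embeds} in vLLM.
    import torch
    from PIL import Image
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    model_id = "llava-hf/llava-1.5-7b-hf"
    processor = AutoProcessor.from_pretrained(model_id)
    hf_model = LlavaForConditionalGeneration.from_pretrained(model_id)

    image = Image.open("example.jpg")  # hypothetical local file
    pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"]

    with torch.no_grad():
        # CLIP vision tower -> chosen hidden layer -> drop the CLS token ->
        # MLP projector, mirroring LLaVA's default "patch" feature selection.
        vision_out = hf_model.vision_tower(pixel_values, output_hidden_states=True)
        patch_features = vision_out.hidden_states[
            hf_model.config.vision_feature_layer][:, 1:]
        image_embeds = hf_model.multi_modal_projector(patch_features)

    # Cast to the serving dtype if needed, e.g. image_embeds.to(torch.float16).
    torch.save(image_embeds, "example.pt")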
tests/models/test_llava_image_embeds.py  (new file, 159 lines)

@@ -0,0 +1,159 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoConfig, AutoTokenizer

from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
    "cherry_blossom":
    "USER: <image>\nWhat is the season?\nASSISTANT:",
})

models = [
    "llava-hf/llava-1.5-7b-hf",
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
    ]

    assert output_str[0] == " "
    hf_output_str = output_str[1:]
    if hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference results should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the vllm runner, we provide MultiModalDataDict objects
    and the corresponding vision language config as input.
    Note that the text input is also adjusted to abide by the vllm contract.
    The text output is sanitized to be comparable with hf.
    """

    # vLLM to load from image embeddings
    vllm_images = [asset.image_embeds for asset in image_assets]

    # transformers to load from PIL images
    hf_images = [asset.pil_image for asset in image_assets]

    vllm_inputs_per_image = [(
        [prompt for _ in size_factors],
        [image for _ in size_factors],
    ) for image, prompt in zip(vllm_images, HF_IMAGE_PROMPTS)]

    hf_inputs_per_image = [(
        [prompt for _ in size_factors],
        [image for _ in size_factors],
    ) for image, prompt in zip(hf_images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in vllm_inputs_per_image
        ]

    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images)
            for prompts, images in hf_inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
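Editor's note (not part of the new test): the sanitization in `vllm_to_hf_output` can be seen on made-up token ids; consecutive copies of the image placeholder are collapsed to a single one so the ids line up with HuggingFace's output. The values 32000 and 2 below are stand-ins for LLaVA-1.5's `image_token_index` and EOS id.

    # Toy illustration with hypothetical ids (no model involved).
    output_ids = [1, 32000, 32000, 32000, 450, 6575, 2]
    image_token_id = 32000

    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
    ]
    print(hf_output_ids)  # [1, 32000, 450, 6575, 2] -- only one placeholder survives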
vllm/assets/image.py

@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Literal

import torch
from PIL import Image

from vllm.assets.base import get_vllm_public_assets
@@ -18,3 +19,12 @@ class ImageAsset:
        image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
                                            s3_prefix=VLM_IMAGES_DIR)
        return Image.open(image_path)

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.
        """
        image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
                                            s3_prefix=VLM_IMAGES_DIR)
        return torch.load(image_path)
vllm/model_executor/models/blip2.py

@@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
@@ -28,6 +28,29 @@ _KEYS_TO_MODIFY_MAPPING = {
    "language_model.model": "language_model",
}

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265


class Blip2ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""


class Blip2ImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]


class Blip2QFormerMultiHeadAttention(nn.Module):

@@ -375,20 +398,6 @@ class Blip2QFormerModel(nn.Module):
        return sequence_output


class Blip2ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""


Blip2ImageInputs = Blip2ImagePixelInputs

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265


def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
    return hf_config.num_query_tokens

@@ -506,18 +515,31 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None:
        if pixel_values is None and image_embeds is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

        return Blip2ImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )
            return Blip2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return Blip2ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
@@ -538,6 +560,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):

    def _process_image_input(self,
                             image_input: Blip2ImageInputs) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)
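Editor's note (not part of the diff): every model touched below follows the same parse-then-dispatch shape that blip2.py introduces above. A stripped-down, standalone sketch with hypothetical names may make the control flow easier to follow; the encoder call is a stand-in, not vLLM code.

    from typing import Literal, Optional, TypedDict, Union

    import torch


    class PixelInputs(TypedDict):
        type: Literal["pixel_values"]
        data: torch.Tensor


    class EmbeddingInputs(TypedDict):
        type: Literal["image_embeds"]
        data: torch.Tensor


    ImageInputs = Union[PixelInputs, EmbeddingInputs]


    def parse_image_input(**kwargs: object) -> Optional[ImageInputs]:
        # At most one of the two keyword arguments is expected to be set.
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None
        if pixel_values is not None:
            return PixelInputs(type="pixel_values", data=pixel_values)
        return EmbeddingInputs(type="image_embeds", data=image_embeds)


    def run_vision_encoder(pixel_values: torch.Tensor) -> torch.Tensor:
        # Stand-in for the model-specific vision tower + projector.
        return torch.zeros(pixel_values.shape[0], 576, 4096)


    def process_image_input(image_input: ImageInputs) -> torch.Tensor:
        # Precomputed embeddings bypass the vision encoder entirely.
        if image_input["type"] == "image_embeds":
            return image_input["data"]
        return run_vision_encoder(image_input["data"])


    parsed = parse_image_input(image_embeds=torch.randn(1, 576, 4096))
    assert parsed is not None
    print(process_image_input(parsed).shape)  # torch.Size([1, 576, 4096])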
vllm/model_executor/models/clip.py

@@ -88,7 +88,13 @@ def input_processor_for_clip(
    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_feature_size = get_clip_image_feature_size(hf_config)
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_clip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            image_feature_size = image_data.shape[0]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override
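Editor's note (not part of the diff): with an embedding input, the number of image placeholder tokens is read straight off the tensor instead of being derived from the CLIP configuration. A small hedged illustration, assuming a per-image tensor with no batch dimension and LLaVA-1.5-like sizes:

    import torch

    # Hypothetical precomputed embedding: (image_feature_size, hidden_size).
    image_embeds = torch.randn(576, 4096)

    image_feature_size = image_embeds.shape[0]
    print(image_feature_size)  # 576 placeholder tokens get reserved in the prompt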
vllm/model_executor/models/fuyu.py

@@ -234,7 +234,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
                                              cache_config=cache_config,
                                              quant_config=quant_config)

    def _parse_and_validate_image_input(self, **kwargs: object):
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
        image_patches = kwargs.pop("image_patches", None)

        if isinstance(image_patches, torch.Tensor):
@@ -249,6 +250,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
                                        data=image_patches)
        return None

    def _process_image_input(
            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:

        assert self.vision_embed_tokens is not None
        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
@@ -261,8 +269,7 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings, _ = self.vision_embed_tokens(
                image_input["data"])
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                    vision_embeddings,
vllm/model_executor/models/internvl.py

@@ -50,6 +50,19 @@ class InternVLImagePixelInputs(TypedDict):
    """


class InternVLImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


InternVLImageInputs = Union[InternVLImagePixelInputs,
                            InternVLImageEmbeddingInputs]


# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
@@ -193,8 +206,10 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
        # add thumbnail image if num_blocks > 1
        if hf_config.use_thumbnail and num_blocks > 1:
            num_blocks += 1
        image_feature_size = num_blocks * num_patches

    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

@@ -205,7 +220,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)
    image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
    new_prompt = prompt.replace('<image>', image_prompt, 1)
    new_prompt_token_ids = tokenizer.encode(new_prompt)

@@ -378,23 +393,49 @@ class InternVLChatModel(nn.Module, SupportsVision):
        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[InternVLImagePixelInputs]:
            self, **kwargs: object) -> Optional[InternVLImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_token_id = kwargs.pop("image_token_id", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None:
        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        self.img_context_token_id = image_token_id[0]

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

        return InternVLImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )
            return InternVLImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_embeds = self.extract_feature(image_input["data"])

        return image_embeds

    def forward(
        self,
@@ -409,9 +450,9 @@ class InternVLChatModel(nn.Module, SupportsVision):
        if image_input is not None:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            vit_embeds = self.extract_feature(image_input["data"])
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                    vit_embeds,
                                                    vision_embeddings,
                                                    self.img_context_token_id)
            input_ids = None
        else:
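Editor's note (not part of the diff): the `image_prompt` change above is what ties the feature size back into the prompt text: `<image>` is replaced by `IMG_START`, then `image_feature_size` copies of `IMG_CONTEXT`, then `IMG_END`. A toy sketch with hypothetical marker strings and a deliberately tiny feature size:

    IMG_START, IMG_CONTEXT, IMG_END = "<img>", "<IMG_CONTEXT>", "</img>"

    prompt = "USER: <image>\nWhat is shown here?\nASSISTANT:"
    image_feature_size = 4  # real models use hundreds or thousands

    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
    new_prompt = prompt.replace('<image>', image_prompt, 1)
    print(new_prompt)
    # USER: <img><IMG_CONTEXT><IMG_CONTEXT><IMG_CONTEXT><IMG_CONTEXT></img>
    # What is shown here?
    # ASSISTANT: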
vllm/model_executor/models/llava.py

@@ -27,6 +27,24 @@ from .utils import (filter_weights, init_vllm_registered_model,
                    merge_vision_embeddings)


class LlavaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


class LlavaImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):

@@ -49,15 +67,6 @@ class LlavaMultiModalProjector(nn.Module):
        return hidden_states


class LlavaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


LlavaImageInputs = LlavaImagePixelInputs


def get_max_llava_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
@@ -210,18 +219,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None:
        if pixel_values is None and image_embeds is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )
        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
@@ -258,6 +279,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)
vllm/model_executor/models/llava_next.py

@@ -60,7 +60,17 @@ class LlavaNextImagePixelInputs(TypedDict):
    """


LlavaNextImageInputs = LlavaNextImagePixelInputs
class LlavaNextImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
@@ -208,7 +218,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
            input_width=width,
        )
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

@@ -320,26 +330,40 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None:
        if pixel_values is None and image_embeds is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")
            if not isinstance(image_sizes, torch.Tensor):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

        return LlavaNextImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
            image_sizes=self._validate_image_sizes(image_sizes),
        )
            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
                image_sizes=self._validate_image_sizes(image_sizes),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
@@ -466,6 +490,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
vllm/model_executor/models/paligemma.py

@@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
from torch import nn
@@ -31,6 +31,25 @@ _KEYS_TO_MODIFY_MAPPING = {
}


class PaliGemmaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""


class PaliGemmaImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                             PaliGemmaImageEmbeddingInputs]


def get_max_paligemma_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config
@@ -107,15 +126,6 @@ class PaliGemmaMultiModalProjector(nn.Module):
        return hidden_states


class PaliGemmaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""


PaliGemmaImageInputs = PaliGemmaImagePixelInputs


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@@ -163,18 +173,30 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None:
        if pixel_values is None and image_embeds is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

        return PaliGemmaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )
        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
@@ -187,26 +209,20 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):

        return image_features

    def _process_image_pixels(
        self,
        inputs: PaliGemmaImagePixelInputs,
    ) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input, )
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)
vllm/model_executor/models/phi3v.py

@@ -70,6 +70,36 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     projection_dim=768)


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


class Phi3VImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]


class Phi3ImageEmbeddingBase(nn.Module):

    def __init__(self) -> None:
@@ -257,24 +287,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
        return image_features_hd_newline


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
@@ -390,7 +402,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
                                                 input_width=w,
                                                 input_height=h)
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

@@ -494,25 +506,55 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        if pixel_values is None and image_embeds is None:
            return None

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")
        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

        return Phi3VImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
            image_sizes=self._validate_image_sizes(image_sizes))
            if not isinstance(image_sizes, torch.Tensor):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
                image_sizes=self._validate_image_sizes(image_sizes))

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds

    def forward(self,
                input_ids: torch.Tensor,
@@ -524,8 +566,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self.vision_embed_tokens(
                image_input["data"], image_input["image_sizes"])
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.model.get_input_embeddings(input_ids)
            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                    vision_embeddings,
vllm/model_executor/models/siglip.py

@@ -97,7 +97,13 @@ def input_processor_for_siglip(
    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            image_feature_size = image_data.shape[0]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override
vllm/multimodal/image.py

@@ -115,6 +115,7 @@ class ImagePlugin(MultiModalPlugin):
                           data: object) -> MultiModalInputs:
        model_config = ctx.model_config

        # PIL image
        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
            image_processor = self._get_hf_image_processor(model_config)
            if image_processor is None:
@@ -129,8 +130,10 @@ class ImagePlugin(MultiModalPlugin):
                raise

            return MultiModalInputs(batch_data)

        # Image embedding
        elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
            raise NotImplementedError("Embeddings input is not supported yet")
            return MultiModalInputs({"image_embeds": data})

        raise TypeError(f"Invalid image type: {type(data)}")
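Editor's note (not part of the diff): this mapper is the glue between the user-facing API and the model code above; a `torch.Tensor` passed as `multi_modal_data={"image": ...}` now reaches the model as an `image_embeds` keyword argument instead of `pixel_values`. A hedged end-to-end sketch, reusing the model from the docs example; the `.pt` file is hypothetical and `max_model_len` must exceed the image feature size:

    import torch
    from vllm import LLM

    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)

    prompt = "USER: <image>\nWhat's the content of the image?\nASSISTANT:"
    image_embeds = torch.load("example.pt")  # hypothetical precomputed tensor

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {"image": image_embeds},
    })
    print(outputs[0].outputs[0].text)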