import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
                    merge_multimodal_embeddings)

logger = init_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# These values result in the max possible feature size (2x2 grid of
# 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


class LlavaNextImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
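    # Illustrative example (not from the original source): with npatches=24
    # (e.g. a 336px tower with 14px patches) and a 2x2 patch grid, a 640x480
    # (width x height) image maps to a 48x48 feature grid; the image is wider
    # than the grid, so 12 rows of padding are cropped, leaving
    # 36 * 48 = 1728 unpadded features plus 36 newline features
    # (one per remaining row).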
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        new_height = (original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
    else:
        new_width = (original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        current_width -= padding * 2

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        base_feature_size -= 1
    elif strategy == "full":
        pass
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (
        unpadded_feature_size,
        newline_feature_size,
    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size


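# Note (illustrative, assuming the reference llava-v1.6 setup with a 336px
# CLIP tower and 14px patches, i.e. a 24x24 base grid): a 448x448 input
# selects the 2x2 anyres grid, and the maximum below works out to
# 2304 unpadded + 48 newline + 576 base = 2928 image tokens. The exact figure
# depends on the model's vision config and grid pinpoints.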
def get_max_llava_next_image_tokens(ctx: InputContext):
    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


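# The dummy data below is used by vLLM during memory profiling: it reserves
# the worst-case number of image tokens in the dummy sequence and pairs it
# with maximum-size dummy images, one per requested image input.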
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


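# The input processor rewrites the prompt before scheduling: the single image
# placeholder token is expanded into `image_feature_size` repeated image
# tokens so that the right number of KV cache slots is reserved (see the
# docstring of LlavaNextForConditionalGeneration.forward).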
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif isinstance(image_data, torch.Tensor):
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaNextConfig):
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
            + vision_feature_layer + 1
    else:
        num_hidden_layers = vision_feature_layer + 1
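    # e.g. vision_feature_layer = -2 with a 24-layer encoder instantiates
    # 24 + (-2) + 1 = 23 layers; the final encoder layer is never needed for
    # the selected features, so it is not built at all.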

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


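# The decorators below hook this model into vLLM's input and multimodal
# registries: roughly, the default image input mapper converts incoming images
# into `pixel_values`/`image_sizes` kwargs via the HF processor, the
# max-image-tokens hook bounds the per-image token count used for profiling,
# and the dummy-data and input-processor hooks are the functions defined above.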
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initialize this to support embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, torch.Tensor):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
                image_sizes=self._validate_image_sizes(image_sizes),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                other_patch_embeds = other_patch_embeds \
                    .view(num_patch_height, num_patch_width, height, width, -1)

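                # With the "unpad" variants, the per-tile features are
                # stitched back into a single (height, width) map, the padding
                # added to fit the anyres grid is cropped away, and the
                # learned `image_newline` embedding is appended to the end of
                # every row before flattening back into a token sequence.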
if "unpad" in strategy:
|
|
other_patch_embeds = other_patch_embeds \
|
|
.permute(4, 0, 2, 1, 3).contiguous() \
|
|
.flatten(1, 2).flatten(2, 3)
|
|
other_patch_embeds = unpad_image(other_patch_embeds,
|
|
(orig_height, orig_width))
|
|
other_patch_embeds = torch.cat((
|
|
other_patch_embeds,
|
|
self.image_newline[:, None, None] \
|
|
.expand(*other_patch_embeds.shape[:-1], 1) \
|
|
.to(other_patch_embeds.device),
|
|
), dim=-1)
|
|
other_patch_embeds = other_patch_embeds \
|
|
.flatten(1, 2).transpose(0, 1)
|
|
else:
|
|
other_patch_embeds = other_patch_embeds \
|
|
.permute(0, 2, 1, 3, 4).contiguous() \
|
|
.flatten(0, 3)
|
|
|
|
merged_patch_embeddings = torch.cat(
|
|
(base_patch_embeds, other_patch_embeds), dim=0)
|
|
else:
|
|
if "unpad" in strategy:
|
|
merged_patch_embeddings = torch.cat(
|
|
(base_patch_embeds,
|
|
self.image_newline[None] \
|
|
.to(base_patch_embeds.device)
|
|
), dim=0)
|
|
else:
|
|
merged_patch_embeddings = base_patch_embeds
|
|
|
|
return merged_patch_embeddings
|
|
|
|
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
|
|
|
|
    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

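        # Both branches below stack every patch from every image into a single
        # vision tower + projector forward pass; the list branch then splits
        # the projected features back per image, since each image may
        # contribute a different number of patches.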
        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in the KV cache, we have to insert placeholder tokens
        before they are fed to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens fed to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

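            # The image embeddings are already merged into `inputs_embeds`,
            # so `input_ids` is cleared and the language model consumes the
            # embeddings directly.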
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Prepare independent weight iterators for each component;
        # itertools.tee lets every consumer below filter the same underlying
        # stream of (name, tensor) pairs.
        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
            weights, 4)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load newline
        newline_weights = filter_weights(newline_weights, "image_newline")
        for name, loaded_weight in newline_weights:
            assert name == ""
            param = self.image_newline
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)