# SPDX-License-Identifier: Apache-2.0
import math
import re
from functools import lru_cache
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import scipy.signal
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig, SiglipVisionConfig
from transformers.utils import logging

from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext)
from vllm.inputs.data import TokenInputs, token_inputs
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsV0Only
from .phi4mm_audio import AudioEmbedding
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

# <|endoftext10|> (see vocab.json in hf model)
_IMAGE_PLACEHOLDER_TOKEN_ID = 200010
# <|endoftext11|>
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011

_AUDIO_MAX_SOUNDFILE_SIZE = 241_000
DUMMY_SAMPLING_FREQUENCY = 16_000  # Hz (i.e. 16 kHz)

DYNAMIC_HD = 16
AUDIO_TOKEN_PATTERN = r"<\|audio_(\d+)\|>"
IMAGE_TOKEN_PATTERN = r"<\|image_(\d+)\|>"

SIGLIP_NAME = "siglip-so400m-patch14-448"
VISION_ENCODER_TO_PROCESSING_CONFIG = {
    'siglip-so400m-patch14-448': {
        'dynamic_hd': 16,
        'vit_image_size': 448,
        'vit_patch_size': 14,
        'token_compression_factor': 2,
    },
}
logger = logging.get_logger(__name__)
# This is a workaround to prevent text (user input) + audio + image
# from being used in the same prompt.
# It includes token ids for "\n" and tokens in added_tokens_decoder
# from the tokenizer_config.json file.
NON_USER_INPUT_TOKENS = {
    198, 200010, 200011, 199999, 200018, 200019, 200020, 200021, 200022,
    200023, 200024, 200025, 200026, 200027, 200028
}


def get_max_dummy_image(ctx: InputContext):
    hf_config = ctx.get_hf_config()
    vision_encoder_name = hf_config.img_processor
    if vision_encoder_name is None:
        vision_encoder_name = SIGLIP_NAME
    prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
    dynamic_hd_size = prepro_config['dynamic_hd']
    vit_image_size = prepro_config['vit_image_size']

    max_side = vit_image_size * dynamic_hd_size
    dummy_image = dummy_image_for_phi4mm(vit_image_size, max_side)
    return dummy_image


# image token length
def get_max_phi4mm_image_tokens(ctx: InputContext):
    dummy_image = get_max_dummy_image(ctx)

    hf_config = ctx.get_hf_config()
    vision_encoder_name = hf_config.img_processor
    if vision_encoder_name is None:
        vision_encoder_name = SIGLIP_NAME
    prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
    dynamic_hd_size = prepro_config['dynamic_hd']
    vit_image_size = prepro_config['vit_image_size']
    vit_patch_size = prepro_config['vit_patch_size']
    token_compression_factor = prepro_config['token_compression_factor']

    image_num_tokens = _compute_num_image_tokens(dummy_image, dynamic_hd_size,
                                                 vit_image_size,
                                                 vit_patch_size,
                                                 token_compression_factor)
    return image_num_tokens


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def _find_target_aspect_ratio(image, image_size, max_num, min_num):
    orig_width, orig_height = image.size

    w_crop_num = math.ceil(orig_width / float(image_size))
    h_crop_num = math.ceil(orig_height / float(image_size))
    if w_crop_num * h_crop_num > max_num:
        aspect_ratio = orig_width / orig_height

        # calculate the existing image aspect ratio
        target_ratios = set((i, j) for i in range(1, max_num + 1)
                            for j in range(1, max_num + 1)
                            if i * j <= max_num and i * j >= min_num)
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

        # find the closest aspect ratio to the target
        target_aspect_ratio = find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)

        # calculate the target width and height
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        logger.debug("target_aspect_ratio: %s", target_aspect_ratio)
    else:
        target_width = image_size * w_crop_num
        target_height = image_size * h_crop_num
        target_aspect_ratio = (w_crop_num, h_crop_num)
    return target_aspect_ratio, target_height, target_width
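
# Illustrative example (comment-only sketch, not executed at import time):
# a 1000x700 image with image_size=448 and max_num=16 needs
# ceil(1000/448) x ceil(700/448) = 3 x 2 crops, which is within budget, so
#
#     _find_target_aspect_ratio(Image.new('RGB', (1000, 700)), 448, 16, 1)
#
# returns ((3, 2), 896, 1344), i.e. the canvas is padded up to 1344x896.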


def _get_padding_size(image, target_height, target_width):
    orig_width, orig_height = image.size
    ratio_width = target_width / orig_width
    ratio_height = target_height / orig_height

    if ratio_width < ratio_height:
        padding_width = 0
        padding_height = target_height - int(orig_height * ratio_width)
    else:
        padding_width = target_width - int(orig_width * ratio_height)
        padding_height = 0
    return padding_height, padding_width
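
# Illustrative example (comment-only sketch): fitting a 500x448 image onto a
# 896x448 target keeps the height ratio (1.0), so only the width is padded:
#
#     _get_padding_size(Image.new('RGB', (500, 448)), 448, 896)
#
# returns (0, 396): no height padding and 896 - 500 = 396 pixels of width
# padding.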


def dynamic_preprocess(image,
                       min_num=1,
                       max_num=12,
                       image_size=384,
                       mask_size=27):
    target_aspect_ratio, target_height, target_width = \
        _find_target_aspect_ratio(
            image, image_size, max_num, min_num)
    padding_height, padding_width = _get_padding_size(image, target_height,
                                                      target_width)

    # Calculate the ratio
    orig_width, orig_height = image.size
    ratio_width = target_width / orig_width
    ratio_height = target_height / orig_height
    if ratio_width < ratio_height:
        new_size = (target_width, int(orig_height * ratio_width))
    else:
        new_size = (int(orig_width * ratio_height), target_height)

    attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]),
                                 int(mask_size * target_aspect_ratio[0])))
    if padding_width >= 14:
        attention_mask[:, -math.floor(padding_width / 14):] = 0
    if padding_height >= 14:
        attention_mask[-math.floor(padding_height / 14):, :] = 0
    assert attention_mask.sum(
    ) > 0, f'attention mask is empty {attention_mask}'

    if min(new_size[1], target_height) < 10 or min(new_size[0],
                                                   target_width) < 10:
        raise ValueError(f'the aspect ratio is very extreme {new_size}')

    image = T.functional.resize(
        image,
        [new_size[1], new_size[0]],
    )

    resized_img = T.functional.pad(image,
                                   [0, 0, padding_width, padding_height],
                                   fill=[255, 255, 255])

    return resized_img, attention_mask


def pad_to_max_num_crops(images, max_crops=5):
    """
    images: B x 3 x H x W, B<=max_crops
    """
    B, _, H, W = images.shape
    if max_crops > B:
        pad = torch.zeros(max_crops - B,
                          3,
                          H,
                          W,
                          dtype=images.dtype,
                          device=images.device)
        images = torch.cat([images, pad], dim=0)
    return images
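
# Illustrative example (comment-only sketch): a stack of 3 crops padded to a
# fixed budget of 5 crops gains two all-zero crops:
#
#     imgs = torch.zeros(3, 3, 448, 448)
#     pad_to_max_num_crops(imgs, max_crops=5).shape  # (5, 3, 448, 448)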


def pad_mask_to_max_num_crops(masks, max_crops=5):
    B, H, W = masks.shape
    if max_crops > B:
        pad = torch.ones(max_crops - B,
                         H,
                         W,
                         dtype=masks.dtype,
                         device=masks.device)
        masks = torch.cat([masks, pad], dim=0)
    return masks


def preprocess(images, dynamic_hd_size, vit_resolution, vit_patch_size):

    # Basic settings.
    img_processor = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Dynamic HD
    base_resolution = vit_resolution
    images = [image.convert('RGB') for image in images]
    # cover 384 and 448 resolution
    mask_resolution = base_resolution // vit_patch_size
    elems, image_attention_masks = [], []
    for im in images:
        elem, attention_mask = dynamic_preprocess(im,
                                                  max_num=dynamic_hd_size,
                                                  image_size=base_resolution,
                                                  mask_size=mask_resolution)
        elems.append(elem)
        image_attention_masks.append(attention_mask)
    hd_images = [img_processor(im) for im in elems]
    global_image = [
        torch.nn.functional.interpolate(
            im.unsqueeze(0).float(),
            size=(base_resolution, base_resolution),
            mode='bicubic',
        ).to(im.dtype) for im in hd_images
    ]
    shapes = [[im.size(1), im.size(2)] for im in hd_images]
    mask_shapes = [[mask.size(0), mask.size(1)]
                   for mask in image_attention_masks]
    global_attention_mask = [
        torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images
    ]
    hd_images_reshape = [
        im.reshape(1, 3, h // base_resolution, base_resolution,
                   w // base_resolution, base_resolution).permute(
                       0, 2, 4, 1, 3, 5).reshape(-1, 3, base_resolution,
                                                 base_resolution).contiguous()
        for im, (h, w) in zip(hd_images, shapes)
    ]
    attention_masks_reshape = [
        mask.reshape(1, h // mask_resolution, mask_resolution,
                     w // mask_resolution, mask_resolution).permute(
                         0, 1, 3, 2, 4).reshape(-1, mask_resolution,
                                                mask_resolution).contiguous()
        for mask, (h, w) in zip(image_attention_masks, mask_shapes)
    ]
    # NOTE token compression is hard-coded here, and odd numbers seem to fail
    downsample_attention_masks = [
        mask[:, 0::2,
             0::2].reshape(1, h // mask_resolution, w // mask_resolution,
                           mask_resolution // 2 + mask_resolution % 2,
                           mask_resolution // 2 +
                           mask_resolution % 2).permute(0, 1, 3, 2, 4)
        for mask, (h, w) in zip(attention_masks_reshape, mask_shapes)
    ]
    downsample_attention_masks = [
        mask.reshape(mask.size(1) * mask.size(2),
                     mask.size(3) * mask.size(4))
        for mask in downsample_attention_masks
    ]
    # NOTE hard-coded number of tokens
    num_img_tokens = [
        256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16
        for mask in downsample_attention_masks
    ]

    hd_images_reshape = [
        torch.cat([_global_image] + [_im], dim=0)
        for _global_image, _im in zip(global_image, hd_images_reshape)
    ]
    hd_masks_reshape = [
        torch.cat([_global_mask] + [_mask],
                  dim=0) for _global_mask, _mask in zip(
                      global_attention_mask, attention_masks_reshape)
    ]
    max_crops = max([img.size(0) for img in hd_images_reshape])
    image_transformed = [
        pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape
    ]
    image_transformed = torch.stack(image_transformed, dim=0)
    mask_transformed = [
        pad_mask_to_max_num_crops(mask, max_crops)
        for mask in hd_masks_reshape
    ]
    mask_transformed = torch.stack(mask_transformed, dim=0)

    returned_input_image_embeds = image_transformed
    returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
    returned_image_attention_mask = mask_transformed
    returned_num_img_tokens = num_img_tokens

    data = {
        "pixel_values": returned_input_image_embeds,
        "image_sizes": returned_image_sizes,
        "image_attention_mask": returned_image_attention_mask,
        "num_img_tokens": returned_num_img_tokens,
    }
    return data


def get_navit_vision_model(layer_idx: int = -1, **kwargs):
    vision_config = {
        "hidden_size": 1152,
        "image_size": 448,
        "intermediate_size": 4304,
        "model_type": "siglip_vision_model",
        "num_attention_heads": 16,
        "num_hidden_layers": 27,
        "patch_size": 14,
    }

    model_config = SiglipVisionConfig(**vision_config, **kwargs)
    if layer_idx < 0:
        num_hidden_layers = model_config.num_hidden_layers \
            + layer_idx + 1
    else:
        num_hidden_layers = layer_idx + 1

    vision_model = Idefics2VisionTransformer(
        config=model_config,
        require_post_norm=False,
        num_hidden_layers_override=num_hidden_layers,
    )

    return vision_model
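
# Illustrative example (comment-only sketch): with the default layer_idx=-1
# all 27 SigLIP layers are kept; layer_idx=-2 (the Phi4MM default below)
# keeps 27 + (-2) + 1 = 26 hidden layers:
#
#     vit = get_navit_vision_model(layer_idx=-2)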


class Phi4MMImageEncoder(nn.Module):
    """Image embedding."""

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig],
                 prefix: str = "",
                 model_dir: str = "") -> None:
        super().__init__()

        # n_embd or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        # layer_idx to output the img features
        if isinstance(config.img_processor, dict):
            self.layer_idx = config.img_processor.get('layer_idx', -2)
            self.type_feature = config.img_processor.get(
                'type_feature', 'patch')
        else:
            self.layer_idx = -2
            self.type_feature = 'patch'

        self.img_processor = get_navit_vision_model(layer_idx=self.layer_idx)

        pe_weight = self.img_processor.embeddings.position_embedding.weight
        L, D = pe_weight.size()
        H = int(math.sqrt(L))
        assert H**2 == L, f'position embedding size {L} is not square'
        if H % 2 != 0:
            self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
            H += 1
        image_dim_out = D
        # ((448/14)//2)**2
        self.num_img_tokens = (H // 2)**2
        self.base_feat_height_target = H

        self.image_dim_out = image_dim_out
        self.img_sizes = None
        self.image_attention_mask = None

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = True
        self.with_learnable_separator = True
        self.hd_transform_order = "sub_glb"
        self.freeze_img_processor = False
        self.crop_size = 448

        # image token compression
        self.image_token_compression_cls = 'avg_pool_2d'
        self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
        self.base_feat_height_reduction = 1
        self.base_feat_height_target = self.base_feat_height_target // 2

        # with_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform == self.with_learnable_separator, \
            'use_hd_transform and with_learnable_separator should have same value'
        assert self.use_hd_transform, \
            'learnable separator is only for hd transform'
        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(
            torch.zeros([
                1, 1, self.image_dim_out * self.base_feat_height_reduction**2
            ]))
        self.sub_GN = nn.Parameter(
            torch.zeros([
                1, 1, 1,
                self.image_dim_out * self.base_feat_height_reduction**2
            ]))

        dim_projection = hidden_size
        depth = 2
        layers = [
            nn.Linear(image_dim_out * self.base_feat_height_reduction**2,
                      dim_projection)
        ]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.vocab_size = config.vocab_size
        self.img_features = None

        self.use_out_place_operations = False

    def get_img_features(self,
                         img_embeds: torch.FloatTensor,
                         attention_mask=None) -> torch.FloatTensor:

        img_feature = self.img_processor(img_embeds,
                                         patch_attention_mask=attention_mask)

        if self.type_feature == "patch":
            patch_feature = img_feature

            use_token_compression = self.image_token_compression is not None
            use_padding = getattr(self, 'img_processor_padding',
                                  None) is not None
            if use_token_compression or use_padding:
                # reshape to 2D tensor
                width = int(math.sqrt(patch_feature.size(1)))
                patch_feature = patch_feature.view(-1, width, width,
                                                   patch_feature.size(-1))
                # convert to NCHW
                patch_feature = patch_feature.permute(0, 3, 1, 2)

                if use_padding:
                    patch_feature = self.img_processor_padding(patch_feature)
                if use_token_compression:
                    patch_feature = self.image_token_compression(
                        patch_feature)

                # convert to NHWC
                patch_feature = patch_feature.permute(0, 2, 3, 1)
                patch_feature = patch_feature.view(
                    -1,
                    patch_feature.size(1) * patch_feature.size(2),
                    patch_feature.size(-1))

            return patch_feature

        raise NotImplementedError

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor,
                image_attention_mask: torch.Tensor) -> torch.FloatTensor:
        """
        Process images and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        image_sizes: [[h1, w1], [h2, w2]]
        image_attention_mask: num_images x num_crops x 32 x 32
        output: (num_images, num_img_tokens, hidden_size)
        """

        # eg
        # pixel_values: torch.Size([1, 7, 3, 448, 448])
        # image_sizes: tensor([[ 896, 1344]], device='cuda:0')
        # output: torch.Size([1, 1841, 3072])

        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        img_sizes = image_sizes
        num_images, num_crops, c, h, w = pixel_values.shape
        bs = num_images
        pixel_values = pixel_values.flatten(0, 1)

        img_features = self.get_img_features(
            pixel_values,
            image_attention_mask.type(torch.BoolTensor).flatten(
                0, 1).to(target_device))

        base_feat_height_target = self.base_feat_height_target
        base_resolution = self.crop_size
        base_feat_height_reduction = self.base_feat_height_reduction

        base_feat_height = base_feat_width = int(np.sqrt(
            img_features.shape[1]))
        assert base_feat_height == base_feat_height_target \
            and base_feat_width == base_feat_height_target, \
            f'base_feat_height: {base_feat_height}, ' \
            f'base_feat_width: {base_feat_width}, ' \
            f'expect {base_feat_height_target} features for hd transform'

        # bs x max_num_crops x (24x24) x C
        img_features = img_features.view(bs, -1,
                                         base_feat_height * base_feat_width,
                                         self.image_dim_out)
        C = self.image_dim_out
        H = base_feat_height

        output_imgs = []
        output_len = []
        # training is tensor, inference is list
        if isinstance(img_sizes, torch.Tensor):
            img_sizes = img_sizes.view(-1, 2)
        for _bs in range(bs):
            h, w = img_sizes[_bs]
            h = h // base_resolution
            w = w // base_resolution
            B_ = h * w

            # 1 x (24x24) x 1024
            global_img_feature = img_features[_bs, :1]

            # 1 x 12 x 12 x 4096
            glb_img = global_img_feature.reshape(1, H, H, C).reshape(
                1, H // base_feat_height_reduction,
                base_feat_height_reduction,
                H // base_feat_height_reduction, base_feat_height_reduction,
                C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape(
                    1, H // base_feat_height_reduction,
                    H // base_feat_height_reduction,
                    base_feat_height_reduction * base_feat_height_reduction *
                    C).contiguous()
            temp_glb_GN = self.sub_GN.repeat(1,
                                             H // base_feat_height_reduction,
                                             1, 1)

            # 1 x 156 x 4096
            glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(
                1, -1,
                base_feat_height_reduction * base_feat_height_reduction * C)

            # (max_num_crops-1) x (12x12) x C
            sub_img = img_features[_bs, 1:]
            # 16x574x1024
            # get rid of padding sub_img
            sub_img = sub_img[:B_]

            # (num_crops, 12, 2, 12, 2, 1024) ->
            # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
            sub_img = sub_img.reshape(B_, H, H, C).reshape(
                B_, H // base_feat_height_reduction,
                base_feat_height_reduction, H // base_feat_height_reduction,
                base_feat_height_reduction,
                C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape(
                    B_, -1, base_feat_height_reduction *
                    base_feat_height_reduction * C).contiguous()
            sub_img = sub_img.reshape(
                1, h, w, base_feat_height // base_feat_height_reduction,
                base_feat_width // base_feat_height_reduction,
                -1).permute(0, 1, 3, 2, 4, 5).reshape(
                    1, h * base_feat_height // base_feat_height_reduction,
                    w * base_feat_width // base_feat_height_reduction,
                    base_feat_height_reduction * base_feat_height_reduction *
                    C)

            if image_attention_mask is not None and len(
                    image_attention_mask) > 0:
                reshaped_image_attention_mask = image_attention_mask[
                    _bs, 1:B_ + 1, 0::2, 0::2].reshape(
                        1, h, w,
                        base_feat_height // base_feat_height_reduction,
                        base_feat_width //
                        base_feat_height_reduction).permute(
                            0, 1, 3, 2, 4).reshape(
                                1, h * base_feat_height //
                                base_feat_height_reduction, w *
                                base_feat_width // base_feat_height_reduction)
                useful_height = int(
                    reshaped_image_attention_mask[0, :, 0].sum().item())
                useful_width = int(
                    reshaped_image_attention_mask[0, 0, :].sum().item())
                sub_img = sub_img[:, :useful_height, :useful_width]
                temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1)
                temp_len = int(
                    image_attention_mask[_bs, :B_ + 1, 0::2, 0::2].sum().item(
                    )) + (useful_height +
                          1) + base_feat_height // base_feat_height_reduction
            else:
                temp_sub_GN = self.sub_GN.repeat(
                    1, h * base_feat_height // base_feat_height_reduction, 1,
                    1)
                temp_len = int((h * w + 1) * self.num_img_tokens + 1 +
                               (h + 1) * base_feat_height //
                               base_feat_height_reduction)

            sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(
                1, -1,
                base_feat_height_reduction * base_feat_height_reduction * C)
            # (1, num_img_tokens, 1024*4)

            # glb + sub
            if self.hd_transform_order == 'glb_sub':
                output_imgs.append(
                    torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
            elif self.hd_transform_order == 'sub_glb':
                output_imgs.append(
                    torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
            else:
                raise NotImplementedError(
                    f'hd_transform_order = {self.hd_transform_order} '
                    'not implemented')

            # temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
            assert temp_len == output_imgs[-1].shape[1], \
                f'temp_len: {temp_len}, ' \
                f'output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}'

            output_len.append(temp_len)

        img_set_tensor = []
        for _output_img in output_imgs:
            img_feature_proj = self.img_projection(
                _output_img.to(target_device).to(target_dtype))
            img_set_tensor.append(img_feature_proj)

        return img_set_tensor


class Phi4MMAudioFeatureInputs(TypedDict):
    type: Literal["audio_features"]
    data: Tuple[NestedTensors]
    """Shape: `((batch_size, num_audios, 80, M), )`"""


class Phi4MMAudioEmbeddingInputs(TypedDict):
    type: Literal["audio_embeds"]
    data: NestedTensors
    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""


Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]


def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
    """Create a Mel filter-bank the same as SpeechLib FbankFC.

    Args:
        sample_rate (int): Sample rate in Hz. number > 0 [scalar]
        n_fft (int): FFT size. int > 0 [scalar]
        n_mels (int): Mel filter size. int > 0 [scalar]
        fmin (float): lowest frequency (in Hz). If None use 0.0.
            float >= 0 [scalar]
        fmax: highest frequency (in Hz). If None use sample_rate / 2.
            float >= 0 [scalar]

    Returns:
        out (numpy.ndarray): Mel transform matrix
            [shape=(n_mels, 1 + n_fft/2)]
    """

    bank_width = int(n_fft // 2 + 1)
    if fmax is None:
        fmax = sample_rate / 2
    if fmin is None:
        fmin = 0
    assert fmin >= 0, "fmin cannot be negative"
    assert (fmin < fmax <=
            sample_rate / 2), "fmax must be between (fmin, samplerate / 2]"

    def mel(f):
        return 1127.0 * np.log(1.0 + f / 700.0)

    def bin2mel(fft_bin):
        return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))

    def f2bin(f):
        return int((f * n_fft / sample_rate) + 0.5)

    # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
    klo = f2bin(fmin) + 1
    khi = f2bin(fmax)

    khi = max(khi, klo)

    # Spec 2: SpeechLib uses triangles in Mel space
    mlo = mel(fmin)
    mhi = mel(fmax)
    m_centers = np.linspace(mlo, mhi, n_mels + 2)
    ms = (mhi - mlo) / (n_mels + 1)

    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
    for m in range(0, n_mels):
        left = m_centers[m]
        center = m_centers[m + 1]
        right = m_centers[m + 2]
        for fft_bin in range(klo, khi):
            mbin = bin2mel(fft_bin)
            if left < mbin < right:
                matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms

    return matrix
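
# Illustrative example (comment-only sketch): the processor below builds the
# filter-bank for 16 kHz audio with a 512-point FFT and 80 mel bins, so the
# returned matrix has shape (n_mels, 1 + n_fft // 2) = (80, 257):
#
#     mel = speechlib_mel(16000, 512, 80, fmax=7690)
#     mel.shape  # (80, 257)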


class LogFbankProcessor:

    def __init__(self):

        self._eightk_method = "fillzero"
        self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T

        self._hamming400 = np.hamming(400)  # for 16k audio
        self._hamming200 = np.hamming(200)  # for 8k audio

    def extract_spectrogram(self, wav, fs):
        """Extract spectrogram features from waveform.
        Args:
            wav (1D array): waveform of the input
            fs (int): sampling rate of the waveform, 16000 or 8000.
                Rates above 16000 are resampled down to 16000; with the
                default "fillzero" method, 8000 Hz audio is processed at
                8 kHz and the missing 4-8 kHz bins are zero-filled.
        Output:
            spec (2D array): a TxD matrix of spectrogram features.
                T is the number of frames.
        """
        if wav.ndim > 1:
            wav = np.squeeze(wav)

        # by default, we extract the mean if stereo
        if len(wav.shape) == 2:
            wav = wav.mean(1)

        # Resample to 16000 or 8000 if needed
        if fs > 16000:
            wav = scipy.signal.resample_poly(wav, 1, fs // 16000)
            fs = 16000
        elif 8000 < fs < 16000:
            wav = scipy.signal.resample_poly(wav, 1, fs // 8000)
            fs = 8000
        elif fs < 8000:
            raise RuntimeError(f"Unsupported sample rate {fs}")

        if fs == 8000:
            if self._eightk_method == "resample":
                # Input audio is 8 kHz. Convert to 16 kHz before feature
                # extraction
                wav = scipy.signal.resample_poly(wav, 2, 1)
                fs = 16000
            # Do nothing here for fillzero method
        elif fs != 16000:
            # Input audio is not a supported sample rate.
            raise RuntimeError(
                f"Input data using an unsupported sample rate: {fs}")

        preemphasis = 0.97

        if fs == 8000:
            n_fft = 256
            win_length = 200
            hop_length = 80
            fft_window = self._hamming200
        elif fs == 16000:
            n_fft = 512
            win_length = 400
            hop_length = 160
            fft_window = self._hamming400

        # Spec 1: SpeechLib cut remaining sample insufficient for a hop
        n_batch = (wav.shape[0] - win_length) // hop_length + 1
        # Here we don't use stride_tricks since the input array may not
        # satisfy the memory layout requirement and we need writeable output.
        # Here we only use a list of views before copying to the destination,
        # so it is more efficient than broadcasting
        y_frames = np.array(
            [
                wav[_stride:_stride + win_length]
                for _stride in range(0, hop_length * n_batch, hop_length)
            ],
            dtype=np.float32,
        )

        # Spec 2: SpeechLib applies preemphasis within each batch
        y_frames_prev = np.roll(y_frames, 1, axis=1)
        y_frames_prev[:, 0] = y_frames_prev[:, 1]
        y_frames = (y_frames - preemphasis * y_frames_prev) * 32768

        S = np.fft.rfft(fft_window * y_frames, n=n_fft,
                        axis=1).astype(np.complex64)

        if fs == 8000:
            # Need to pad the output to look like 16 kHz data but with zeros
            # in the 4 to 8 kHz bins.
            frames, bins = S.shape
            padarray = np.zeros((frames, bins))
            S = np.concatenate((S[:, 0:-1], padarray),
                               axis=1)  # Nyquist bin gets set to zero

        spec = np.abs(S).astype(np.float32)
        return spec

    def extract_features(self, wav, fs):
        """Extract log filterbank features from waveform.
        Args:
            wav (1D array): waveform of the input
            fs (int): sampling rate of the waveform, 16000 or 8000.
        Output:
            log_fbank (2D array): a TxD matrix of log Mel filterbank features.
                D=80, and T is the number of frames.
        """
        spec = self.extract_spectrogram(wav, fs)
        spec_power = spec**2

        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
        log_fbank = np.log(fbank_power).astype(np.float32)

        return log_fbank


@lru_cache
def audio_feature_extractor() -> LogFbankProcessor:
    # Creates an instance of the audio processor, needed to extract the
    # audio features from the sound file
    # LRU cache ensures that we only make one copy
    return LogFbankProcessor()


def _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size,
                              vit_patch_size, token_compression_factor):
    """
    compute the number of tokens an image is expected to take up considering
    the image encoder architecture and exclude output features containing
    only padding pixels

    for siglip, vit_image_size=448, vit_patch_size=14, so output will be
    32x32 feature map
    NOTE right now, Phi4MM uses hard-coded token_compression_factor=2
    """
    assert vit_image_size % vit_patch_size == 0, \
        "vit_image_size must be divisible by vit_patch_size"
    assert vit_image_size // vit_patch_size % token_compression_factor == 0, \
        "vit_image_size // vit_patch_size must be divisible by "\
        "token_compression_factor"

    target_aspect_ratio, target_height, target_width = (
        _find_target_aspect_ratio(image,
                                  vit_image_size,
                                  dynamic_hd_size,
                                  min_num=1))
    assert target_aspect_ratio[0] * vit_image_size == target_width, \
        f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}"
    assert target_aspect_ratio[1] * vit_image_size == target_height, \
        f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}"
    assert (target_height % vit_image_size == 0
            and target_width % vit_image_size == 0)

    padding_height, padding_width = _get_padding_size(image, target_height,
                                                      target_width)
    assert padding_width == 0 or padding_height == 0, \
        "padding_width or padding_height must be 0"

    target_feat_width = target_width // vit_patch_size
    target_feat_height = target_height // vit_patch_size
    if padding_width >= vit_patch_size:
        assert padding_height == 0, "padding_height not 0"
        non_pad_feat_width = target_feat_width - math.floor(
            padding_width / vit_patch_size)
        non_pad_feat_height = target_feat_height
    elif padding_height >= vit_patch_size:
        assert padding_width == 0, "padding_width not 0"
        non_pad_feat_height = target_feat_height - math.floor(
            padding_height / vit_patch_size)
        non_pad_feat_width = target_feat_width
    else:
        # small padding shorter than a vit patch
        non_pad_feat_width = target_feat_width
        non_pad_feat_height = target_feat_height

    feat_width = non_pad_feat_width // token_compression_factor
    feat_height = non_pad_feat_height // token_compression_factor
    # NOTE it's possible that the non-padding feature is not divisible
    if non_pad_feat_width % token_compression_factor != 0:
        feat_width += 1
    if non_pad_feat_height % token_compression_factor != 0:
        feat_height += 1
    num_hd_patch_tokens = feat_width * feat_height
    num_hd_newline_tokens = feat_height
    vit_feature_size = vit_image_size // vit_patch_size
    num_global_image_tokens = (vit_feature_size //
                               token_compression_factor)**2
    num_sep_tokens = 1
    num_global_image_newline_tokens = \
        vit_feature_size // token_compression_factor

    return (num_global_image_tokens + num_sep_tokens + num_hd_patch_tokens +
            num_hd_newline_tokens + num_global_image_newline_tokens)
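
# Illustrative example (comment-only sketch): a 448x448 input needs a single
# crop, so with vit_image_size=448, vit_patch_size=14 and
# token_compression_factor=2 the count is 256 (global) + 1 (separator)
# + 256 (HD patches) + 16 (HD newlines) + 16 (global newlines) = 545:
#
#     img = Image.new('RGB', (448, 448))
#     _compute_num_image_tokens(img, 16, 448, 14, 2)  # 545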


def compute_logfbank_output_size(wav_length: int, fs: int) -> Tuple[int, int]:
    """
    Compute the output size of the `extract_features` method.

    Args:
        wav_length (int): Length of the input waveform in samples.
        fs (int): Sampling rate of the waveform, either 16000 or 8000.

    Returns:
        tuple (int, int): Output size as (T, D), where:
            T: Number of time frames.
            D: Number of Mel filterbank bins (80).
    """

    # Resample to 16000 or 8000 if needed
    if fs > 16000:
        wav_length //= fs // 16000
        fs = 16000
    elif 8000 <= fs < 16000:
        # We'll resample to 16K from 8K
        wav_length *= 2
        fs = 16000
    elif fs < 8000:
        raise RuntimeError(f"Unsupported sample rate {fs}")

    # Spectrogram parameters for 16 kHz
    win_length = 400  # Frame length in samples
    hop_length = 160  # Frame shift in samples
    mel_bins = 80  # Number of mel filterbank bins

    # Calculate number of frames (T)
    T = (wav_length - win_length) // hop_length + 1
    if T < 1:
        raise ValueError("Waveform too short for given parameters.")

    # Return time frames (T) and mel bins (D)
    return T, mel_bins
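
# Illustrative example (comment-only sketch): two seconds of 16 kHz audio is
# 32000 samples, giving (32000 - 400) // 160 + 1 = 198 frames of 80 mel bins:
#
#     compute_logfbank_output_size(32000, 16000)  # (198, 80)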


def _get_audio_embed_sizes(audios, ctx: InputContext):
    """
    Get the audio embedding sizes for each audio file.

    Args:
        audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples
            of waveform and sample rate.
        ctx (InputContext): Input context.

    Returns:
        List[int]: List of audio embedding sizes.
    """
    audio_embed_sizes = []
    for audio in audios:
        audio_data, sf = audio
        audio_frames, _ = compute_logfbank_output_size(len(audio_data), sf)
        audio_embed_size = _compute_audio_embed_size(ctx.get_hf_config(),
                                                     audio_frames)
        audio_embed_sizes.append(audio_embed_size)
    return audio_embed_sizes


def _get_audio_id_to_input_ids(audios, ctx: InputContext, prompt_str=""):
    """
    The following will search for `<|audio_{idx}|>` tokens and
    return a mapping of audio placeholder tokens to audio placeholder token
    ids based on the size of the audio embeddings.

    Args:
        audios (List[Tuple[np.ndarray, int]]): List of audio files as tuples
            of waveform and sample rate.
        ctx (InputContext): Input context.
        prompt_str (str): The prompt string.

    Returns:
        Dict[str, List[int]]: Mapping of audio placeholder tokens to audio
            placeholder token ids.
    """
    if len(audios) == 0:
        return {}

    audio_embed_sizes = _get_audio_embed_sizes(audios, ctx)
    audio_ids = re.findall(AUDIO_TOKEN_PATTERN, prompt_str)
    audio_ids = [int(audio_id) for audio_id in audio_ids]
    assert len(audio_ids) == len(
        audio_embed_sizes
    ), "Number of audio tokens and audio features do not match"
    assert tuple(audio_ids) == tuple(range(
        1, len(audio_ids) + 1)), "Audio ids are not in order!"
    audio_id_to_input_ids = {
        f"<|audio_{audio_id}|>":
        [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
        for audio_id, audio_embed_size in zip(audio_ids, audio_embed_sizes)
    }

    return audio_id_to_input_ids
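
# Illustrative example (comment-only sketch): for a prompt "<|audio_1|>" and
# one audio clip whose embedding is, say, 188 tokens long, the mapping is
#
#     {"<|audio_1|>": [_AUDIO_PLACEHOLDER_TOKEN_ID] * 188}
#
# i.e. each placeholder string expands to one run of the placeholder token id.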


def _count_image_tokens(images, ctx: InputContext):
    hf_config = ctx.get_hf_config()
    vision_encoder_name = hf_config.img_processor
    if vision_encoder_name is None:
        vision_encoder_name = SIGLIP_NAME
    prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
    dynamic_hd_size = prepro_config['dynamic_hd']
    vit_image_size = prepro_config['vit_image_size']
    vit_patch_size = prepro_config['vit_patch_size']
    token_compression_factor = prepro_config['token_compression_factor']

    image_token_counts = [
        _compute_num_image_tokens(image, dynamic_hd_size, vit_image_size,
                                  vit_patch_size, token_compression_factor)
        for image in images
    ]
    return image_token_counts


def _get_image_id_to_input_ids(images, prompt, ctx: InputContext):
    if len(images) == 0:
        return {}

    image_ids = re.findall(IMAGE_TOKEN_PATTERN, prompt)
    image_ids = [int(image_id) for image_id in image_ids]
    assert len(image_ids) == len(
        set(image_ids)), "Duplicate image tokens in prompt"
    assert len(images) == len(
        image_ids), "Number of images and image tokens in prompt do not match"

    # NOTE the following assertion is not strictly necessary
    assert tuple(image_ids) == tuple(range(
        1, len(image_ids) + 1)), "Image ids are not in order"

    image_token_counts = _count_image_tokens(images, ctx)
    image_id_to_input_ids = {
        f"<|image_{image_id}|>": [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_tokens
        for image_id, num_tokens in zip(image_ids, image_token_counts)
    }
    return image_id_to_input_ids


def input_processor_for_phi4mm(ctx: InputContext,
                               inputs: DecoderOnlyInputs) -> TokenInputs:
    """
    Implements the input processor, which transforms the input prompt ids
    to include the audio placeholder token. This will become the `input_ids`
    in `forward` for the model.

    Args:
        ctx (InputContext): Input context.
        inputs (DecoderOnlyInputs): The inputs (e.g. prompt, prompt_token_ids)
            to process.

    Returns:
        TokenInputs: Processed inputs
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if (multi_modal_data is None or
        ("audio" not in multi_modal_data
         and "image" not in multi_modal_data)):
        # pure text input, so no need to do pre-processing
        return inputs

    prompt_str = inputs.get("prompt")
    prompt_token_ids = inputs.get("prompt_token_ids")
    # for offline_inference, we will get str input and we parse MM special
    # tokens from it
    # (ignore prompt_token_ids)
    # for OAI server, we will get prompt_token_ids, where MM special tokens
    # are already parsed

    if 'audio' in multi_modal_data:
        audios = multi_modal_data["audio"]

        if not isinstance(audios, list):
            audios = [audios]
        if prompt_str is not None:
            audio_id_to_input_ids = _get_audio_id_to_input_ids(
                audios, ctx, prompt_str=prompt_str)
            audio_embed_sizes = []
        elif prompt_token_ids is not None:
            audio_id_to_input_ids = {}
            audio_embed_sizes = _get_audio_embed_sizes(audios, ctx)
    else:
        audio_id_to_input_ids = {}
        audio_embed_sizes = []

    if 'image' in multi_modal_data:
        # PIL Image or list of PIL Images
        images = multi_modal_data["image"]
        if not isinstance(images, list):
            images = [images]
        if prompt_str is not None:
            image_id_to_input_ids = _get_image_id_to_input_ids(
                images, prompt_str, ctx)
            image_token_counts = []
        elif prompt_token_ids is not None:
            image_id_to_input_ids = {}
            image_token_counts = _count_image_tokens(images, ctx)
    else:
        image_id_to_input_ids = {}
        image_token_counts = []

    # Handle the case where the prompt is a string and we need to manually
    # tokenize it.
    # In this case, the `audio_id_to_input_ids` dict will be mapping from
    # an audio placeholder
    # string (e.g. `<|audio_1|>`) to the audio placeholder tokens for the
    # given audio length.
    if prompt_str:
        pattern = r"(<\|image_\d+\|>|<\|audio_\d+\|>)"
        prompt_chunk_strings = re.split(pattern, prompt_str)
        prompt_chunk_strings = [s for s in prompt_chunk_strings if s != ""]

        # Create the new input_ids with the placeholder image and audio
        # tokens inserted
        tokenizer = cached_tokenizer_from_config(ctx.model_config)
        input_ids = []
        has_image, has_audio, has_user_text_input = False, False, False
        for prompt_chunk_string in prompt_chunk_strings:
            if re.match(IMAGE_TOKEN_PATTERN, prompt_chunk_string):
                input_ids.extend(image_id_to_input_ids[prompt_chunk_string])
                has_image = True
            elif re.match(AUDIO_TOKEN_PATTERN, prompt_chunk_string):
                input_ids.extend(audio_id_to_input_ids[prompt_chunk_string])
                has_audio = True
            else:
                curr_token_ids = tokenizer(prompt_chunk_string).input_ids
                if not has_user_text_input:
                    for token_id in curr_token_ids:
                        if token_id not in NON_USER_INPUT_TOKENS:
                            has_user_text_input = True
                            break
                input_ids.extend(curr_token_ids)
        if has_audio and has_image and has_user_text_input:
            raise ValueError(
                "Phi4MMForCausalLM does not support text + audio + image"
                " inputs in the same prompt")
    # Handle the case where the prompt is already tokenized
    else:
        assert prompt_token_ids is not None, \
            "If string prompt isn't provided, prompt_token_ids must be " \
            "provided"

        i = 0
        input_ids = prompt_token_ids
        # only needed for later assertion
        img_cnt, audio_cnt, user_text_input_cnt = 0, 0, 0
        image_token_count_iter = iter(image_token_counts)
        audio_embed_size_iter = iter(audio_embed_sizes)
        while i < len(input_ids):
            token_id = input_ids[i]
            if token_id == _AUDIO_PLACEHOLDER_TOKEN_ID:
                token_count = next(audio_embed_size_iter)
                audio_cnt += 1
            elif token_id == _IMAGE_PLACEHOLDER_TOKEN_ID:
                token_count = next(image_token_count_iter)
                img_cnt += 1
            else:
                user_text_input_cnt += 1 if token_id not in \
                    NON_USER_INPUT_TOKENS else 0
                i += 1
                continue
            tokens = [token_id] * token_count
            input_ids = input_ids[:i] + tokens + input_ids[i + 1:]
            i += token_count

        if audio_cnt > 0 and img_cnt > 0 and user_text_input_cnt > 0:
            raise ValueError(
                "Phi4MMForCausalLM does not support text + audio + image"
                " inputs in the same prompt")
        # If the below assertion fails, it might be that input pure-text
        # messages contain image/audio special tokens literally
        # (<|endoftext10|>, <|endoftext11|>).
        assert (img_cnt == len(image_token_counts)), (
            f"Number of image tokens in prompt_token_ids ({img_cnt}) "
            f"does not match number of images ({len(image_token_counts)})")
        assert (audio_cnt == len(audio_embed_sizes)), (
            f"Number of audio tokens in prompt_token_ids ({audio_cnt}) "
            f"does not match number of audios ({len(audio_embed_sizes)})")

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(
        prompt_token_ids=input_ids,
        prompt=prompt_str,
        multi_modal_data=multi_modal_data,
    )


def _compute_audio_embed_size(hf_config, audio_frames):
    """
    Compute the audio embedding size based on the audio frames and
    compression rate.
    """
    compression_rate = hf_config.embd_layer['audio_embd_layer'][
        'compression_rate']
    # NOTE: this is a hard-coded value but might be configurable in the future
    qformer_compression_rate = 1
    integer = audio_frames // compression_rate
    remainder = audio_frames % compression_rate

    result = integer if remainder == 0 else integer + 1

    integer = result // qformer_compression_rate
    remainder = result % qformer_compression_rate
    result = integer if remainder == 0 else integer + 1  # qformer compression

    return result
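
# Illustrative example (comment-only sketch; the compression rate comes from
# the HF config, and the value 8 below is an assumption for the sake of the
# arithmetic): 1504 audio frames with compression_rate=8 and
# qformer_compression_rate=1 give ceil(1504 / 8) = 188 embedding tokens.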


def get_max_phi4mm_audio_tokens(ctx: InputContext) -> int:
    return 10000


def dummy_audio_for_phi4mm(audio_count: int) -> list:
    """
    Create dummy audio data for the Phi4MM model, which is used for profiling.

    Args:
        audio_count (int): Number of audio samples.

    Returns:
        list: Dummy audio data as (waveform, sampling frequency) tuples.
    """
    dummy_audio = np.full((_AUDIO_MAX_SOUNDFILE_SIZE, ), 0.0)
    return [(dummy_audio, DUMMY_SAMPLING_FREQUENCY)] * audio_count


def dummy_image_for_phi4mm(width: int, height: int):
    image = Image.new('RGB', (width, height), color='black')
    return image


def dummy_data_for_phi4mm(ctx: InputContext, seq_len: int,
                          mm_counts: Mapping[str, int]) -> DummyData:
    """
    Create dummy sequence (input_ids) and audio data for the Phi4MM model,
    which is used for profiling.

    In this case, the sequence data is a bunch of 0s with a number of audio
    tokens that correspond to the audio embed size of the
    _AUDIO_MAX_SOUNDFILE_SIZE.

    Args:
        ctx (InputContext): Input context.
        seq_len (int): Length of the sequence.
        mm_counts (Mapping[str, int]): Multi-modal counts.

    Returns:
        Tuple: Dummy sequence data and dummy audio data.
    """
    audio_count = mm_counts["audio"]
    audio_frames, _ = compute_logfbank_output_size(_AUDIO_MAX_SOUNDFILE_SIZE,
                                                   DUMMY_SAMPLING_FREQUENCY)
    audio_feature_size = _compute_audio_embed_size(ctx.get_hf_config(),
                                                   audio_frames)

    image_count = mm_counts["image"]
    dummy_image = get_max_dummy_image(ctx)
    max_image_tokens = get_max_phi4mm_image_tokens(ctx)
    total_image_tokens = image_count * max_image_tokens

    if seq_len - audio_feature_size * audio_count - total_image_tokens < 0:
        raise RuntimeError(
            f"Phi4MM cannot process {audio_count} audios and {image_count} "
            f"images in a prompt, please increase max_model_len to be "
            f"larger than "
            f"{audio_feature_size * audio_count + total_image_tokens}"
            " or reduce audio/image limit by --limit-mm-per-prompt.")

    if audio_feature_size * audio_count > total_image_tokens:
        seq_data = SequenceData.from_prompt_token_counts(
            (_AUDIO_PLACEHOLDER_TOKEN_ID, audio_feature_size * audio_count),
            (0, seq_len - audio_feature_size * audio_count),
        )
        mm_data = {
            "audio": dummy_audio_for_phi4mm(audio_count),
        }
    else:
        seq_data = SequenceData.from_prompt_token_counts(
            (_IMAGE_PLACEHOLDER_TOKEN_ID, total_image_tokens),
            (0, seq_len - total_image_tokens),
        )
        mm_data = {
            "image": [dummy_image] * image_count,
        }
    return DummyData(seq_data, mm_data)


def input_mapper_for_phi4mm_audio(ctx: InputContext,
                                  data: object) -> MultiModalKwargs:
    """
    This function is used to create the MultiModalKwargs for the Phi4MM
    (audio) model.
    Specifically, for audio, we extract the audio features from the sound
    file and create pairs of audio features and audio embed lengths (the
    latter of which is used to repeat the audio placeholder token in the
    input prompt IDs).
    These pairs are used, downstream, in `_audio_features_to_embeddings`
    (via `_process_audio_input`).

    Note that the incoming audio data (each entry in `data`) is a tuple of
    the audio data and the sampling frequency (e.g. from soundfile.read).

    Args:
        ctx (InputContext): Input context.
        data (object): Audio data.

    Returns:
        MultiModalKwargs: Multi-modal inputs.
    """
    if not isinstance(data, list):
        data = [data]

    if len(data) == 0:
        return MultiModalKwargs()

    audio_features = []
    for audio_input in data:
        if not isinstance(audio_input, tuple):
            raise NotImplementedError(
                f"Unsupported data type: {type(audio_input)}")

        audio, sf = audio_input
        feature_extractor = audio_feature_extractor()
        single_audio_features = feature_extractor.extract_features(audio, sf)
        feat_stride = (1 if not hasattr(feature_extractor, "stride") else
                       feature_extractor.stride)
        audio_frames = len(single_audio_features) * feat_stride
        single_audio_embed_size = _compute_audio_embed_size(
            ctx.get_hf_config(), audio_frames)
        single_audio_feature_audio_len_pair = (
            single_audio_features,
            [single_audio_embed_size],
        )
        audio_features.append(single_audio_feature_audio_len_pair)
    return MultiModalKwargs({"audio_features": audio_features})


def input_mapper_for_phi4mm_image(ctx: InputContext, data: object):
    if not isinstance(data, list):
        data = [data]
    # data: list of PIL images
    if len(data) == 0:
        return MultiModalKwargs()
    hf_config = ctx.get_hf_config()
    vision_encoder_name = hf_config.img_processor
    if vision_encoder_name is None:
        vision_encoder_name = SIGLIP_NAME
    prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name]
    dynamic_hd_size = prepro_config['dynamic_hd']
    vit_image_size = prepro_config['vit_image_size']
    vit_patch_size = prepro_config['vit_patch_size']

    image_input_dict = preprocess(data, dynamic_hd_size, vit_image_size,
                                  vit_patch_size)
    return MultiModalKwargs({
        "pixel_values": image_input_dict["pixel_values"],
        "image_sizes": image_input_dict["image_sizes"],
        "image_attention_mask": image_input_dict["image_attention_mask"],
        "num_img_tokens": image_input_dict["num_img_tokens"],
    })


def cat_with_pad(tensors, dim, padding_value=0):
    """
    cat along dim, while pad to max for all other dims
    """
    ndim = tensors[0].dim()
    assert all(
        t.dim() == ndim for t in
        tensors[1:]), "All tensors must have the same number of dimensions"

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    index = 0
    for t in tensors:
        # Create a slice list where every dimension except dim is full slice
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        # Update only the concat dimension slice
        slices[dim] = slice(index, index + t.shape[dim])

        output[tuple(slices)] = t
        index += t.shape[dim]

    return output
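
# Illustrative example (comment-only sketch): concatenating a (2, 3) and a
# (1, 5) tensor along dim 0 pads the narrower one with zeros on the right:
#
#     a = torch.ones(2, 3)
#     b = torch.full((1, 5), 2.0)
#     cat_with_pad([a, b], dim=0).shape  # (3, 5); a's rows end in two zeros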


@MULTIMODAL_REGISTRY.register_input_mapper("audio",
                                           input_mapper_for_phi4mm_audio)
@MULTIMODAL_REGISTRY.register_input_mapper("image",
                                           input_mapper_for_phi4mm_image)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_max_phi4mm_audio_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "image", get_max_phi4mm_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi4mm)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi4mm)
class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
                        SupportsV0Only):
    """
    Implements the Phi-4-multimodal-instruct model in vLLM.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "qkv_proj",
        ],
        "gate_up_proj": [
            "gate_up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "base_layer.": "",
        },
        orig_to_new_prefix={
            "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
            "embed_tokens_extend.audio_projection_for_vision.",
            "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
            "embed_tokens_extend.audio_projection.",
            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        assert multimodal_config, "multimodal_config is required"
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config
        self.lora_config = lora_config

        # Tensor/Pipeline parallel not supported for now.
        assert get_pp_group(
        ).world_size == 1, "pipeline parallel is not supported"

        self.vision_encoder = Phi4MMImageEncoder(
            config,
            quant_config,
            prefix="model.vision_embed_tokens",
            model_dir=config._name_or_path)

        if isinstance(config.embd_layer["audio_embd_layer"], dict):
            embedding_config = {
                "embedding_cls":
                config.embd_layer["audio_embd_layer"]["embedding_cls"],
                **config.embd_layer["audio_embd_layer"],
            }
        else:
            embedding_config = {
                "embedding_cls": self.config.embd_layer["embedding_cls"]
            }

        self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=(
                DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size),
            quant_config=quant_config,
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

    def _audio_features_to_embeddings(
            self,
            input_ids: torch.Tensor,
            input_features: List[torch.Tensor],
            audio_input_sizes: torch.Tensor,
            audio_projection_mode: str,
    ) -> torch.Tensor:
        """
        Convert audio features to embeddings, which are used as input to the
        model (via `inputs_embeds`).

        Args:
            input_ids (torch.Tensor): Input IDs (the prompt in this case).
            input_features (list[torch.Tensor]): Input features (the audio
                embeddings).
            audio_input_sizes (list[torch.Tensor]): Audio input sizes (the
                audio embed lengths to use for padding the audio placeholder
                token in the input prompt IDs).
        """
        # The audio projection can either be a single linear or Sequential,
        # so handle both cases
        if isinstance(self.embed_tokens_extend.audio_projection,
                      nn.Sequential):
            target_dtype = self.embed_tokens_extend.audio_projection[
                0].bias.dtype
        else:
            target_dtype = self.embed_tokens_extend.audio_projection.bias.dtype

        audio_input = [
            input.unsqueeze(0).to(target_dtype) for input in input_features
        ]
        kwargs = {
            "wte": self.model.embed_tokens,
            'audio_projection_mode': audio_projection_mode
        }
        audio_embeddings = self.embed_tokens_extend(input_ids, audio_input,
                                                    audio_input_sizes,
                                                    **kwargs)
        audio_embeddings = audio_embeddings.to(target_dtype)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Phi4MMAudioInputs]:
        """
        Parse and validate the audio input to the model. This handles both
        audio features and audio embeddings, but only the former is used for
        now.

        Args:
            kwargs (object): Keyword arguments.

        Returns:
            Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")

            return Phi4MMAudioFeatureInputs(type="audio_features",
                                            data=audio_features)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
                                              data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(self, input_ids: torch.Tensor,
                             audio_input: Phi4MMAudioInputs,
                             audio_projection_mode: str) -> NestedTensors:
        """
        Create the audio embeddings from the audio input, where the audio
        input is pairs of audio features and audio embed lengths. The audio
        input is created by `input_mapper_for_phi4mm_audio`.

        Args:
            input_ids (torch.Tensor): Input IDs (the prompt in this case,
                before the audio token replication).
            audio_input (Phi4MMAudioInputs): Audio input.

        Returns:
            NestedTensors: Audio embeddings
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        audio_features = audio_input["data"]
        # The feature list is nested: the first dim is the batch dim
        # (e.g. multiple examples) and the second dim is the multi-audio dim
        # (e.g. multiple audios in the same example)
        audio_feature = [i[0] for j in audio_features for i in j]
        audio_feature_len = [i[1].item() for j in audio_features for i in j]
        # Add a batch dim via `unsqueeze(0)` and strip it again with
        # `squeeze(0)` after the embedding lookup

        return self._audio_features_to_embeddings(
            input_ids.unsqueeze(0),
            audio_feature,
            audio_feature_len,
            audio_projection_mode,
        ).squeeze(0)

    def _parse_and_validate_image_input(self,
                                        **kwargs: object) -> Optional[Dict]:
        pixel_values: Optional[Dict] = kwargs.get("pixel_values")
        if pixel_values is None:
            return None

        image_sizes = kwargs.get("image_sizes")
        image_attention_mask = kwargs.get("image_attention_mask")
        num_img_tokens = kwargs.get("num_img_tokens")
        assert image_sizes is not None and image_attention_mask is not None\
            and num_img_tokens is not None, "Missing image inputs"

        if isinstance(pixel_values, list):
            assert pixel_values[0].dim() == 5, "Incorrect image inputs"
            # list len is batch_size.
            # each tensor has dimension: num_img_per_example, num_hd_patches,
            # channels, height, width.
            # need to pad along num_hd_patches.
            # mask size num_img_per_prompt, num_hd_patches, feat_h, feat_w.
            pixel_values = cat_with_pad(pixel_values, dim=0)
        elif isinstance(pixel_values, torch.Tensor):
            # dimension: batch_size, num_img_per_example, num_hd_patches,
            # channels, height, width.
            # we flatten first 2 dims to make it a single large batch for
            # SigLIP Encoder.
            assert pixel_values.dim() == 6, "Incorrect image inputs"
            pixel_values = pixel_values.flatten(0, 1)
        else:
            raise ValueError("Incorrect pixel_values inputs")

        if isinstance(image_attention_mask, list):
            image_attention_mask = cat_with_pad(image_attention_mask, dim=0)
        elif isinstance(image_attention_mask, torch.Tensor):
            image_attention_mask = image_attention_mask.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_attention_mask inputs")

        if isinstance(image_sizes, list):
            image_sizes = torch.cat(image_sizes, dim=0)
        elif isinstance(image_sizes, torch.Tensor):
            image_sizes = image_sizes.flatten(0, 1)
        else:
            raise ValueError("Incorrect image_sizes inputs")

        if isinstance(num_img_tokens, list):
            num_img_tokens = [
                n for num_tensor in num_img_tokens
                for n in num_tensor.tolist()
            ]
        elif isinstance(num_img_tokens, torch.Tensor):
            num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
        else:
            raise ValueError("Incorrect num_img_tokens inputs")

        return {
            'pixel_values': pixel_values,
            'image_sizes': image_sizes,
            'image_attention_mask': image_attention_mask,
            'num_img_tokens': num_img_tokens,
        }

    def merge_image_features_to_inputs_embeds(
            self,
            input_ids: torch.Tensor,
            inputs_embeds: torch.Tensor,
            image_set_tensors: List[torch.Tensor],
    ):
        position_tuple = (input_ids == _IMAGE_PLACEHOLDER_TOKEN_ID).nonzero(
            as_tuple=True)

        assert all([t.shape[0] == 1 for t in image_set_tensors
                    ]), 'img_set_tensor should have shape (1, N_tokens, C)'
        # Shape: (merged_N_tokens, C)
        image_set_tensor = torch.cat(image_set_tensors, dim=1).squeeze(0)
        image_set_tensor = image_set_tensor.to(inputs_embeds.dtype).to(
            inputs_embeds.device)
        merged_embeds = inputs_embeds.index_put(
            indices=position_tuple,
            values=image_set_tensor,
            accumulate=False,
        )
        return merged_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            # Each entry in this is a pair of audio_features and audio_embed
            # lengths
            audio_input = self._parse_and_validate_audio_input(**kwargs)
            image_inputs = self._parse_and_validate_image_input(**kwargs)

            has_audio = audio_input is not None
            has_image = image_inputs is not None

            if has_audio:
                audio_projection_mode = 'vision' if has_image else 'speech'
                inputs_embeds = self._process_audio_input(
                    input_ids, audio_input, audio_projection_mode)

            if has_image:
                dtype = self.vision_encoder.img_processor.embeddings.\
                    patch_embedding.weight.dtype
                pixel_values = image_inputs['pixel_values'].to(dtype)
                image_sizes = image_inputs['image_sizes']
                image_attention_mask = image_inputs['image_attention_mask']
                image_set_tensors = self.vision_encoder(
                    pixel_values, image_sizes, image_attention_mask)
                if not has_audio:
                    inputs_embeds = self.model.embed_tokens(input_ids)

                inputs_embeds = self.merge_image_features_to_inputs_embeds(
                    input_ids, inputs_embeds, image_set_tensors)

            if has_image or has_audio:
                # multi-modal input, we have set inputs_embeds properly in
                # previous steps
                input_ids = None
            else:
                # text-only, we keep using original input_ids
                inputs_embeds = None

        hidden_states = self.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> None:
        weights = ((name, data) for name, data in weights
                   if "lora" not in name)
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model.",
            connector=["audio_projection_for_vision", "audio_projection"],
            tower_model=["vision_encoder", "embed_tokens_extend"],
        )