mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 06:55:01 +08:00
[Model] Refactor and decouple phi3v image embedding (#6621)
This commit is contained in:
parent
b6df37f943
commit
25e778aa16
@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
input_processor_for_clip)
|
||||
from .interfaces import SupportsVision
|
||||
from .utils import merge_vision_embeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -71,9 +72,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
|
||||
|
||||
class Phi3ImageEmbeddingBase(nn.Module):
|
||||
|
||||
def __init__(self, wte=None) -> None:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.wte = wte
|
||||
self.layer_idx: int
|
||||
self.type_feature: str
|
||||
self.img_processor: CLIPVisionModel
|
||||
@ -100,10 +100,9 @@ class Phi3ImageEmbeddingBase(nn.Module):
|
||||
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
"""Phi3 Image embedding with HD transform."""
|
||||
|
||||
def __init__(self, config: PretrainedConfig, wte=None) -> None:
|
||||
super().__init__(wte)
|
||||
def __init__(self, config: PretrainedConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.image_token_id = _IMAGE_TOKEN_ID
|
||||
# n_embed or hidden_size
|
||||
hidden_size = config.n_embd if hasattr(
|
||||
config, 'n_embd') else config.hidden_size
|
||||
@ -149,118 +148,115 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
nn.Linear(dim_projection, dim_projection)])
|
||||
self.img_projection = nn.Sequential(*layers)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.type_feature = config.img_processor.get('type_feature', 'patch')
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor,
|
||||
pixel_values: torch.FloatTensor,
|
||||
def forward(self, pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor) -> torch.FloatTensor:
|
||||
"""process and merge text embeddings with image embeddings."""
|
||||
"""
|
||||
process image and return vision embeddings.
|
||||
|
||||
# (batch_size, max_num_crops, 3, height, width)
|
||||
img_embeds = pixel_values
|
||||
pixel_values: (num_images, num_crops, c, h, w)
|
||||
output: (num_images, num_img_tokens, hidden_size)
|
||||
"""
|
||||
num_images, num_crops, c, h, w = pixel_values.shape
|
||||
pixel_values = pixel_values.flatten(0, 1)
|
||||
img_features = self.get_img_features(pixel_values)
|
||||
img_features = img_features.reshape(num_images, num_crops, -1,
|
||||
self.image_dim_out)
|
||||
image_features_proj = self.hd_feature_transform(
|
||||
img_features, image_sizes)
|
||||
return image_features_proj
|
||||
|
||||
# (batch_size, 2)
|
||||
img_sizes = image_sizes
|
||||
def hd_feature_transform(self, image_features, image_sizes):
|
||||
"""
|
||||
image_features: (num_images, num_crops+1, 24*24, 1024)
|
||||
"""
|
||||
assert (
|
||||
self.hd_transform_order == 'sub_glb'
|
||||
), f'hd_transform_order `{self.hd_transform_order}` not implemented'
|
||||
if isinstance(self.img_projection, nn.Sequential):
|
||||
target_device = self.img_projection[0].bias.device
|
||||
target_dtype = self.img_projection[0].bias.dtype
|
||||
else: # It's a single nn.Linear layer
|
||||
target_device = self.img_projection.bias.device
|
||||
target_dtype = self.img_projection.bias.dtype
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
global_image_features = image_features[:,
|
||||
0] # (num_images, 24*24, 1024)
|
||||
# global feature can be viewed as a special HD case with num_crops 1x1
|
||||
global_image_features_hd = self.reshape_hd_patches_2x2merge(
|
||||
global_image_features, 1, 1)
|
||||
global_image_features_hd_newline = self.add_image_newline(
|
||||
global_image_features_hd)
|
||||
|
||||
positions = torch.nonzero(input_ids == self.image_token_id)
|
||||
all_image_embeddings = []
|
||||
# need a for loop to process each image because of different image sizes
|
||||
# (patch arrangement is different for each image)
|
||||
for i, img_size in enumerate(image_sizes):
|
||||
h, w = img_size
|
||||
h_crop = h // 336
|
||||
w_crop = w // 336
|
||||
num_crops = h_crop * w_crop
|
||||
|
||||
select = False
|
||||
# NOTE: real num_crops is padded
|
||||
# (num_crops, 24*24, 1024)
|
||||
sub_image_features = image_features[i, 1:1 + num_crops]
|
||||
sub_image_features_hd = self.reshape_hd_patches_2x2merge(
|
||||
sub_image_features, h_crop, w_crop)
|
||||
sub_image_features_hd_newline = self.add_image_newline(
|
||||
sub_image_features_hd)
|
||||
|
||||
target_dtype = self.img_projection[0].bias.dtype
|
||||
# [sub features, separator, global features]
|
||||
all_image_embeddings.append(
|
||||
torch.cat([
|
||||
sub_image_features_hd_newline.squeeze(
|
||||
0), # (h_crop*12*(w_crop*12+1), 4096)
|
||||
self.glb_GN.squeeze(0),
|
||||
global_image_features_hd_newline[i],
|
||||
]))
|
||||
|
||||
if len(positions.tolist()) > 0:
|
||||
# if self.use_hd_transform and img_sizes:
|
||||
# img_embeds: (num_images, max_num_crops, 3, H, W)
|
||||
# img_sizes: (num_images, 2).view(1, -1)
|
||||
image_features_proj = self.img_projection(
|
||||
torch.stack(all_image_embeddings).to(target_device, target_dtype)
|
||||
) # (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)
|
||||
|
||||
bs = img_embeds.shape[0]
|
||||
# Nx(HW)xC
|
||||
img_features = self.get_img_features(img_embeds.flatten(0, 1))
|
||||
base_feat_height = base_feat_width = int(
|
||||
img_features.shape[1]**0.5)
|
||||
return image_features_proj
|
||||
|
||||
# bs x max_num_crops x (24x24) x C
|
||||
img_features = img_features.view(
|
||||
bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
|
||||
C = self.image_dim_out
|
||||
H = base_feat_height
|
||||
def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
|
||||
"""
|
||||
image_features: (num_images*num_crops, 24*24, 1024)
|
||||
output: (num_images, h_crop*12, w_crop*12, 4096)
|
||||
where h_crop*w_crop == num_crops
|
||||
"""
|
||||
N, L, C = image_features.shape
|
||||
assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
|
||||
num_images = N // (h_crop * w_crop)
|
||||
H = int(L**0.5)
|
||||
image_features_hd = (
|
||||
image_features.reshape(N, H, H, C) # N, 24, 24, 1024
|
||||
.reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
|
||||
.permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
|
||||
.reshape(N, -1, 4 * C) # N, 144, 4096
|
||||
.reshape(num_images, h_crop, w_crop, H // 2, H // 2,
|
||||
-1) # n_img, h_crop, w_crop, 12, 12, 4096
|
||||
.permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
|
||||
.reshape(num_images, h_crop * H // 2, w_crop * H // 2,
|
||||
4 * C) # n_img, h_crop*12, w_crop*12, 4096
|
||||
)
|
||||
return image_features_hd
|
||||
|
||||
output_imgs = []
|
||||
output_len = []
|
||||
|
||||
for _bs in range(bs):
|
||||
h, w = img_sizes[_bs]
|
||||
h = h // 336
|
||||
w = w // 336
|
||||
B_ = h * w
|
||||
|
||||
# 1 x (24x24) x 1024
|
||||
global_img_feature = img_features[_bs, :1]
|
||||
|
||||
# 1 x 12 x 12 x 4096
|
||||
glb_img = global_img_feature \
|
||||
.reshape(1, H // 2, 2, H // 2, 2,C) \
|
||||
.permute(0, 1, 3, 2, 4, 5) \
|
||||
.reshape(1, H // 2, H // 2, 4 * C)
|
||||
temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)
|
||||
|
||||
# 1 x 156 x 4096
|
||||
glb_img = torch.cat([glb_img, temp_glb_GN],
|
||||
dim=2).reshape(1, -1, 4 * C)
|
||||
|
||||
# (max_num_crops-1) x (12x12) x C
|
||||
sub_img = img_features[_bs, 1:]
|
||||
# 16x574x1024
|
||||
# get rid of padding sub_img
|
||||
sub_img = sub_img[:B_]
|
||||
|
||||
sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \
|
||||
.permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C)
|
||||
sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \
|
||||
.permute(0, 1, 3, 2, 4, 5) \
|
||||
.reshape(1, h * 12, w * 12, 4 * C)
|
||||
temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
|
||||
sub_img = torch.cat([sub_img, temp_sub_GN],
|
||||
dim=2).reshape(1, -1, 4 * C)
|
||||
# (1, num_img_tokens, 1024*4)
|
||||
|
||||
# glb + sub
|
||||
if self.hd_transform_order == 'glb_sub':
|
||||
output_imgs.append(
|
||||
torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
|
||||
elif self.hd_transform_order == 'sub_glb':
|
||||
output_imgs.append(
|
||||
torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
|
||||
|
||||
temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
|
||||
output_len.append(temp_len)
|
||||
|
||||
num_img_tokens = output_len
|
||||
img_set_tensor = []
|
||||
for _output_img in output_imgs:
|
||||
img_feature_proj = self.img_projection(
|
||||
_output_img.to(target_dtype))
|
||||
img_set_tensor.append(img_feature_proj)
|
||||
select = True
|
||||
|
||||
input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
|
||||
|
||||
hidden_states = self.wte(input_ids)
|
||||
|
||||
if select:
|
||||
idx = 0
|
||||
for i, cnt in enumerate(num_img_tokens):
|
||||
hidden_states[positions[idx, 0],
|
||||
positions[idx, 1]:positions[idx, 1] +
|
||||
cnt] = (img_set_tensor[i].to(
|
||||
hidden_states.dtype))
|
||||
idx += cnt
|
||||
|
||||
return hidden_states.squeeze(0)
|
||||
def add_image_newline(self, image_features_hd):
|
||||
"""
|
||||
image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
|
||||
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
|
||||
"""
|
||||
num_images, h, w, hid_dim = image_features_hd.shape
|
||||
# add the newline token to the HD image feature patches
|
||||
newline_embeddings = self.sub_GN.expand(num_images, h, -1,
|
||||
-1) # (n_img, h, 1, hid_dim)
|
||||
image_features_hd_newline = torch.cat(
|
||||
[image_features_hd, newline_embeddings],
|
||||
dim=2).reshape(num_images, -1, hid_dim)
|
||||
return image_features_hd_newline
|
||||
|
||||
|
||||
class Phi3VImagePixelInputs(TypedDict):
|
||||
@ -458,12 +454,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.image_token_id = _IMAGE_TOKEN_ID
|
||||
|
||||
self.model = LlamaModel(config, cache_config, quant_config)
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(
|
||||
config, self.model.embed_tokens)
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
@ -530,9 +526,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
inputs_embeds = self.vision_embed_tokens(
|
||||
input_ids, image_input["data"], image_input["image_sizes"])
|
||||
|
||||
vision_embeddings = self.vision_embed_tokens(
|
||||
image_input["data"], image_input["image_sizes"])
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
|
||||
vision_embeddings,
|
||||
self.image_token_id)
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user