mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:55:51 +08:00
[Model] Initialize Phi-3-vision support (#4986)
This commit is contained in:
parent
fa9e385229
commit
daef218b55
@ -135,6 +135,10 @@ Alongside each architecture, we include some popular models that use it.
|
||||
- Phi-3-Small
|
||||
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
|
||||
-
|
||||
* - :code:`Phi3VForCausalLM`
|
||||
- Phi-3-Vision
|
||||
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
|
||||
-
|
||||
* - :code:`QWenLMHeadModel`
|
||||
- Qwen
|
||||
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
||||
|
||||
57
examples/phi3v_example.py
Normal file
57
examples/phi3v_example.py
Normal file
@ -0,0 +1,57 @@
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.multimodal.image import ImagePixelData
|
||||
|
||||
|
||||
def run_phi3v():
|
||||
model_path = "microsoft/Phi-3-vision-128k-instruct"
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
image_input_type="pixel_values",
|
||||
image_token_id=32044,
|
||||
image_input_shape="1,3,1008,1344",
|
||||
image_feature_size=1921,
|
||||
disable_image_processor=False,
|
||||
)
|
||||
|
||||
image = Image.open("images/cherry_blossom.jpg")
|
||||
|
||||
# single-image prompt
|
||||
prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n" # noqa: E501
|
||||
prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "<s>")
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=64)
|
||||
|
||||
outputs = llm.generate({
|
||||
"prompt": prompt,
|
||||
"sampling_params": sampling_params,
|
||||
"multi_modal_data": ImagePixelData(image),
|
||||
})
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
print(generated_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
|
||||
local_directory = "images"
|
||||
|
||||
# Make sure the local directory exists or create it
|
||||
os.makedirs(local_directory, exist_ok=True)
|
||||
|
||||
# Use AWS CLI to sync the directory, assume anonymous access
|
||||
subprocess.check_call([
|
||||
"aws",
|
||||
"s3",
|
||||
"sync",
|
||||
s3_bucket_path,
|
||||
local_directory,
|
||||
"--no-sign-request",
|
||||
])
|
||||
run_phi3v()
|
||||
@ -14,6 +14,7 @@ peft
|
||||
requests
|
||||
ray
|
||||
sentence-transformers # required for embedding
|
||||
torchvision # required for the image processor of phi3v
|
||||
|
||||
# Benchmarking
|
||||
aiohttp
|
||||
|
||||
@ -144,6 +144,7 @@ class HfRunner:
|
||||
model_name: str,
|
||||
dtype: str = "half",
|
||||
*,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
is_embedding_model: bool = False,
|
||||
is_vision_model: bool = False,
|
||||
) -> None:
|
||||
@ -166,11 +167,13 @@ class HfRunner:
|
||||
else:
|
||||
auto_cls = AutoModelForCausalLM
|
||||
|
||||
model_kwargs = model_kwargs if model_kwargs is not None else {}
|
||||
self.model = self.wrap_device(
|
||||
auto_cls.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
**model_kwargs,
|
||||
))
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
||||
124
tests/models/test_phi3v.py
Normal file
124
tests/models/test_phi3v.py
Normal file
@ -0,0 +1,124 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
from ..conftest import IMAGE_FILES
|
||||
|
||||
pytestmark = pytest.mark.llava
|
||||
|
||||
# The image token is placed before "user" on purpose so that the test can pass
|
||||
HF_IMAGE_PROMPTS = [
|
||||
"<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501
|
||||
"<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
|
||||
]
|
||||
|
||||
assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
|
||||
|
||||
|
||||
def iter_phi3v_configs(model_name: str):
|
||||
image_hw_to_feature_size = {
|
||||
(1008, 1344): 1921,
|
||||
}
|
||||
|
||||
for (h, w), f in image_hw_to_feature_size.items():
|
||||
for input_type, input_shape in [
|
||||
(VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
|
||||
]:
|
||||
yield (model_name,
|
||||
VisionLanguageConfig(image_input_type=input_type,
|
||||
image_feature_size=f,
|
||||
image_token_id=32044,
|
||||
image_input_shape=input_shape,
|
||||
image_processor=model_name,
|
||||
image_processor_revision=None))
|
||||
|
||||
|
||||
model_and_vl_config = [
|
||||
*iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"),
|
||||
]
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
|
||||
vlm_config: VisionLanguageConfig, model_id: str):
|
||||
"""Sanitize vllm output to be comparable with hf output.
|
||||
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
|
||||
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
|
||||
It also reduces `output_str` from "<image><image>bla" to "bla".
|
||||
"""
|
||||
input_ids, output_str = vllm_output
|
||||
image_token_id = vlm_config.image_token_id
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
image_token_str = tokenizer.decode(image_token_id)
|
||||
|
||||
hf_input_ids = [
|
||||
input_id if input_id != image_token_id else 0
|
||||
for idx, input_id in enumerate(input_ids)
|
||||
]
|
||||
hf_output_str = output_str \
|
||||
.replace(image_token_str * vlm_config.image_feature_size, "") \
|
||||
.replace("<s>", " ").replace("<|user|>", "") \
|
||||
.replace("<|end|>\n<|assistant|>", " ")
|
||||
|
||||
return hf_input_ids, hf_output_str
|
||||
|
||||
|
||||
target_dtype = "half"
|
||||
if is_cpu():
|
||||
target_dtype = "bfloat16"
|
||||
|
||||
|
||||
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
|
||||
# Since we use _attn_implementation="eager" for hf_runner, here is
|
||||
# numeric difference for longer context and test can't pass
|
||||
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [8])
|
||||
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
|
||||
model_and_config, dtype: str, max_tokens: int) -> None:
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
For huggingface runner, we provide the PIL images as input.
|
||||
For vllm runner, we provide MultiModalData objects and corresponding
|
||||
vision language config as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
model_id, vlm_config = model_and_config
|
||||
|
||||
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
|
||||
hf_model_kwargs = {"_attn_implementation": "eager"}
|
||||
with hf_runner(model_id, dtype=dtype,
|
||||
model_kwargs=hf_model_kwargs) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
|
||||
max_tokens,
|
||||
images=hf_images)
|
||||
|
||||
vllm_image_prompts = [
|
||||
p.replace("<|image_1|>",
|
||||
"<|image|>" * vlm_config.image_feature_size + "<s>")
|
||||
for p in HF_IMAGE_PROMPTS
|
||||
]
|
||||
|
||||
with vllm_runner(model_id,
|
||||
max_model_len=2048,
|
||||
dtype=dtype,
|
||||
enforce_eager=True,
|
||||
**vlm_config.as_cli_args_dict()) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
|
||||
max_tokens,
|
||||
images=vllm_images)
|
||||
|
||||
for i in range(len(HF_IMAGE_PROMPTS)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_to_hf_output(
|
||||
vllm_outputs[i], vlm_config, model_id)
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
@ -49,6 +49,7 @@ _GENERATION_MODELS = {
|
||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
|
||||
|
||||
379
vllm/model_executor/models/phi3v.py
Normal file
379
vllm/model_executor/models/phi3v.py
Normal file
@ -0,0 +1,379 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The vLLM team.
|
||||
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VisionLanguageConfig
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import get_dummy_image_data
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"model.vision_embed_tokens": "vision_embed_tokens",
|
||||
}
|
||||
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
|
||||
hidden_act="quick_gelu",
|
||||
hidden_size=1024,
|
||||
image_size=336,
|
||||
intermediate_size=4096,
|
||||
num_attention_heads=16,
|
||||
num_channels=3,
|
||||
num_hidden_layers=24,
|
||||
patch_size=14,
|
||||
projection_dim=768)
|
||||
|
||||
|
||||
class Phi3ImageEmbeddingBase(nn.Module):
|
||||
|
||||
def __init__(self, wte=None) -> None:
|
||||
super().__init__()
|
||||
self.wte = wte
|
||||
self.layer_idx: int
|
||||
self.type_feature: str
|
||||
self.img_processor: CLIPVisionModel
|
||||
|
||||
def set_img_features(self, img_features: torch.FloatTensor) -> None:
|
||||
self.img_features = img_features
|
||||
|
||||
def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
|
||||
self.img_sizes = img_sizes
|
||||
|
||||
def get_img_features(self,
|
||||
img_embeds: torch.FloatTensor) -> torch.FloatTensor:
|
||||
LAYER_IDX = self.layer_idx
|
||||
TYPE_FEATURE = self.type_feature
|
||||
|
||||
img_processor_output = self.img_processor(img_embeds,
|
||||
output_hidden_states=True)
|
||||
img_feature = img_processor_output.hidden_states[LAYER_IDX]
|
||||
|
||||
if TYPE_FEATURE == "patch":
|
||||
patch_feature = img_feature[:, 1:]
|
||||
return patch_feature
|
||||
|
||||
if TYPE_FEATURE == "cls_patch":
|
||||
return img_feature
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
|
||||
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
"""Phi3 Image embedding with HD transform."""
|
||||
|
||||
def __init__(self,
|
||||
vision_language_config: VisionLanguageConfig,
|
||||
config: PretrainedConfig,
|
||||
wte=None) -> None:
|
||||
super().__init__(wte)
|
||||
|
||||
self.image_token_id = vision_language_config.image_token_id
|
||||
# n_embed or hidden_size
|
||||
hidden_size = config.n_embd if hasattr(
|
||||
config, 'n_embd') else config.hidden_size
|
||||
|
||||
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
|
||||
self.img_processor = CLIPVisionModel(clip_config)
|
||||
image_dim_out = config.img_processor['image_dim_out']
|
||||
self.num_img_tokens = config.img_processor['num_img_tokens']
|
||||
|
||||
self.image_dim_out = image_dim_out
|
||||
self.img_sizes = None
|
||||
|
||||
# global_gn and sub_gn for hd transform, serves as line separator
|
||||
self.use_hd_transform = config.embd_layer.get('use_hd_transform',
|
||||
False)
|
||||
self.with_learnable_separator = config.embd_layer.get(
|
||||
'with_learnable_separator', False)
|
||||
self.hd_transform_order = config.embd_layer.get(
|
||||
'hd_transform_order', 'glb_sub')
|
||||
# with_hd_transform and with_learnable_separator should have same value
|
||||
assert self.use_hd_transform and self.with_learnable_separator
|
||||
|
||||
# 1024 * 4, merge spatial to channel dimension
|
||||
self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
|
||||
self.sub_GN = nn.Parameter(
|
||||
torch.empty([1, 1, 1, self.image_dim_out * 4]))
|
||||
|
||||
dim_projection = hidden_size
|
||||
depth = 2
|
||||
layers = [nn.Linear(image_dim_out * 4, dim_projection)]
|
||||
for _ in range(1, depth):
|
||||
layers.extend(
|
||||
[nn.GELU(),
|
||||
nn.Linear(dim_projection, dim_projection)])
|
||||
self.img_projection = nn.Sequential(*layers)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.img_features = None
|
||||
|
||||
self.layer_idx = config.img_processor.get('layer_idx', -2)
|
||||
self.type_feature = config.img_processor.get('type_feature', 'patch')
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.LongTensor,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes=None) -> torch.FloatTensor:
|
||||
"""process and merge text embeddings with image embeddings."""
|
||||
|
||||
img_embeds = pixel_values
|
||||
img_sizes = image_sizes
|
||||
|
||||
if self.img_features is not None:
|
||||
img_embeds = self.img_features.clone()
|
||||
self.img_features = None
|
||||
|
||||
if self.img_sizes is not None:
|
||||
img_sizes = self.img_sizes
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
positions = torch.nonzero(input_ids == self.image_token_id)
|
||||
|
||||
select = False
|
||||
|
||||
target_device = self.img_projection[0].bias.device
|
||||
target_dtype = self.img_projection[0].bias.dtype
|
||||
|
||||
if len(positions.tolist()) > 0:
|
||||
# if self.use_hd_transform and img_sizes:
|
||||
# img_embeds: (num_images, max_num_crops, 3, H, W)
|
||||
# img_sizes: (num_images, 2).view(1, -1)
|
||||
|
||||
bs = img_embeds.shape[0]
|
||||
# Nx(HW)xC
|
||||
img_features = self.get_img_features(img_embeds.flatten(0, 1))
|
||||
base_feat_height = base_feat_width = int(
|
||||
img_features.shape[1]**0.5)
|
||||
|
||||
# bs x max_num_crops x (24x24) x C
|
||||
img_features = img_features.view(
|
||||
bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
|
||||
C = self.image_dim_out
|
||||
H = base_feat_height
|
||||
|
||||
output_imgs = []
|
||||
output_len = []
|
||||
|
||||
if isinstance(img_sizes, torch.Tensor):
|
||||
img_sizes.squeeze_(0)
|
||||
|
||||
for _bs in range(bs):
|
||||
h, w = img_sizes
|
||||
h = h // 336
|
||||
w = w // 336
|
||||
B_ = h * w
|
||||
|
||||
# 1 x (24x24) x 1024
|
||||
global_img_feature = img_features[_bs, :1]
|
||||
|
||||
# 1 x 12 x 12 x 4096
|
||||
glb_img = global_img_feature \
|
||||
.reshape(1, H // 2, 2, H // 2, 2,C) \
|
||||
.permute(0, 1, 3, 2, 4, 5) \
|
||||
.reshape(1, H // 2, H // 2, 4 * C)
|
||||
temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)
|
||||
|
||||
# 1 x 156 x 4096
|
||||
glb_img = torch.cat([glb_img, temp_glb_GN],
|
||||
dim=2).reshape(1, -1, 4 * C)
|
||||
|
||||
# (max_num_crops-1) x (12x12) x C
|
||||
sub_img = img_features[_bs, 1:]
|
||||
# 16x574x1024
|
||||
# get rid of padding sub_img
|
||||
sub_img = sub_img[:B_]
|
||||
|
||||
sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \
|
||||
.permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C)
|
||||
sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \
|
||||
.permute(0, 1, 3, 2, 4, 5) \
|
||||
.reshape(1, h * 12, w * 12, 4 * C)
|
||||
temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
|
||||
sub_img = torch.cat([sub_img, temp_sub_GN],
|
||||
dim=2).reshape(1, -1, 4 * C)
|
||||
# (1, num_img_tokens, 1024*4)
|
||||
|
||||
# glb + sub
|
||||
if self.hd_transform_order == 'glb_sub':
|
||||
output_imgs.append(
|
||||
torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
|
||||
elif self.hd_transform_order == 'sub_glb':
|
||||
output_imgs.append(
|
||||
torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
|
||||
|
||||
temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
|
||||
output_len.append(temp_len)
|
||||
|
||||
num_img_tokens = output_len
|
||||
img_set_tensor = []
|
||||
for _output_img in output_imgs:
|
||||
img_feature_proj = self.img_projection(
|
||||
_output_img.to(target_device, target_dtype))
|
||||
img_set_tensor.append(img_feature_proj)
|
||||
select = True
|
||||
|
||||
input_ids.clamp_min_(0).clamp_max_(self.vocab_size)
|
||||
|
||||
hidden_states = self.wte(input_ids)
|
||||
|
||||
if select:
|
||||
idx = 0
|
||||
for i, cnt in enumerate(num_img_tokens):
|
||||
hidden_states[positions[idx, 0],
|
||||
positions[idx, 1]:positions[idx, 1] +
|
||||
cnt] = (img_set_tensor[i].to(
|
||||
hidden_states.device, hidden_states.dtype))
|
||||
idx += cnt
|
||||
|
||||
return hidden_states.squeeze(0)
|
||||
|
||||
|
||||
class Phi3VImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
|
||||
|
||||
image_sizes: torch.Tensor
|
||||
"""Shape: (batch_size, 2)"""
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_pixel_input()
|
||||
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
|
||||
class Phi3VForCausalLM(VisionLanguageModelBase):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
vision_language_config: VisionLanguageConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
super().__init__(vision_language_config)
|
||||
self.config = config
|
||||
self.model = LlamaModel(config, cache_config, quant_config)
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(
|
||||
vision_language_config, config, self.model.embed_tokens)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_sizes = kwargs.pop("image_sizes", None)
|
||||
|
||||
expected_input_type = self.vision_language_config.image_input_type
|
||||
ImageInputType = VisionLanguageConfig.ImageInputType
|
||||
|
||||
if expected_input_type != ImageInputType.PIXEL_VALUES:
|
||||
raise ValueError(
|
||||
f"Unexpected image input type: {expected_input_type}."
|
||||
"Phi3v only support pixel_values input currently.")
|
||||
|
||||
if pixel_values is not None and image_sizes is not None:
|
||||
return Phi3VImagePixelInputs(type="pixel_values",
|
||||
data=pixel_values,
|
||||
image_sizes=image_sizes)
|
||||
|
||||
return None
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata, **kwargs: object):
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
inputs_embeds = self.vision_embed_tokens(
|
||||
input_ids, image_input["data"], image_input["image_sizes"])
|
||||
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head.weight, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# We only do sharding for language model
|
||||
# and not vision model for now.
|
||||
if "vision_embed_tokens" in name and self.vision_embed_tokens:
|
||||
continue
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
@ -79,6 +79,8 @@ def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
|
||||
|
||||
if config.hf_config.model_type in ("llava", "llava_next"):
|
||||
full_prompt = f"{image_prompt}\n{text_prompt}"
|
||||
elif config.hf_config.model_type == 'phi3_v':
|
||||
full_prompt = f"{image_prompt}<s>\n{text_prompt}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported model type: {config.hf_config.model_type}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user