[Model] Initialize Phi-3-vision support (#4986)

Isotr0py 2024-06-18 10:34:33 +08:00 committed by GitHub
parent fa9e385229
commit daef218b55
8 changed files with 571 additions and 0 deletions


@@ -135,6 +135,10 @@ Alongside each architecture, we include some popular models that use it.
    - Phi-3-Small
    - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
    -
  * - :code:`Phi3VForCausalLM`
    - Phi-3-Vision
    - :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
    -
  * - :code:`QWenLMHeadModel`
    - Qwen
    - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.

examples/phi3v_example.py Normal file

@@ -0,0 +1,57 @@
import os
import subprocess

from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData


def run_phi3v():
    model_path = "microsoft/Phi-3-vision-128k-instruct"
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        max_model_len=4096,
        image_input_type="pixel_values",
        image_token_id=32044,
        image_input_shape="1,3,1008,1344",
        image_feature_size=1921,
        disable_image_processor=False,
    )

    image = Image.open("images/cherry_blossom.jpg")

    # single-image prompt
    prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n"  # noqa: E501
    prompt = prompt.replace("<|image_1|>", "<|image|>" * 1921 + "<s>")

    sampling_params = SamplingParams(temperature=0, max_tokens=64)

    outputs = llm.generate({
        "prompt": prompt,
        "sampling_params": sampling_params,
        "multi_modal_data": ImagePixelData(image),
    })
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Make sure the local directory exists or create it
    os.makedirs(local_directory, exist_ok=True)

    # Use AWS CLI to sync the directory, assume anonymous access
    subprocess.check_call([
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ])
    run_phi3v()


@@ -14,6 +14,7 @@ peft
requests
ray
sentence-transformers # required for embedding
torchvision # required for the image processor of phi3v
# Benchmarking
aiohttp


@@ -144,6 +144,7 @@ class HfRunner:
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        is_vision_model: bool = False,
    ) -> None:
@@ -166,11 +167,13 @@ class HfRunner:
        else:
            auto_cls = AutoModelForCausalLM

        model_kwargs = model_kwargs if model_kwargs is not None else {}
        self.model = self.wrap_device(
            auto_cls.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
                **model_kwargs,
            ))

        self.tokenizer = AutoTokenizer.from_pretrained(

tests/models/test_phi3v.py Normal file

@@ -0,0 +1,124 @@
from typing import List, Tuple

import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig
from vllm.utils import is_cpu

from ..conftest import IMAGE_FILES

pytestmark = pytest.mark.llava

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
    "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n",  # noqa: E501
    "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
]

assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)


def iter_phi3v_configs(model_name: str):
    image_hw_to_feature_size = {
        (1008, 1344): 1921,
    }

    for (h, w), f in image_hw_to_feature_size.items():
        for input_type, input_shape in [
            (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
        ]:
            yield (model_name,
                   VisionLanguageConfig(image_input_type=input_type,
                                        image_feature_size=f,
                                        image_token_id=32044,
                                        image_input_shape=input_shape,
                                        image_processor=model_name,
                                        image_processor_revision=None))


model_and_vl_config = [
    *iter_phi3v_configs("microsoft/Phi-3-vision-128k-instruct"),
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
                      vlm_config: VisionLanguageConfig, model_id: str):
    """Sanitize vllm output to be comparable with hf output.
    The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
    x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
    It also reduces `output_str` from "<image><image>bla" to "bla".
    """
    input_ids, output_str = vllm_output
    image_token_id = vlm_config.image_token_id

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_token_str = tokenizer.decode(image_token_id)

    hf_input_ids = [
        input_id if input_id != image_token_id else 0
        for idx, input_id in enumerate(input_ids)
    ]
    hf_output_str = output_str \
        .replace(image_token_str * vlm_config.image_feature_size, "") \
        .replace("<s>", " ").replace("<|user|>", "") \
        .replace("<|end|>\n<|assistant|>", " ")

    return hf_input_ids, hf_output_str


target_dtype = "half"
if is_cpu():
    target_dtype = "bfloat16"


# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# Since we use _attn_implementation="eager" for hf_runner, there is a
# numeric difference for longer contexts and the test can't pass
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [8])
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
                model_and_config, dtype: str, max_tokens: int) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the vllm runner, we provide MultiModalData objects and the
    corresponding vision language config as input.
    Note, the text input is also adjusted to abide by the vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    model_id, vlm_config = model_and_config

    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
    hf_model_kwargs = {"_attn_implementation": "eager"}
    with hf_runner(model_id, dtype=dtype,
                   model_kwargs=hf_model_kwargs) as hf_model:
        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
                                              max_tokens,
                                              images=hf_images)

    vllm_image_prompts = [
        p.replace("<|image_1|>",
                  "<|image|>" * vlm_config.image_feature_size + "<s>")
        for p in HF_IMAGE_PROMPTS
    ]

    with vllm_runner(model_id,
                     max_model_len=2048,
                     dtype=dtype,
                     enforce_eager=True,
                     **vlm_config.as_cli_args_dict()) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                                  max_tokens,
                                                  images=vllm_images)

    for i in range(len(HF_IMAGE_PROMPTS)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
            vllm_outputs[i], vlm_config, model_id)
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@@ -49,6 +49,7 @@ _GENERATION_MODELS = {
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),


@@ -0,0 +1,379 @@
# coding=utf-8
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig
from transformers.utils import logging

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data
from vllm.sequence import SamplerOutput

logger = logging.get_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)


class Phi3ImageEmbeddingBase(nn.Module):

    def __init__(self, wte=None) -> None:
        super().__init__()
        self.wte = wte
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def set_img_features(self, img_features: torch.FloatTensor) -> None:
        self.img_features = img_features

    def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
        self.img_sizes = img_sizes

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        LAYER_IDX = self.layer_idx
        TYPE_FEATURE = self.type_feature

        img_processor_output = self.img_processor(img_embeds,
                                                  output_hidden_states=True)
        img_feature = img_processor_output.hidden_states[LAYER_IDX]

        if TYPE_FEATURE == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            return img_feature

        raise NotImplementedError


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

    def __init__(self,
                 vision_language_config: VisionLanguageConfig,
                 config: PretrainedConfig,
                 wte=None) -> None:
        super().__init__(wte)

        self.image_token_id = vision_language_config.image_token_id
        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
        self.img_processor = CLIPVisionModel(clip_config)
        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out
        self.img_sizes = None

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # with_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.vocab_size = config.vocab_size

        self.img_features = None

        self.layer_idx = config.img_processor.get('layer_idx', -2)
        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(self,
                input_ids: torch.LongTensor,
                pixel_values: torch.FloatTensor,
                image_sizes=None) -> torch.FloatTensor:
        """process and merge text embeddings with image embeddings."""

        img_embeds = pixel_values
        img_sizes = image_sizes

        if self.img_features is not None:
            img_embeds = self.img_features.clone()
            self.img_features = None

        if self.img_sizes is not None:
            img_sizes = self.img_sizes

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        positions = torch.nonzero(input_ids == self.image_token_id)

        select = False

        target_device = self.img_projection[0].bias.device
        target_dtype = self.img_projection[0].bias.dtype

        if len(positions.tolist()) > 0:
            # if self.use_hd_transform and img_sizes:
            # img_embeds: (num_images, max_num_crops, 3, H, W)
            # img_sizes: (num_images, 2).view(1, -1)

            bs = img_embeds.shape[0]
            # Nx(HW)xC
            img_features = self.get_img_features(img_embeds.flatten(0, 1))

            base_feat_height = base_feat_width = int(
                img_features.shape[1]**0.5)

            # bs x max_num_crops x (24x24) x C
            img_features = img_features.view(
                bs, -1, base_feat_height * base_feat_width, self.image_dim_out)
            C = self.image_dim_out
            H = base_feat_height

            output_imgs = []
            output_len = []

            if isinstance(img_sizes, torch.Tensor):
                img_sizes.squeeze_(0)

            for _bs in range(bs):
                h, w = img_sizes
                h = h // 336
                w = w // 336
                B_ = h * w

                # 1 x (24x24) x 1024
                global_img_feature = img_features[_bs, :1]

                # 1 x 12 x 12 x 4096
                glb_img = global_img_feature \
                    .reshape(1, H // 2, 2, H // 2, 2, C) \
                    .permute(0, 1, 3, 2, 4, 5) \
                    .reshape(1, H // 2, H // 2, 4 * C)
                temp_glb_GN = self.sub_GN.repeat(1, H // 2, 1, 1)

                # 1 x 156 x 4096
                glb_img = torch.cat([glb_img, temp_glb_GN],
                                    dim=2).reshape(1, -1, 4 * C)

                # (max_num_crops-1) x (12x12) x C
                sub_img = img_features[_bs, 1:]
                # 16x574x1024
                # get rid of padding sub_img
                sub_img = sub_img[:B_]

                sub_img = sub_img.reshape(B_, H // 2, 2, H // 2, 2, C) \
                    .permute(0, 1, 3, 2, 4, 5).reshape(B_, -1, 4 * C)
                sub_img = sub_img.reshape(1, h, w, 12, 12, -1) \
                    .permute(0, 1, 3, 2, 4, 5) \
                    .reshape(1, h * 12, w * 12, 4 * C)
                temp_sub_GN = self.sub_GN.repeat(1, h * 12, 1, 1)
                sub_img = torch.cat([sub_img, temp_sub_GN],
                                    dim=2).reshape(1, -1, 4 * C)
                # (1, num_img_tokens, 1024*4)

                # glb + sub
                if self.hd_transform_order == 'glb_sub':
                    output_imgs.append(
                        torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
                elif self.hd_transform_order == 'sub_glb':
                    output_imgs.append(
                        torch.cat([sub_img, self.glb_GN, glb_img], dim=1))

                temp_len = int((h * w + 1) * 144 + 1 + (h + 1) * 12)
                output_len.append(temp_len)

            num_img_tokens = output_len

            img_set_tensor = []
            for _output_img in output_imgs:
                img_feature_proj = self.img_projection(
                    _output_img.to(target_device, target_dtype))
                img_set_tensor.append(img_feature_proj)
            select = True

        input_ids.clamp_min_(0).clamp_max_(self.vocab_size)

        hidden_states = self.wte(input_ids)

        if select:
            idx = 0
            for i, cnt in enumerate(num_img_tokens):
                hidden_states[positions[idx, 0],
                              positions[idx, 1]:positions[idx, 1] +
                              cnt] = (img_set_tensor[i].to(
                                  hidden_states.device, hidden_states.dtype))
                idx += cnt

        return hidden_states.squeeze(0)


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""

    image_sizes: torch.Tensor
    """Shape: (batch_size, 2)"""


@MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class Phi3VForCausalLM(VisionLanguageModelBase):

    def __init__(self,
                 config: PretrainedConfig,
                 vision_language_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__(vision_language_config)
        self.config = config
        self.model = LlamaModel(config, cache_config, quant_config)
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            vision_language_config, config, self.model.embed_tokens)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        expected_input_type = self.vision_language_config.image_input_type
        ImageInputType = VisionLanguageConfig.ImageInputType

        if expected_input_type != ImageInputType.PIXEL_VALUES:
            raise ValueError(
                f"Unexpected image input type: {expected_input_type}."
                "Phi3v only support pixel_values input currently.")

        if pixel_values is not None and image_sizes is not None:
            return Phi3VImagePixelInputs(type="pixel_values",
                                         data=pixel_values,
                                         image_sizes=image_sizes)

        return None

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata, **kwargs: object):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            inputs_embeds = self.vision_embed_tokens(
                input_ids, image_input["data"], image_input["image_sizes"])
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # We only do sharding for language model
                # and not vision model for now.
                if "vision_embed_tokens" in name and self.vision_embed_tokens:
                    continue
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
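
As a sanity check on the constants used throughout this commit, the temp_len formula in forward() above reproduces the image_feature_size of 1921 for the default image_input_shape of 1,3,1008,1344 (a 3 x 4 grid of 336-pixel crops plus the global image and separator tokens). A short worked example, for illustration only:

# Worked example of the image-token count, mirroring the formula in forward():
#     temp_len = (h * w + 1) * 144 + 1 + (h + 1) * 12
h, w = 1008 // 336, 1344 // 336  # 3 x 4 grid of 336x336 crops
temp_len = (h * w + 1) * 144 + 1 + (h + 1) * 12
print(temp_len)  # 1921, matching image_feature_size in the example and tests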


@@ -79,6 +79,8 @@ def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
    if config.hf_config.model_type in ("llava", "llava_next"):
        full_prompt = f"{image_prompt}\n{text_prompt}"
    elif config.hf_config.model_type == 'phi3_v':
        full_prompt = f"{image_prompt}<s>\n{text_prompt}"
    else:
        raise ValueError(
            f"Unsupported model type: {config.hf_config.model_type}")