mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 16:49:08 +08:00
[Model] Initialize Fuyu-8B support (#3924)
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
fb6af8bc08
commit
540c0368b1
@ -137,6 +137,10 @@ Decoder-only Language Models
|
|||||||
- Phi-3-Small
|
- Phi-3-Small
|
||||||
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
|
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
|
||||||
-
|
-
|
||||||
|
* - :code:`PersimmonForCausalLM`
|
||||||
|
- Persimmon
|
||||||
|
- :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.
|
||||||
|
-
|
||||||
* - :code:`QWenLMHeadModel`
|
* - :code:`QWenLMHeadModel`
|
||||||
- Qwen
|
- Qwen
|
||||||
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
||||||
@ -178,6 +182,10 @@ Vision Language Models
|
|||||||
- Models
|
- Models
|
||||||
- Example HuggingFace Models
|
- Example HuggingFace Models
|
||||||
- :ref:`LoRA <lora>`
|
- :ref:`LoRA <lora>`
|
||||||
|
* - :code:`FuyuForCausalLM`
|
||||||
|
- Fuyu
|
||||||
|
- :code:`adept/fuyu-8b` etc.
|
||||||
|
-
|
||||||
* - :code:`LlavaForConditionalGeneration`
|
* - :code:`LlavaForConditionalGeneration`
|
||||||
- LLaVA-1.5
|
- LLaVA-1.5
|
||||||
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
|
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
|
||||||
|
|||||||
31
examples/fuyu_example.py
Normal file
31
examples/fuyu_example.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
def run_fuyu():
|
||||||
|
llm = LLM(model="adept/fuyu-8b", max_model_len=4096)
|
||||||
|
|
||||||
|
# single-image prompt
|
||||||
|
prompt = "What is the highest life expectancy at of male?\n"
|
||||||
|
url = "https://huggingface.co/adept/fuyu-8b/resolve/main/chart.png"
|
||||||
|
image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
sampling_params = SamplingParams(temperature=0, max_tokens=64)
|
||||||
|
|
||||||
|
outputs = llm.generate(
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
"multi_modal_data": {
|
||||||
|
"image": image
|
||||||
|
},
|
||||||
|
},
|
||||||
|
sampling_params=sampling_params)
|
||||||
|
|
||||||
|
for o in outputs:
|
||||||
|
generated_text = o.outputs[0].text
|
||||||
|
print(generated_text)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_fuyu()
|
||||||
142
tests/models/test_fuyu.py
Normal file
142
tests/models/test_fuyu.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
from typing import List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.multimodal.utils import rescale_image_size
|
||||||
|
from vllm.sequence import SampleLogprobs
|
||||||
|
from vllm.utils import is_cpu
|
||||||
|
|
||||||
|
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||||
|
from .utils import check_logprobs_close
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.vlm
|
||||||
|
|
||||||
|
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||||
|
"stop_sign": "What's the content of the image?\n", # noqa: E501
|
||||||
|
"cherry_blossom": "What is the season?\n",
|
||||||
|
"boardwalk": "What's in this image?\n",
|
||||||
|
})
|
||||||
|
|
||||||
|
models = ["adept/fuyu-8b"]
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||||
|
Optional[SampleLogprobs]]):
|
||||||
|
"""Sanitize vllm output to be comparable with hf output."""
|
||||||
|
output_ids, output_str, out_logprobs = vllm_output
|
||||||
|
|
||||||
|
hf_output_str = output_str.lstrip() + "|ENDOFTEXT|"
|
||||||
|
|
||||||
|
return output_ids, hf_output_str, out_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(
|
||||||
|
hf_runner: Type[HfRunner],
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
image_assets: _ImageAssets,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
size_factors: List[float],
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
num_logprobs: int,
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
distributed_executor_backend: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Inference result should be the same between hf and vllm.
|
||||||
|
|
||||||
|
All the image fixtures for the test is under tests/images.
|
||||||
|
For huggingface runner, we provide the PIL images as input.
|
||||||
|
For vllm runner, we provide MultiModalDataDict objects
|
||||||
|
and corresponding vision language config as input.
|
||||||
|
Note, the text input is also adjusted to abide by vllm contract.
|
||||||
|
The text output is sanitized to be able to compare with hf.
|
||||||
|
"""
|
||||||
|
images = [asset.pil_image for asset in image_assets]
|
||||||
|
|
||||||
|
inputs_per_image = [(
|
||||||
|
[prompt for _ in size_factors],
|
||||||
|
[rescale_image_size(image, factor) for factor in size_factors],
|
||||||
|
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||||
|
|
||||||
|
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||||
|
# vLLM needs a fresh new process without cuda initialization.
|
||||||
|
# if we run HF first, the cuda initialization will be done and it
|
||||||
|
# will hurt multiprocessing backend with fork method (the default method).
|
||||||
|
|
||||||
|
# max_model_len should be greater than image_feature_size
|
||||||
|
with vllm_runner(model,
|
||||||
|
max_model_len=2560,
|
||||||
|
max_num_seqs=1,
|
||||||
|
dtype=dtype,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
|
enforce_eager=True) as vllm_model:
|
||||||
|
vllm_outputs_per_image = [
|
||||||
|
vllm_model.generate_greedy_logprobs(prompts,
|
||||||
|
max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
images=vllm_images)
|
||||||
|
for prompts, vllm_images in inputs_per_image
|
||||||
|
]
|
||||||
|
|
||||||
|
with hf_runner(model, dtype=dtype) as hf_model:
|
||||||
|
hf_model.model.get_output_embeddings = lambda: \
|
||||||
|
hf_model.model.language_model.get_output_embeddings()
|
||||||
|
eos_token_id = hf_model.processor.tokenizer.eos_token_id
|
||||||
|
hf_outputs_per_image = [
|
||||||
|
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||||
|
max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
images=hf_images,
|
||||||
|
eos_token_id=eos_token_id)
|
||||||
|
for prompts, hf_images in inputs_per_image
|
||||||
|
]
|
||||||
|
|
||||||
|
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||||
|
vllm_outputs_per_image):
|
||||||
|
check_logprobs_close(
|
||||||
|
outputs_0_lst=hf_outputs,
|
||||||
|
outputs_1_lst=[
|
||||||
|
vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
|
||||||
|
],
|
||||||
|
name_0="hf",
|
||||||
|
name_1="vllm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
target_dtype = "half"
|
||||||
|
if is_cpu():
|
||||||
|
target_dtype = "bfloat16"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", models)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"size_factors",
|
||||||
|
[
|
||||||
|
# No image
|
||||||
|
[],
|
||||||
|
# Single-scale
|
||||||
|
[0.25],
|
||||||
|
# Single-scale, batched
|
||||||
|
[0.25, 0.25, 0.25],
|
||||||
|
# Multi-scale
|
||||||
|
[0.25, 0.2, 0.15],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [128])
|
||||||
|
@pytest.mark.parametrize("num_logprobs", [10])
|
||||||
|
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||||
|
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||||
|
run_test(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
image_assets,
|
||||||
|
model,
|
||||||
|
size_factors=size_factors,
|
||||||
|
dtype=dtype,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
||||||
@ -23,6 +23,7 @@ _GENERATION_MODELS = {
|
|||||||
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
|
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
|
||||||
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
|
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
|
||||||
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||||
|
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
|
||||||
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
|
||||||
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
||||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||||
@ -49,6 +50,7 @@ _GENERATION_MODELS = {
|
|||||||
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
||||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||||
|
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
|
||||||
"PaliGemmaForConditionalGeneration":
|
"PaliGemmaForConditionalGeneration":
|
||||||
("paligemma", "PaliGemmaForConditionalGeneration"),
|
("paligemma", "PaliGemmaForConditionalGeneration"),
|
||||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||||
|
|||||||
328
vllm/model_executor/models/fuyu.py
Normal file
328
vllm/model_executor/models/fuyu.py
Normal file
@ -0,0 +1,328 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch Fuyu model."""
|
||||||
|
import math
|
||||||
|
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import FuyuConfig, FuyuImageProcessor
|
||||||
|
|
||||||
|
from vllm.attention import AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.base import MultiModalInputs
|
||||||
|
from vllm.multimodal.image import (cached_get_image_processor,
|
||||||
|
cached_get_tokenizer)
|
||||||
|
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
|
||||||
|
|
||||||
|
from .interfaces import SupportsVision
|
||||||
|
from .utils import merge_vision_embeddings
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Cannot find the following 2 numbers from hf config.
|
||||||
|
_IMAGE_TOKEN_ID = 71011
|
||||||
|
_NEWLINE_TOKEN_ID = 71019
|
||||||
|
|
||||||
|
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
|
||||||
|
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
|
||||||
|
|
||||||
|
|
||||||
|
class FuyuImagePixelInputs(TypedDict):
|
||||||
|
type: Literal["pixel_values"]
|
||||||
|
data: torch.Tensor
|
||||||
|
"""
|
||||||
|
Shape:
|
||||||
|
(batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_num_image_tokens(
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
calculate number of image tokens needed for a given image size
|
||||||
|
The expected Fuyu image prompts is in format:
|
||||||
|
(image_token * ncols + newline_token) * nrows
|
||||||
|
args:
|
||||||
|
image_size: Tuple[int, int] - (width, height) of the image
|
||||||
|
returns:
|
||||||
|
ncols: int - number of image tokens in x direction
|
||||||
|
nrows: int - number of image tokens in y direction
|
||||||
|
"""
|
||||||
|
ncol = math.ceil(width / 30)
|
||||||
|
nrow = math.ceil(height / 30)
|
||||||
|
return ncol, nrow
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_fuyu_image_feature_size():
|
||||||
|
|
||||||
|
return _calculate_num_image_tokens(
|
||||||
|
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||||
|
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_fuyu_image_tokens(ctx: InputContext):
|
||||||
|
ncol, nrow = get_max_fuyu_image_feature_size()
|
||||||
|
return (ncol + 1) * nrow
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int):
|
||||||
|
ncol, nrow = get_max_fuyu_image_feature_size()
|
||||||
|
image_feature_size = get_max_fuyu_image_tokens(ctx)
|
||||||
|
|
||||||
|
token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
|
||||||
|
token_ids += [0] * (seq_len - image_feature_size)
|
||||||
|
return SequenceData(token_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_image_for_fuyu(
|
||||||
|
image_width: int,
|
||||||
|
image_height: int,
|
||||||
|
):
|
||||||
|
image = Image.new("RGB", (image_width, image_height), color=0)
|
||||||
|
return {"image": image}
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int):
|
||||||
|
seq_data = dummy_seq_data_for_fuyu(ctx, seq_len)
|
||||||
|
mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||||
|
MAX_IMAGE_FEATURE_SIZE_HEIGHT)
|
||||||
|
return seq_data, mm_data
|
||||||
|
|
||||||
|
|
||||||
|
def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
|
||||||
|
data: Image.Image):
|
||||||
|
image_encoding = image_processor.preprocess(data, return_tensors="pt")
|
||||||
|
batch_images = torch.stack([img[0] for img in image_encoding["images"]
|
||||||
|
]).unsqueeze(1)
|
||||||
|
image_unpadded_heights = torch.tensor(
|
||||||
|
image_encoding["image_unpadded_heights"])
|
||||||
|
image_unpadded_widths = torch.tensor(
|
||||||
|
image_encoding["image_unpadded_widths"])
|
||||||
|
|
||||||
|
batch_size = len(image_encoding["images"])
|
||||||
|
image_present = torch.ones(batch_size, 1, 1)
|
||||||
|
model_image_input = image_processor.preprocess_with_tokenizer_info(
|
||||||
|
image_input=batch_images,
|
||||||
|
image_present=image_present,
|
||||||
|
image_unpadded_h=image_unpadded_heights,
|
||||||
|
image_unpadded_w=image_unpadded_widths,
|
||||||
|
image_placeholder_id=_IMAGE_TOKEN_ID,
|
||||||
|
image_newline_id=_NEWLINE_TOKEN_ID,
|
||||||
|
variable_sized=True,
|
||||||
|
)
|
||||||
|
return model_image_input
|
||||||
|
|
||||||
|
|
||||||
|
def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
|
||||||
|
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||||
|
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||||
|
return llm_inputs
|
||||||
|
|
||||||
|
model_config = ctx.model_config
|
||||||
|
image_data = multi_modal_data["image"]
|
||||||
|
new_multi_modal_data = {}
|
||||||
|
# process image data
|
||||||
|
if isinstance(image_data, Image.Image):
|
||||||
|
# Fuyu's image_processor can also finish token padding
|
||||||
|
image_processor: FuyuImageProcessor = cached_get_image_processor(
|
||||||
|
model_config.model)
|
||||||
|
|
||||||
|
model_image_input = _fuyu_image_preprocess(image_processor, image_data)
|
||||||
|
image_patches = torch.stack([
|
||||||
|
image_patch[0]
|
||||||
|
for image_patch in model_image_input["image_patches"]
|
||||||
|
])
|
||||||
|
new_multi_modal_data["image"] = image_patches
|
||||||
|
|
||||||
|
elif isinstance(image_data, torch.Tensor):
|
||||||
|
raise NotImplementedError("Embeddings input is not supported yet")
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||||
|
|
||||||
|
# process prompts
|
||||||
|
prompt = llm_inputs["prompt"]
|
||||||
|
prompt_token_ids = llm_inputs["prompt_token_ids"]
|
||||||
|
tokenizer = cached_get_tokenizer(model_config.model)
|
||||||
|
# dim0 is batch_size, dim1 is subseq_size which will always be 1
|
||||||
|
image_input_ids: List[List[
|
||||||
|
torch.Tensor]] = model_image_input["image_input_ids"]
|
||||||
|
image_input_ids = image_input_ids[0][0].tolist()
|
||||||
|
bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
|
||||||
|
boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
|
||||||
|
|
||||||
|
new_prompt = prompt + "\x04"
|
||||||
|
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
|
||||||
|
1:] + boa_token
|
||||||
|
|
||||||
|
return LLMInputs(prompt=new_prompt,
|
||||||
|
prompt_token_ids=new_prompt_token_ids,
|
||||||
|
multi_modal_data=new_multi_modal_data)
|
||||||
|
|
||||||
|
|
||||||
|
def input_mapper_for_fuyu(ctx: InputContext, data: object):
|
||||||
|
model_config = ctx.model_config
|
||||||
|
if isinstance(data, Image.Image):
|
||||||
|
# Fuyu's image_processor can also finish token padding
|
||||||
|
image_processor: FuyuImageProcessor = cached_get_image_processor(
|
||||||
|
model_config.model)
|
||||||
|
|
||||||
|
model_image_input = _fuyu_image_preprocess(image_processor, data)
|
||||||
|
data = torch.stack([
|
||||||
|
image_patch[0]
|
||||||
|
for image_patch in model_image_input["image_patches"]
|
||||||
|
])
|
||||||
|
|
||||||
|
# image has been processed with prompt in input processor
|
||||||
|
return MultiModalInputs({"image_patches": data})
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
|
||||||
|
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
|
||||||
|
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
|
||||||
|
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
|
||||||
|
class FuyuForCausalLM(nn.Module, SupportsVision):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: FuyuConfig,
|
||||||
|
multimodal_config: MultiModalConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.multimodal_config = multimodal_config
|
||||||
|
|
||||||
|
self.padding_idx = config.pad_token_id
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.image_token_id = _IMAGE_TOKEN_ID
|
||||||
|
self.image_feature_size = config.patch_size**2 * config.num_channels
|
||||||
|
|
||||||
|
self.vision_embed_tokens = ColumnParallelLinear(
|
||||||
|
self.image_feature_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.language_model = PersimmonForCausalLM(config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
def _parse_and_validate_image_input(self, **kwargs: object):
|
||||||
|
image_patches = kwargs.pop("image_patches", None)
|
||||||
|
|
||||||
|
if isinstance(image_patches, torch.Tensor):
|
||||||
|
expected_feature_size = self.image_feature_size
|
||||||
|
if image_patches.size(-1) != expected_feature_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected image patches to have the last dimension of "
|
||||||
|
f"{expected_feature_size}, got {image_patches.size(-1)}")
|
||||||
|
image_patches = image_patches.to(
|
||||||
|
self.vision_embed_tokens.weight.dtype)
|
||||||
|
return FuyuImagePixelInputs(type="pixel_values",
|
||||||
|
data=image_patches)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
**kwargs: object,
|
||||||
|
):
|
||||||
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||||
|
|
||||||
|
if image_input is not None:
|
||||||
|
vision_embeddings, _ = self.vision_embed_tokens(
|
||||||
|
image_input["data"])
|
||||||
|
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||||
|
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
|
||||||
|
vision_embeddings,
|
||||||
|
self.image_token_id)
|
||||||
|
|
||||||
|
else:
|
||||||
|
inputs_embeds = None
|
||||||
|
|
||||||
|
hidden_states = self.language_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
kv_caches=kv_caches,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.language_model.logits_processor(
|
||||||
|
self.language_model.lm_head, hidden_states, sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
next_tokens = self.language_model.sampler(logits, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if ("rotary_emb.cos_cached" in name
|
||||||
|
or "rotary_emb.sin_cached" in name):
|
||||||
|
# Models trained using ColossalAI may include these tensors in
|
||||||
|
# the checkpoint. Skip them.
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
|
||||||
|
if "query_key_value" in name:
|
||||||
|
# copy from vllm/model_executor/models/bloom.py
|
||||||
|
# NOTE: Fuyu's fused QKV's output_dim has the shape of
|
||||||
|
# (num_heads * 3 * head_size), while the
|
||||||
|
# required shape is (3 * num_heads * head_size).
|
||||||
|
# Thus, we need weight conversion.
|
||||||
|
output_dim = getattr(param, "output_dim", None)
|
||||||
|
num_heads = self.config.num_attention_heads
|
||||||
|
if output_dim is not None:
|
||||||
|
loaded_weight_shape = loaded_weight.shape
|
||||||
|
loaded_weight = loaded_weight.view(
|
||||||
|
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
|
||||||
|
loaded_weight_shape[output_dim + 1:])
|
||||||
|
loaded_weight = loaded_weight.transpose(
|
||||||
|
output_dim, output_dim + 1)
|
||||||
|
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
|
||||||
|
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
333
vllm/model_executor/models/persimmon.py
Normal file
333
vllm/model_executor/models/persimmon.py
Normal file
@ -0,0 +1,333 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only persimmon model compatible with HuggingFace weights."""
|
||||||
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import PersimmonConfig
|
||||||
|
from transformers.activations import ReLUSquaredActivation
|
||||||
|
|
||||||
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||||
|
|
||||||
|
|
||||||
|
class PersimmonMLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: PersimmonConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.act = ReLUSquaredActivation()
|
||||||
|
|
||||||
|
def forward(self, hidden_states) -> torch.Tensor:
|
||||||
|
hidden_states, _ = self.dense_h_to_4h(hidden_states)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states, _ = self.dense_4h_to_h(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class PersimmonAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: PersimmonConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
tensor_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.total_num_heads = config.num_attention_heads
|
||||||
|
self.num_heads = self.total_num_heads // tensor_parallel_world_size
|
||||||
|
self.head_dim = self.hidden_size // self.total_num_heads
|
||||||
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
self.partial_rotary_factor = config.partial_rotary_factor
|
||||||
|
self.is_causal = True
|
||||||
|
|
||||||
|
assert (self.head_dim * self.total_num_heads) == self.hidden_size
|
||||||
|
assert self.total_num_heads % tensor_parallel_world_size == 0
|
||||||
|
|
||||||
|
self.query_key_value = QKVParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.dense = RowParallelLinear(
|
||||||
|
self.num_heads * self.head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.is_qk_layernorm = config.qk_layernorm
|
||||||
|
|
||||||
|
if self.is_qk_layernorm:
|
||||||
|
self.q_layernorm = nn.LayerNorm(self.head_dim)
|
||||||
|
self.k_layernorm = nn.LayerNorm(self.head_dim)
|
||||||
|
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=int(self.partial_rotary_factor * self.head_dim),
|
||||||
|
max_position=self.max_position_embeddings,
|
||||||
|
base=self.rope_theta,
|
||||||
|
)
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
scale=self.scaling,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
|
||||||
|
seq_length = x.shape[0]
|
||||||
|
return x.view(seq_length, self.num_heads, self.head_dim)
|
||||||
|
|
||||||
|
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# [seq_length, num_heads, head_dim] -> [seq_length, hidden_size]
|
||||||
|
seq_length = x.shape[0]
|
||||||
|
return x.view(seq_length, self.num_heads * self.head_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# [seq_length, 3 x hidden_size]
|
||||||
|
qkv, _ = self.query_key_value(hidden_states)
|
||||||
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
|
|
||||||
|
if self.is_qk_layernorm:
|
||||||
|
# [seq_length, num_heads, head_dim]
|
||||||
|
q = self._split_heads(q)
|
||||||
|
k = self._split_heads(k)
|
||||||
|
|
||||||
|
q = self.q_layernorm(q)
|
||||||
|
k = self.k_layernorm(k)
|
||||||
|
|
||||||
|
q = self._merge_heads(q)
|
||||||
|
k = self._merge_heads(k)
|
||||||
|
|
||||||
|
q, k = self.rotary_emb(position_ids, q, k)
|
||||||
|
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||||
|
output, _ = self.dense(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class PersimmonDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: PersimmonConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.self_attn = PersimmonAttention(config=config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.mlp = PersimmonMLP(config, quant_config=quant_config)
|
||||||
|
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
position_ids=position_ids,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
outputs = hidden_states
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class PersimmonModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: PersimmonConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||||
|
config.hidden_size)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
PersimmonDecoderLayer(config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
else:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
hidden_states = self.layers[i](
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
kv_caches[i],
|
||||||
|
attn_metadata,
|
||||||
|
)
|
||||||
|
hidden_states = self.final_layernorm(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class PersimmonForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.model = PersimmonModel(config,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=False)
|
||||||
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
kv_caches=kv_caches,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
|
return next_tokens
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "rotary_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if ("rotary_emb.cos_cached" in name
|
||||||
|
or "rotary_emb.sin_cached" in name):
|
||||||
|
# Models trained using ColossalAI may include these tensors in
|
||||||
|
# the checkpoint. Skip them.
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
|
||||||
|
if "query_key_value" in name:
|
||||||
|
# copy from vllm/model_executor/models/bloom.py
|
||||||
|
# NOTE: Persimmon's fused QKV's output_dim has the shape of
|
||||||
|
# (num_heads * 3 * head_size), while the
|
||||||
|
# required shape is (3 * num_heads * head_size).
|
||||||
|
# Thus, we need weight conversion.
|
||||||
|
output_dim = getattr(param, "output_dim", None)
|
||||||
|
num_heads = self.config.num_attention_heads
|
||||||
|
if output_dim is not None:
|
||||||
|
loaded_weight_shape = loaded_weight.shape
|
||||||
|
loaded_weight = loaded_weight.view(
|
||||||
|
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
|
||||||
|
loaded_weight_shape[output_dim + 1:])
|
||||||
|
loaded_weight = loaded_weight.transpose(
|
||||||
|
output_dim, output_dim + 1)
|
||||||
|
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
|
||||||
|
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
Loading…
x
Reference in New Issue
Block a user