[Model] Initialize Fuyu-8B support (#3924)

Co-authored-by: Roger Wang <ywang@roblox.com>
Isotr0py 2024-07-14 13:27:14 +08:00 committed by GitHub
parent fb6af8bc08
commit 540c0368b1
6 changed files with 844 additions and 0 deletions

docs/source/models/supported_models.rst

@@ -137,6 +137,10 @@ Decoder-only Language Models
- Phi-3-Small
- :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc.
-
* - :code:`PersimmonForCausalLM`
- Persimmon
- :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc.
-
* - :code:`QWenLMHeadModel`
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
@@ -178,6 +182,10 @@ Vision Language Models
- Models
- Example HuggingFace Models
- :ref:`LoRA <lora>`
* - :code:`FuyuForCausalLM`
- Fuyu
- :code:`adept/fuyu-8b`, etc.
-
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.

examples/fuyu_example.py (new file, 31 lines)

@@ -0,0 +1,31 @@
import requests
from PIL import Image

from vllm import LLM, SamplingParams


def run_fuyu():
    llm = LLM(model="adept/fuyu-8b", max_model_len=4096)

    # single-image prompt
    prompt = "What is the highest life expectancy at birth of males?\n"
    url = "https://huggingface.co/adept/fuyu-8b/resolve/main/chart.png"
    image = Image.open(requests.get(url, stream=True).raw)
    sampling_params = SamplingParams(temperature=0, max_tokens=64)

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {
                "image": image
            },
        },
        sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    run_fuyu()
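
The same LLM instance also accepts a list of such prompt dicts in a single generate call, which batches several image prompts together. The sketch below is illustrative rather than part of the commit: it reuses the chart image for both questions, and the second question is made up.

import requests
from PIL import Image

from vllm import LLM, SamplingParams


def run_fuyu_batched():
    llm = LLM(model="adept/fuyu-8b", max_model_len=4096)

    url = "https://huggingface.co/adept/fuyu-8b/resolve/main/chart.png"
    image = Image.open(requests.get(url, stream=True).raw)

    # One dict per batch element, each with its own prompt and image.
    prompts = [
        "What is the highest life expectancy at birth of males?\n",
        "Which country does the tallest bar belong to?\n",
    ]
    inputs = [{
        "prompt": prompt,
        "multi_modal_data": {
            "image": image
        },
    } for prompt in prompts]

    sampling_params = SamplingParams(temperature=0, max_tokens=64)
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    for o in outputs:
        print(o.outputs[0].text)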

tests/models/test_fuyu.py (new file, 142 lines)

@@ -0,0 +1,142 @@
from typing import List, Optional, Tuple, Type
import pytest
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "What's the content of the image?\n", # noqa: E501
"cherry_blossom": "What is the season?\n",
"boardwalk": "What's in this image?\n",
})
models = ["adept/fuyu-8b"]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]]):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
hf_output_str = output_str.lstrip() + "|ENDOFTEXT|"
return output_ids, hf_output_str, out_logprobs
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    # NOTE: take care of the order: run vLLM first, and then run HF.
    # vLLM needs a fresh new process without any prior CUDA initialization;
    # if we run HF first, CUDA will already be initialized, which breaks the
    # multiprocessing backend with the fork start method (the default method).

    # max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=2560,
max_num_seqs=1,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=vllm_images)
for prompts, vllm_images in inputs_per_image
]
with hf_runner(model, dtype=dtype) as hf_model:
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.get_output_embeddings()
eos_token_id = hf_model.processor.tokenizer.eos_token_id
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=hf_images,
eos_token_id=eos_token_id)
for prompts, hf_images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
target_dtype = "half"
if is_cpu():
target_dtype = "bfloat16"
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[0.25],
# Single-scale, batched
[0.25, 0.25, 0.25],
# Multi-scale
[0.25, 0.2, 0.15],
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
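
As an aside, the size_factors parameterization above turns each asset image into several rescaled copies batched with a repeated prompt. A minimal, self-contained sketch of that pairing logic, using plain PIL resizing in place of vllm.multimodal.utils.rescale_image_size (simple rounding assumed) and a made-up 1920x1080 dummy image:

from typing import List, Tuple

from PIL import Image


def build_inputs_per_image(
    prompts: List[str],
    images: List[Image.Image],
    size_factors: List[float],
) -> List[Tuple[List[str], List[Image.Image]]]:
    # For every (prompt, image) pair, repeat the prompt once per size factor
    # and pair it with a copy of the image rescaled by that factor,
    # mirroring inputs_per_image in run_test above.
    return [(
        [prompt for _ in size_factors],
        [
            image.resize((round(image.width * f), round(image.height * f)))
            for f in size_factors
        ],
    ) for image, prompt in zip(images, prompts)]


if __name__ == "__main__":
    batches = build_inputs_per_image(
        ["What's in this image?\n"],
        [Image.new("RGB", (1920, 1080), color=0)],
        [0.25, 0.2, 0.15],
    )
    for prompt_batch, image_batch in batches:
        print(len(prompt_batch), [im.size for im in image_batch])
        # -> 3 [(480, 270), (384, 216), (288, 162)]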

vllm/model_executor/models/__init__.py

@@ -23,6 +23,7 @@ _GENERATION_MODELS = {
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
@@ -49,6 +50,7 @@ _GENERATION_MODELS = {
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PaliGemmaForConditionalGeneration":
("paligemma", "PaliGemmaForConditionalGeneration"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),

vllm/model_executor/models/fuyu.py (new file, 328 lines)

@@ -0,0 +1,328 @@
# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
from transformers import FuyuConfig, FuyuImageProcessor
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
logger = init_logger(__name__)
# The following 2 token IDs cannot be found in the hf config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
class FuyuImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""
Shape:
(batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
"""
def _calculate_num_image_tokens(
height: int,
width: int,
) -> Tuple[int, int]:
"""
calculate number of image tokens needed for a given image size
The expected Fuyu image prompts is in format:
(image_token * ncols + newline_token) * nrows
args:
image_size: Tuple[int, int] - (width, height) of the image
returns:
ncols: int - number of image tokens in x direction
nrows: int - number of image tokens in y direction
"""
ncol = math.ceil(width / 30)
nrow = math.ceil(height / 30)
return ncol, nrow
def get_max_fuyu_image_feature_size():
return _calculate_num_image_tokens(
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
def get_max_fuyu_image_tokens(ctx: InputContext):
ncol, nrow = get_max_fuyu_image_feature_size()
return (ncol + 1) * nrow
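# e.g. for the 1920 x 1080 ceiling above, ncol = ceil(1920 / 30) = 64 and
# nrow = ceil(1080 / 30) = 36, so the largest image prompt reserves
# (64 + 1) * 36 = 2340 tokens: ncol image tokens plus one newline token per
# row, repeated for nrow rows.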
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int):
ncol, nrow = get_max_fuyu_image_feature_size()
image_feature_size = get_max_fuyu_image_tokens(ctx)
token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
token_ids += [0] * (seq_len - image_feature_size)
return SequenceData(token_ids)
def dummy_image_for_fuyu(
image_width: int,
image_height: int,
):
image = Image.new("RGB", (image_width, image_height), color=0)
return {"image": image}
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_fuyu(ctx, seq_len)
mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH,
MAX_IMAGE_FEATURE_SIZE_HEIGHT)
return seq_data, mm_data
def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
data: Image.Image):
image_encoding = image_processor.preprocess(data, return_tensors="pt")
batch_images = torch.stack([img[0] for img in image_encoding["images"]
]).unsqueeze(1)
image_unpadded_heights = torch.tensor(
image_encoding["image_unpadded_heights"])
image_unpadded_widths = torch.tensor(
image_encoding["image_unpadded_widths"])
batch_size = len(image_encoding["images"])
image_present = torch.ones(batch_size, 1, 1)
model_image_input = image_processor.preprocess_with_tokenizer_info(
image_input=batch_images,
image_present=image_present,
image_unpadded_h=image_unpadded_heights,
image_unpadded_w=image_unpadded_widths,
image_placeholder_id=_IMAGE_TOKEN_ID,
image_newline_id=_NEWLINE_TOKEN_ID,
variable_sized=True,
)
return model_image_input
def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
image_data = multi_modal_data["image"]
new_multi_modal_data = {}
# process image data
if isinstance(image_data, Image.Image):
        # Fuyu's image_processor also handles the image token padding
image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, image_data)
image_patches = torch.stack([
image_patch[0]
for image_patch in model_image_input["image_patches"]
])
new_multi_modal_data["image"] = image_patches
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
# process prompts
prompt = llm_inputs["prompt"]
prompt_token_ids = llm_inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(model_config.model)
# dim0 is batch_size, dim1 is subseq_size which will always be 1
image_input_ids: List[List[
torch.Tensor]] = model_image_input["image_input_ids"]
image_input_ids = image_input_ids[0][0].tolist()
bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
new_prompt = prompt + "\x04"
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
1:] + boa_token
return LLMInputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data)
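# The resulting sequence is therefore: the image placeholder ids
# ((ncol image tokens + 1 newline token) per row, for nrow rows), followed by
# <s>, the original prompt token ids without their leading BOS, and finally
# the \x04 beginning-of-answer token.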
def input_mapper_for_fuyu(ctx: InputContext, data: object):
model_config = ctx.model_config
if isinstance(data, Image.Image):
        # Fuyu's image_processor also handles the image token padding
image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, data)
data = torch.stack([
image_patch[0]
for image_patch in model_image_input["image_patches"]
])
    # The image has already been processed together with the prompt
    # in the input processor.
return MultiModalInputs({"image_patches": data})
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsVision):
def __init__(self,
config: FuyuConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.image_token_id = _IMAGE_TOKEN_ID
self.image_feature_size = config.patch_size**2 * config.num_channels
self.vision_embed_tokens = ColumnParallelLinear(
self.image_feature_size,
config.hidden_size,
quant_config=quant_config,
)
self.language_model = PersimmonForCausalLM(config,
cache_config=cache_config,
quant_config=quant_config)
def _parse_and_validate_image_input(self, **kwargs: object):
image_patches = kwargs.pop("image_patches", None)
if isinstance(image_patches, torch.Tensor):
expected_feature_size = self.image_feature_size
if image_patches.size(-1) != expected_feature_size:
raise ValueError(
f"Expected image patches to have the last dimension of "
f"{expected_feature_size}, got {image_patches.size(-1)}")
image_patches = image_patches.to(
self.vision_embed_tokens.weight.dtype)
return FuyuImagePixelInputs(type="pixel_values",
data=image_patches)
return None
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings, _ = self.vision_embed_tokens(
image_input["data"])
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
self.image_token_id)
else:
inputs_embeds = None
hidden_states = self.language_model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.language_model.logits_processor(
self.language_model.lm_head, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.language_model.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
param = params_dict[name]
if "query_key_value" in name:
                # Copied from vllm/model_executor/models/bloom.py.
                # NOTE: Fuyu's fused QKV weight is laid out along output_dim
                # as (num_heads * 3 * head_size), while vLLM expects
                # (3 * num_heads * head_size), so the weight needs to be
                # reordered.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
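
The head-interleaved QKV reordering in load_weights above (and again in persimmon.py below) is easiest to see on a tiny tensor. This standalone sketch uses made-up sizes (2 heads, head_size 3, hidden size 4) to show how the view/transpose/reshape turns the checkpoint's per-head [q, k, v] row layout into the [all q, all k, all v] layout expected along output_dim:

import torch

num_heads, head_size, hidden = 2, 3, 4
output_dim = 0  # the fused QKV output dimension of the weight

# Checkpoint layout: rows grouped per head as [q0, k0, v0, q1, k1, v1].
weight = torch.arange(num_heads * 3 * head_size * hidden,
                      dtype=torch.float32).view(num_heads * 3 * head_size,
                                                hidden)

shape = weight.shape
reordered = weight.view(shape[:output_dim] + (num_heads, 3, -1) +
                        shape[output_dim + 1:])
reordered = reordered.transpose(output_dim, output_dim + 1)
reordered = reordered.reshape(shape)

# Rows are now grouped per projection as [q0, q1, k0, k1, v0, v1].
assert reordered[:head_size].equal(weight[:head_size])  # q0 stays first
assert reordered[head_size:2 * head_size].equal(
    weight[3 * head_size:4 * head_size])  # q1 moves up next to q0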

vllm/model_executor/models/persimmon.py (new file, 333 lines)

@@ -0,0 +1,333 @@
# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PersimmonConfig
from transformers.activations import ReLUSquaredActivation
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
class PersimmonMLP(nn.Module):
def __init__(self,
config: PersimmonConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
quant_config=quant_config)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
quant_config=quant_config)
self.act = ReLUSquaredActivation()
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.dense_4h_to_h(hidden_states)
return hidden_states
class PersimmonAttention(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
tensor_parallel_world_size = get_tensor_model_parallel_world_size()
self.hidden_size = config.hidden_size
self.total_num_heads = config.num_attention_heads
self.num_heads = self.total_num_heads // tensor_parallel_world_size
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
self.is_causal = True
assert (self.head_dim * self.total_num_heads) == self.hidden_size
assert self.total_num_heads % tensor_parallel_world_size == 0
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
bias=True,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.num_heads * self.head_dim,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
self.is_qk_layernorm = config.qk_layernorm
if self.is_qk_layernorm:
self.q_layernorm = nn.LayerNorm(self.head_dim)
self.k_layernorm = nn.LayerNorm(self.head_dim)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=int(self.partial_rotary_factor * self.head_dim),
max_position=self.max_position_embeddings,
base=self.rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
quant_config=quant_config)
def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
# [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
seq_length = x.shape[0]
return x.view(seq_length, self.num_heads, self.head_dim)
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
# [seq_length, num_heads, head_dim] -> [seq_length, hidden_size]
seq_length = x.shape[0]
return x.view(seq_length, self.num_heads * self.head_dim)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# [seq_length, 3 x hidden_size]
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.is_qk_layernorm:
# [seq_length, num_heads, head_dim]
q = self._split_heads(q)
k = self._split_heads(k)
q = self.q_layernorm(q)
k = self.k_layernorm(k)
q = self._merge_heads(q)
k = self._merge_heads(k)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.dense(attn_output)
return output
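# Note: rotary embeddings are applied to only part of each head. For example,
# with head_dim = 64 and partial_rotary_factor = 0.5, rotary_dim is
# int(0.5 * 64) = 32, so RoPE rotates the first 32 dimensions of q and k and
# leaves the remaining 32 untouched.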
class PersimmonDecoderLayer(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PersimmonAttention(config=config,
cache_config=cache_config,
quant_config=quant_config)
self.mlp = PersimmonMLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states + residual
outputs = hidden_states
return outputs
class PersimmonModel(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PersimmonDecoderLayer(config,
cache_config=cache_config,
quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
hidden_states = self.layers[i](
positions,
hidden_states,
kv_caches[i],
attn_metadata,
)
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class PersimmonForCausalLM(nn.Module):
def __init__(self,
config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.model = PersimmonModel(config,
cache_config=cache_config,
quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=False)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
param = params_dict[name]
if "query_key_value" in name:
                # Copied from vllm/model_executor/models/bloom.py.
                # NOTE: Persimmon's fused QKV weight is laid out along
                # output_dim as (num_heads * 3 * head_size), while vLLM
                # expects (3 * num_heads * head_size), so the weight needs
                # to be reordered.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
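
Since PersimmonForCausalLM is also registered as a standalone text model by this commit, it can be exercised without any image input. A minimal sketch along the lines of the Fuyu example; the model name comes from the supported-models table above, while the prompt wording and sampling settings are illustrative (the "human: ... adept:" template is assumed from Adept's model card):

from vllm import LLM, SamplingParams


def run_persimmon():
    llm = LLM(model="adept/persimmon-8b-chat", max_model_len=4096)
    sampling_params = SamplingParams(temperature=0, max_tokens=64)

    prompt = "human: What is the capital of France?\n\nadept:"
    outputs = llm.generate(prompt, sampling_params=sampling_params)
    for o in outputs:
        print(o.outputs[0].text)


if __name__ == "__main__":
    run_persimmon()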