[VLM][Model] Support image input for Chameleon (#6633)

Roger Wang 2024-07-22 23:50:48 -07:00 committed by GitHub
parent c5201240a4
commit 22fa2e35cb
7 changed files with 696 additions and 58 deletions

View File

@@ -182,6 +182,10 @@ Vision Language Models
    - Models
    - Example HuggingFace Models
    - :ref:`LoRA <lora>`
+  * - :code:`ChameleonForConditionalGeneration`
+    - Chameleon
+    - :code:`facebook/chameleon-7b` etc.
+    -
  * - :code:`FuyuForCausalLM`
    - Fuyu
    - :code:`adept/fuyu-8b` etc.
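For readers of the docs change above, the snippet below is a minimal usage sketch of the newly supported image input; it is not taken from this commit. It assumes vLLM's standard multi_modal_data API for single-image prompts and reuses the prompt format from the test added below; the image filename is hypothetical.

from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)
image = Image.open("cherry_blossom.jpg").convert("RGB")  # hypothetical local file

# Pass the PIL image alongside the prompt; "<image>" marks where it goes.
outputs = llm.generate(
    {
        "prompt": "USER: <image>\nWhat is the season?\nASSISTANT:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)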

View File

@@ -0,0 +1,102 @@
import re
from typing import List, Optional, Type
import pytest
from vllm.multimodal.utils import rescale_image_size
from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
pytestmark = pytest.mark.vlm
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
"cherry_blossom":
"USER: <image>\nWhat is the season?\nASSISTANT:",
})
models = ["facebook/chameleon-7b"]
#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def run_test(
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Test if the model can generate text given
a batch of images and prompts.
"""
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
with vllm_runner(model,
max_model_len=4096,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
for prompts, images in inputs_per_image:
vllm_outputs = vllm_model.generate_greedy(prompts,
max_tokens,
images=images)
for i in range(len(vllm_outputs)):
# format prompt back to original
replacements = {
"<racm3:break>": "",
"<eoss>": "",
"<reserved08706>": ""
}
pattern = '|'.join(replacements.keys())
vllm_result = re.sub(
pattern,
lambda match: replacements[match.group(0)], #noqa B023
vllm_outputs[i][1])
vllm_result = vllm_result.replace("<image>", "", 1023)
assert vllm_result[:len(prompts[i])] == prompts[i]
# assert at least 10 new characters are generated
# (to take stop token into account)
assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
max_tokens: int) -> None:
run_test(
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
tensor_parallel_size=1,
)

View File

@@ -105,7 +105,8 @@ def _image_token_str(model_config: ModelConfig,
        return None
    if model_type.startswith("llava"):
        return tokenizer.decode(model_config.hf_config.image_token_index)
+    if model_type == "chameleon":
+        return "<image>"

    raise TypeError("Unknown model type: {model_type}")

View File

@@ -16,9 +16,10 @@ _GENERATION_MODELS = {
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
-    "ChameleonForCausalLM":
-    ("chameleon", "ChameleonForConditionalGeneration"
-     ),  #TODO(ywang96): fix model name when huggingface fixes it
+    #TODO(ywang96): remove this when huggingface fixes the model repo
+    "ChameleonForCausalLM": ("chameleon", "ChameleonForConditionalGeneration"),
+    "ChameleonForConditionalGeneration":
+    ("chameleon", "ChameleonForConditionalGeneration"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),

View File

@@ -1,13 +1,17 @@
from functools import cached_property
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple,
+                    TypedDict)

import torch
import torch.nn.functional as F
+from PIL import Image
from torch import nn

from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -22,10 +26,114 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput
-from vllm.transformers_utils.configs import ChameleonConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import (cached_get_tokenizer,
+                                   repeat_and_pad_image_tokens)
+from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
+from vllm.transformers_utils.configs import (ChameleonConfig,
+                                             ChameleonVQVAEConfig)
from vllm.utils import print_warning_once
from .interfaces import SupportsVision
logger = init_logger(__name__)
# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
CHAMELEON_IMAGE_TOKEN_ID = 8711
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
CHAMELEON_SEP_TOKEN_ID = 8710
class ChameleonImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
def get_max_chameleon_image_tokens(ctx: InputContext):
return CHAMELEON_IMAGE_SEQ_LENGTH
def dummy_seq_data_for_chameleon(
seq_len: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
return SequenceData(token_ids)
def dummy_image_for_chameleon(
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = CHAMELEON_CROP_SIZE_WIDTH
height = CHAMELEON_CROP_SIZE_HEIGHT
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_chameleon(
seq_len,
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
)
mm_data = dummy_image_for_chameleon()
return seq_data, mm_data
def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
"""
Processing input prompt to insert required tokens for image placeholder.
See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
""" # noqa
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
)
# Appending sep token for chat mode to follow default processor
# behavior
new_prompt += tokenizer.sep_token
new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
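# Worked example (illustrative): with the constants above, each image
# placeholder token (CHAMELEON_IMAGE_TOKEN_ID, 8711) in the prompt is expanded
# by this processor into
#     [8197] + [8711] * 1024 + [8196]
# (image-start, CHAMELEON_IMAGE_SEQ_LENGTH image tokens, image-end), and a
# trailing sep token (8710) is appended for chat-style prompts.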
class ChameleonLayerNorm(nn.LayerNorm):
@@ -318,12 +426,333 @@ class ChameleonSwinDecoderLayer(nn.Module):
        return hidden_states, residual
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):
def __init__(self, config: ChameleonVQVAEConfig):
super().__init__()
self.num_embeddings = config.num_embeddings
self.embedding_dim = config.embed_dim
self.beta = getattr(config, "beta", 0.25)
self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
self.re_embed = self.num_embeddings
def forward(self, hidden_state: torch.Tensor):
hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
distances = (
torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) +
torch.sum(self.embedding.weight**2, dim=1) -
2 * torch.einsum("bd,dn->bn", hidden_state_flattened,
self.embedding.weight.transpose(0, 1)))
min_encoding_indices = torch.argmin(distances, dim=1)
hidden_state_quant = self.embedding(min_encoding_indices).view(
hidden_state.shape)
# compute loss for embedding
loss = torch.mean((hidden_state_quant.detach() - hidden_state)**
2) + self.beta * torch.mean(
(hidden_state_quant - hidden_state.detach())**2)
# preserve gradients
hidden_state_quant = hidden_state + (hidden_state_quant -
hidden_state).detach()
# reshape back to match original input shape
hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
2).contiguous()
return hidden_state_quant, loss, min_encoding_indices
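# Shape note (descriptive): for an input of shape (B, C, H, W) the quantizer
# returns a quantized tensor of the same shape, a scalar commitment loss
# (unused at inference; the VQ model is frozen here), and flat codebook
# indices of shape (B * H * W,), which ChameleonVQVAE.encode later exposes as
# the per-image token indices.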
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, hidden_states: torch.Tensor):
# no asymmetric padding in torch conv, must do it ourselves
hidden_states = F.pad(hidden_states,
pad=(0, 1, 0, 1),
mode="constant",
value=0)
hidden_states = self.conv(hidden_states)
return hidden_states
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
def __init__(
self,
config: ChameleonVQVAEConfig,
in_channels: int,
out_channels=None,
conv_shortcut=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None \
else out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
self.norm2 = torch.nn.GroupNorm(num_groups=32,
num_channels=out_channels,
eps=1e-6,
affine=True)
self.dropout = torch.nn.Dropout(config.dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, hidden_states: torch.Tensor):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states *= torch.sigmoid(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states *= torch.sigmoid(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
residual = self.conv_shortcut(residual)
else:
residual = self.nin_shortcut(residual)
return residual + hidden_states
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, hidden_states: torch.Tensor):
residual = hidden_states
hidden_states = self.norm(hidden_states)
query_states = self.q(hidden_states)
key_states = self.k(hidden_states)
value_states = self.v(hidden_states)
# compute attention
batch_size, channels, height, width = query_states.shape
query_states = query_states.reshape(batch_size, channels,
height * width).permute(0, 2, 1)
key_states = key_states.reshape(batch_size, channels, height * width)
attn_weights = torch.bmm(query_states, key_states)
attn_weights = attn_weights * (int(channels)**(-0.5))
attn_weights = F.softmax(attn_weights, dim=2)
# attend to values
value_states = value_states.reshape(batch_size, channels,
height * width)
attn_weights = attn_weights.permute(0, 2, 1)
attn_output = torch.bmm(value_states,
attn_weights).reshape(batch_size, channels,
height, width)
attn_output = self.proj_out(attn_output)
return residual + attn_output
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):
def __init__(self, config: ChameleonVQVAEConfig):
super().__init__()
self.num_resolutions = len(config.channel_multiplier)
self.num_res_blocks = config.num_res_blocks
base_channels = config.base_channels
resolution = config.resolution
in_channels = config.in_channels
double_latent = config.double_latent
latent_channels = config.latent_channels
channel_multiplier = config.channel_multiplier
self.conv_in = torch.nn.Conv2d(in_channels,
base_channels,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_channel_multiplier = (1, ) + tuple(channel_multiplier)
self.in_channel_multiplier = in_channel_multiplier
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = base_channels * in_channel_multiplier[i_level]
block_out = base_channels * channel_multiplier[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ChameleonVQVAEEncoderResnetBlock(
config=config,
in_channels=block_in,
out_channels=block_out,
))
block_in = block_out
if (config.attn_resolutions is not None
and curr_res in config.attn_resolutions
and config.attn_type == "vanilla"):
attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
self.mid = nn.Module()
self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
config=config,
in_channels=block_in,
out_channels=block_in,
)
self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
block_in) if config.attn_type == "vanilla" else nn.Identity()
self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
config=config,
in_channels=block_in,
out_channels=block_in,
)
self.norm_out = torch.nn.GroupNorm(num_groups=32,
num_channels=block_in,
eps=1e-6,
affine=True)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * latent_channels if double_latent else latent_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, pixel_values: torch.Tensor):
# downsampling
hidden_states = [self.conv_in(pixel_values)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
hidden_state = self.down[i_level].block[i_block](
hidden_states[-1], )
if len(self.down[i_level].attn) > 0:
hidden_state = self.down[i_level].attn[i_block](
hidden_state)
hidden_states.append(hidden_state)
if i_level != self.num_resolutions - 1:
hidden_states.append(self.down[i_level].downsample(
hidden_states[-1]))
# middle
last_hidden_state = hidden_states[-1]
last_hidden_state = self.mid.block_1(last_hidden_state)
last_hidden_state = self.mid.attn_1(last_hidden_state)
last_hidden_state = self.mid.block_2(last_hidden_state)
# end
last_hidden_state = self.norm_out(last_hidden_state)
last_hidden_state *= torch.sigmoid(last_hidden_state)
last_hidden_state = self.conv_out(last_hidden_state)
return last_hidden_state
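# Shape note (descriptive, assuming the default VQ config): a (B, 3, 512, 512)
# input is downsampled at every resolution level except the last, i.e.
# len(channel_multiplier) - 1 = 4 times, so conv_out yields a
# (B, latent_channels, 32, 32) feature map, a 32 x 32 = 1024-entry latent grid
# per image.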
# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):
def __init__(self, config: ChameleonVQVAEConfig):
super().__init__()
self.encoder = ChameleonVQVAEEncoder(config)
self.quantize = ChameleonVQVAEVectorQuantizer(config)
self.quant_conv = torch.nn.Conv2d(config.latent_channels,
config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen
def encode(
self, pixel_values: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quant, emb_loss, indices = self.quantize(hidden_states)
return quant, emb_loss, indices
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

-    def __init__(self, vocab_map):
+    def __init__(self, vocab_map: Dict[str, int]):
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")
@@ -401,13 +830,23 @@ class ChameleonModel(nn.Module):
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        # TODO: Support image input
-        # self.vqmodel = ChameleonVQModel(config.vq_config)
+        self.vqmodel = ChameleonVQVAE(config.vq_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)
def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Tokenizes images into discrete tokens with VQGAN module. Converts
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
special tokens.
"""
batch_size = pixel_values.shape[0]
_, _, image_toks = self.vqmodel.encode(pixel_values)
bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
bpe_toks = bpe_toks.view(batch_size, -1)
return bpe_toks
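# Pipeline note (descriptive): pixel_values of shape (B, 3, 512, 512) are
# encoded into 1024 VQ indices per image (the 32x32 latent grid), then mapped
# through the img2bpe vocabulary to BPE token ids of shape (B, 1024), matching
# CHAMELEON_IMAGE_SEQ_LENGTH.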
    def forward(
        self,
        input_ids: Optional[torch.Tensor],
@@ -434,16 +873,22 @@ class ChameleonModel(nn.Module):
        return hidden_states


-class ChameleonForConditionalGeneration(nn.Module):
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
+class ChameleonForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(
        self,
        config: ChameleonConfig,
+        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
+        self.multimodal_config = multimodal_config
        self.model = ChameleonModel(config, cache_config, quant_config)
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
@@ -458,6 +903,36 @@ class ChameleonForConditionalGeneration(nn.Module):
            config.vocab_size, logit_scale)
        self.sampler = Sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
CHAMELEON_CROP_SIZE_WIDTH)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return ChameleonImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
    def forward(
        self,
        input_ids: torch.Tensor,
@@ -468,10 +943,17 @@ class ChameleonForConditionalGeneration(nn.Module):
        **kwargs,
    ) -> torch.Tensor:

-        # TODO (ywang96): Support image input
-        # image_tokens = self.process_image_input(**kwargs)
-        # image_mask = input_ids == self.vocabulary_mapping.image_token_id
-        # input_ids[special_image_mask] = image_tokens.flatten().to(input_ids.dtype) #noqa
+        image_input = self._parse_and_validate_image_input(**kwargs)
+
+        if image_input is not None:
+            assert self.model.vqmodel is not None
+            image_tokens = self.model.get_image_tokens(image_input["data"].to(
+                self.config.torch_dtype))
+            image_token_id = self.model.vocabulary_mapping.image_token_id
+            special_image_mask = input_ids == image_token_id
+            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
+            input_ids = input_ids.masked_scatter(special_image_mask,
+                                                 image_tokens)

        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
@@ -511,43 +993,52 @@ class ChameleonForConditionalGeneration(nn.Module):
            if "rotary_emb.inv_freq" in name:
                continue
-            # Skip loading vqgan
-            # TODO: add support for the vision model
-            if "vqmodel" in name:
-                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                if name.endswith("kv_scale"):
-                    remapped_kv_scale_name = name.replace(
-                        ".kv_scale", ".attn.kv_scale")
-                    if remapped_kv_scale_name not in params_dict:
-                        print_warning_once(
-                            f"Found kv scale in the checkpoint (e.g. {name}), "
-                            "but not found the expected name in the model "
-                            f"(e.g. {remapped_kv_scale_name}). kv-scale is "
-                            "not loaded.")
-                        continue
-                    else:
-                        name = remapped_kv_scale_name
+            use_default_weight_loading = False
+            if "vqmodel" in name:
+                if self.model.vqmodel is not None:
+                    # We only do sharding for language model and
+                    # not vqvae for now.
+                    use_default_weight_loading = True
+            else:
+                for (param_name, weight_name,
+                     shard_id) in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    if name.endswith("kv_scale"):
+                        remapped_kv_scale_name = name.replace(
+                            ".kv_scale", ".attn.kv_scale")
+                        if remapped_kv_scale_name not in params_dict:
+                            print_warning_once(
+                                "Found kv scale in the checkpoint (e.g. "
+                                f"{name}), but not found the expected name in "
+                                f"the model (e.g. {remapped_kv_scale_name}). "
+                                "kv-scale is not loaded.")
+                            continue
+                        else:
+                            name = remapped_kv_scale_name
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)

View File

@@ -1,4 +1,5 @@
-from vllm.transformers_utils.configs.chameleon import ChameleonConfig
+from vllm.transformers_utils.configs.chameleon import (ChameleonConfig,
+                                                        ChameleonVQVAEConfig)
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
@@ -12,6 +13,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
__all__ = [
    "ChameleonConfig",
+    "ChameleonVQVAEConfig",
    "ChatGLMConfig",
    "DbrxConfig",
    "MPTConfig",

View File

@@ -1,3 +1,5 @@
+from typing import List, Optional
+
from transformers import PretrainedConfig

@@ -5,9 +7,7 @@ from transformers import PretrainedConfig
# transformers once the new release with Chameleon support
# is available.
class ChameleonConfig(PretrainedConfig):

    model_type = "chameleon"
-    is_composition = True
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
@@ -31,7 +31,7 @@ class ChameleonConfig(PretrainedConfig):
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
-        qk_layernorm=False,
+        model_parallel_size=1,
        swin_norm=False,
        vq_config=None,
        vocabulary_map=None,
@@ -46,10 +46,6 @@ class ChameleonConfig(PretrainedConfig):
        self.num_attention_heads = num_attention_heads
        self.mlp_bias = mlp_bias

-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
@@ -60,10 +56,14 @@ class ChameleonConfig(PretrainedConfig):
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
-        self.qk_layernorm = qk_layernorm
+        self.model_parallel_size = model_parallel_size
        self.swin_norm = swin_norm
-        # vq config is currently ignored
-        # self.vq_config = ChameleonVQConfig(**vq_config)
+
+        if vq_config is None:
+            vq_config = {}
+
+        self.vq_config = ChameleonVQVAEConfig(**vq_config)
+
        self.vocabulary_map = vocabulary_map

        super().__init__(
@@ -99,3 +99,40 @@ class ChameleonConfig(PretrainedConfig):
            raise ValueError(
                "`rope_scaling`'s factor field must be a float > 1, "
                f"got {rope_scaling_factor}")
class ChameleonVQVAEConfig(PretrainedConfig):
model_type = "chameleon_vqgan"
def __init__(
self,
embed_dim: int = 256,
num_embeddings: int = 8192,
double_latent: bool = False,
latent_channels: int = 256,
resolution: int = 512,
in_channels: int = 3,
base_channels: int = 128,
channel_multiplier: List[int] = [1, 1, 2, 2, 4], #noqa
num_res_blocks: int = 2,
attn_resolutions: Optional[List[int]] = None,
dropout: float = 0.0,
attn_type: str = "vanilla",
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_embeddings = num_embeddings
self.double_latent = double_latent
self.latent_channels = latent_channels
self.resolution = resolution
self.in_channels = in_channels
self.base_channels = base_channels
self.channel_multiplier = channel_multiplier
self.num_res_blocks = num_res_blocks
self.attn_resolutions = attn_resolutions
self.dropout = dropout
self.attn_type = attn_type
self.initializer_range = initializer_range
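# Quick sanity check (illustrative) of the defaults above: with resolution=512
# and channel_multiplier=[1, 1, 2, 2, 4] (downsampling after every level but
# the last), the VQ latent grid is 32 x 32, i.e. 1024 discrete image tokens
# per image, matching CHAMELEON_IMAGE_SEQ_LENGTH in the model file.
vq_config = ChameleonVQVAEConfig()
latent_side = vq_config.resolution // 2**(len(vq_config.channel_multiplier) - 1)
assert latent_side * latent_side == 1024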