mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 18:35:58 +08:00
[VLM][Model] Support image input for Chameleon (#6633)
This commit is contained in:
parent
c5201240a4
commit
22fa2e35cb
@ -182,6 +182,10 @@ Vision Language Models
|
|||||||
- Models
|
- Models
|
||||||
- Example HuggingFace Models
|
- Example HuggingFace Models
|
||||||
- :ref:`LoRA <lora>`
|
- :ref:`LoRA <lora>`
|
||||||
|
* - :code:`ChameleonForConditionalGeneration`
|
||||||
|
- Chameleon
|
||||||
|
- :code:`facebook/chameleon-7b` etc.
|
||||||
|
-
|
||||||
* - :code:`FuyuForCausalLM`
|
* - :code:`FuyuForCausalLM`
|
||||||
- Fuyu
|
- Fuyu
|
||||||
- :code:`adept/fuyu-8b` etc.
|
- :code:`adept/fuyu-8b` etc.
|
||||||
|
|||||||
102
tests/models/test_chameleon.py
Normal file
102
tests/models/test_chameleon.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import re
|
||||||
|
from typing import List, Optional, Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.multimodal.utils import rescale_image_size
|
||||||
|
|
||||||
|
from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.vlm
|
||||||
|
|
||||||
|
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||||
|
"stop_sign":
|
||||||
|
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
|
||||||
|
"cherry_blossom":
|
||||||
|
"USER: <image>\nWhat is the season?\nASSISTANT:",
|
||||||
|
})
|
||||||
|
|
||||||
|
models = ["facebook/chameleon-7b"]
|
||||||
|
|
||||||
|
|
||||||
|
#TODO (ywang96): Add correctness test when chameleon is
|
||||||
|
# available on transformers.
|
||||||
|
def run_test(
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
image_assets: _ImageAssets,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
size_factors: List[float],
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
distributed_executor_backend: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Test if the model can generate text given
|
||||||
|
a batch of images and prompts.
|
||||||
|
|
||||||
|
"""
|
||||||
|
images = [asset.pil_image for asset in image_assets]
|
||||||
|
|
||||||
|
inputs_per_image = [(
|
||||||
|
[prompt for _ in size_factors],
|
||||||
|
[rescale_image_size(image, factor) for factor in size_factors],
|
||||||
|
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||||
|
|
||||||
|
with vllm_runner(model,
|
||||||
|
max_model_len=4096,
|
||||||
|
dtype=dtype,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
|
enforce_eager=True) as vllm_model:
|
||||||
|
|
||||||
|
for prompts, images in inputs_per_image:
|
||||||
|
vllm_outputs = vllm_model.generate_greedy(prompts,
|
||||||
|
max_tokens,
|
||||||
|
images=images)
|
||||||
|
for i in range(len(vllm_outputs)):
|
||||||
|
|
||||||
|
# format prompt back to original
|
||||||
|
replacements = {
|
||||||
|
"<racm3:break>": "",
|
||||||
|
"<eoss>": "",
|
||||||
|
"<reserved08706>": ""
|
||||||
|
}
|
||||||
|
pattern = '|'.join(replacements.keys())
|
||||||
|
vllm_result = re.sub(
|
||||||
|
pattern,
|
||||||
|
lambda match: replacements[match.group(0)], #noqa B023
|
||||||
|
vllm_outputs[i][1])
|
||||||
|
vllm_result = vllm_result.replace("<image>", "", 1023)
|
||||||
|
assert vllm_result[:len(prompts[i])] == prompts[i]
|
||||||
|
|
||||||
|
# assert at least 10 new characters are generated
|
||||||
|
# (to take stop token into account)
|
||||||
|
assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", models)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"size_factors",
|
||||||
|
[
|
||||||
|
# Single-scale
|
||||||
|
[1.0],
|
||||||
|
# Single-scale, batched
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
# Multi-scale
|
||||||
|
[0.25, 0.5, 1.0],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [128])
|
||||||
|
def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
|
||||||
|
max_tokens: int) -> None:
|
||||||
|
run_test(
|
||||||
|
vllm_runner,
|
||||||
|
image_assets,
|
||||||
|
model,
|
||||||
|
size_factors=size_factors,
|
||||||
|
dtype=dtype,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
||||||
@ -105,7 +105,8 @@ def _image_token_str(model_config: ModelConfig,
|
|||||||
return None
|
return None
|
||||||
if model_type.startswith("llava"):
|
if model_type.startswith("llava"):
|
||||||
return tokenizer.decode(model_config.hf_config.image_token_index)
|
return tokenizer.decode(model_config.hf_config.image_token_index)
|
||||||
|
if model_type == "chameleon":
|
||||||
|
return "<image>"
|
||||||
raise TypeError("Unknown model type: {model_type}")
|
raise TypeError("Unknown model type: {model_type}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -16,9 +16,10 @@ _GENERATION_MODELS = {
|
|||||||
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
|
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
|
||||||
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
|
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
|
||||||
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
||||||
"ChameleonForCausalLM":
|
#TODO(ywang96): remove this when huggingface fixes the model repo
|
||||||
("chameleon", "ChameleonForConditionalGeneration"
|
"ChameleonForCausalLM": ("chameleon", "ChameleonForConditionalGeneration"),
|
||||||
), #TODO(ywang96): fix model name when huggingface fixes it
|
"ChameleonForConditionalGeneration":
|
||||||
|
("chameleon", "ChameleonForConditionalGeneration"),
|
||||||
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
||||||
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
|
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
|
||||||
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
|
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
|
||||||
|
|||||||
@ -1,13 +1,17 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple,
|
||||||
|
TypedDict)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
@ -22,10 +26,114 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.transformers_utils.configs import ChameleonConfig
|
from vllm.multimodal.image import (cached_get_tokenizer,
|
||||||
|
repeat_and_pad_image_tokens)
|
||||||
|
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
|
||||||
|
from vllm.transformers_utils.configs import (ChameleonConfig,
|
||||||
|
ChameleonVQVAEConfig)
|
||||||
from vllm.utils import print_warning_once
|
from vllm.utils import print_warning_once
|
||||||
|
|
||||||
|
from .interfaces import SupportsVision
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# These configs are not part of the model config but the preprocessor
|
||||||
|
# and processor files, so we hardcode them in the model file for now.
|
||||||
|
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
|
||||||
|
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
|
||||||
|
CHAMELEON_IMAGE_TOKEN_ID = 8711
|
||||||
|
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
|
||||||
|
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
|
||||||
|
CHAMELEON_SEP_TOKEN_ID = 8710
|
||||||
|
|
||||||
|
|
||||||
|
class ChameleonImagePixelInputs(TypedDict):
|
||||||
|
type: Literal["pixel_values"]
|
||||||
|
data: torch.Tensor
|
||||||
|
"""Shape: `(batch_size, num_channels, height, width)`"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_chameleon_image_tokens(ctx: InputContext):
|
||||||
|
return CHAMELEON_IMAGE_SEQ_LENGTH
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_seq_data_for_chameleon(
|
||||||
|
seq_len: int,
|
||||||
|
*,
|
||||||
|
image_token_id: int,
|
||||||
|
image_feature_size_override: Optional[int] = None,
|
||||||
|
):
|
||||||
|
if image_feature_size_override is None:
|
||||||
|
image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
|
||||||
|
else:
|
||||||
|
image_feature_size = image_feature_size_override
|
||||||
|
|
||||||
|
token_ids = [image_token_id] * image_feature_size
|
||||||
|
token_ids += [0] * (seq_len - image_feature_size)
|
||||||
|
return SequenceData(token_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_image_for_chameleon(
|
||||||
|
image_width_override: Optional[int] = None,
|
||||||
|
image_height_override: Optional[int] = None,
|
||||||
|
):
|
||||||
|
width = CHAMELEON_CROP_SIZE_WIDTH
|
||||||
|
height = CHAMELEON_CROP_SIZE_HEIGHT
|
||||||
|
if image_width_override is not None:
|
||||||
|
width = image_width_override
|
||||||
|
if image_height_override is not None:
|
||||||
|
height = image_height_override
|
||||||
|
|
||||||
|
image = Image.new("RGB", (width, height), color=0)
|
||||||
|
return {"image": image}
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int):
|
||||||
|
|
||||||
|
seq_data = dummy_seq_data_for_chameleon(
|
||||||
|
seq_len,
|
||||||
|
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
mm_data = dummy_image_for_chameleon()
|
||||||
|
return seq_data, mm_data
|
||||||
|
|
||||||
|
|
||||||
|
def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Processing input prompt to insert required tokens for image placeholder.
|
||||||
|
|
||||||
|
See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
|
||||||
|
""" # noqa
|
||||||
|
|
||||||
|
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||||
|
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||||
|
return llm_inputs
|
||||||
|
|
||||||
|
model_config = ctx.model_config
|
||||||
|
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||||
|
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
|
||||||
|
tokenizer,
|
||||||
|
llm_inputs.get("prompt"),
|
||||||
|
llm_inputs["prompt_token_ids"],
|
||||||
|
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
|
||||||
|
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
|
||||||
|
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
|
||||||
|
pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Appending sep token for chat mode to follow default processor
|
||||||
|
# behavior
|
||||||
|
new_prompt += tokenizer.sep_token
|
||||||
|
new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
|
||||||
|
|
||||||
|
# NOTE: Create a defensive copy of the original inputs
|
||||||
|
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||||
|
prompt=new_prompt,
|
||||||
|
multi_modal_data=multi_modal_data)
|
||||||
|
|
||||||
|
|
||||||
class ChameleonLayerNorm(nn.LayerNorm):
|
class ChameleonLayerNorm(nn.LayerNorm):
|
||||||
|
|
||||||
@ -318,12 +426,333 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
|||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
|
||||||
|
class ChameleonVQVAEVectorQuantizer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: ChameleonVQVAEConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.num_embeddings = config.num_embeddings
|
||||||
|
self.embedding_dim = config.embed_dim
|
||||||
|
self.beta = getattr(config, "beta", 0.25)
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
|
||||||
|
self.re_embed = self.num_embeddings
|
||||||
|
|
||||||
|
def forward(self, hidden_state: torch.Tensor):
|
||||||
|
hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
|
||||||
|
hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
|
||||||
|
|
||||||
|
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||||
|
distances = (
|
||||||
|
torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) +
|
||||||
|
torch.sum(self.embedding.weight**2, dim=1) -
|
||||||
|
2 * torch.einsum("bd,dn->bn", hidden_state_flattened,
|
||||||
|
self.embedding.weight.transpose(0, 1)))
|
||||||
|
|
||||||
|
min_encoding_indices = torch.argmin(distances, dim=1)
|
||||||
|
hidden_state_quant = self.embedding(min_encoding_indices).view(
|
||||||
|
hidden_state.shape)
|
||||||
|
|
||||||
|
# compute loss for embedding
|
||||||
|
loss = torch.mean((hidden_state_quant.detach() - hidden_state)**
|
||||||
|
2) + self.beta * torch.mean(
|
||||||
|
(hidden_state_quant - hidden_state.detach())**2)
|
||||||
|
|
||||||
|
# preserve gradients
|
||||||
|
hidden_state_quant = hidden_state + (hidden_state_quant -
|
||||||
|
hidden_state).detach()
|
||||||
|
|
||||||
|
# reshape back to match original input shape
|
||||||
|
hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
|
||||||
|
2).contiguous()
|
||||||
|
|
||||||
|
return hidden_state_quant, loss, min_encoding_indices
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
|
||||||
|
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=0)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor):
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
hidden_states = F.pad(hidden_states,
|
||||||
|
pad=(0, 1, 0, 1),
|
||||||
|
mode="constant",
|
||||||
|
value=0)
|
||||||
|
hidden_states = self.conv(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
|
||||||
|
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: ChameleonVQVAEConfig,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels=None,
|
||||||
|
conv_shortcut=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels if out_channels is None \
|
||||||
|
else out_channels
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
|
||||||
|
self.norm1 = torch.nn.GroupNorm(num_groups=32,
|
||||||
|
num_channels=in_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
affine=True)
|
||||||
|
self.conv1 = torch.nn.Conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
self.norm2 = torch.nn.GroupNorm(num_groups=32,
|
||||||
|
num_channels=out_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
affine=True)
|
||||||
|
self.dropout = torch.nn.Dropout(config.dropout)
|
||||||
|
self.conv2 = torch.nn.Conv2d(out_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
else:
|
||||||
|
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor):
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.norm1(hidden_states)
|
||||||
|
hidden_states *= torch.sigmoid(hidden_states)
|
||||||
|
hidden_states = self.conv1(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = self.norm2(hidden_states)
|
||||||
|
hidden_states *= torch.sigmoid(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.conv2(hidden_states)
|
||||||
|
|
||||||
|
if self.in_channels != self.out_channels:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
residual = self.conv_shortcut(residual)
|
||||||
|
else:
|
||||||
|
residual = self.nin_shortcut(residual)
|
||||||
|
|
||||||
|
return residual + hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
|
||||||
|
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.norm = torch.nn.GroupNorm(num_groups=32,
|
||||||
|
num_channels=in_channels,
|
||||||
|
eps=1e-6,
|
||||||
|
affine=True)
|
||||||
|
self.q = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.k = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.v = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||||
|
in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor):
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
query_states = self.q(hidden_states)
|
||||||
|
key_states = self.k(hidden_states)
|
||||||
|
value_states = self.v(hidden_states)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
batch_size, channels, height, width = query_states.shape
|
||||||
|
query_states = query_states.reshape(batch_size, channels,
|
||||||
|
height * width).permute(0, 2, 1)
|
||||||
|
key_states = key_states.reshape(batch_size, channels, height * width)
|
||||||
|
attn_weights = torch.bmm(query_states, key_states)
|
||||||
|
attn_weights = attn_weights * (int(channels)**(-0.5))
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=2)
|
||||||
|
|
||||||
|
# attend to values
|
||||||
|
value_states = value_states.reshape(batch_size, channels,
|
||||||
|
height * width)
|
||||||
|
attn_weights = attn_weights.permute(0, 2, 1)
|
||||||
|
attn_output = torch.bmm(value_states,
|
||||||
|
attn_weights).reshape(batch_size, channels,
|
||||||
|
height, width)
|
||||||
|
|
||||||
|
attn_output = self.proj_out(attn_output)
|
||||||
|
return residual + attn_output
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
|
||||||
|
class ChameleonVQVAEEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: ChameleonVQVAEConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_resolutions = len(config.channel_multiplier)
|
||||||
|
self.num_res_blocks = config.num_res_blocks
|
||||||
|
base_channels = config.base_channels
|
||||||
|
resolution = config.resolution
|
||||||
|
in_channels = config.in_channels
|
||||||
|
double_latent = config.double_latent
|
||||||
|
latent_channels = config.latent_channels
|
||||||
|
channel_multiplier = config.channel_multiplier
|
||||||
|
|
||||||
|
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||||
|
base_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
curr_res = resolution
|
||||||
|
in_channel_multiplier = (1, ) + tuple(channel_multiplier)
|
||||||
|
self.in_channel_multiplier = in_channel_multiplier
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = base_channels * in_channel_multiplier[i_level]
|
||||||
|
block_out = base_channels * channel_multiplier[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
block.append(
|
||||||
|
ChameleonVQVAEEncoderResnetBlock(
|
||||||
|
config=config,
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_out,
|
||||||
|
))
|
||||||
|
block_in = block_out
|
||||||
|
if (config.attn_resolutions is not None
|
||||||
|
and curr_res in config.attn_resolutions
|
||||||
|
and config.attn_type == "vanilla"):
|
||||||
|
attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
|
||||||
|
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
|
||||||
|
curr_res = curr_res // 2
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
|
||||||
|
config=config,
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
)
|
||||||
|
self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
|
||||||
|
block_in) if config.attn_type == "vanilla" else nn.Identity()
|
||||||
|
self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
|
||||||
|
config=config,
|
||||||
|
in_channels=block_in,
|
||||||
|
out_channels=block_in,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm_out = torch.nn.GroupNorm(num_groups=32,
|
||||||
|
num_channels=block_in,
|
||||||
|
eps=1e-6,
|
||||||
|
affine=True)
|
||||||
|
self.conv_out = torch.nn.Conv2d(
|
||||||
|
block_in,
|
||||||
|
2 * latent_channels if double_latent else latent_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, pixel_values: torch.Tensor):
|
||||||
|
# downsampling
|
||||||
|
hidden_states = [self.conv_in(pixel_values)]
|
||||||
|
for i_level in range(self.num_resolutions):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
hidden_state = self.down[i_level].block[i_block](
|
||||||
|
hidden_states[-1], )
|
||||||
|
if len(self.down[i_level].attn) > 0:
|
||||||
|
hidden_state = self.down[i_level].attn[i_block](
|
||||||
|
hidden_state)
|
||||||
|
hidden_states.append(hidden_state)
|
||||||
|
if i_level != self.num_resolutions - 1:
|
||||||
|
hidden_states.append(self.down[i_level].downsample(
|
||||||
|
hidden_states[-1]))
|
||||||
|
|
||||||
|
# middle
|
||||||
|
last_hidden_state = hidden_states[-1]
|
||||||
|
last_hidden_state = self.mid.block_1(last_hidden_state)
|
||||||
|
last_hidden_state = self.mid.attn_1(last_hidden_state)
|
||||||
|
last_hidden_state = self.mid.block_2(last_hidden_state)
|
||||||
|
|
||||||
|
# end
|
||||||
|
last_hidden_state = self.norm_out(last_hidden_state)
|
||||||
|
last_hidden_state *= torch.sigmoid(last_hidden_state)
|
||||||
|
last_hidden_state = self.conv_out(last_hidden_state)
|
||||||
|
return last_hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
|
||||||
|
class ChameleonVQVAE(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: ChameleonVQVAEConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = ChameleonVQVAEEncoder(config)
|
||||||
|
self.quantize = ChameleonVQVAEVectorQuantizer(config)
|
||||||
|
self.quant_conv = torch.nn.Conv2d(config.latent_channels,
|
||||||
|
config.embed_dim, 1)
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
|
||||||
|
config.latent_channels, 1)
|
||||||
|
self.eval() # Chameleon's VQ model is frozen
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self, pixel_values: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
hidden_states = self.encoder(pixel_values)
|
||||||
|
hidden_states = self.quant_conv(hidden_states)
|
||||||
|
quant, emb_loss, indices = self.quantize(hidden_states)
|
||||||
|
return quant, emb_loss, indices
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
|
||||||
class ChameleonImageVocabularyMapping:
|
class ChameleonImageVocabularyMapping:
|
||||||
"""
|
"""
|
||||||
A class for mapping discrete image tokens from VQGAN to BPE tokens.
|
A class for mapping discrete image tokens from VQGAN to BPE tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, vocab_map):
|
def __init__(self, vocab_map: Dict[str, int]):
|
||||||
self.vocab_map = vocab_map
|
self.vocab_map = vocab_map
|
||||||
self.image_token_id = vocab_map.get("<image>")
|
self.image_token_id = vocab_map.get("<image>")
|
||||||
|
|
||||||
@ -401,13 +830,23 @@ class ChameleonModel(nn.Module):
|
|||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.vqmodel = ChameleonVQVAE(config.vq_config)
|
||||||
# TODO: Support image input
|
|
||||||
# self.vqmodel = ChameleonVQModel(config.vq_config)
|
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return self.embed_tokens(input_ids)
|
return self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Tokenizes images into discrete tokens with VQGAN module. Converts
|
||||||
|
obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
|
||||||
|
special tokens.
|
||||||
|
"""
|
||||||
|
batch_size = pixel_values.shape[0]
|
||||||
|
_, _, image_toks = self.vqmodel.encode(pixel_values)
|
||||||
|
bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
|
||||||
|
bpe_toks = bpe_toks.view(batch_size, -1)
|
||||||
|
return bpe_toks
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor],
|
input_ids: Optional[torch.Tensor],
|
||||||
@ -434,16 +873,22 @@ class ChameleonModel(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class ChameleonForConditionalGeneration(nn.Module):
|
@MULTIMODAL_REGISTRY.register_image_input_mapper()
|
||||||
|
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
|
||||||
|
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
|
||||||
|
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
|
||||||
|
class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ChameleonConfig,
|
config: ChameleonConfig,
|
||||||
|
multimodal_config: MultiModalConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.multimodal_config = multimodal_config
|
||||||
self.model = ChameleonModel(config, cache_config, quant_config)
|
self.model = ChameleonModel(config, cache_config, quant_config)
|
||||||
self.unpadded_vocab_size = config.vocab_size
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
@ -458,6 +903,36 @@ class ChameleonForConditionalGeneration(nn.Module):
|
|||||||
config.vocab_size, logit_scale)
|
config.vocab_size, logit_scale)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
|
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
|
||||||
|
CHAMELEON_CROP_SIZE_WIDTH)
|
||||||
|
actual_dims = tuple(data.shape[1:])
|
||||||
|
|
||||||
|
if actual_dims != expected_dims:
|
||||||
|
expected_expr = ("batch_size", *map(str, expected_dims))
|
||||||
|
raise ValueError(
|
||||||
|
f"The expected shape of pixel values is {expected_expr}. "
|
||||||
|
f"You supplied {tuple(data.shape)}.")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _parse_and_validate_image_input(
|
||||||
|
self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
|
||||||
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
|
|
||||||
|
if pixel_values is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(pixel_values, torch.Tensor):
|
||||||
|
raise ValueError("Incorrect type of pixel values. "
|
||||||
|
f"Got type: {type(pixel_values)}")
|
||||||
|
|
||||||
|
return ChameleonImagePixelInputs(
|
||||||
|
type="pixel_values",
|
||||||
|
data=self._validate_pixel_values(pixel_values),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
@ -468,10 +943,17 @@ class ChameleonForConditionalGeneration(nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
# TODO (ywang96): Support image input
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||||
# image_tokens = self.process_image_input(**kwargs)
|
|
||||||
# image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
if image_input is not None:
|
||||||
# input_ids[special_image_mask] = image_tokens.flatten().to(input_ids.dtype) #noqa
|
assert self.model.vqmodel is not None
|
||||||
|
image_tokens = self.model.get_image_tokens(image_input["data"].to(
|
||||||
|
self.config.torch_dtype))
|
||||||
|
image_token_id = self.model.vocabulary_mapping.image_token_id
|
||||||
|
special_image_mask = input_ids == image_token_id
|
||||||
|
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
|
||||||
|
input_ids = input_ids.masked_scatter(special_image_mask,
|
||||||
|
image_tokens)
|
||||||
|
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
attn_metadata)
|
attn_metadata)
|
||||||
@ -511,43 +993,52 @@ class ChameleonForConditionalGeneration(nn.Module):
|
|||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Skip loading vqgan
|
|
||||||
# TODO: add support for the vision model
|
|
||||||
if "vqmodel" in name:
|
|
||||||
continue
|
|
||||||
if ("rotary_emb.cos_cached" in name
|
if ("rotary_emb.cos_cached" in name
|
||||||
or "rotary_emb.sin_cached" in name):
|
or "rotary_emb.sin_cached" in name):
|
||||||
# Models trained using ColossalAI may include these tensors in
|
# Models trained using ColossalAI may include these tensors in
|
||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
use_default_weight_loading = False
|
||||||
if weight_name not in name:
|
if "vqmodel" in name:
|
||||||
continue
|
if self.model.vqmodel is not None:
|
||||||
name = name.replace(weight_name, param_name)
|
# We only do sharding for language model and
|
||||||
# Skip loading extra bias for GPTQ models.
|
# not vqvae for now.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
use_default_weight_loading = True
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
for (param_name, weight_name,
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
shard_id) in stacked_params_mapping:
|
||||||
continue
|
if weight_name not in name:
|
||||||
# Remapping the name of FP8 kv-scale.
|
|
||||||
if name.endswith("kv_scale"):
|
|
||||||
remapped_kv_scale_name = name.replace(
|
|
||||||
".kv_scale", ".attn.kv_scale")
|
|
||||||
if remapped_kv_scale_name not in params_dict:
|
|
||||||
print_warning_once(
|
|
||||||
f"Found kv scale in the checkpoint (e.g. {name}), "
|
|
||||||
"but not found the expected name in the model "
|
|
||||||
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
|
|
||||||
"not loaded.")
|
|
||||||
continue
|
continue
|
||||||
else:
|
name = name.replace(weight_name, param_name)
|
||||||
name = remapped_kv_scale_name
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
# Remapping the name of FP8 kv-scale.
|
||||||
|
if name.endswith("kv_scale"):
|
||||||
|
remapped_kv_scale_name = name.replace(
|
||||||
|
".kv_scale", ".attn.kv_scale")
|
||||||
|
if remapped_kv_scale_name not in params_dict:
|
||||||
|
print_warning_once(
|
||||||
|
"Found kv scale in the checkpoint (e.g. "
|
||||||
|
f"{name}), but not found the expected name in "
|
||||||
|
f"the model (e.g. {remapped_kv_scale_name}). "
|
||||||
|
"kv-scale is not loaded.")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
name = remapped_kv_scale_name
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
if use_default_weight_loading and name in params_dict:
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from vllm.transformers_utils.configs.chameleon import ChameleonConfig
|
from vllm.transformers_utils.configs.chameleon import (ChameleonConfig,
|
||||||
|
ChameleonVQVAEConfig)
|
||||||
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
|
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
|
||||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||||
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
|
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
|
||||||
@ -12,6 +13,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ChameleonConfig",
|
"ChameleonConfig",
|
||||||
|
"ChameleonVQVAEConfig",
|
||||||
"ChatGLMConfig",
|
"ChatGLMConfig",
|
||||||
"DbrxConfig",
|
"DbrxConfig",
|
||||||
"MPTConfig",
|
"MPTConfig",
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
@ -5,9 +7,7 @@ from transformers import PretrainedConfig
|
|||||||
# transformers once the new release with Chameleon support
|
# transformers once the new release with Chameleon support
|
||||||
# is available.
|
# is available.
|
||||||
class ChameleonConfig(PretrainedConfig):
|
class ChameleonConfig(PretrainedConfig):
|
||||||
|
|
||||||
model_type = "chameleon"
|
model_type = "chameleon"
|
||||||
is_composition = True
|
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -31,7 +31,7 @@ class ChameleonConfig(PretrainedConfig):
|
|||||||
rope_scaling=None,
|
rope_scaling=None,
|
||||||
attention_bias=False,
|
attention_bias=False,
|
||||||
attention_dropout=0.0,
|
attention_dropout=0.0,
|
||||||
qk_layernorm=False,
|
model_parallel_size=1,
|
||||||
swin_norm=False,
|
swin_norm=False,
|
||||||
vq_config=None,
|
vq_config=None,
|
||||||
vocabulary_map=None,
|
vocabulary_map=None,
|
||||||
@ -46,10 +46,6 @@ class ChameleonConfig(PretrainedConfig):
|
|||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
self.mlp_bias = mlp_bias
|
self.mlp_bias = mlp_bias
|
||||||
|
|
||||||
# for backward compatibility
|
|
||||||
if num_key_value_heads is None:
|
|
||||||
num_key_value_heads = num_attention_heads
|
|
||||||
|
|
||||||
self.num_key_value_heads = num_key_value_heads
|
self.num_key_value_heads = num_key_value_heads
|
||||||
self.hidden_act = hidden_act
|
self.hidden_act = hidden_act
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
@ -60,10 +56,14 @@ class ChameleonConfig(PretrainedConfig):
|
|||||||
self._rope_scaling_validation()
|
self._rope_scaling_validation()
|
||||||
self.attention_bias = attention_bias
|
self.attention_bias = attention_bias
|
||||||
self.attention_dropout = attention_dropout
|
self.attention_dropout = attention_dropout
|
||||||
self.qk_layernorm = qk_layernorm
|
self.model_parallel_size = model_parallel_size
|
||||||
self.swin_norm = swin_norm
|
self.swin_norm = swin_norm
|
||||||
# vq config is currently ignored
|
|
||||||
# self.vq_config = ChameleonVQConfig(**vq_config)
|
if vq_config is None:
|
||||||
|
vq_config = {}
|
||||||
|
|
||||||
|
self.vq_config = ChameleonVQVAEConfig(**vq_config)
|
||||||
|
|
||||||
self.vocabulary_map = vocabulary_map
|
self.vocabulary_map = vocabulary_map
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -99,3 +99,40 @@ class ChameleonConfig(PretrainedConfig):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`rope_scaling`'s factor field must be a float > 1, "
|
"`rope_scaling`'s factor field must be a float > 1, "
|
||||||
f"got {rope_scaling_factor}")
|
f"got {rope_scaling_factor}")
|
||||||
|
|
||||||
|
|
||||||
|
class ChameleonVQVAEConfig(PretrainedConfig):
|
||||||
|
|
||||||
|
model_type = "chameleon_vqgan"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int = 256,
|
||||||
|
num_embeddings: int = 8192,
|
||||||
|
double_latent: bool = False,
|
||||||
|
latent_channels: int = 256,
|
||||||
|
resolution: int = 512,
|
||||||
|
in_channels: int = 3,
|
||||||
|
base_channels: int = 128,
|
||||||
|
channel_multiplier: List[int] = [1, 1, 2, 2, 4], #noqa
|
||||||
|
num_res_blocks: int = 2,
|
||||||
|
attn_resolutions: Optional[List[int]] = None,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
attn_type: str = "vanilla",
|
||||||
|
initializer_range=0.02,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_embeddings = num_embeddings
|
||||||
|
self.double_latent = double_latent
|
||||||
|
self.latent_channels = latent_channels
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.base_channels = base_channels
|
||||||
|
self.channel_multiplier = channel_multiplier
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.attn_resolutions = attn_resolutions
|
||||||
|
self.dropout = dropout
|
||||||
|
self.attn_type = attn_type
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user