mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 05:04:58 +08:00
1099 lines
40 KiB
Python
1099 lines
40 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from functools import cached_property
|
|
from itertools import islice
|
|
from typing import Annotated, Any, Literal
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from transformers import (
|
|
BatchFeature,
|
|
ChameleonConfig,
|
|
ChameleonProcessor,
|
|
ChameleonVQVAEConfig,
|
|
)
|
|
|
|
from vllm.attention.layer import Attention
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.config.multimodal import BaseDummyOptions
|
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.conv import Conv2dLayer
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (
|
|
MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
default_weight_loader,
|
|
row_parallel_weight_loader,
|
|
)
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.inputs import (
|
|
MultiModalDataDict,
|
|
MultiModalFieldConfig,
|
|
MultiModalKwargsItems,
|
|
)
|
|
from vllm.multimodal.parse import MultiModalDataItems
|
|
from vllm.multimodal.processing import (
|
|
BaseMultiModalProcessor,
|
|
BaseProcessingInfo,
|
|
PromptReplacement,
|
|
PromptUpdate,
|
|
PromptUpdateDetails,
|
|
)
|
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|
|
|
from .interfaces import (
|
|
MultiModalEmbeddings,
|
|
SupportsMultiModal,
|
|
SupportsPP,
|
|
SupportsQuant,
|
|
)
|
|
from .utils import (
|
|
is_pp_missing_parameter,
|
|
make_empty_intermediate_tensors_factory,
|
|
make_layers,
|
|
maybe_prefix,
|
|
)
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class ChameleonImagePixelInputs(TensorSchema):
    """
    Schema describing raw image pixel inputs for Chameleon.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Tag identifying this input variant.
    type: Literal["pixel_values"]

    # Pixel tensor; the shape annotation fixes the channel axis to 3 and
    # leaves "bn", "h" and "w" symbolic.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
|
|
|
|
|
class ChameleonProcessingInfo(BaseProcessingInfo):
    """Accessors for the HF config/processor used when processing Chameleon inputs."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as ``ChameleonConfig``."""
        hf_config = self.ctx.get_hf_config(ChameleonConfig)
        return hf_config

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``ChameleonProcessor``; ``kwargs`` are forwarded as-is."""
        hf_processor = self.ctx.get_hf_processor(ChameleonProcessor, **kwargs)
        return hf_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits: a single image per prompt."""
        limits = {"image": 1}
        return limits

    def get_num_image_tokens(self) -> int:
        """Return the processor's ``image_seq_length``."""
        return self.get_hf_processor().image_seq_length
|
|
|
|
|
|
class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo]):
    """Builds dummy text and image inputs for Chameleon."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per requested image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Return dummy images sized to the VQ config's resolution.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items per modality; only "image" is read.
            mm_options: Optional per-modality overrides; only "image" is read.
        """
        # Dummy images are square with the side taken from ``vq_config.resolution``.
        side = self.info.get_hf_config().vq_config.resolution
        image_count = mm_counts.get("image", 0)

        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=image_count,
            overrides=image_overrides,
        )
        return {"image": dummy_images}
|
|
|
|
|
|
class ChameleonMultiModalProcessor(BaseMultiModalProcessor[ChameleonProcessingInfo]):
    """Multimodal processor that expands Chameleon image placeholders."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor, or tokenize directly for text-only prompts.

        When ``mm_data`` is empty the prompt is encoded with the tokenizer
        and the separator token is appended; otherwise the call is delegated
        to the parent implementation.
        """
        if mm_data:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        token_ids = self.info.get_tokenizer().encode(prompt)
        token_ids = self._apply_hf_processor_tokens_only(token_ids)
        return BatchFeature({"input_ids": [token_ids]}, tensor_type="pt")

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Return ``prompt_tokens`` with the tokenizer's sep token id appended."""
        # HF processor adds sep token for chat mode
        tok = self.info.get_tokenizer()
        token_to_id = tok.get_vocab()
        sep_id = token_to_id[tok.sep_token]  # type: ignore

        return prompt_tokens + [sep_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as a batched field of the "image" modality."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Build the replacement for each image placeholder token.

        A single image token is replaced by
        ``[image_start] + [image_token] * N + [image_end]``, where ``N`` is
        ``self.info.get_num_image_tokens()``; only the image-token positions
        are selected as embedding positions.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        token_to_id = self.info.get_tokenizer().get_vocab()

        start_id = token_to_id[hf_processor.image_start_token]
        placeholder_id = token_to_id[hf_processor.image_token]
        end_id = token_to_id[hf_processor.image_end_token]

        repeated = [placeholder_id] * self.info.get_num_image_tokens()
        expanded = [start_id] + repeated + [end_id]

        replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=PromptUpdateDetails.select_token_id(
                expanded,
                embed_token_id=placeholder_id,
            ),
        )
        return [replacement]
|
|
|
|
|
|
class ChameleonLayerNorm(nn.LayerNorm):
    """
    LayerNorm whose statistics are taken over the last dimension only,
    while the affine weight/bias keep the full ``hidden_size`` shape.

    Args:
        hidden_size: Shape used by the parent ``nn.LayerNorm`` to allocate
            ``weight`` and ``bias`` (in this file it is passed as a
            ``(num_heads, head_dim)`` tuple). Only its last entry is used
            as the normalized shape.
        *args, **kwargs: Forwarded to ``nn.LayerNorm`` (e.g. ``eps``).
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        # The parent allocated parameters with the full shape; override the
        # normalized shape so only the trailing dimension is normalized.
        self.normalized_shape = (hidden_size[-1],)

        set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader})
        set_weight_attrs(self.bias, {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states):
        """Normalize over the last dimension, then apply weight and bias."""
        # Use the epsilon configured on the module instead of a hard-coded
        # 1e-5 so an ``eps`` passed to the constructor is honoured. The
        # ``nn.LayerNorm`` default is 1e-5, so default behaviour is unchanged.
        normalized = F.layer_norm(
            hidden_states, self.normalized_shape, None, None, eps=self.eps
        )
        return normalized * self.weight + self.bias
|
|
|
|
|
|
# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
|
|
class ChameleonMLP(nn.Module):
    """
    Gated MLP: a merged gate/up projection, ``SiluAndMul`` activation and a
    down projection.

    Args:
        hidden_size: Input size of the gate/up projection and output size
            of the down projection.
        intermediate_size: Size of each of the two merged gate/up outputs.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config passed to both linears.
        bias: Whether both linear layers use a bias.
        prefix: Module-name prefix for the sub-layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Gate and up projections share a single merged linear layer.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

        activation_supported = hidden_act == "silu"
        if not activation_supported:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then the down projection."""
        # Both linear layers return a pair; only the first element is used.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
|
|
|
|
|
|
# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
|
|
class ChameleonAttention(nn.Module):
    """
    Self-attention with per-head LayerNorm applied to the query and key
    projections before rotary embedding.

    Args:
        hidden_size: Model hidden size.
        num_heads: Total number of query heads (before tensor-parallel split).
        num_kv_heads: Total number of key/value heads (before split).
        rope_parameters: Passed through to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        quant_config: Optional quantization config for the linear/attention
            layers.
        bias: Whether the QKV and output projections use a bias.
        cache_config: Optional cache config passed to ``Attention``.
        prefix: Module-name prefix for the sub-layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 4096,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are always split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank sizes derived from the head layout.
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Norm parameters are shaped (heads-on-this-rank, head_dim).
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the per-head norms to ``q`` and ``k`` and return them flattened."""
        # Expose the (heads, head_dim) axes so the norm can act per head.
        q_heads = q.reshape(-1, self.num_heads, self.head_dim)
        k_heads = k.reshape(-1, self.num_kv_heads, self.head_dim)

        q_normed = self.q_norm(q_heads)
        k_normed = self.k_norm(k_heads)

        # Collapse the two trailing axes back into a single feature axis.
        q_flat = q_normed.view(q_normed.shape[0], -1)
        k_flat = k_normed.view(k_normed.shape[0], -1)
        return q_flat, k_flat

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to QKV, normalize and rotate Q/K, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)

        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)

        context = self.attn(q, k, v)
        projected, _ = self.o_proj(context)
        return projected
|
|
|
|
|
|
class ChameleonDecoderLayer(nn.Module):
    """
    Decoder layer with RMSNorm applied before attention and before the MLP.

    The residual stream is threaded through the RMSNorm calls: ``forward``
    returns the MLP output together with the residual, without adding them.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # Optional config fields fall back to defaults when absent.
        max_positions = getattr(config, "max_position_embeddings", 4096)
        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        mlp_bias = getattr(config, "mlp_bias", False)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=mlp_bias,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Run attention and the MLP.

        Args:
            positions: Positions forwarded to the attention layer.
            hidden_states: Layer input.
            residual: Residual carried from the previous layer, or ``None``
                for the first layer, in which case the input itself becomes
                the residual.

        Returns:
            ``(mlp_output, residual)``.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        mlp_in, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(mlp_in)

        return mlp_out, residual
|
|
|
|
|
|
class ChameleonSwinDecoderLayer(nn.Module):
    """
    Decoder layer that normalizes the *output* of attention and of the MLP
    before adding each back onto its input.

    The incoming ``residual`` argument is ignored. The returned hidden state
    already includes both residual additions; the returned residual is the
    state that was fed into the MLP.
    """

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # Optional config fields fall back to defaults when absent.
        max_positions = getattr(config, "max_position_embeddings", 4096)
        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        mlp_bias = getattr(config, "mlp_bias", False)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=mlp_bias,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run attention and the MLP with post-normalization.

        Args:
            positions: Positions forwarded to the attention layer.
            hidden_states: Layer input.
            residual: Accepted for signature parity; not read.

        Returns:
            ``(output, mlp_input)`` where ``output`` already contains both
            residual additions.
        """
        layer_input = hidden_states
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=layer_input,
        )
        post_attn = self.input_layernorm(attn_out) + layer_input

        # Fully Connected
        mlp_out = self.mlp(post_attn)
        mlp_out = self.post_attention_layernorm(mlp_out)
        output = post_attn + mlp_out

        return output, post_attn
|
|
|
|
|
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
|
|
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """
    Nearest-neighbour vector quantizer over a learned codebook.

    Reads ``num_embeddings`` and ``embed_dim`` from ``config``; ``beta``
    is optional and defaults to 0.25.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        """
        Quantize a channels-first 4-D feature map.

        Args:
            hidden_state: Tensor permuted with ``(0, 2, 3, 1)`` and then
                viewed as rows of ``embedding_dim`` values, so dimension 1
                must hold the embedding channels.

        Returns:
            ``(quantized, loss, indices)``: the quantized map in the input
            layout, the scalar codebook/commitment loss, and the flat
            codebook index of every spatial position.
        """
        # Move channels last so each position is one row of the flat view.
        channels_last = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat = channels_last.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        z_sq = torch.sum(flat**2, dim=1, keepdim=True)
        e_sq = torch.sum(codebook**2, dim=1)
        cross = torch.einsum("bd,dn->bn", flat, codebook.transpose(0, 1))
        distances = z_sq + e_sq - 2 * cross

        nearest = torch.argmin(distances, dim=1)
        quantized = self.embedding(nearest).view(channels_last.shape)

        # compute loss for embedding
        codebook_term = torch.mean((quantized.detach() - channels_last) ** 2)
        commitment_term = torch.mean((quantized - channels_last.detach()) ** 2)
        loss = codebook_term + self.beta * commitment_term

        # preserve gradients
        quantized = channels_last + (quantized - channels_last).detach()

        # reshape back to match original input shape
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        return quantized, loss, nearest
|
|
|
|
|
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
|
|
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 convolution preceded by one-sided zero padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        # Channel count is preserved; padding is applied manually in forward.
        self.conv = Conv2dLayer(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, hidden_states: torch.Tensor):
        """Pad one zero on the trailing side of the last two dims, then convolve."""
        # no asymmetric padding in torch conv, must do it ourselves
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
|
|
|
|
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
|
|
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """
    Residual block: (GroupNorm -> x*sigmoid(x) -> 3x3 conv) twice, with
    dropout before the second conv and a projected shortcut when the channel
    count changes.

    Args:
        config: Provides ``dropout``.
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``None`` means the same as
            ``in_channels``.
        conv_shortcut: When the channel count changes, use a 3x3 conv for
            the shortcut instead of a 1x1 conv.
    """

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut

        # Build every layer from the *resolved* channel count. The raw
        # ``out_channels`` argument may be None, which the layers below
        # cannot accept.
        resolved_out = self.out_channels

        self.norm1 = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = Conv2dLayer(
            in_channels, resolved_out, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = torch.nn.GroupNorm(
            num_groups=32, num_channels=resolved_out, eps=1e-6, affine=True
        )
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = Conv2dLayer(
            resolved_out, resolved_out, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            # The shortcut must change the channel count to match the main path.
            if self.use_conv_shortcut:
                self.conv_shortcut = Conv2dLayer(
                    in_channels, resolved_out, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = Conv2dLayer(
                    in_channels, resolved_out, kernel_size=1, stride=1, padding=0
                )

    def forward(self, hidden_states: torch.Tensor):
        """Return ``shortcut(input) + main_path(input)``."""
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        # In-place x * sigmoid(x) on the freshly produced norm output.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
|
|
|
|
|
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
|
|
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """
    Single-head self-attention over the spatial positions of a feature map,
    using 1x1 convolutions for the q/k/v/output projections.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )

        # All four projections are channel-preserving 1x1 convolutions.
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        self.q = Conv2dLayer(in_channels, in_channels, **pointwise)
        self.k = Conv2dLayer(in_channels, in_channels, **pointwise)
        self.v = Conv2dLayer(in_channels, in_channels, **pointwise)
        self.proj_out = Conv2dLayer(in_channels, in_channels, **pointwise)

    def forward(self, hidden_states: torch.Tensor):
        """Return the input plus the projected attention output."""
        normed = self.norm(hidden_states)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # compute attention
        batch_size, channels, height, width = query.shape
        num_positions = height * width

        query = query.reshape(batch_size, channels, num_positions).permute(0, 2, 1)
        key = key.reshape(batch_size, channels, num_positions)
        scores = torch.bmm(query, key)
        scores = scores * (int(channels) ** (-0.5))
        probs = F.softmax(scores, dim=2)

        # attend to values
        value = value.reshape(batch_size, channels, num_positions)
        probs = probs.permute(0, 2, 1)
        attended = torch.bmm(value, probs).reshape(
            batch_size, channels, height, width
        )

        return hidden_states + self.proj_out(attended)
|
|
|
|
|
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
|
|
class ChameleonVQVAEEncoder(nn.Module):
    """
    Convolutional encoder: input conv, a stack of down-sampling stages
    (resnet blocks with optional attention), a middle section and an output
    conv.

    Reads from ``config``: ``channel_multiplier``, ``num_res_blocks``,
    ``base_channels``, ``resolution``, ``in_channels``, ``double_latent``,
    ``latent_channels``, ``attn_resolutions`` and ``attn_type``.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = Conv2dLayer(
            in_channels, base_channels, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        # Stage i takes base * in_mult[i] channels and emits base * mult[i].
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for _ in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                # Attention is added only at the configured resolutions.
                if (
                    config.attn_resolutions is not None
                    and curr_res in config.attn_resolutions
                    and config.attn_type == "vanilla"
                ):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every stage except the last halves the spatial resolution.
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = (
            ChameleonVQVAEEncoderAttnBlock(block_in)
            if config.attn_type == "vanilla"
            else nn.Identity()
        )
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = Conv2dLayer(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        """Encode ``pixel_values`` (cast to the input conv's dtype) to latents."""
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling
        # Only the most recent feature map is ever consumed, so a single
        # running tensor is kept instead of a list of every intermediate
        # activation; this lets earlier maps be freed as soon as possible.
        hidden_state = self.conv_in(pixel_values)
        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                hidden_state = stage.block[i_block](hidden_state)
                if len(stage.attn) > 0:
                    hidden_state = stage.attn[i_block](hidden_state)
            if i_level != last_level:
                hidden_state = stage.downsample(hidden_state)

        # middle
        hidden_state = self.mid.block_1(hidden_state)
        hidden_state = self.mid.attn_1(hidden_state)
        hidden_state = self.mid.block_2(hidden_state)

        # end
        hidden_state = self.norm_out(hidden_state)
        # In-place x * sigmoid(x) on the freshly produced norm output.
        hidden_state *= torch.sigmoid(hidden_state)
        hidden_state = self.conv_out(hidden_state)
        return hidden_state
|
|
|
|
|
|
# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
|
|
class ChameleonVQVAE(nn.Module):
    """
    VQ-VAE image tokenizer: encoder, 1x1 pre-quantization conv and vector
    quantizer. The module is put in eval mode at construction.
    """

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        latent_channels = config.latent_channels
        embed_dim = config.embed_dim

        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = Conv2dLayer(latent_channels, embed_dim, 1)
        # Not used by ``encode``; kept as a registered sub-module.
        self.post_quant_conv = Conv2dLayer(embed_dim, latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode images and quantize the latents.

        Returns:
            ``(quant, emb_loss, indices)`` as produced by the quantizer.
        """
        latents = self.quant_conv(self.encoder(pixel_values))
        quant, emb_loss, indices = self.quantize(latents)
        return quant, emb_loss, indices
|
|
|
|
|
|
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
|
|
class ChameleonImageVocabularyMapping:
    """
    Maps discrete image tokens from VQGAN to BPE tokens.

    Image tokens are the vocabulary entries whose name starts with
    ``"IMGIMG"``. The characters between that prefix and the final
    character encode the VQ index, with letters ``A``-``J`` standing for
    digits ``0``-``9``.
    """

    def __init__(self, vocab_map: dict[str, int]):
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        """Inverse of ``vocab_map``: token id -> token name."""
        return dict(zip(self.vocab_map.values(), self.vocab_map.keys()))

    @cached_property
    def image_tokens(self):
        """Sorted BPE ids of all tokens whose name starts with ``"IMGIMG"``."""
        ids = [
            token_id
            for token_name, token_id in self.vocab_map.items()
            if token_name.startswith("IMGIMG")
        ]
        ids.sort()
        return ids

    @cached_property
    def bpe2img(self):
        """Mapping from image-token BPE id to the VQ index in its name."""
        letters_to_digits = str.maketrans("ABCDEFGHIJ", "0123456789")
        prefix_len = len("IMGIMG")

        decoded = {}
        for bpe_id in self.image_tokens:
            # Drop the "IMGIMG" prefix and the final character of the name.
            payload = self.val2name[bpe_id][prefix_len:-1]
            decoded[bpe_id] = int(payload.translate(letters_to_digits))
        return decoded

    @cached_property
    def img2bpe(self):
        """Inverse of ``bpe2img``: VQ index -> BPE id."""
        return {img_id: bpe_id for bpe_id, img_id in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        """Pair of tensors: the sorted BPE ids and the sorted VQ indices."""
        sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()))
        sorted_img = torch.tensor(sorted(self.bpe2img.values()))
        return sorted_bpe, sorted_img

    @cached_property
    def img2bpe_mapping_tensor(self):
        """Int lookup table indexed by VQ index; unmapped entries are 0."""
        lookup = self.img2bpe
        table = torch.zeros(max(lookup.keys()) + 1, dtype=torch.int)
        table[list(lookup.keys())] = torch.tensor(
            list(lookup.values()), dtype=torch.int
        )
        return table

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Translate VQ indices to BPE ids; the result is on the input's device."""
        # The lookup table lives on CPU, so index there and move back.
        source_device = img_batch.device
        cpu_indices = img_batch.to("cpu")
        return self.img2bpe_mapping_tensor[cpu_indices].to(source_device)
|
|
|
|
|
|
class ChameleonModel(nn.Module):
    """
    Chameleon backbone: token embedding, decoder stack, final RMSNorm and
    the VQ-VAE used to tokenize images.

    The decoder layer class is chosen from ``config.swin_norm``:
    ``ChameleonSwinDecoderLayer`` when set, ``ChameleonDecoderLayer``
    otherwise.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
        decoder_layer = (
            ChameleonDecoderLayer
            if not self.config.swin_norm
            else ChameleonSwinDecoderLayer
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into discrete tokens with VQGAN module and converts
        the obtained image tokens into BPE tokens.

        Returns:
            BPE token ids viewed as ``(pixel_values.shape[0], -1)``.
        """
        batch_size = pixel_values.shape[0]
        # Only the codebook indices are needed; quantized map and loss are dropped.
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        return bpe_toks.view(batch_size, -1)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """
        Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Positions forwarded to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` received
                from the previous pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings for the first rank.

        Returns:
            ``IntermediateTensors`` on non-last pipeline ranks, otherwise the
            final normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        # ChameleonSwinDecoderLayer already returns hidden states that
        # include its residual additions, so passing ``residual`` to the
        # final norm would add the pre-MLP state a second time. Likewise,
        # with no residual at all there is nothing to add. In both cases
        # normalize the hidden states on their own.
        if self.config.swin_norm or residual is None:
            return self.norm(hidden_states)

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(
|
|
ChameleonMultiModalProcessor,
|
|
info=ChameleonProcessingInfo,
|
|
dummy_inputs=ChameleonDummyInputsBuilder,
|
|
)
|
|
class ChameleonForConditionalGeneration(
|
|
nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
|
|
):
|
|
merge_by_field_config = True
|
|
|
|
packed_modules_mapping = {
|
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
|
}
|
|
|
|
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
    """
    Return the prompt placeholder for item ``i`` of ``modality``.

    Raises:
        ValueError: If ``modality`` does not start with ``"image"``.
    """
    is_image = modality.startswith("image")
    if not is_image:
        raise ValueError("Only image modality is supported")

    return "<image>"
|
|
|
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the backbone, LM head and logits processor from ``vllm_config``."""
    super().__init__()
    model_config = vllm_config.model_config
    config = model_config.hf_config

    self.config = config
    self.multimodal_config = model_config.multimodal_config
    self.model = ChameleonModel(
        vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
    )

    self.lm_head = ParallelLMHead(
        config.vocab_size,
        config.hidden_size,
        prefix=maybe_prefix(prefix, "lm_head"),
    )
    # With tied embeddings the LM head shares the input embedding weight.
    if config.tie_word_embeddings:
        self.lm_head.weight = self.model.embed_tokens.weight

    # ``logit_scale`` is optional in the config; default to no scaling.
    scale = getattr(config, "logit_scale", 1.0)
    self.logits_processor = LogitsProcessor(config.vocab_size, scale=scale)
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors
    )
|
|
|
|
def _parse_and_validate_image_input(
    self, **kwargs: object
) -> ChameleonImagePixelInputs | None:
    """
    Wrap ``pixel_values`` from ``kwargs`` in a ``ChameleonImagePixelInputs``.

    Returns ``None`` when no ``pixel_values`` entry is present (or it is
    ``None``). The height/width bindings are both set to the VQ config's
    ``resolution``.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    if pixel_values is None:
        return None

    vq_config: ChameleonVQVAEConfig = self.config.vq_config
    side = vq_config.resolution

    return ChameleonImagePixelInputs(
        type="pixel_values",
        data=pixel_values,
        resolve_bindings={"h": side, "w": side},
    )
|
|
|
|
def get_language_model(self) -> torch.nn.Module:
|
|
return self.model
|
|
|
|
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
|
if image_input is None:
|
|
return []
|
|
assert self.model.vqmodel is not None
|
|
image_tokens = self.model.get_image_tokens(
|
|
image_input["data"].to(self.config.dtype)
|
|
)
|
|
vision_embeddings = self.model.embed_input_ids(image_tokens)
|
|
return vision_embeddings
|
|
|
|
def forward(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
intermediate_tensors: IntermediateTensors | None = None,
|
|
inputs_embeds: torch.Tensor | None = None,
|
|
**kwargs,
|
|
) -> torch.Tensor | IntermediateTensors:
|
|
if intermediate_tensors is not None:
|
|
inputs_embeds = None
|
|
|
|
hidden_states = self.model(
|
|
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
|
)
|
|
return hidden_states
|
|
|
|
def compute_logits(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
) -> torch.Tensor | None:
|
|
logits = self.logits_processor(self.lm_head, hidden_states)
|
|
|
|
# Disallow image tokens which does not include special
|
|
# begin-image and end-image tokens
|
|
if logits is not None:
|
|
image_tokens = self.model.vocabulary_mapping.image_tokens
|
|
logits[:, image_tokens] = torch.finfo(logits.dtype).min
|
|
|
|
return logits
|
|
|
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
|
stacked_params_mapping = [
|
|
# (param_name, shard_name, shard_id)
|
|
(".qkv_proj", ".q_proj", "q"),
|
|
(".qkv_proj", ".k_proj", "k"),
|
|
(".qkv_proj", ".v_proj", "v"),
|
|
(".gate_up_proj", ".gate_proj", 0),
|
|
(".gate_up_proj", ".up_proj", 1),
|
|
]
|
|
params_dict = dict(self.named_parameters())
|
|
loaded_params: set[str] = set()
|
|
for name, loaded_weight in weights:
|
|
if "rotary_emb.inv_freq" in name:
|
|
continue
|
|
|
|
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
|
# Models trained using ColossalAI may include these tensors in
|
|
# the checkpoint. Skip them.
|
|
continue
|
|
|
|
# With tie_word_embeddings, we can skip lm_head.weight
|
|
# The weight might appear unnecessarily in the files if the model is
|
|
# processed with quantization, LoRA, fine-tuning, etc.
|
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
|
continue
|
|
|
|
use_default_weight_loading = False
|
|
if "vqmodel" in name:
|
|
if self.model.vqmodel is not None:
|
|
# We only do sharding for language model and
|
|
# not vqvae for now.
|
|
use_default_weight_loading = True
|
|
else:
|
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
|
if weight_name not in name:
|
|
continue
|
|
name = name.replace(weight_name, param_name)
|
|
# Skip loading extra bias for GPTQ models.
|
|
if name.endswith(".bias") and name not in params_dict:
|
|
continue
|
|
if is_pp_missing_parameter(name, self):
|
|
continue
|
|
param = params_dict[name]
|
|
weight_loader = param.weight_loader
|
|
weight_loader(param, loaded_weight, shard_id)
|
|
break
|
|
else:
|
|
# Skip loading extra bias for GPTQ models.
|
|
if name.endswith(".bias") and name not in params_dict:
|
|
continue
|
|
# Remapping the name of FP8 kv-scale.
|
|
if name.endswith("kv_scale"):
|
|
remapped_kv_scale_name = name.replace(
|
|
".kv_scale", ".attn.kv_scale"
|
|
)
|
|
if remapped_kv_scale_name not in params_dict:
|
|
logger.warning_once(
|
|
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
|
|
name,
|
|
remapped_kv_scale_name,
|
|
)
|
|
continue
|
|
else:
|
|
name = remapped_kv_scale_name
|
|
if is_pp_missing_parameter(name, self):
|
|
continue
|
|
param = params_dict[name]
|
|
weight_loader = getattr(
|
|
param, "weight_loader", default_weight_loader
|
|
)
|
|
weight_loader(param, loaded_weight)
|
|
if use_default_weight_loading and name in params_dict:
|
|
if is_pp_missing_parameter(name, self):
|
|
continue
|
|
param = params_dict[name]
|
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
weight_loader(param, loaded_weight)
|
|
loaded_params.add(name)
|
|
return loaded_params
|