Matthew Bonanni 430dd4d9eb
[Attention] Remove imports from vllm/attention/__init__.py (#29342)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
2025-11-26 10:53:15 -07:00

1099 lines
40 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from itertools import islice
from typing import Annotated, Any, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
BatchFeature,
ChameleonConfig,
ChameleonProcessor,
ChameleonVQVAEConfig,
)
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
row_parallel_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from .utils import (
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
class ChameleonImagePixelInputs(TensorSchema):
    """
    Schema describing raw pixel inputs for Chameleon images.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Tag identifying this input variant; always "pixel_values".
    type: Literal["pixel_values"]

    # Pixel tensor annotated with shape (bn, 3, h, w).
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
class ChameleonProcessingInfo(BaseProcessingInfo):
    """Accessors for the Chameleon HF config/processor and image limits."""

    def get_hf_config(self):
        """Fetch the HF config from the context, requesting `ChameleonConfig`."""
        hf_config = self.ctx.get_hf_config(ChameleonConfig)
        return hf_config

    def get_hf_processor(self, **kwargs: object):
        """Fetch the HF processor from the context, requesting
        `ChameleonProcessor` and forwarding `kwargs`."""
        hf_processor = self.ctx.get_hf_processor(ChameleonProcessor, **kwargs)
        return hf_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Report the per-modality limits: a single image."""
        return dict(image=1)

    def get_num_image_tokens(self) -> int:
        """Return the processor's `image_seq_length`."""
        return self.get_hf_processor().image_seq_length
class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo]):
    """Builds placeholder text and image data for the requested image count."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy square images sized to the VQ config's resolution.

        `seq_len` is not used here. `mm_options["image"]`, when present, is
        forwarded to `_get_dummy_images` as `overrides`.
        """
        # Images are square: one side length serves as both width and height.
        side = self.info.get_hf_config().vq_config.resolution
        image_count = mm_counts.get("image", 0)

        if mm_options:
            overrides = mm_options.get("image")
        else:
            overrides = None

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}
class ChameleonMultiModalProcessor(BaseMultiModalProcessor[ChameleonProcessingInfo]):
    """Multimodal processor that expands the image placeholder for Chameleon."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or tokenize directly when `mm_data` is empty.

        In the text-only case the prompt is encoded with the tokenizer, the
        separator token is appended, and the ids are wrapped in a
        `BatchFeature`.
        """
        if mm_data:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        token_ids = self.info.get_tokenizer().encode(prompt)
        token_ids = self._apply_hf_processor_tokens_only(token_ids)
        return BatchFeature({"input_ids": [token_ids]}, tensor_type="pt")

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Return `prompt_tokens` with the tokenizer's sep token id appended."""
        # HF processor adds sep token for chat mode
        tokenizer = self.info.get_tokenizer()
        sep_id = tokenizer.get_vocab()[tokenizer.sep_token]  # type: ignore
        return prompt_tokens + [sep_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` as a batched field of the image modality."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with start + N image tokens + end.

        N is `get_num_image_tokens()`; the image token id is selected as the
        embedding token id of the replacement.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        start_id = vocab[processor.image_start_token]
        placeholder_id = vocab[processor.image_token]
        end_id = vocab[processor.image_end_token]

        repeated = [placeholder_id] * self.info.get_num_image_tokens()
        expanded = [start_id] + repeated + [end_id]

        replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=PromptUpdateDetails.select_token_id(
                expanded,
                embed_token_id=placeholder_id,
            ),
        )
        return [replacement]
class ChameleonLayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes over the last dimension only.

    The parameters are created by `nn.LayerNorm` with the full `hidden_size`
    shape passed in, while normalization itself runs over
    `hidden_size[-1]`. The affine weight and bias are applied afterwards by
    broadcasting.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        """
        Args:
            hidden_size: Indexable shape; its last entry is the normalized
                dimension.
            *args, **kwargs: Forwarded to `nn.LayerNorm` (e.g. `eps`).
        """
        super().__init__(hidden_size, *args, **kwargs)
        # Override the shape recorded by nn.LayerNorm: only the trailing
        # dimension is normalized in forward().
        self.normalized_shape = (hidden_size[-1],)

        loader_attrs = {"weight_loader": row_parallel_weight_loader}
        set_weight_attrs(self.weight, loader_attrs)
        set_weight_attrs(self.bias, loader_attrs)

    def forward(self, hidden_states):
        """Normalize over the last dimension, then apply weight and bias."""
        # Use the configured epsilon (nn.LayerNorm defaults to 1e-5) rather
        # than a hard-coded constant, so an `eps` passed to the constructor
        # is honored.
        hidden_states = F.layer_norm(
            hidden_states, self.normalized_shape, None, None, eps=self.eps
        )
        return hidden_states * self.weight + self.bias
# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):
    """Feed-forward block: merged gate/up projection, `SiluAndMul`, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """
        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only "silu" is accepted.
            quant_config: Quantization config passed to both linear layers.
            bias: Whether both linear layers use a bias.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If `hidden_act` is not "silu".
        """
        super().__init__()
        linear_kwargs = dict(bias=bias, quant_config=quant_config)

        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **linear_kwargs,
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            prefix=f"{prefix}.down_proj",
            **linear_kwargs,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, and down projection to `x`."""
        # Each linear layer returns a pair; only the first element is used.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):
    """Self-attention with per-head LayerNorm on queries and keys before RoPE."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_parameters: dict[str, Any],
        max_position_embeddings: int = 4096,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            hidden_size: Model hidden width.
            num_heads: Total number of query heads (before TP sharding).
            num_kv_heads: Total number of key/value heads (before TP sharding).
            rope_parameters: Passed through to `get_rope`.
            max_position_embeddings: Maximum position passed to `get_rope`.
            quant_config: Quantization config for the linear/attention layers.
            bias: Whether the QKV and output projections use a bias.
            cache_config: Cache config passed to `Attention`.
            prefix: Parameter-name prefix for the sub-layers.
        """
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        kv_heads_shardable = self.total_num_kv_heads >= tp_size
        if kv_heads_shardable:
            # At least as many KV heads as TP ranks: partition the KV heads
            # across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than TP ranks: replicate the KV heads across
            # the tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Norm parameters have one row per local head.
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize `q` and `k` per head and return them flattened again."""
        # Split the trailing dimension into (heads, head_dim) for the norm.
        q_heads = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim))
        k_heads = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim))

        # Merge (heads, head_dim) back into a single trailing dimension.
        q_flat = q_heads.view(*q_heads.shape[:-2], -1)
        k_flat = k_heads.view(*k_heads.shape[:-2], -1)
        return q_flat, k_flat

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to QKV, apply QK-norm and RoPE, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)

        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)

        context = self.attn(q, k, v)
        projected, _ = self.o_proj(context)
        return projected
class ChameleonDecoderLayer(nn.Module):
    """Decoder layer that normalizes before attention and before the MLP."""

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build attention, MLP, and the two RMSNorm layers from `config`."""
        super().__init__()
        self.hidden_size = config.hidden_size

        # Fall back to the number of attention heads when the config has no
        # separate KV head count.
        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        max_positions = getattr(config, "max_position_embeddings", 4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the layer and return `(hidden_states, residual)`.

        When `residual` is None the input itself becomes the residual and the
        input norm is applied without one; otherwise the norm is called with
        the residual and returns the updated pair.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        # Self attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        mlp_out = self.mlp(hidden_states)
        return mlp_out, residual
class ChameleonSwinDecoderLayer(nn.Module):
    """Decoder layer that normalizes the output of attention and of the MLP
    before adding each back onto the residual stream."""

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build attention, MLP, and the two RMSNorm layers from `config`."""
        super().__init__()
        self.hidden_size = config.hidden_size

        # Fall back to the number of attention heads when the config has no
        # separate KV head count.
        kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
        max_positions = getattr(config, "max_position_embeddings", 4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=kv_heads,
            rope_parameters=config.rope_parameters,
            max_position_embeddings=max_positions,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return `(hidden_states, residual)`.

        The incoming `residual` argument is not read. The returned
        `hidden_states` already includes both residual additions; the
        returned `residual` is the hidden state that entered the MLP.
        """
        # Attention branch: attend, normalize, then add the skip connection.
        skip = hidden_states
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = self.input_layernorm(attn_out) + skip

        # Fully Connected
        residual = hidden_states
        mlp_out = self.post_attention_layernorm(self.mlp(hidden_states))
        hidden_states = residual + mlp_out

        return hidden_states, residual
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):
    """Maps each feature vector to its nearest codebook embedding."""

    def __init__(self, config: ChameleonVQVAEConfig):
        """Create the codebook from `config.num_embeddings` / `config.embed_dim`."""
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        # Commitment-loss weight; 0.25 when the config does not define it.
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        """Quantize `hidden_state`.

        The input is permuted from dims (0, 1, 2, 3) to (0, 2, 3, 1) so that
        the embedding dimension is last, then flattened into rows of size
        `embedding_dim`.

        Returns:
            A tuple `(quantized, loss, indices)`: the quantized tensor
            permuted back to the input layout, the scalar embedding loss, and
            the flat codebook index of every row.
        """
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        flat = hidden_state.view(-1, self.embedding_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        z_sq = torch.sum(flat**2, dim=1, keepdim=True)
        e_sq = torch.sum(codebook**2, dim=1)
        cross = torch.einsum("bd,dn->bn", flat, codebook.transpose(0, 1))
        distances = z_sq + e_sq - 2 * cross

        min_encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(min_encoding_indices).view(hidden_state.shape)

        # compute loss for embedding
        codebook_term = torch.mean((quantized.detach() - hidden_state) ** 2)
        commitment_term = torch.mean((quantized - hidden_state.detach()) ** 2)
        loss = codebook_term + self.beta * commitment_term

        # preserve gradients
        quantized = hidden_state + (quantized - hidden_state).detach()

        # reshape back to match original input shape
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        return quantized, loss, min_encoding_indices
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
    """Stride-2 3x3 convolution preceded by manual one-sided zero padding."""

    def __init__(self, in_channels: int):
        """Create the convolution; input and output channel counts are equal."""
        super().__init__()
        self.conv = Conv2dLayer(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=0,
        )

    def forward(self, hidden_states: torch.Tensor):
        """Pad one zero at the end of each of the last two dims, then convolve."""
        # no asymmetric padding in torch conv, must do it ourselves
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> x*sigmoid(x) -> 3x3 conv) stages
    plus a shortcut that is projected when the channel count changes."""

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        """
        Args:
            config: VQ-VAE config; `config.dropout` sets the dropout rate.
            in_channels: Number of input channels.
            out_channels: Number of output channels; defaults to
                `in_channels` when None.
            conv_shortcut: When the channel count changes, use a 3x3 conv
                shortcut instead of a 1x1 one.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.use_conv_shortcut = conv_shortcut

        # Use the resolved width from here on, so the default
        # `out_channels=None` does not reach the layer constructors.
        out_channels = self.out_channels

        self.norm1 = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        self.conv1 = Conv2dLayer(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = torch.nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = Conv2dLayer(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        # A shortcut projection is only needed when the widths differ.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = Conv2dLayer(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = Conv2dLayer(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, hidden_states: torch.Tensor):
        """Return the shortcut (projected if needed) plus the block output."""
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        # In-place x * sigmoid(x) on the norm output.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):
    """Attention over the flattened spatial positions of a feature map, using
    1x1 convolutions for the query, key, value, and output projections."""

    def __init__(self, in_channels: int):
        """Create the norm and the four 1x1 projections over `in_channels`."""
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )

        pointwise = dict(kernel_size=1, stride=1, padding=0)
        self.q = Conv2dLayer(in_channels, in_channels, **pointwise)
        self.k = Conv2dLayer(in_channels, in_channels, **pointwise)
        self.v = Conv2dLayer(in_channels, in_channels, **pointwise)
        self.proj_out = Conv2dLayer(in_channels, in_channels, **pointwise)

    def forward(self, hidden_states: torch.Tensor):
        """Return `hidden_states` plus the projected attention output."""
        residual = hidden_states
        normed = self.norm(hidden_states)

        query_states = self.q(normed)
        key_states = self.k(normed)
        value_states = self.v(normed)

        # compute attention
        batch_size, channels, height, width = query_states.shape
        num_positions = height * width

        # (batch, positions, channels) x (batch, channels, positions)
        query_states = query_states.reshape(batch_size, channels, num_positions)
        query_states = query_states.permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, num_positions)

        scores = torch.bmm(query_states, key_states)
        scores = scores * (int(channels) ** (-0.5))
        scores = F.softmax(scores, dim=2)

        # attend to values
        value_states = value_states.reshape(batch_size, channels, num_positions)
        scores = scores.permute(0, 2, 1)
        attended = torch.bmm(value_states, scores)
        attended = attended.reshape(batch_size, channels, height, width)

        return residual + self.proj_out(attended)
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):
    """Convolutional encoder: an input conv, a stack of resolution levels
    (resnet blocks, optional attention, downsampling between levels), a
    middle section, and an output norm + conv."""

    def __init__(self, config: ChameleonVQVAEConfig):
        """Build all encoder stages from `config`.

        One level is created per entry of `config.channel_multiplier`; every
        level except the last ends with a downsample that halves the tracked
        resolution.
        """
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = Conv2dLayer(
            in_channels, base_channels, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        # Input multiplier of level i is the output multiplier of level i-1
        # (1 for the first level).
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]

            # Constant for the whole level: curr_res only changes after the
            # level's blocks have been built.
            level_uses_attn = (
                config.attn_resolutions is not None
                and curr_res in config.attn_resolutions
                and config.attn_type == "vanilla"
            )
            for _ in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                if level_uses_attn:
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = (
            ChameleonVQVAEEncoderAttnBlock(block_in)
            if config.attn_type == "vanilla"
            else nn.Identity()
        )
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(
            num_groups=32, num_channels=block_in, eps=1e-6, affine=True
        )
        self.conv_out = Conv2dLayer(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        """Encode `pixel_values` and return the output of `conv_out`.

        The input is first cast to the dtype of `conv_in`'s weight.
        """
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling
        # Only the most recent feature map is ever consumed, so keep a single
        # reference rather than a list of every intermediate activation;
        # earlier maps can then be freed as soon as they are superseded.
        hidden_state = self.conv_in(pixel_values)
        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                hidden_state = stage.block[i_block](hidden_state)
                if len(stage.attn) > 0:
                    hidden_state = stage.attn[i_block](hidden_state)
            if i_level != last_level:
                hidden_state = stage.downsample(hidden_state)

        # middle
        last_hidden_state = self.mid.block_1(hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state
# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):
    """Encoder + 1x1 quant conv + vector quantizer, put in eval mode on
    construction."""

    def __init__(self, config: ChameleonVQVAEConfig):
        """Build the encoder, quantizer, and the pre/post quantization convs."""
        super().__init__()
        latent_dim = config.latent_channels
        code_dim = config.embed_dim

        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = Conv2dLayer(latent_dim, code_dim, 1)
        self.post_quant_conv = Conv2dLayer(code_dim, latent_dim, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode and quantize `pixel_values`.

        Returns:
            The quantizer's `(quant, emb_loss, indices)` triple.
        """
        latents = self.quant_conv(self.encoder(pixel_values))
        quant, emb_loss, indices = self.quantize(latents)
        return quant, emb_loss, indices
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    Translates between discrete VQGAN image codes and BPE token ids.

    Image tokens are the vocabulary entries whose name starts with "IMGIMG".
    The part of the name after that prefix, minus its final character, spells
    the image code with the letters A-J standing for the digits 0-9.
    """

    def __init__(self, vocab_map: dict[str, int]):
        """Store the name -> id vocabulary and look up the "<image>" id."""
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        """Inverse vocabulary: token id -> token name."""
        return dict(zip(self.vocab_map.values(), self.vocab_map.keys()))

    @cached_property
    def image_tokens(self):
        """Sorted ids of every token whose name starts with "IMGIMG"."""
        return sorted(
            token_id
            for token_name, token_id in self.vocab_map.items()
            if token_name.startswith("IMGIMG")
        )

    @cached_property
    def bpe2img(self):
        """Map each image-token BPE id to its integer image code."""
        # A..J -> 0..9; every other character is left unchanged.
        letters_to_digits = str.maketrans("ABCDEFGHIJ", "0123456789")
        prefix_len = len("IMGIMG")

        def decode(token_name: str) -> int:
            # Drop the "IMGIMG" prefix and the trailing character.
            return int(token_name[prefix_len:-1].translate(letters_to_digits))

        return {
            token_id: decode(self.val2name[token_id])
            for token_id in self.image_tokens
        }

    @cached_property
    def img2bpe(self):
        """Inverse of `bpe2img`: image code -> BPE id."""
        return {img_code: bpe_id for bpe_id, img_code in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        """Pair of tensors: sorted BPE ids and sorted image codes."""
        sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()))
        sorted_img = torch.tensor(sorted(self.bpe2img.values()))
        return sorted_bpe, sorted_img

    @cached_property
    def img2bpe_mapping_tensor(self):
        """Lookup table indexed by image code holding the BPE id.

        Codes without an entry map to 0.
        """
        table = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        img_codes = torch.tensor(list(self.img2bpe.keys()))
        bpe_ids = torch.tensor(list(self.img2bpe.values()), dtype=torch.int)
        table[img_codes] = bpe_ids
        return table

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Convert image codes to BPE ids, returned on `img_batch`'s device."""
        # The lookup table is indexed with a CPU copy of the batch.
        looked_up = self.img2bpe_mapping_tensor[img_batch.cpu()]
        return looked_up.to(img_batch.device)
class ChameleonModel(nn.Module):
    """Chameleon backbone: token embeddings, decoder layers, final RMSNorm,
    and the VQ-VAE used to turn images into token ids."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from `vllm_config`.

        The decoder layer class is `ChameleonSwinDecoderLayer` when
        `config.swin_norm` is set and `ChameleonDecoderLayer` otherwise.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
        decoder_layer = (
            ChameleonDecoderLayer
            if not self.config.swin_norm
            else ChameleonSwinDecoderLayer
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for `input_ids`."""
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenize images into discrete codes with the VQ-VAE and convert the
        codes into BPE token ids.

        Returns a tensor viewed as `(batch_size, -1)`, where `batch_size` is
        the first dimension of `pixel_values`. No begin/end-of-image tokens
        are added here.
        """
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's decoder layers.

        On the first pipeline rank the input is `inputs_embeds` if given,
        otherwise the embeddings of `input_ids`; on later ranks it is read
        from `intermediate_tensors`. Ranks other than the last return
        `IntermediateTensors`; the last rank returns the normalized hidden
        states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        if self.config.swin_norm:
            # Swin layers return hidden states that already include the
            # residual additions; passing `residual` to the norm would add
            # the pre-MLP state a second time.
            return self.norm(hidden_states)

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
@MULTIMODAL_REGISTRY.register_processor(
    ChameleonMultiModalProcessor,
    info=ChameleonProcessingInfo,
    dummy_inputs=ChameleonDummyInputsBuilder,
)
class ChameleonForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
):
    """Chameleon backbone plus LM head, logits processing, and weight loading."""

    merge_by_field_config = True

    # Fused parameter name -> checkpoint sub-parameter names.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image modality.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head, and logits processor."""
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config
        self.model = ChameleonModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Share the embedding matrix with the output head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            config.vocab_size, scale=getattr(config, "logit_scale", 1.0)
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> ChameleonImagePixelInputs | None:
        """Wrap `pixel_values` from `kwargs` in a schema object, or return
        None when it is absent.

        The schema's "h" and "w" bindings are both resolved to the VQ
        config's resolution.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        vq_config: ChameleonVQVAEConfig = self.config.vq_config
        side = vq_config.resolution
        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": side, "w": side},
        )

    def get_language_model(self) -> torch.nn.Module:
        """Return the backbone model."""
        return self.model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Convert image inputs to embeddings; returns `[]` without an image.

        Pixel values are cast to `config.dtype`, tokenized via the backbone's
        `get_image_tokens`, and embedded like ordinary token ids.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        assert self.model.vqmodel is not None
        pixels = image_input["data"].to(self.config.dtype)
        image_tokens = self.model.get_image_tokens(pixels)
        return self.model.embed_input_ids(image_tokens)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone; `inputs_embeds` is dropped whenever
        `intermediate_tensors` is provided."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits and mask out every image-token column."""
        logits = self.logits_processor(self.lm_head, hidden_states)

        # Disallow the image tokens themselves; this set does not include
        # the special begin-image and end-image tokens.
        if logits is not None:
            masked_ids = self.model.vocabulary_mapping.image_tokens
            logits[:, masked_ids] = torch.finfo(logits.dtype).min

        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into the model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused `qkv_proj` / `gate_up_proj` parameters with a shard id.
        Names containing "vqmodel" use the plain per-parameter loader.

        Returns:
            The set of (possibly remapped) names recorded as loaded.
        """
        fused_shards = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        # "rotary_emb.inv_freq" is always skipped; the cos/sin caches may be
        # present in checkpoints of models trained using ColossalAI.
        skipped_markers = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )

        params = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, tensor in weights:
            if any(marker in name for marker in skipped_markers):
                continue

            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            use_default_weight_loading = False
            if "vqmodel" in name:
                if self.model.vqmodel is not None:
                    # We only do sharding for language model and
                    # not vqvae for now.
                    use_default_weight_loading = True
            else:
                for fused_name, shard_name, shard_id in fused_shards:
                    if shard_name not in name:
                        continue
                    name = name.replace(shard_name, fused_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    target = params[name]
                    target.weight_loader(target, tensor, shard_id)
                    break
                else:
                    # Reached only when the loop above did not `break`.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale"
                        )
                        if remapped_kv_scale_name not in params:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                remapped_kv_scale_name,
                            )
                            continue
                        name = remapped_kv_scale_name

                    if is_pp_missing_parameter(name, self):
                        continue

                    target = params[name]
                    loader = getattr(target, "weight_loader", default_weight_loader)
                    loader(target, tensor)

            if use_default_weight_loading and name in params:
                if is_pp_missing_parameter(name, self):
                    continue
                target = params[name]
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, tensor)

            loaded_params.add(name)

        return loaded_params