Harry Mellor 8fcaaf6a16
Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-10-12 09:51:31 -07:00

248 lines
8.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# A modified implementation of the AIMv2 Transformer
# inserted here also the image tokenizer used by Ovis2
from collections.abc import Iterable
import torch
import torch.nn as nn
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.ovis import AIMv2Config
class AIMv2SwiGLUFFN(nn.Module):
def __init__(
self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
):
super().__init__()
hidden_features = config.intermediate_size
in_features = config.hidden_size
bias = config.use_bias
self.fc13 = MergedColumnParallelLinear(
in_features,
[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc13",
)
self.fc2 = RowParallelLinear(
input_size=hidden_features,
output_size=in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc13(x)
x = self.act_fn(x)
x, _ = self.fc2(x)
return x
class AIMv2PatchEmbed(nn.Module):
def __init__(self, config: AIMv2Config):
super().__init__()
self.proj = nn.Conv2d(
config.num_channels,
config.hidden_size,
kernel_size=(config.patch_size, config.patch_size),
stride=(config.patch_size, config.patch_size),
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm.forward_native(x)
return x
class AIMv2ViTPreprocessor(nn.Module):
def __init__(self, config: AIMv2Config):
super().__init__()
num_patches = (config.image_size // config.patch_size) ** 2
self.patchifier = AIMv2PatchEmbed(config)
self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))
def forward(self, x: torch.Tensor) -> torch.Tensor:
tokens = self.patchifier(x)
_, N, _ = tokens.shape
pos_embed = self.pos_embed.to(tokens.device)
tokens = tokens + pos_embed[:, :N]
return tokens
class AIMv2Attention(nn.Module):
def __init__(
self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.qkv = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
)
self.proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
bias=config.use_bias,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = MultiHeadAttention(
self.num_heads_per_partition, self.head_dim, self.scale
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
x = self.attn(q, k, v)
x, _ = self.proj(x)
return x
class AIMv2Block(nn.Module):
def __init__(
self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str
):
super().__init__()
self.attn = AIMv2Attention(
config, quant_config=quant_config, prefix=f"{prefix}.attn"
)
self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = AIMv2SwiGLUFFN(
config, quant_config=quant_config, prefix=f"{prefix}.mlp"
)
self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.norm_1.forward_native(x))
x = x + self.mlp(self.norm_2.forward_native(x))
return x
class AIMv2Transformer(nn.Module):
def __init__(
self,
config: AIMv2Config,
quant_config: QuantizationConfig,
*,
require_post_norm: bool | None = None,
prefix: str = "",
):
super().__init__()
self.blocks = nn.ModuleList(
[
AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
for i in range(config.num_hidden_layers)
]
)
if require_post_norm:
self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.post_trunk_norm = None
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
# they take the -1 as the ref embeddings, like a clip skip
for block in self.blocks:
tokens = block(tokens)
if self.post_trunk_norm is not None:
tokens = self.post_trunk_norm(tokens)
return tokens
class AIMv2Model(torch.nn.Module):
def __init__(
self,
config: AIMv2Config,
quant_config: QuantizationConfig,
*,
require_post_norm: bool | None = None,
prefix: str = "",
):
super().__init__()
self.preprocessor = AIMv2ViTPreprocessor(config)
self.trunk = AIMv2Transformer(
config,
quant_config=quant_config,
require_post_norm=require_post_norm,
prefix=f"{prefix}.trunk",
)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
x = self.preprocessor(pixel_values)
x = self.trunk(x)
return x
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".fc13", ".fc1", 0),
(".fc13", ".fc3", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel
if (
name.startswith("trunk.post_trunk_norm")
and self.trunk.post_trunk_norm is None
):
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params