# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
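
# Building blocks of the RADIO vision encoder as used by vLLM: patch
# extraction and embedding (Im2Patches, ViTPatchLinear, ViTPatchGenerator),
# learnable CLS/register tokens (ClsToken), and the InternViT-based encoder
# wrappers (RadioInternVisionModel, RadioModel).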

import math
from collections.abc import Iterable
from itertools import repeat
from typing import TypeAlias

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionEncoder

input_dim_t: TypeAlias = int | tuple[int, int]
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor


def _ntuple(n):
    def parse(x):
        if isinstance(x, Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
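# Usage sketch (illustrative): these helpers broadcast scalars and pass
# iterables through unchanged, e.g. to_2tuple(224) -> (224, 224) and
# to_2tuple((224, 336)) -> (224, 336).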


class ClsToken(nn.Module):
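    """Prepends learnable CLS and register tokens to a patch sequence.

    When ``num_registers`` is not given, ``register_multiple`` pads the total
    number of prepended tokens up to a multiple of ``register_multiple``.
    """
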
    def __init__(
        self,
        ndim: int,
        num_tokens: int = 1,
        enabled: bool = True,
        register_multiple: int | None = None,
        num_registers: int | None = None,
    ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if enabled:
            if num_registers:
                self.num_registers = num_registers
            elif register_multiple:
                self.num_registers = register_multiple - (
                    num_tokens % register_multiple
                )

            scale = ndim**-0.5
            self.token = nn.Parameter(
                torch.randn(num_tokens + self.num_registers, ndim) * scale
            )
        else:
            self.token = None

        self.num_patches = self.num_tokens + self.num_registers

    def forward(self, x: torch.Tensor):
        if self.token is None:
            return x

        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
        x = torch.cat(
            [
                token,
                x,
            ],
            dim=1,
        )

        return x


class ViTPatchGenerator(nn.Module):
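    """Converts an image into a sequence of embedded ViT patch tokens.

    Combines Im2Patches, a linear patch embedder, an (optionally cropped or
    interpolated) absolute position embedding, and optional CLS/register
    tokens, followed by an optional LayerNorm over the patch sequence.
    """
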
    def __init__(
        self,
        # config: PretrainedConfig,
        patch_size: int,
        embed_dim: int,
        input_dims: input_dim_t,
        abs_pos: bool = True,
        normalize_patches: bool = False,
        cls_token: bool = False,
        max_input_dims: input_dim_t | None = None,
        pos_dropout: float = 0.0,
        return_pos_enc: bool = False,
        num_cls_tokens: int = 1,
        register_multiple: int | None = None,
        num_registers: int | None = None,
        patch_bias: bool = False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        if isinstance(input_dims, int):
            input_dims = (input_dims, input_dims)

        if max_input_dims is None:
            max_input_dims = input_dims
        if isinstance(max_input_dims, int):
            max_input_dims = (max_input_dims, max_input_dims)

        max_input_dims = tuple(
            int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims
        )

        self.cpe_mode = max_input_dims != input_dims
        self.pos_dropout = pos_dropout
        self.return_pos_enc = return_pos_enc

        factory = dict(device=device, dtype=dtype)

        self.patch_size = patch_size
        self.abs_pos = abs_pos
        self.embed_dim = embed_dim

        self.num_rows = max_input_dims[0] // patch_size
        self.num_cols = max_input_dims[1] // patch_size
        self.input_dims = tuple(d // patch_size for d in input_dims)
        self.num_patches = self.num_rows * self.num_cols
        self.max_input_dims = max_input_dims

        self.im_to_patches = Im2Patches(patch_size)
        self.embedder = ViTPatchLinear(
            patch_size, embed_dim, bias=patch_bias, **factory
        )

        if abs_pos:
            scale = embed_dim**-0.5
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.num_patches, embed_dim, **factory) * scale
            )

        self.cls_token = ClsToken(
            embed_dim,
            num_tokens=num_cls_tokens,
            enabled=cls_token,
            register_multiple=register_multiple,
            num_registers=num_registers,
        )

        self.patch_normalizer = (
            nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
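        # Patchify -> linear embed -> add position encoding -> prepend
        # CLS/register tokens -> (optional) patch normalization.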
        patches = self.embed_patches(x)
        patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
        patches = self.cls_token(patches)
        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches

    @property
    def apply_cls_token(self):
        return self.cls_token.enabled

    @property
    def num_cls_tokens(self):
        return self.cls_token.num_tokens

    @property
    def num_cls_patches(self):
        return self.cls_token.num_patches

    @property
    def num_registers(self):
        return self.cls_token.num_registers

    @property
    def num_skip(self):
        return self.num_cls_tokens + self.num_registers

    def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
        if src_embed.shape != targ_embed.shape:
            src_size = int(math.sqrt(src_embed.shape[1]))

            assert src_size**2 == src_embed.shape[1], (
                "Unable to interpolate non-square embedding"
            )

            src_embed = rearrange(
                src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size
            )
            src_embed = F.interpolate(
                src_embed,
                size=(self.num_rows, self.num_cols),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_embed = rearrange(src_embed, "b c h w -> b (h w) c")
        targ_embed.data.copy_(src_embed)

    def _load_projection(
        self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor
    ):
        if src_proj_weight.shape != targ_proj_weight.shape:
            src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))

            assert (src_patch_size**2) * 3 == src_proj_weight.shape[1], (
                "Unable to interpolate non-square patch size"
            )

            src_proj_weight = rearrange(
                src_proj_weight,
                "b (c h w) -> b c h w",
                c=3,
                h=src_patch_size,
                w=src_patch_size,
            )
            src_proj_weight = F.interpolate(
                src_proj_weight,
                size=(self.patch_size, self.patch_size),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)")
        targ_proj_weight.data.copy_(src_proj_weight)

    def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
        patches = self.im_to_patches(x)
        patches = self.embedder(patches)
        return patches

    def apply_pos_enc(
        self,
        patches: torch.Tensor,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
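        # Returns (patches + position encoding, position encoding); the
        # undropped encoding is returned alongside so callers can reuse it.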
        if not self.abs_pos:
            return patches, None

        pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)

        if self.training and self.pos_dropout > 0:
            keeps = (
                torch.rand(
                    patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device
                )
                > self.pos_dropout
            )
            pos_enc_drop = torch.where(keeps, pos_enc, 0)
        else:
            pos_enc_drop = pos_enc

        return patches + pos_enc_drop, pos_enc

    def get_pos_enc(
        self,
        batch_size: int,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> torch.Tensor:
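        # Optionally gathers per-sample position encodings for a subset of
        # patches when ``patch_idxs`` is provided.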
        if input_size is None:
            input_dims = self.input_dims
        else:
            input_dims = tuple(d // self.patch_size for d in input_size)

        pos_embed = self._get_pos_embeddings(batch_size, input_dims)

        if patch_idxs is None:
            return pos_embed

        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])

        pos_embed = torch.gather(
            pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
        )
        return pos_embed

    def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]):
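        # Resamples the learned (num_rows x num_cols) position embedding grid
        # to ``input_dims``. In CPE mode during training, a random crop with
        # random scale/aspect ratio is sampled via grid_sample (a training-time
        # augmentation); otherwise the grid is bilinearly interpolated and/or
        # windowed to the target size.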
        if (self.num_rows, self.num_cols) == input_dims:
            return self.pos_embed

        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(
            0, 3, 1, 2
        )

        def window_select(pos_embed):
            if input_dims[0] < pos_embed.shape[-2]:
                pos_embed = pos_embed[..., : input_dims[0], :]
            if input_dims[1] < pos_embed.shape[-1]:
                pos_embed = pos_embed[..., :, : input_dims[1]]
            return pos_embed

        if self.cpe_mode:
            if self.training:
                min_scale = math.sqrt(0.1)
                scale = (
                    torch.rand(batch_size, 1, 1, device=pos_embed.device)
                    * (1 - min_scale)
                    + min_scale
                )
                aspect_min = math.log(3 / 4)
                aspect_max = -aspect_min
                aspect = torch.exp(
                    torch.rand(batch_size, 1, 1, device=pos_embed.device)
                    * (aspect_max - aspect_min)
                    + aspect_min
                )

                scale_x = scale * aspect
                scale_y = scale * (1 / aspect)
                scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)

                pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (
                    1 - scale_xy
                )

                lin_x = torch.linspace(
                    0, 1, steps=input_dims[1], device=pos_embed.device
                )[None, None].expand(batch_size, input_dims[0], -1)
                lin_y = torch.linspace(
                    0, 1, steps=input_dims[0], device=pos_embed.device
                )[None, :, None].expand(batch_size, -1, input_dims[1])

                lin_xy = torch.stack([lin_x, lin_y], dim=-1)

                grid_xy = lin_xy * scale_xy + pos_xy

                # Convert to [-1, 1] range
                grid_xy.mul_(2).sub_(1)

                pos_embed = F.grid_sample(
                    pos_embed.float().expand(batch_size, -1, -1, -1),
                    grid=grid_xy,
                    mode="bilinear",
                    padding_mode="zeros",
                    align_corners=True,
                ).to(pos_embed.dtype)
            else:
                max_dim = max(input_dims)
                pos_embed = F.interpolate(
                    pos_embed.float(),
                    size=(max_dim, max_dim),
                    align_corners=True,
                    mode="bilinear",
                ).to(pos_embed.dtype)

                pos_embed = window_select(pos_embed)
        else:
            pos_embed = window_select(pos_embed)

        if pos_embed.shape[-2:] != input_dims:
            pos_embed = F.interpolate(
                pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear"
            ).to(pos_embed.dtype)

        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)

        return pos_embed


class Im2Patches(nn.Module):
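    """Splits an image into flattened, non-overlapping patches.

    Shape sketch: (B, C, H, W) -> (B, (H // p) * (W // p), C * p * p)
    for patch size ``p``.
    """
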
    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.patch_size == 1:
            patches = x.flatten(2)
            patches = patches.permute(0, 2, 1)
            return patches

        py = x.shape[-2] // self.patch_size
        px = x.shape[-1] // self.patch_size
        patches = rearrange(
            x,
            "b c (py yy) (px xx) -> b (py px) (c yy xx)",
            py=py,
            yy=self.patch_size,
            px=px,
            xx=self.patch_size,
        )
        return patches


class ViTPatchLinear(nn.Linear):
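    """Linear projection from a flattened RGB patch (3 * p * p) to embed_dim."""
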
    def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):
        super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory)
        self.patch_size = patch_size


class RadioInternVisionModel(nn.Module):
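    """InternViT encoder fed by a RADIO-style ViTPatchGenerator."""
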
    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig = None,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.img_size, self.grid_size, self.num_patches = self._init_img_size(
            to_2tuple(config.patch_size), config.image_size
        )
        max_img_size = int(
            round(config.max_img_size / config.patch_size) * config.patch_size
        )
        self.patch_generator = ViTPatchGenerator(
            config.patch_size,
            config.hidden_size,
            input_dims=self.img_size,
            max_input_dims=max_img_size,
            cls_token=True,
            register_multiple=config.reg_tokens,
        )

        self.encoder = InternVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
        )

    def _init_img_size(self, patch_size, img_size: int | tuple[int, int]):
        if img_size is None:
            return None, None, None
        img_size = to_2tuple(img_size)
        grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, grid_size, num_patches

    def get_input_embeddings(self):
        return self.embeddings

    def forward(self, x: torch.Tensor) -> torch.FloatTensor:
        assert self.patch_generator is not None
        hidden_states = self.patch_generator(x)
        encoder_outputs = self.encoder(inputs_embeds=hidden_states)
        return encoder_outputs


class RadioModel(nn.Module):
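    """RADIO vision tower for vLLM.

    Wraps RadioInternVisionModel, remaps Hugging Face checkpoint weights in
    load_weights, and strips CLS/register tokens from the encoder output.
    """
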
    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.model = RadioInternVisionModel(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=prefix,
        )

    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        pixel_embeds: torch.Tensor | None = None,
    ) -> torch.FloatTensor:
        y = self.model(pixel_values)
        return self._extract_final(y)

    def load_weights(self, weights) -> set[str]:
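        # Remap checkpoint names onto this module's parameters, e.g.
        # (illustrative key): 'radio_model.model.blocks.0.attn.qkv.weight'
        # -> 'model.encoder.layers.0.attn.qkv.weight'.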
        loaded_params: set[str] = set()
        params_dict = dict(self.named_parameters())

        if isinstance(weights, dict):
            weights_list = list(weights.items())
        else:
            weights_list = list(weights)

        for name, weight in weights_list:
            if not name.startswith("radio_model."):
                # Skip non-radio weights
                continue

            sub = name[len("radio_model.") :]  # drop "radio_model." prefix

            # Skip buffers not used in vLLM
            if sub in {"summary_idxs"}:
                continue
            if sub.startswith("input_conditioner."):
                # we normalize in the input processor,
                # based on norm and std values from the config
                continue

            vllm_key = None
            if sub.startswith("model.patch_generator."):
                vllm_key = f"model.patch_generator.{sub.split('.', 2)[-1]}"
            elif sub.startswith("input_conditioner."):
                vllm_key = f"input_conditioner.{sub.split('.', 1)[-1]}"
            elif sub.startswith("model.blocks."):
                # Encoder blocks: HF 'model.blocks.{i}.' ->
                # vLLM 'model.encoder.layers.{i}.'
                parts = sub.split(".")
                if len(parts) >= 4:
                    layer_idx = parts[2]
                    suffix = ".".join(parts[3:])
                    # Skip layer-scale entries that vLLM doesn't use
                    if suffix in {"ls1", "ls2"} or suffix.startswith(("ls1.", "ls2.")):
                        continue
                    vllm_key = f"model.encoder.layers.{layer_idx}.{suffix}"

            if vllm_key and vllm_key in params_dict:
                param = params_dict[vllm_key]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
                loaded_params.add(vllm_key)

        return loaded_params

    def _extract_final(self, y: torch.Tensor):
        # Remove CLS + REGISTERS tokens
        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is not None:
            all_feat = y[:, patch_gen.num_skip :]

        return all_feat