Add RADIO Vision Encoder Support to vLLM (#24595)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: root <root@cw-dfw-h100-001-305-026.cm.cluster>
Authored by danielafrimi on 2025-09-17 15:53:30 +03:00; committed by GitHub
parent e120533d7a
commit 252ada5559
5 changed files with 826 additions and 56 deletions
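
For orientation, the new encoder can be exercised on its own along the same lines as the test added in this commit. The sketch below is illustrative rather than part of the diff: it assumes the `nvidia/C-RADIOv2-H` checkpoint used in that test, a CUDA device, and that the checkpoint's image processor keeps the input resolution.

```python
# Minimal sketch mirroring the RADIO test added in this commit.
import torch
from PIL import Image
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

from vllm.model_executor.models.radio import RadioModel
from vllm.transformers_utils.configs.radio import RadioConfig

model_id = "nvidia/C-RADIOv2-H"  # checkpoint used in the new test
hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

# Build the vLLM-side config from the checkpoint's timm-style args.
radio_config = RadioConfig(
    model_name=hf_config.args["model"],
    reg_tokens=hf_config.args["register_multiple"],
)
vision_model = RadioModel(radio_config)

# RadioModel.load_weights remaps the HF "radio_model.*" namespace onto the
# vLLM module tree (see vllm/model_executor/models/radio.py below).
hf_model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
vision_model.load_weights(hf_model.state_dict())
vision_model = vision_model.to("cuda", torch.float16)

# 432x640 matches the supported resolution used in the test.
processor = CLIPImageProcessor.from_pretrained(model_id)
image = Image.new("RGB", (640, 432))  # placeholder image, width x height
pixel_values = processor(image, return_tensors="pt").pixel_values
features = vision_model(pixel_values=pixel_values.to("cuda", torch.float16))
# features: (batch, num_patches, hidden) with CLS/register tokens removed
```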

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.models.radio import RadioModel
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import ImageTestAssets
# We use snapshot_download to avoid conflicts between dynamic_module and
# trust_remote_code in hf_runner.
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
@torch.inference_mode()
def run_radio_test(
image_assets: ImageTestAssets,
model_id: str,
*,
dtype: str,
):
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
img_processor = CLIPImageProcessor.from_pretrained(model)
images = [asset.pil_image for asset in image_assets]
    # The input resolution must be a multiple of the model's
    # `min_resolution_step`. Per `get_nearest_supported_resolution`, the
    # nearest supported resolution for the 432x642 assets is 432x640, so the
    # processed width is cropped to 640 below.
pixel_values = [
img_processor(
image,
return_tensors='pt').pixel_values.to(torch_dtype)[:, :, :, :640]
for image in images
]
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
hf_model = AutoModel.from_pretrained(
model_id,
config=config,
torch_dtype=torch_dtype,
trust_remote_code=True,
).to("cuda")
hf_model.eval()
hf_outputs_per_image = [
hf_model(pixel_value.to("cuda")).features
for pixel_value in pixel_values
]
radio_config = RadioConfig(model_name=config.args["model"],
reg_tokens=config.args["register_multiple"])
vllm_model = RadioModel(radio_config)
vllm_model.load_weights(hf_model.state_dict())
vllm_model = vllm_model.to("cuda", torch_dtype)
vllm_outputs_per_image = [
vllm_model(pixel_values=pixel_value.to("cuda"))
for pixel_value in pixel_values
]
del vllm_model, hf_model
cleanup_dist_env_and_memory()
cos_similar = nn.CosineSimilarity(dim=-1)
for vllm_output, hf_output in zip(vllm_outputs_per_image,
hf_outputs_per_image):
assert cos_similar(vllm_output, hf_output).mean() > 0.99
@pytest.mark.parametrize("model_id", [
"nvidia/C-RADIOv2-H",
])
@pytest.mark.parametrize("dtype", ["half"])
def test_radio(dist_init, image_assets, model_id, dtype: str) -> None:
run_radio_test(
image_assets,
model_id,
dtype=dtype,
)

View File

@@ -18,8 +18,8 @@ import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import (AutoModel, BatchEncoding, BatchFeature,
PretrainedConfig, TensorType)
from transformers import (BatchEncoding, BatchFeature, PretrainedConfig,
TensorType)
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import ReLUSquaredActivation
@@ -32,6 +32,7 @@ from vllm.model_executor.models.internvl import (calculate_internvl_targets,
get_internvl_target_ratios)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.radio import RadioModel
from vllm.model_executor.models.utils import (flatten_bn,
init_vllm_registered_model,
maybe_prefix,
@@ -48,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -122,11 +124,6 @@ NanoNemotronVLVideoInputs = Union[NanoNemotronVLVideoPixelInputs,
NanoNemotronVLVideoEmbeddingInputs]
def input_conditioner(x, norm_mean, norm_std):
y = (x - norm_mean) / norm_std
return y
def dynamic_preprocess(image,
*,
image_size=512,
@@ -305,8 +302,7 @@ class BaseNanoNemotronVLProcessor(ABC):
images, max_num_tiles)
image_inputs: dict[str, NestedTensors] = {
"pixel_values_flat":
input_conditioner(torch.cat(pixel_values_lst), self.norm_mean,
self.norm_std),
torch.cat(pixel_values_lst),
"image_num_patches":
torch.tensor([len(item) for item in pixel_values_lst]),
}
@@ -428,8 +424,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
video_inputs: dict[str, NestedTensors] = {
"pixel_values_flat_video":
input_conditioner(torch.cat(pixel_values_lst_video),
self.norm_mean, self.norm_std),
torch.cat(pixel_values_lst_video),
"video_num_patches":
torch.tensor([len(item) for item in pixel_values_lst_video]),
}
@@ -905,18 +900,9 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.vision_model = AutoModel.from_config(config.vision_config,
trust_remote_code=True)
self.vision_model.model._initialize_weights = (
self.vision_model.model._init_weights)
# Move input normalization to processor to mirror original HF
# implementation where normalization is done in fp32
self.vision_model.radio_model.make_preprocessor_external()
self.vision_model = self.vision_model.to(
self.vision_model = self.get_vit_model_from_radio_config(config).to(
self.language_model.config.torch_dtype)
self.drop_vision_class_token = True
# Construct the vision projection.
vit_hidden_size = config.vit_hidden_size
vision_projection_hidden_size = config.projector_hidden_size
@@ -972,7 +958,7 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
return x
def extract_feature(self, pixel_values):
vit_embeds = self.vision_model(pixel_values).features
vit_embeds = self.vision_model(pixel_values)
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
h = w = int(vit_embeds.shape[1]**0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
@@ -1212,47 +1198,39 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
adapter_dict = dict(self.mlp1.named_parameters())
def is_vision_model_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("vision_model")
def is_llm(name: str) -> bool:
return name.startswith("language_model")
def is_adapter_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("mlp1")
# Get references to parameters for direct loading
vision_model_dict = dict(self.vision_model.named_parameters())
vision_model_buffers = dict(self.vision_model.named_buffers())
adapter_dict = dict(self.mlp1.named_parameters())
def is_vision_weights(name: str) -> bool:
return name.startswith("vision_model.radio_model.")
def llm_weights_generator():
# Single pass over weights
for name, w in weights:
if is_vision_model_weights((name, w)):
# Load vision encoder weights directly
trimmed_name = ".".join(name.split(".")[1:])
if "input_conditioner" in trimmed_name:
continue
if trimmed_name in vision_model_buffers:
param = vision_model_buffers[trimmed_name]
else:
param = vision_model_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_adapter_weights((name, w)):
# Load vision-language adapter weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = adapter_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
else:
# LLM weights: yield them to be loaded
# by language_model.load_weights
assert name.startswith("language_model")
trimmed_name = ".".join(name.split(".")[1:])
yield (trimmed_name, w)
# Separate weights by component
llm_weights = []
vision_weights = []
# Now we call the language model load with the generator
self.language_model.load_weights(llm_weights_generator())
for name, w in weights:
if is_llm(name):
# Strip 'language_model.' prefix for LLM weights
llm_weights.append((".".join(name.split(".")[1:]), w))
elif is_adapter_weights((name, w)):
# Load vision-language adapter weights directly
trimmed_name = ".".join(name.split(".")[1:])
param = adapter_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_weights(name):
# Convert: vision_model.radio_model.* → radio_model.*
hf_key = name[len(
"vision_model."):] # Remove "vision_model." prefix
vision_weights.append((hf_key, w))
self.language_model.load_weights(llm_weights)
self.vision_model.load_weights(vision_weights)
def print_architecture(self,
detailed: bool = True,
@@ -1370,6 +1348,30 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
},
}
def get_vit_model_from_radio_config(self, hf_config):
hf_config_vision = hf_config.vision_config
model_name = hf_config_vision.args.get("model")
if model_name is None:
raise ValueError(f'Unsupported vit model type: {model_name}')
preferred_resolution = getattr(hf_config_vision,
"preferred_resolution", None)
image_size = preferred_resolution[0] if preferred_resolution else 224
patch_size = getattr(hf_config_vision, "patch_size", 16)
radio_config = RadioConfig(
model_name=model_name,
image_size=image_size,
patch_size=patch_size,
norm_mean=hf_config.norm_mean,
norm_std=hf_config.norm_std,
reg_tokens=(hf_config_vision.args.get("register_multiple")
if hasattr(hf_config_vision, "args")
and isinstance(hf_config_vision.args, dict) else None),
)
return RadioModel(config=radio_config)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs)
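
With the vision tower now built directly from `RadioConfig`, the reworked `load_weights` above routes checkpoint weights by prefix instead of using a single generator pass: `language_model.*` entries are stripped of their prefix and handed to `language_model.load_weights`, `mlp1.*` adapter weights are loaded in place, and `vision_model.radio_model.*` entries keep the `radio_model.` prefix so `RadioModel.load_weights` can remap them. A toy illustration with hypothetical key names:

```python
# Illustrative only; the real routing lives in NemotronH_Nano_VL.load_weights.
checkpoint_keys = [
    "language_model.backbone.embeddings.weight",                 # hypothetical
    "mlp1.0.weight",                                             # hypothetical
    "vision_model.radio_model.model.blocks.0.attn.qkv.weight",   # hypothetical
]

for name in checkpoint_keys:
    if name.startswith("language_model."):
        print("LLM     <-", name.split(".", 1)[1])      # prefix stripped
    elif name.startswith("mlp1."):
        print("adapter <-", name.split(".", 1)[1])      # loaded directly
    elif name.startswith("vision_model.radio_model."):
        print("vision  <-", name[len("vision_model."):])  # keeps "radio_model."
```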

View File

@@ -0,0 +1,576 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import math
from collections.abc import Iterable
from itertools import repeat
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionEncoder
input_dim_t = Union[int, tuple[int, int]]
norm_t = Union[tuple[float, float, float], torch.Tensor]
def _ntuple(n):
def parse(x):
if isinstance(x, Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
class InputConditioner(nn.Module):
def __init__(
self,
input_scale: float,
norm_mean: norm_t,
norm_std: norm_t,
        dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.dtype = dtype
self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
def forward(self, x: torch.Tensor):
y = (x - self.norm_mean) / self.norm_std
if self.dtype is not None:
y = y.to(self.dtype)
return y
def _to_tensor(v: norm_t):
return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)
class ClsToken(nn.Module):
def __init__(
self,
ndim: int,
num_tokens: int = 1,
enabled: bool = True,
register_multiple: Optional[int] = None,
num_registers: Optional[int] = None,
):
super().__init__()
self.ndim = ndim
self.enabled = enabled
self.num_registers = 0
self.num_tokens = num_tokens
if enabled:
if num_registers:
self.num_registers = num_registers
elif register_multiple:
self.num_registers = register_multiple - (num_tokens %
register_multiple)
scale = ndim**-0.5
self.token = nn.Parameter(
torch.randn(num_tokens + self.num_registers, ndim) * scale)
else:
self.token = None
self.num_patches = self.num_tokens + self.num_registers
def forward(self, x: torch.Tensor):
if self.token is None:
return x
token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
x = torch.cat([
token,
x,
], dim=1)
return x
class ViTPatchGenerator(nn.Module):
def __init__(
self,
# config: PretrainedConfig,
patch_size: int,
embed_dim: int,
input_dims: input_dim_t,
abs_pos: bool = True,
normalize_patches: bool = False,
cls_token: bool = False,
max_input_dims: Optional[input_dim_t] = None,
pos_dropout: float = 0.0,
return_pos_enc: bool = False,
num_cls_tokens: int = 1,
register_multiple: Optional[int] = None,
num_registers: Optional[int] = None,
patch_bias: bool = False,
device=None,
dtype=None,
):
super().__init__()
if isinstance(input_dims, int):
input_dims = (input_dims, input_dims)
if max_input_dims is None:
max_input_dims = input_dims
if isinstance(max_input_dims, int):
max_input_dims = (max_input_dims, max_input_dims)
max_input_dims = tuple(
int(math.ceil(d / patch_size) * patch_size)
for d in max_input_dims)
self.cpe_mode = max_input_dims != input_dims
self.pos_dropout = pos_dropout
self.return_pos_enc = return_pos_enc
factory = dict(device=device, dtype=dtype)
self.patch_size = patch_size
self.abs_pos = abs_pos
self.embed_dim = embed_dim
self.num_rows = max_input_dims[0] // patch_size
self.num_cols = max_input_dims[1] // patch_size
self.input_dims = tuple(d // patch_size for d in input_dims)
self.num_patches = self.num_rows * self.num_cols
self.max_input_dims = max_input_dims
self.im_to_patches = Im2Patches(patch_size)
self.embedder = ViTPatchLinear(patch_size,
embed_dim,
bias=patch_bias,
**factory)
if abs_pos:
scale = embed_dim**-0.5
self.pos_embed = nn.Parameter(
torch.randn(1, self.num_patches, embed_dim, **factory) * scale)
self.cls_token = ClsToken(
embed_dim,
num_tokens=num_cls_tokens,
enabled=cls_token,
register_multiple=register_multiple,
num_registers=num_registers,
)
self.patch_normalizer = nn.LayerNorm(
embed_dim) if normalize_patches else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
patches = self.embed_patches(x)
patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
patches = self.cls_token(patches)
patches = self.patch_normalizer(patches)
if self.return_pos_enc:
return patches, pos_enc
return patches
@property
def apply_cls_token(self):
return self.cls_token.enabled
@property
def num_cls_tokens(self):
return self.cls_token.num_tokens
@property
def num_cls_patches(self):
return self.cls_token.num_patches
@property
def num_registers(self):
return self.cls_token.num_registers
@property
def num_skip(self):
return self.num_cls_tokens + self.num_registers
def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
if src_embed.shape != targ_embed.shape:
src_size = int(math.sqrt(src_embed.shape[1]))
assert src_size**2 == src_embed.shape[
1], 'Unable to interpolate non-square embedding'
src_embed = rearrange(src_embed,
'b (h w) c -> b c h w',
h=src_size,
w=src_size)
src_embed = F.interpolate(src_embed,
size=(self.num_rows, self.num_cols),
mode='bicubic',
align_corners=True,
antialias=False)
src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')
targ_embed.data.copy_(src_embed)
def _load_projection(self, src_proj_weight: torch.Tensor,
targ_proj_weight: torch.Tensor):
if src_proj_weight.shape != targ_proj_weight.shape:
src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))
assert (src_patch_size**2) * 3 == src_proj_weight.shape[
1], 'Unable to interpolate non-square patch size'
src_proj_weight = rearrange(src_proj_weight,
'b (c h w) -> b c h w',
c=3,
h=src_patch_size,
w=src_patch_size)
src_proj_weight = F.interpolate(src_proj_weight,
size=(self.patch_size,
self.patch_size),
mode='bicubic',
align_corners=True,
antialias=False)
src_proj_weight = rearrange(src_proj_weight,
'b c h w -> b (c h w)')
targ_proj_weight.data.copy_(src_proj_weight)
def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
patches = self.im_to_patches(x)
patches = self.embedder(patches)
return patches
def apply_pos_enc(
self,
patches: torch.Tensor,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[tuple[int, int]] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        if not self.abs_pos:
            # Keep the return arity consistent with the positional branch.
            return patches, None
pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
if self.training and self.pos_dropout > 0:
keeps = torch.rand(patches.shape[0],
1,
1,
dtype=pos_enc.dtype,
device=pos_enc.device) > self.pos_dropout
pos_enc_drop = torch.where(keeps, pos_enc, 0)
else:
pos_enc_drop = pos_enc
return patches + pos_enc_drop, pos_enc
def get_pos_enc(
self,
batch_size: int,
patch_idxs: Optional[torch.Tensor] = None,
input_size: Optional[tuple[int, int]] = None,
) -> torch.Tensor:
if input_size is None:
input_dims = self.input_dims
else:
input_dims = tuple(d // self.patch_size for d in input_size)
pos_embed = self._get_pos_embeddings(batch_size, input_dims)
if patch_idxs is None:
return pos_embed
exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(
-1, -1, pos_embed.shape[-1])
pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1),
dim=1,
index=exp_patch_idxs)
return pos_embed
def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int,
int]):
if (self.num_rows, self.num_cols) == input_dims:
return self.pos_embed
pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols,
-1).permute(0, 3, 1, 2)
def window_select(pos_embed):
if input_dims[0] < pos_embed.shape[-2]:
pos_embed = pos_embed[..., :input_dims[0], :]
if input_dims[1] < pos_embed.shape[-1]:
pos_embed = pos_embed[..., :, :input_dims[1]]
return pos_embed
if self.cpe_mode:
if self.training:
min_scale = math.sqrt(0.1)
scale = torch.rand(batch_size, 1, 1, device=pos_embed.device
) * (1 - min_scale) + min_scale
aspect_min = math.log(3 / 4)
aspect_max = -aspect_min
aspect = torch.exp(
torch.rand(batch_size, 1, 1, device=pos_embed.device) *
(aspect_max - aspect_min) + aspect_min)
scale_x = scale * aspect
scale_y = scale * (1 / aspect)
scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
pos_xy = torch.rand(
batch_size, 1, 1, 2,
device=pos_embed.device) * (1 - scale_xy)
lin_x = torch.linspace(
0, 1, steps=input_dims[1],
device=pos_embed.device)[None, None].expand(
batch_size, input_dims[0], -1)
lin_y = torch.linspace(
0, 1, steps=input_dims[0],
device=pos_embed.device)[None, :, None].expand(
batch_size, -1, input_dims[1])
lin_xy = torch.stack([lin_x, lin_y], dim=-1)
grid_xy = lin_xy * scale_xy + pos_xy
# Convert to [-1, 1] range
grid_xy.mul_(2).sub_(1)
pos_embed = F.grid_sample(
pos_embed.float().expand(batch_size, -1, -1, -1),
grid=grid_xy,
mode='bilinear',
padding_mode='zeros',
align_corners=True,
).to(pos_embed.dtype)
else:
max_dim = max(input_dims)
pos_embed = F.interpolate(pos_embed.float(),
size=(max_dim, max_dim),
align_corners=True,
mode='bilinear').to(pos_embed.dtype)
pos_embed = window_select(pos_embed)
else:
pos_embed = window_select(pos_embed)
if pos_embed.shape[-2:] != input_dims:
pos_embed = F.interpolate(pos_embed.float(),
size=input_dims,
align_corners=True,
mode='bilinear').to(pos_embed.dtype)
pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
return pos_embed
class Im2Patches(nn.Module):
def __init__(self, patch_size: int):
super().__init__()
self.patch_size = patch_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.patch_size == 1:
patches = x.flatten(2)
patches = patches.permute(0, 2, 1)
return patches
py = x.shape[-2] // self.patch_size
px = x.shape[-1] // self.patch_size
patches = rearrange(
x,
'b c (py yy) (px xx) -> b (py px) (c yy xx)',
py=py,
yy=self.patch_size,
px=px,
xx=self.patch_size,
)
return patches
class ViTPatchLinear(nn.Linear):
def __init__(self,
patch_size: int,
embed_dim: int,
bias: bool = False,
**factory):
super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory)
self.patch_size = patch_size
class RadioInternVisionModel(nn.Module):
packed_modules_mapping = {
"qkv": ["qkv"],
}
def __init__(
self,
config: PretrainedConfig = None,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.img_size, self.grid_size, self.num_patches = self._init_img_size(
to_2tuple(config.patch_size), config.image_size)
max_img_size = int(
round(config.max_img_size / config.patch_size) * config.patch_size)
self.patch_generator = ViTPatchGenerator(
config.patch_size,
config.hidden_size,
input_dims=self.img_size,
max_input_dims=max_img_size,
cls_token=True,
register_multiple=config.reg_tokens)
self.encoder = InternVisionEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.encoder",
)
def _init_img_size(self, patch_size, img_size: Union[int, tuple[int,
int]]):
if img_size is None:
return None, None, None
img_size = to_2tuple(img_size)
grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
num_patches = grid_size[0] * grid_size[1]
return img_size, grid_size, num_patches
def get_input_embeddings(self):
return self.embeddings
def forward(self, x: torch.Tensor) -> torch.FloatTensor:
assert self.patch_generator is not None
hidden_states = self.patch_generator(x)
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
return encoder_outputs
class RadioModel(nn.Module):
packed_modules_mapping = {
"qkv": ["qkv"],
}
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.input_conditioner = InputConditioner(
input_scale=1.0,
norm_mean=config.norm_mean,
norm_std=config.norm_std,
)
self.model = RadioInternVisionModel(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
prefix=prefix)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_embeds: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
x = self.input_conditioner(pixel_values)
y = self.model(x)
return self._extract_final(y)
def load_weights(self, weights) -> set[str]:
loaded_params: set[str] = set()
params_dict = dict(self.named_parameters())
if isinstance(weights, dict):
weights_list = list(weights.items())
else:
weights_list = list(weights)
for name, weight in weights_list:
if not name.startswith("radio_model."):
# Skip non-radio weights
continue
sub = name[len("radio_model."):] # drop "radio_model." prefix
# Skip buffers not used in vLLM
if sub in {"summary_idxs"}:
continue
vllm_key = None
if sub.startswith("model.patch_generator."):
vllm_key = f"model.patch_generator.{sub.split('.', 2)[-1]}"
elif sub.startswith("input_conditioner."):
vllm_key = f"input_conditioner.{sub.split('.', 1)[-1]}"
elif sub.startswith("model.blocks."):
# Encoder blocks: HF 'model.blocks.{i}.' ->
# vLLM 'model.encoder.layers.{i}.'
parts = sub.split(".")
if len(parts) >= 4:
layer_idx = parts[2]
suffix = ".".join(parts[3:])
# Skip layer-scale entries that vLLM doesn't use
if suffix in {"ls1", "ls2"} or suffix.startswith(
("ls1.", "ls2.")):
continue
vllm_key = f"model.encoder.layers.{layer_idx}.{suffix}"
if vllm_key and vllm_key in params_dict:
param = params_dict[vllm_key]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(vllm_key)
return loaded_params
def _extract_final(self, y: torch.Tensor):
# Remove CLS + REGISTERS tokens
patch_gen = getattr(self.model, "patch_generator", None)
if patch_gen is not None:
all_feat = y[:, patch_gen.num_skip:]
return all_feat
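
`RadioModel.load_weights` consumes the HF checkpoint's `radio_model.*` namespace directly: input-conditioner and patch-generator weights map one-to-one, ViT blocks are renamed onto the `InternVisionEncoder` layer names, and entries vLLM does not use (the `summary_idxs` buffer, `ls1`/`ls2` layer scales) are dropped. A few hypothetical keys showing the translation:

```python
# Illustrative only; key names are examples, the mapping is in load_weights above.
hf_to_vllm = {
    "radio_model.input_conditioner.norm_mean": "input_conditioner.norm_mean",
    "radio_model.model.patch_generator.pos_embed": "model.patch_generator.pos_embed",
    # HF ViT blocks land on vLLM's InternVisionEncoder layers.
    "radio_model.model.blocks.3.attn.qkv.weight": "model.encoder.layers.3.attn.qkv.weight",
    # Skipped entirely:
    "radio_model.summary_idxs": None,
    "radio_model.model.blocks.3.ls1": None,
}
```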

View File

@@ -26,6 +26,7 @@ from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.olmo3 import Olmo3Config
from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
@@ -48,6 +49,7 @@ __all__ = [
"Nemotron_Nano_VL_Config",
"Olmo3Config",
"OvisConfig",
"RadioConfig",
"SpeculatorsConfig",
"UltravoxConfig",
"Step3VLConfig",

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Radio vision model configuration"""
from typing import Optional, Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
VIT_TIMM_DIM_BY_NAME: dict[str, tuple[int, int, int, int]] = {
"vit_small_patch16_224": (384, 12, 6, 1536),
"vit_base_patch16_224": (768, 12, 12, 3072),
"vit_large_patch16_224": (1024, 24, 16, 4096),
"vit_huge_patch16_224": (1280, 32, 16, 5120),
}
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
class RadioConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a Radio
vision model. It is used to instantiate a Radio model according to the
specified arguments, defining the model architecture.
Args:
        model_name (`str`):
            Name of the vision transformer variant (e.g.,
            "vit_base_patch16_224"); used to look up the architecture
            dimensions in `VIT_TIMM_DIM_BY_NAME`. Required, no default.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
qkv_bias (`bool`, *optional*, defaults to True):
Whether to add a bias to the queries, keys and values.
qk_normalization (`bool`, *optional*, defaults to False):
Whether to apply normalization to queries and keys.
norm_type (`str`, *optional*, defaults to "layer_norm"):
The normalization type to use.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
initializer_factor (`float`, *optional*, defaults to 1.0):
A factor for initializing all weight matrices.
hidden_act (`str`, *optional*, defaults to "gelu"):
The non-linear activation function in the encoder.
max_img_size (`int`, *optional*, defaults to 2048):
Maximum image size for position embeddings.
norm_mean (`tuple` or `list`, *optional*,
defaults to (0.48145466, 0.4578275, 0.40821073)):
Mean values for image normalization (RGB channels).
norm_std (`tuple` or `list`, *optional*,
defaults to (0.26862954, 0.26130258, 0.27577711)):
Standard deviation values for image normalization (RGB channels).
        reg_tokens (`int`, *optional*):
            Register-token multiple, forwarded to the patch generator as
            `register_multiple` to size the extra register tokens.
"""
model_type = "radio"
def __init__(
self,
model_name: str,
image_size: int = 224,
patch_size: int = 16,
qkv_bias: bool = True,
qk_normalization: bool = False,
norm_type: str = "layer_norm",
layer_norm_eps: float = 1e-6,
initializer_factor: float = 1.0,
hidden_act: str = "gelu",
max_img_size: int = 2048,
norm_mean: Union[tuple[float, float, float], list] = OPENAI_CLIP_MEAN,
norm_std: Union[tuple[float, float, float], list] = OPENAI_CLIP_STD,
reg_tokens: Optional[int] = None,
**kwargs,
):
self.model_name = model_name
(
self.hidden_size,
self.num_hidden_layers,
self.num_attention_heads,
self.intermediate_size,
) = VIT_TIMM_DIM_BY_NAME[model_name]
self.image_size = image_size
self.patch_size = patch_size
self.qkv_bias = qkv_bias
self.qk_normalization = qk_normalization
self.norm_type = norm_type
self.layer_norm_eps = layer_norm_eps
self.initializer_factor = initializer_factor
self.hidden_act = hidden_act
self.max_img_size = max_img_size
self.norm_mean = list(norm_mean) if isinstance(norm_mean,
(tuple,
list)) else norm_mean
self.norm_std = list(norm_std) if isinstance(norm_std,
(tuple,
list)) else norm_std
self.reg_tokens = reg_tokens
super().__init__(**kwargs)
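
Since the architecture dimensions are looked up from `VIT_TIMM_DIM_BY_NAME`, a `RadioConfig` only strictly needs the timm-style model name; everything else has defaults. A minimal sketch, where `image_size` and `reg_tokens` are example values that would normally come from the checkpoint's config (`preferred_resolution`, `args["register_multiple"]`):

```python
from vllm.transformers_utils.configs.radio import RadioConfig

config = RadioConfig(
    model_name="vit_huge_patch16_224",  # -> hidden 1280, 32 layers, 16 heads, mlp 5120
    image_size=512,   # example preferred resolution
    patch_size=16,
    reg_tokens=8,     # example register multiple; checkpoint-specific in practice
)
assert (config.hidden_size, config.num_hidden_layers) == (1280, 32)
```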