# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
"""Inference-only BAGEL model compatible with HuggingFace weights.
BAGEL is a unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
"""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, TypeAlias
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.bagel import BagelProcessor
from vllm.utils.tensor_schema import TensorSchema
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)

logger = init_logger(__name__)


class BagelImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
type: Literal["pixel_values"]
    pixel_values: torch.Tensor  # Shape: (bn, 3, h, w)


BagelImageInputs: TypeAlias = BagelImagePixelInputs


class BagelVisionMLP(nn.Module):
"""MLP connector for vision features."""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: int,
act_layer: str = "gelu_pytorch_tanh",
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.fc1 = ColumnParallelLinear(
in_features,
hidden_features,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.act = get_act_fn(act_layer)
self.fc2 = RowParallelLinear(
hidden_features,
out_features,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc1(x)
x = self.act(x)
x, _ = self.fc2(x)
        return x


class PositionEmbedding(nn.Module):
"""2D position embedding for vision tokens using sin-cos embeddings."""
def __init__(self, max_num_patch_per_side: int, hidden_size: int):
super().__init__()
self.max_num_patch_per_side = max_num_patch_per_side
self.hidden_size = hidden_size
        # Create fixed (non-learnable) 2D sin-cos position embeddings,
        # registered below as a non-persistent buffer
pos_embed = self._get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side)
self.register_buffer(
"pos_embed",
torch.from_numpy(pos_embed).float(),
persistent=False,
)
@staticmethod
def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int):
"""Generate 2D sin-cos position embeddings."""
import numpy as np
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = PositionEmbedding._get_2d_sincos_pos_embed_from_grid(
embed_dim, grid
)
return pos_embed
@staticmethod
def _get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid):
"""Generate 2D sin-cos position embeddings from grid."""
import numpy as np
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = PositionEmbedding._get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0]
)
emb_w = PositionEmbedding._get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1]
)
emb = np.concatenate([emb_h, emb_w], axis=1)
return emb
@staticmethod
def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos):
"""Generate 1D sin-cos position embeddings."""
import numpy as np
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega
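        # Worked example (illustrative): for embed_dim = 8,
        # omega = 1 / 10000 ** ([0, 1, 2, 3] / 4) ~= [1.0, 0.1, 0.01, 0.001];
        # each position p then contributes sin(p * omega) and cos(p * omega),
        # concatenated below into an 8-dim embedding.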
pos = pos.reshape(-1)
out = np.einsum("m,d->md", pos, omega)
emb_sin = np.sin(out)
emb_cos = np.cos(out)
emb = np.concatenate([emb_sin, emb_cos], axis=1)
return emb
def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
"""
Args:
position_ids: Flattened position IDs, shape (N,) where each ID
corresponds to a position in the flattened grid
Returns:
Position embeddings of shape (N, hidden_size)
"""
# Ensure position_ids are on the same device as pos_embed
position_ids = position_ids.to(self.pos_embed.device)
        return self.pos_embed[position_ids]


class BagelProcessingInfo(BaseProcessingInfo):
"""Processing information for BAGEL model."""
def get_hf_processor(self, **kwargs: object) -> BagelProcessor:
from vllm.transformers_utils.processor import cached_get_image_processor
image_processor = cached_get_image_processor(
self.ctx.model_config.model,
revision=self.ctx.model_config.revision,
trust_remote_code=self.ctx.model_config.trust_remote_code,
)
tokenizer = self.get_tokenizer()
return BagelProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
**kwargs,
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_config = self.get_hf_config()
# Calculate max tokens per image
# For BAGEL: (vit_max_num_patch_per_side) ** 2
max_num_patches = hf_config.vit_max_num_patch_per_side**2
return {"image": max_num_patches}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vit_config = hf_config.vit_config
patch_size = vit_config.patch_size
# Calculate number of patches
num_patches_h = image_height // patch_size
num_patches_w = image_width // patch_size
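        # Worked example (hypothetical numbers, not the actual BAGEL config):
        # with patch_size = 14 and a 448 x 448 image, this returns
        # (448 // 14) * (448 // 14) = 32 * 32 = 1024 image tokens.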
        return num_patches_h * num_patches_w


class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]):
"""Build dummy inputs for BAGEL model profiling."""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
# Use a simple placeholder for each image
return "<|image_pad|>" * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
hf_config = self.info.get_hf_config()
vit_config = hf_config.vit_config
# Use the configured image size
image_size = vit_config.image_size
image_overrides = mm_options.get("image") if mm_options else None
return {
"image": self._get_dummy_images(
width=image_size,
height=image_size,
num_images=num_images,
overrides=image_overrides,
),
        }


class BagelMultiModalProcessor(BaseMultiModalProcessor[BagelProcessingInfo]):
"""Multimodal processor for BAGEL model."""
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptReplacement]:
"""Replace image placeholders with the correct number of tokens."""
hf_config = self.info.get_hf_config()
# Get the tokenizer to look up the image token ID
tokenizer = self.info.get_tokenizer()
image_token_id = tokenizer.get_vocab().get("<|image_pad|>")
if image_token_id is None:
raise ValueError(
"Image token '<|image_pad|>' not found in tokenizer vocabulary"
)
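        # Each <|image_pad|> placeholder in the prompt is expanded into a fixed
        # number of image tokens, vit_max_num_patch_per_side ** 2 (e.g. a
        # hypothetical 64 patches per side would yield 64 * 64 = 4096 tokens).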
def get_replacement_bagel(item_idx: int):
            # For BAGEL, the number of tokens is determined by the maximum
            # number of patches per side, not by the actual image size
num_tokens = hf_config.vit_max_num_patch_per_side**2
# Use the image token ID from tokenizer
return [image_token_id] * num_tokens
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_bagel,
)
]
def _get_mm_fields_config(
self,
hf_inputs: Any,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return {
"pixel_values": MultiModalFieldConfig.batched("image"),
        }


@MULTIMODAL_REGISTRY.register_processor(
BagelMultiModalProcessor,
info=BagelProcessingInfo,
dummy_inputs=BagelDummyInputsBuilder,
)
class BagelForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
):
"""
BAGEL: A unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
The image generation part is not supported in vLLM.
"""
    # Weight mapping from HF to vLLM (the checkpoint prefixes already match
    # the vLLM module names, so this mapper is effectively an identity map)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.": "language_model.",
"vit_model.": "vit_model.",
"connector.": "connector.",
"vit_pos_embed.": "vit_pos_embed.",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
# Ensure we have a BagelConfig (check by name to handle trust_remote_code)
# When trust_remote_code=True, the config comes from transformers_modules
if type(config).__name__ != "BagelConfig":
raise ValueError(
f"Expected BagelConfig, got {type(config).__name__}. "
"Make sure the model config is properly loaded."
)
self.config = config
self.multimodal_config = multimodal_config
# Initialize language model (Qwen2)
# Pass the llm_config from BagelConfig to initialize Qwen2 properly
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.llm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
# Initialize vision model (SigLIP) if visual understanding is enabled
if config.visual_und:
            # Fix vit_config: the checkpoint has 26 layers (0-25) while the
            # config says 27. Also disable the vision head, which is not
            # present in the checkpoint.
vit_config = config.vit_config
if vit_config.num_hidden_layers == 27:
logger.warning(
"Overriding vit_config.num_hidden_layers from 27 to 26 "
"to match the Bagel model checkpoint."
)
vit_config.num_hidden_layers = 26
if not hasattr(vit_config, "vision_use_head"):
logger.warning(
"Setting vit_config.vision_use_head to False as it is not "
"present in the Bagel model checkpoint."
)
vit_config.vision_use_head = False
self.vit_model = SiglipVisionModel(
config=vit_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vit_model"),
)
# Initialize connector (MLP)
vit_hidden_size = config.vit_config.hidden_size
llm_hidden_size = config.llm_config.hidden_size
self.connector = BagelVisionMLP(
in_features=vit_hidden_size,
hidden_features=llm_hidden_size,
out_features=llm_hidden_size,
act_layer=config.connector_act,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "connector"),
)
# Position embedding for vision tokens
self.vit_pos_embed = PositionEmbedding(
max_num_patch_per_side=config.vit_max_num_patch_per_side,
hidden_size=llm_hidden_size,
)
else:
self.vit_model = None
self.connector = None
self.vit_pos_embed = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> BagelImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return None
return BagelImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
)
def _process_image_input(
self, image_input: BagelImageInputs
) -> tuple[torch.Tensor, ...]:
"""Process image inputs through vision encoder and connector."""
pixel_values = image_input["pixel_values"]
# Handle potential extra batch dimension
# Expected shape: (batch_size * num_images, 3, H, W)
# But might receive: (batch_size, num_images, 3, H, W)
if pixel_values.ndim == 5:
# Flatten batch and num_images dimensions
batch_size, num_images, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(
batch_size * num_images, channels, height, width
)
# Get vision features from SigLIP
# pixel_values shape: (batch_size * num_images, 3, H, W)
vision_features = self.vit_model(pixel_values)
# Pass through connector
vision_embeds = self.connector(vision_features)
# Add position embeddings
batch_size, num_patches, hidden_size = vision_embeds.shape
patch_size = self.config.vit_config.patch_size
image_size = self.config.vit_config.image_size
# Calculate grid dimensions
num_patches_per_side = image_size // patch_size
# Create flattened position IDs (0 to num_patches-1)
# For BAGEL, we use extrapolate mode by default
h_coords = torch.arange(num_patches_per_side, device=vision_embeds.device)
w_coords = torch.arange(num_patches_per_side, device=vision_embeds.device)
position_ids = (
h_coords[:, None] * self.config.vit_max_num_patch_per_side + w_coords
).flatten()
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1).flatten()
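        # Worked example (illustrative numbers): with num_patches_per_side = 2
        # and vit_max_num_patch_per_side = 70, position_ids for one image are
        # [0, 1, 70, 71], i.e. row-major indices into the frozen 70 x 70
        # position-embedding table.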
# Add position embeddings
pos_embeds = self.vit_pos_embed(position_ids)
pos_embeds = pos_embeds.reshape(batch_size, num_patches, hidden_size)
# Ensure pos_embeds are on the same device as vision_embeds
pos_embeds = pos_embeds.to(vision_embeds.device)
vision_embeds = vision_embeds + pos_embeds
# Split by image
return tuple(vision_embeds)
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
"""Get multimodal embeddings from input."""
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
return self._process_image_input(image_input)
def get_language_model(self) -> nn.Module:
return self.language_model
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
"""Run forward pass for BAGEL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a batch.
positions: Flattened (concatenated) position ids corresponding to a batch.
intermediate_tensors: Intermediate tensors from prior forward pass.
inputs_embeds: Optional tensor of input embeddings.
"""
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights from checkpoint."""
skip_prefixes = []
        # Skip vit_pos_embed.pos_embed since the PositionEmbedding module
        # recomputes it as a non-persistent buffer
skip_prefixes.append("vit_pos_embed.pos_embed")
# If visual understanding is disabled, skip vision-related weights
if self.vit_model is None:
skip_prefixes.extend(["vit_model.", "connector.", "vit_pos_embed"])
        # Skip image-generation weights since vLLM only supports BAGEL's
        # text-to-text and image-to-text paths.
# Filter out all image generation components:
# - 'moe_gen': MoE generation weights
# - 'latent_pos_embed': Latent position embeddings for VAE
# - 'llm2vae', 'vae2llm': LLM-VAE projections
# - 'time_embedder': Timestep embeddings for diffusion
# - VAE encoder/decoder: Use specific prefixes to avoid matching vision encoder
generation_keywords = [
"moe_gen",
"latent_pos_embed",
"llm2vae",
"vae2llm",
"time_embedder",
]
vae_prefixes = [
"decoder.",
"encoder.",
] # VAE encoder/decoder, not vision encoder
filtered_weights = []
for name, tensor in weights:
# Skip generation-related keywords
if any(skip in name for skip in generation_keywords):
continue
if any(name.startswith(prefix) for prefix in vae_prefixes):
continue
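            # Handle checkpoints that store the SigLIP patch embedding as a
            # flattened 2D (linear-style) weight: reshape and permute it into
            # the Conv2d layout (out_channels, in_channels, patch, patch)
            # expected by vLLM's SiglipVisionModel.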
if "patch_embedding.weight" in name and tensor.ndim == 2:
out_channels = tensor.shape[0]
in_features = tensor.shape[1]
patch_size = self.config.vit_config.patch_size
in_channels = self.config.vit_config.num_channels
if in_features == in_channels * patch_size * patch_size:
tensor = tensor.reshape(
out_channels, patch_size, patch_size, in_channels
)
tensor = tensor.permute(0, 3, 1, 2).contiguous()
filtered_weights.append((name, tensor))
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(filtered_weights, mapper=self.hf_to_vllm_mapper)