Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-18 18:04:29 +08:00)
[New Model] BAGEL support (AR only) (#28439)
Signed-off-by: princepride <wangzhipeng628@gmail.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
e3a1cd1c59
commit
1adeb3b84c
@@ -661,6 +661,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
| `BagelForConditionalGeneration` | BAGEL | T + I<sup>+</sup> | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ |
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ |
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ |
@@ -118,6 +118,32 @@ def run_bee(questions: list[str], modality: str) -> ModelRequestData:
    )


def run_bagel(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
    model_name = "ByteDance-Seed/BAGEL-7B-MoT"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={modality: 1},
    )

    prompts = [
        (
            f"<|im_start|>user\n<|image_pad|>\n{question}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
        for question in questions
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# BLIP-2
def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
@@ -1832,6 +1858,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
model_example_map = {
    "aria": run_aria,
    "aya_vision": run_aya_vision,
    "bagel": run_bagel,
    "bee": run_bee,
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
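The vision-language example script consumes the returned `ModelRequestData` outside of this diff; below is a minimal, self-contained sketch of that flow. The image path and sampling settings are illustrative assumptions, not part of the commit.

# Hypothetical driver for run_bagel(), mirroring how the example harness is
# assumed to turn ModelRequestData into generations.
from dataclasses import asdict

from PIL import Image

from vllm import LLM, SamplingParams

req_data = run_bagel(["What is shown in this image?"], modality="image")
llm = LLM(**asdict(req_data.engine_args))

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
outputs = llm.generate(
    [
        {"prompt": prompt, "multi_modal_data": {"image": image}}
        for prompt in req_data.prompts
    ],
    SamplingParams(temperature=0.0, max_tokens=64),
)
for out in outputs:
    print(out.outputs[0].text)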
@@ -582,6 +582,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
        "nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0.dev"
    ),
    "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
    "BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"),
    "BeeForConditionalGeneration": _HfExamplesInfo(
        "Open-Bee/Bee-8B-RL",
        trust_remote_code=True,
vllm/model_executor/models/bagel.py (new file, 584 lines)
@@ -0,0 +1,584 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
"""Inference-only BAGEL model compatible with HuggingFace weights.

BAGEL is a unified multimodal model for image understanding and generation.
For vLLM, we focus on the image understanding (vision-to-text) capabilities.
"""

from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, TypeAlias

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.bagel import BagelProcessor
from vllm.utils.tensor_schema import TensorSchema

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .siglip import SiglipVisionModel
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)

logger = init_logger(__name__)

class BagelImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor  # Shape: (bn, 3, h, w)


BagelImageInputs: TypeAlias = BagelImagePixelInputs


class BagelVisionMLP(nn.Module):
    """MLP connector for vision features."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        act_layer: str = "gelu_pytorch_tanh",
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act = get_act_fn(act_layer)
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.fc1(x)
        x = self.act(x)
        x, _ = self.fc2(x)
        return x

class PositionEmbedding(nn.Module):
    """2D position embedding for vision tokens using sin-cos embeddings."""

    def __init__(self, max_num_patch_per_side: int, hidden_size: int):
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size

        # Create fixed (non-learnable) 2D sin-cos position embeddings
        pos_embed = self._get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side)
        self.register_buffer(
            "pos_embed",
            torch.from_numpy(pos_embed).float(),
            persistent=False,
        )

    @staticmethod
    def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int):
        """Generate 2D sin-cos position embeddings."""
        import numpy as np

        grid_h = np.arange(grid_size, dtype=np.float32)
        grid_w = np.arange(grid_size, dtype=np.float32)
        grid = np.meshgrid(grid_w, grid_h)  # w goes first
        grid = np.stack(grid, axis=0)
        grid = grid.reshape([2, 1, grid_size, grid_size])
        pos_embed = PositionEmbedding._get_2d_sincos_pos_embed_from_grid(
            embed_dim, grid
        )
        return pos_embed

    @staticmethod
    def _get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid):
        """Generate 2D sin-cos position embeddings from grid."""
        import numpy as np

        assert embed_dim % 2 == 0
        # use half of dimensions to encode grid_h
        emb_h = PositionEmbedding._get_1d_sincos_pos_embed_from_grid(
            embed_dim // 2, grid[0]
        )
        emb_w = PositionEmbedding._get_1d_sincos_pos_embed_from_grid(
            embed_dim // 2, grid[1]
        )
        emb = np.concatenate([emb_h, emb_w], axis=1)
        return emb

    @staticmethod
    def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos):
        """Generate 1D sin-cos position embeddings."""
        import numpy as np

        assert embed_dim % 2 == 0
        omega = np.arange(embed_dim // 2, dtype=np.float64)
        omega /= embed_dim / 2.0
        omega = 1.0 / 10000**omega

        pos = pos.reshape(-1)
        out = np.einsum("m,d->md", pos, omega)

        emb_sin = np.sin(out)
        emb_cos = np.cos(out)
        emb = np.concatenate([emb_sin, emb_cos], axis=1)
        return emb

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            position_ids: Flattened position IDs, shape (N,) where each ID
                corresponds to a position in the flattened grid
        Returns:
            Position embeddings of shape (N, hidden_size)
        """
        # Ensure position_ids are on the same device as pos_embed
        position_ids = position_ids.to(self.pos_embed.device)
        return self.pos_embed[position_ids]

class BagelProcessingInfo(BaseProcessingInfo):
    """Processing information for BAGEL model."""

    def get_hf_processor(self, **kwargs: object) -> BagelProcessor:
        from vllm.transformers_utils.processor import cached_get_image_processor

        image_processor = cached_get_image_processor(
            self.ctx.model_config.model,
            revision=self.ctx.model_config.revision,
            trust_remote_code=self.ctx.model_config.trust_remote_code,
        )

        tokenizer = self.get_tokenizer()

        return BagelProcessor(
            image_processor=image_processor,
            tokenizer=tokenizer,
            **kwargs,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        hf_config = self.get_hf_config()
        # Calculate max tokens per image
        # For BAGEL: (vit_max_num_patch_per_side) ** 2
        max_num_patches = hf_config.vit_max_num_patch_per_side**2
        return {"image": max_num_patches}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self.get_hf_config()
        vit_config = hf_config.vit_config
        patch_size = vit_config.patch_size

        # Calculate number of patches
        num_patches_h = image_height // patch_size
        num_patches_w = image_width // patch_size
        return num_patches_h * num_patches_w


class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]):
    """Build dummy inputs for BAGEL model profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        # Use a simple placeholder for each image
        return "<|image_pad|>" * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        hf_config = self.info.get_hf_config()
        vit_config = hf_config.vit_config

        # Use the configured image size
        image_size = vit_config.image_size
        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=image_size,
                height=image_size,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }

class BagelMultiModalProcessor(BaseMultiModalProcessor[BagelProcessingInfo]):
    """Multimodal processor for BAGEL model."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptReplacement]:
        """Replace image placeholders with the correct number of tokens."""
        hf_config = self.info.get_hf_config()

        # Get the tokenizer to look up the image token ID
        tokenizer = self.info.get_tokenizer()
        image_token_id = tokenizer.get_vocab().get("<|image_pad|>")
        if image_token_id is None:
            raise ValueError(
                "Image token '<|image_pad|>' not found in tokenizer vocabulary"
            )

        def get_replacement_bagel(item_idx: int):
            # For BAGEL, calculate number of tokens based on max patch size
            num_tokens = hf_config.vit_max_num_patch_per_side**2
            # Use the image token ID from tokenizer
            return [image_token_id] * num_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_bagel,
            )
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: Any,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
        }

@MULTIMODAL_REGISTRY.register_processor(
    BagelMultiModalProcessor,
    info=BagelProcessingInfo,
    dummy_inputs=BagelDummyInputsBuilder,
)
class BagelForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP
):
    """
    BAGEL: A unified multimodal model for image understanding and generation.

    For vLLM, we focus on the image understanding (vision-to-text) capabilities.
    The image generation part is not supported in vLLM.
    """

    # Weight mapping from HF to vLLM
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "language_model.": "language_model.",
            "vit_model.": "vit_model.",
            "connector.": "connector.",
            "vit_pos_embed.": "vit_pos_embed.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        # Ensure we have a BagelConfig (check by name to handle trust_remote_code)
        # When trust_remote_code=True, the config comes from transformers_modules
        if type(config).__name__ != "BagelConfig":
            raise ValueError(
                f"Expected BagelConfig, got {type(config).__name__}. "
                "Make sure the model config is properly loaded."
            )

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize language model (Qwen2)
        # Pass the llm_config from BagelConfig to initialize Qwen2 properly
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.llm_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        # Initialize vision model (SigLIP) if visual understanding is enabled
        if config.visual_und:
            # Fix vit_config: checkpoint has 26 layers (0-25) but config says 27
            # Also disable head as it's not in checkpoint
            vit_config = config.vit_config
            if vit_config.num_hidden_layers == 27:
                logger.warning(
                    "Overriding vit_config.num_hidden_layers from 27 to 26 "
                    "to match the Bagel model checkpoint."
                )
                vit_config.num_hidden_layers = 26
            if not hasattr(vit_config, "vision_use_head"):
                logger.warning(
                    "Setting vit_config.vision_use_head to False as it is not "
                    "present in the Bagel model checkpoint."
                )
                vit_config.vision_use_head = False

            self.vit_model = SiglipVisionModel(
                config=vit_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vit_model"),
            )

            # Initialize connector (MLP)
            vit_hidden_size = config.vit_config.hidden_size
            llm_hidden_size = config.llm_config.hidden_size

            self.connector = BagelVisionMLP(
                in_features=vit_hidden_size,
                hidden_features=llm_hidden_size,
                out_features=llm_hidden_size,
                act_layer=config.connector_act,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "connector"),
            )

            # Position embedding for vision tokens
            self.vit_pos_embed = PositionEmbedding(
                max_num_patch_per_side=config.vit_max_num_patch_per_side,
                hidden_size=llm_hidden_size,
            )
        else:
            self.vit_model = None
            self.connector = None
            self.vit_pos_embed = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> BagelImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        return BagelImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
        )

    def _process_image_input(
        self, image_input: BagelImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Process image inputs through vision encoder and connector."""
        pixel_values = image_input["pixel_values"]

        # Handle potential extra batch dimension
        # Expected shape: (batch_size * num_images, 3, H, W)
        # But might receive: (batch_size, num_images, 3, H, W)
        if pixel_values.ndim == 5:
            # Flatten batch and num_images dimensions
            batch_size, num_images, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.reshape(
                batch_size * num_images, channels, height, width
            )

        # Get vision features from SigLIP
        # pixel_values shape: (batch_size * num_images, 3, H, W)
        vision_features = self.vit_model(pixel_values)

        # Pass through connector
        vision_embeds = self.connector(vision_features)

        # Add position embeddings
        batch_size, num_patches, hidden_size = vision_embeds.shape
        patch_size = self.config.vit_config.patch_size
        image_size = self.config.vit_config.image_size

        # Calculate grid dimensions
        num_patches_per_side = image_size // patch_size

        # Create flattened position IDs (0 to num_patches-1)
        # For BAGEL, we use extrapolate mode by default
        h_coords = torch.arange(num_patches_per_side, device=vision_embeds.device)
        w_coords = torch.arange(num_patches_per_side, device=vision_embeds.device)
        position_ids = (
            h_coords[:, None] * self.config.vit_max_num_patch_per_side + w_coords
        ).flatten()
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1).flatten()

        # Add position embeddings
        pos_embeds = self.vit_pos_embed(position_ids)
        pos_embeds = pos_embeds.reshape(batch_size, num_patches, hidden_size)
        # Ensure pos_embeds are on the same device as vision_embeds
        pos_embeds = pos_embeds.to(vision_embeds.device)
        vision_embeds = vision_embeds + pos_embeds

        # Split by image
        return tuple(vision_embeds)

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Get multimodal embeddings from input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_language_model(self) -> nn.Module:
        return self.language_model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for BAGEL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a batch.
            positions: Flattened (concatenated) position ids corresponding to a batch.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights from checkpoint."""
        skip_prefixes = []
        # Skip vit_pos_embed.pos_embed as it's handled by PositionEmbedding module
        skip_prefixes.append("vit_pos_embed.pos_embed")

        # If visual understanding is disabled, skip vision-related weights
        if self.vit_model is None:
            skip_prefixes.extend(["vit_model.", "connector.", "vit_pos_embed"])

        # Skip generation-related weights since we only support text2text and image2text
        # Filter out all image generation components:
        # - 'moe_gen': MoE generation weights
        # - 'latent_pos_embed': Latent position embeddings for VAE
        # - 'llm2vae', 'vae2llm': LLM-VAE projections
        # - 'time_embedder': Timestep embeddings for diffusion
        # - VAE encoder/decoder: Use specific prefixes to avoid matching vision encoder
        generation_keywords = [
            "moe_gen",
            "latent_pos_embed",
            "llm2vae",
            "vae2llm",
            "time_embedder",
        ]
        vae_prefixes = [
            "decoder.",
            "encoder.",
        ]  # VAE encoder/decoder, not vision encoder
        filtered_weights = []
        for name, tensor in weights:
            # Skip generation-related keywords
            if any(skip in name for skip in generation_keywords):
                continue
            if any(name.startswith(prefix) for prefix in vae_prefixes):
                continue

            # The checkpoint stores the ViT patch embedding as a 2D linear
            # weight; reshape it into the Conv2d kernel layout expected by
            # SiglipVisionModel.
            if "patch_embedding.weight" in name and tensor.ndim == 2:
                out_channels = tensor.shape[0]
                in_features = tensor.shape[1]
                patch_size = self.config.vit_config.patch_size
                in_channels = self.config.vit_config.num_channels
                if in_features == in_channels * patch_size * patch_size:
                    tensor = tensor.reshape(
                        out_channels, patch_size, patch_size, in_channels
                    )
                    tensor = tensor.permute(0, 3, 1, 2).contiguous()

            filtered_weights.append((name, tensor))

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(filtered_weights, mapper=self.hf_to_vllm_mapper)
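For orientation, the token accounting in `BagelProcessingInfo` and the position-ID layout in `_process_image_input` reduce to simple arithmetic. The concrete numbers below (SigLIP patch size 14, a 980-pixel square input, the 70-patch maximum grid defaulted in `BagelConfig`) are assumptions for illustration, not values read from this diff.

# Back-of-envelope check of BAGEL's image-token accounting; all concrete
# numbers here are assumed for illustration.
patch_size = 14                  # assumed SigLIP patch size
image_size = 980                 # assumed square input resolution
vit_max_num_patch_per_side = 70  # default shown in BagelConfig below

# get_num_image_tokens(): patches produced for one image of this size
num_patches_per_side = image_size // patch_size        # 70
num_image_tokens = num_patches_per_side**2             # 4900

# get_mm_max_tokens_per_item(): worst case used for memory profiling
max_tokens_per_image = vit_max_num_patch_per_side**2   # 4900

# _process_image_input() indexes a flattened 70x70 sin-cos grid row-major:
# position_id(row, col) = row * vit_max_num_patch_per_side + col
assert num_image_tokens == max_tokens_per_image == 4900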
@@ -122,6 +122,8 @@ class Qwen2Attention(nn.Module):
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        dual_chunk_attention_config: dict[str, Any] | None = None,
        qk_norm: bool = False,
        rms_norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
@@ -144,6 +146,7 @@ class Qwen2Attention(nn.Module):
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.dual_chunk_attention_config = dual_chunk_attention_config
        self.qk_norm = qk_norm

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
@@ -162,6 +165,11 @@ class Qwen2Attention(nn.Module):
            prefix=f"{prefix}.o_proj",
        )

        # QK Normalization support (used in BAGEL and some other models)
        if self.qk_norm:
            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position,
@@ -197,6 +205,23 @@ class Qwen2Attention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Apply QK normalization if enabled (before RoPE)
        if self.qk_norm:
            # Reshape to apply per-head normalization
            # q shape: (total_tokens, q_size) -> (total_tokens, num_heads, head_dim)
            total_tokens = q.shape[0]
            q = q.view(total_tokens, self.num_heads, self.head_dim)
            k = k.view(total_tokens, self.num_kv_heads, self.head_dim)

            # Apply normalization
            q = self.q_norm(q)
            k = self.k_norm(k)

            # Reshape back
            q = q.view(total_tokens, self.q_size)
            k = k.view(total_tokens, self.kv_size)

        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
@@ -227,6 +252,9 @@ class Qwen2DecoderLayer(nn.Module):
        else:
            attn_type = AttentionType.ENCODER_ONLY

        # Check if QK normalization is enabled (used in BAGEL and some other models)
        qk_norm = getattr(config, "qk_norm", False)

        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
@@ -238,6 +266,8 @@ class Qwen2DecoderLayer(nn.Module):
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config,
            qk_norm=qk_norm,
            rms_norm_eps=config.rms_norm_eps,
        )
        self.mlp = Qwen2MLP(
            hidden_size=self.hidden_size,
@@ -480,6 +510,8 @@ class Qwen2Model(nn.Module):
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
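The new `qk_norm` path applies per-head RMSNorm to queries and keys before rotary embeddings. A small stand-alone sketch of that ordering in plain PyTorch (head counts and dimensions are arbitrary, and `nn.RMSNorm` stands in for vLLM's fused RMSNorm layer):

import torch
from torch import nn

num_heads, num_kv_heads, head_dim, tokens = 8, 2, 64, 5
q = torch.randn(tokens, num_heads * head_dim)
k = torch.randn(tokens, num_kv_heads * head_dim)

# Per-head normalization over head_dim, as enabled by qk_norm=True.
q_norm = nn.RMSNorm(head_dim, eps=1e-6)
k_norm = nn.RMSNorm(head_dim, eps=1e-6)
q = q_norm(q.view(tokens, num_heads, head_dim)).view(tokens, -1)
k = k_norm(k.view(tokens, num_kv_heads, head_dim)).view(tokens, -1)

# Only after normalization are rotary embeddings and attention applied;
# with qk_norm=False the block above is skipped entirely.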
@@ -272,6 +272,7 @@ _MULTIMODAL_MODELS = {
        "aya_vision",
        "AyaVisionForConditionalGeneration",
    ),
    "BagelForConditionalGeneration": ("bagel", "BagelForConditionalGeneration"),
    "BeeForConditionalGeneration": ("bee", "BeeForConditionalGeneration"),
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": (

@@ -66,6 +66,7 @@ class LazyConfigDict(dict):

_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
    afmoe="AfmoeConfig",
    bagel="BagelConfig",
    chatglm="ChatGLMConfig",
    deepseek_vl_v2="DeepseekVLV2Config",
    deepseek_v32="DeepseekV3Config",

@@ -16,6 +16,7 @@ import importlib

_CLASS_TO_MODULE: dict[str, str] = {
    "AfmoeConfig": "vllm.transformers_utils.configs.afmoe",
    "BagelConfig": "vllm.transformers_utils.configs.bagel",
    "ChatGLMConfig": "vllm.transformers_utils.configs.chatglm",
    "DeepseekVLV2Config": "vllm.transformers_utils.configs.deepseek_vl2",
    "DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr",
@@ -54,6 +55,7 @@ _CLASS_TO_MODULE: dict[str, str] = {

__all__ = [
    "AfmoeConfig",
    "BagelConfig",
    "ChatGLMConfig",
    "DeepseekVLV2Config",
    "DeepseekV3Config",
vllm/transformers_utils/configs/bagel.py (new file, 53 lines)
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import PretrainedConfig, SiglipVisionConfig
from transformers.models.qwen2 import Qwen2Config


class BagelConfig(PretrainedConfig):
    """Configuration class for BAGEL model."""

    model_type = "bagel"

    def __init__(
        self,
        visual_gen: bool = True,
        visual_und: bool = True,
        llm_config: dict | Qwen2Config | None = None,
        vit_config: dict | SiglipVisionConfig | None = None,
        vae_config: dict | None = None,
        latent_patch_size: int = 2,
        max_latent_size: int = 32,
        vit_max_num_patch_per_side: int = 70,
        connector_act: str = "gelu_pytorch_tanh",
        interpolate_pos: bool = False,
        timestep_shift: float = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.visual_gen = visual_gen
        self.visual_und = visual_und

        # Convert dict configs to proper config objects
        if isinstance(llm_config, dict):
            self.llm_config = Qwen2Config(**llm_config)
        else:
            self.llm_config = llm_config or Qwen2Config()

        if isinstance(vit_config, dict):
            self.vit_config = SiglipVisionConfig(**vit_config)
        else:
            self.vit_config = vit_config or SiglipVisionConfig()

        self.vae_config = vae_config or {"z_channels": 16, "downsample": 8}
        self.latent_patch_size = latent_patch_size
        self.max_latent_size = max_latent_size
        self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
        self.connector_act = connector_act
        self.interpolate_pos = interpolate_pos
        self.timestep_shift = timestep_shift

    @property
    def hidden_size(self) -> int:
        """Return the hidden size of the language model."""
        return self.llm_config.hidden_size
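A quick illustration of how `BagelConfig` promotes the nested dicts found in a checkpoint's `config.json` to typed sub-configs; the dict values here are toy numbers, not the real checkpoint configuration.

from vllm.transformers_utils.configs.bagel import BagelConfig

cfg = BagelConfig(
    llm_config={"hidden_size": 3584, "num_hidden_layers": 28},  # toy values
    vit_config={"hidden_size": 1152, "patch_size": 14},         # toy values
    vit_max_num_patch_per_side=70,
)
print(type(cfg.llm_config).__name__)  # Qwen2Config
print(type(cfg.vit_config).__name__)  # SiglipVisionConfig
print(cfg.hidden_size)                # 3584, forwarded from llm_config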
@@ -8,6 +8,7 @@ reasons:
- There is a need to override the existing processor to support vLLM.
"""

from vllm.transformers_utils.processors.bagel import BagelProcessor
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor
@@ -15,6 +16,7 @@ from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor

__all__ = [
    "BagelProcessor",
    "DeepseekVLV2Processor",
    "HunYuanVLProcessor",
    "HunYuanVLImageProcessor",
vllm/transformers_utils/processors/bagel.py (new file, 73 lines)
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
"""BAGEL processor for image and text inputs."""

from transformers import AutoProcessor
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput


class BagelProcessor(ProcessorMixin):
    """
    Constructs a BAGEL processor which wraps a
    SigLIP image processor and a Qwen2 tokenizer.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __call__(
        self,
        text: TextInput
        | PreTokenizedInput
        | list[TextInput]
        | list[PreTokenizedInput] = None,
        images: ImageInput = None,
        **kwargs,
    ):
        """
        Main method to prepare for the model one or several sequence(s) and image(s).
        """
        if images is not None:
            # Process images with the image processor
            # Ensure return_tensors is set to "pt" for PyTorch tensors
            image_kwargs = {**kwargs}
            if "return_tensors" not in image_kwargs:
                image_kwargs["return_tensors"] = "pt"
            pixel_values = self.image_processor(images, **image_kwargs)
        else:
            pixel_values = None

        text_inputs = self.tokenizer(text, **kwargs) if text is not None else None

        if pixel_values is not None and text_inputs is not None:
            text_inputs["pixel_values"] = pixel_values["pixel_values"]
            return text_inputs
        elif pixel_values is not None:
            return pixel_values
        else:
            return text_inputs

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's batch_decode.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's decode.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


AutoProcessor.register("BagelProcessor", BagelProcessor)
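A minimal usage sketch for `BagelProcessor`: it is wired here to a stock SigLIP image processor and a Qwen2-family tokenizer, which are stand-ins for the components that would be loaded from the BAGEL checkpoint (with trust_remote_code=True) in practice.

from PIL import Image
from transformers import AutoTokenizer, SiglipImageProcessor

from vllm.transformers_utils.processors.bagel import BagelProcessor

processor = BagelProcessor(
    image_processor=SiglipImageProcessor(),
    tokenizer=AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct"),
)

image = Image.new("RGB", (448, 448), color=(127, 127, 127))  # synthetic image
batch = processor(
    text="<|im_start|>user\n<|image_pad|>\nDescribe the image.<|im_end|>\n",
    images=image,
)
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']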