Cyrus Leung c46b932df2
[Chore] Deprecate SupportsMultiModal.merge_by_field_config (#30170)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-12-06 07:57:28 +00:00

1379 lines
47 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
from typing import Annotated, Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import BatchFeature, PixtralVisionConfig, TensorType
from transformers.image_utils import ImageInput
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding,
apply_rotary_pos_emb,
position_ids_in_meshgrid,
)
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalUUIDDict,
NestedTensors,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model, maybe_prefix
from .vision import (
VisionEncoderInfo,
VisionFeatureSelectStrategy,
resolve_visual_encoder_outputs,
)
try:
    # Note: vLLM does not install xformers by default.
    from xformers import ops as xops

    # Xformers FA is not compatible with B200, so disable it on CUDA devices
    # that pass the `has_device_capability(100)` check.
    USE_XFORMERS_OPS = not (
        current_platform.is_cuda() and current_platform.has_device_capability(100)
    )
except ImportError:
    USE_XFORMERS_OPS = False

# `mm_projector_id` value that enables the learned PatchMerger.
PATCH_MERGE = "patch_merge"
class PixtralImagePixelInputs(TensorSchema):
    """
    Pixel inputs for Pixtral.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image

    `images` holds the pixel data (`ImageEncoding.image`) that the Mistral
    image encoder produces for each image. `h` and `w` are dynamic, so images
    of different sizes may be passed as a list of per-image tensors rather
    than a single stacked tensor.
    """

    type: Literal["pixel_values"] = "pixel_values"

    images: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
    ]
class PixtralProcessorAdapter:
    """
    Wrap `mistral_common.tokens.tokenizers.multimodal.ImageEncoder` behind a
    callable interface shaped like a HuggingFace processor.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer

    @property
    def image_processor(self) -> ImageEncoder:
        """The tokenizer's multimodal image encoder (asserted to be an ImageEncoder)."""
        encoder = self.tokenizer.instruct.mm_encoder
        assert isinstance(encoder, ImageEncoder)
        return encoder

    @cached_property
    def image_break_id(self) -> int:
        special = self.image_processor.special_ids
        return special.img_break

    @cached_property
    def image_token_id(self) -> int:
        special = self.image_processor.special_ids
        return special.img

    @cached_property
    def image_end_id(self) -> int:
        special = self.image_processor.special_ids
        return special.img_end

    @cached_property
    def image_size(self) -> int:
        mm_config = self.image_processor.mm_config
        return mm_config.max_image_size

    @cached_property
    def patch_size(self) -> int:
        mm_config = self.image_processor.mm_config
        return mm_config.image_patch_size

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | list[ImageInput] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        """Tokenize text and/or encode images.

        Without images, `text` is tokenized directly. With images, only empty
        (dummy) text entries are accepted. The image token sequences are
        concatenated and repeated once per text entry, and the encoded pixel
        tensors are returned under "images".
        """
        texts = [] if text is None else text
        if not isinstance(texts, list):
            texts = [texts]

        image_list = [] if images is None else images
        if not isinstance(image_list, list):
            image_list = [image_list]

        if not image_list:
            token_ids = self.tokenizer(texts).input_ids
            return {"input_ids": torch.tensor(token_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        has_real_text = any(len(t) > 0 for t in texts)
        if has_real_text:
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411."
            )

        pixel_tensors: list[torch.Tensor] = []
        token_tensors: list[torch.Tensor] = []
        for img in image_list:
            encoding = self.image_processor(ImageChunk(image=img))
            pixel_tensors.append(torch.tensor(encoding.image))
            token_tensors.append(torch.tensor(encoding.tokens))

        input_ids = torch.cat(token_tensors)[None].expand(len(texts), -1)
        return BatchFeature({"input_ids": input_ids, "images": pixel_tensors})
class PixtralProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Pixtral, backed by a `MistralTokenizer`."""

    def get_tokenizer(self) -> MistralTokenizer:
        """Return the cached tokenizer.

        Raises:
            ValueError: if the configured tokenizer is not a `MistralTokenizer`.
        """
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("This model requires `--tokenizer-mode mistral`")

        return tokenizer

    def get_hf_processor(self) -> PixtralProcessorAdapter:
        return PixtralProcessorAdapter(self.get_tokenizer())

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No fixed upper bound on the number of images per prompt.
        return {"image": None}

    def get_vision_config(
        self,
        processor: PixtralProcessorAdapter | None = None,
    ) -> PixtralVisionConfig:
        """Build a `PixtralVisionConfig` from the image encoder's settings."""
        if processor is None:
            processor = self.get_hf_processor()

        return PixtralVisionConfig(
            image_size=processor.image_size,
            patch_size=processor.patch_size,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: PixtralProcessorAdapter | None = None,
    ) -> int:
        """Return the number of image-feature tokens (`ncols * nrows`).

        Break/end tokens are excluded from the count.
        """
        if processor is None:
            processor = self.get_hf_processor()

        ncols, nrows = processor.image_processor._image_to_num_tokens(
            Image.new("RGB", (image_width, image_height))
        )

        return ncols * nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        image_processor = self.get_hf_processor().image_processor
        max_image_size = image_processor.mm_config.max_image_size

        return ImageSize(width=max_image_size, height=max_image_size)
class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
    """Build profiling inputs: empty text plus max-size dummy images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        n_images = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=n_images,
            overrides=overrides,
        )
        return {"image": dummy_images}

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """Tokenize a single user message through the Mistral chat template.

        The message contains the dummy text followed by one image chunk per
        dummy image.
        """
        tokenizer = self.info.get_tokenizer()

        text = self.get_dummy_text(mm_counts)
        mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)

        content = [TextChunk(text=text)]
        content.extend(ImageChunk(image=img) for img in mm_data.get("image", []))
        request = ChatCompletionRequest(messages=[UserMessage(content=content)])
        encoded = tokenizer.mistral.encode_chat_completion(request)

        return ProcessorInputs(
            prompt=encoded.tokens,
            mm_data=mm_data,
            tokenization_kwargs={"truncation": False},
        )
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]):
    """Multimodal processor whose image tokens come from the chat template."""

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {"images": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        break_id = processor.image_break_id
        img_id = processor.image_token_id
        end_id = processor.image_end_id

        def get_replacement(item_idx: int):
            # Layout: nrows rows of `ncols` image tokens, each row followed by
            # a break token; the very last break becomes the end token.
            size = mm_items.get_items("image", ImageProcessorItems).get_image_size(
                item_idx
            )
            ncols, nrows = processor.image_processor._image_to_num_tokens(
                Image.new("RGB", (size.width, size.height))
            )
            row = [img_id] * ncols + [break_id]
            token_ids = row * nrows
            token_ids[-1] = end_id

            return PromptUpdateDetails.select_token_id(token_ids, img_id)

        # target="" never matches: the chat template already inserted the
        # image tokens (see `_cached_apply_hf_processor`).
        replacement = PromptReplacement(
            modality="image",
            target="",
            replacement=get_replacement,
        )
        return [replacement]

    def _cached_apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        prompt_ids, mm_info, _ignored_flag = super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # The chat template has already placed the image tokens, so report
        # the prompt updates as applied.
        return prompt_ids, mm_info, True
@MULTIMODAL_REGISTRY.register_processor(
    PixtralMultiModalProcessor,
    info=PixtralProcessingInfo,
    dummy_inputs=PixtralDummyInputsBuilder,
)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Pixtral: a vision tower in front of a registered text model.

    The vision tower is built only when the multimodal config allows at
    least one image per prompt. It consists of `VisionTransformer`, an
    optional RMSNorm, an optional `PatchMerger`, and `VisionLanguageAdapter`.
    The text model is built from `config.text_config`.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        # Image tokens are already inserted by the chat template, so no
        # textual placeholder is needed.
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # Keep only the vision_config entries that VisionEncoderArgs declares.
        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_encoder = VisionTransformer(self.vision_args)
            self.pre_mm_projector_norm = (
                RMSNorm(self.vision_args.hidden_size, eps=1e-5)
                if self.vision_args.add_pre_mm_projector_layer_norm
                else None
            )
            self.patch_merger = (
                PatchMerger(
                    vision_encoder_dim=self.vision_args.hidden_size,
                    spatial_merge_size=self.vision_args.spatial_merge_size,
                    use_mlp_bias=False,
                )
                if self.vision_args.mm_projector_id == PATCH_MERGE
                else None
            )
            self.vision_language_adapter = VisionLanguageAdapter(
                self.vision_args, dim=config.text_config.hidden_size
            )
        else:
            # Images disabled: no vision tower is built, and its checkpoint
            # weights are skipped in `load_weights`.
            self.vision_encoder = None
            self.pre_mm_projector_norm = None
            self.patch_merger = None
            self.vision_language_adapter = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> PixtralImagePixelInputs | None:
        images = kwargs.pop("images", None)
        if images is None:
            return None

        return PixtralImagePixelInputs(
            type="pixel_values",
            images=images,
        )

    def _process_image_input(
        self,
        image_input: PixtralImagePixelInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Run the vision tower and return one embedding tensor per image."""
        assert (
            self.vision_encoder is not None and self.vision_language_adapter is not None
        )

        images = image_input["images"]
        image_features = self.vision_encoder(images)
        feature_sizes = [image_feature.shape[0] for image_feature in image_features]
        image_features = torch.cat(image_features)
        if self.pre_mm_projector_norm is not None:
            image_features = self.pre_mm_projector_norm(image_features)
        if self.patch_merger is not None:
            patch_size = self.vision_args.patch_size
            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
            # Per-image (rows, cols) of the patch grid, from (C, H, W) pixels.
            img_patch_dims = [
                (img.shape[1] // patch_size, img.shape[2] // patch_size)
                for img in images
            ]
            # Each merged token replaces spatial_merge_size**2 patch tokens.
            feature_sizes = [
                feature_size // spatial_merge_size_square
                for feature_size in feature_sizes
            ]
            image_features = self.patch_merger(
                image_features, image_sizes=img_patch_dims
            )
        image_embeds = self.vision_language_adapter(image_features)
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for pixtral."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the vision tower and language model.

        A weight whose name starts with one of the vision-tower prefixes is
        loaded into that submodule. Its first dotted component is stripped to
        form the parameter name. If the target submodule was not built, the
        weight is dropped. All other weights are forwarded to
        `self.language_model.load_weights`.
        """
        # Ordered prefix -> parameter dict. Prefixes are tested in this order
        # with `str.startswith`. A `None` value means the submodule is absent,
        # so its weights are skipped.
        vision_params: dict[str, dict[str, torch.Tensor] | None] = {
            weight_prefix: (
                dict(module.named_parameters()) if module is not None else None
            )
            for weight_prefix, module in (
                ("vision_encoder", self.vision_encoder),
                ("patch_merger", self.patch_merger),
                ("pre_mm_projector_norm", self.pre_mm_projector_norm),
                ("vision_language_adapter", self.vision_language_adapter),
            )
        }

        def llm_weights_generator():
            # Single pass over weights. Vision weights are loaded as a side
            # effect while the language model consumes this generator.
            for name, w in weights:
                matched = next(
                    (p for p in vision_params if name.startswith(p)), None
                )
                if matched is None:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    yield (name, w)
                    continue

                params = vision_params[matched]
                if params is None:
                    continue

                trimmed_name = ".".join(name.split(".")[1:])
                param = params[trimmed_name]
                with torch.no_grad():
                    default_weight_loader(param, w)

        # Now we call the language model load with the generator
        self.language_model.load_weights(llm_weights_generator())
# Vision encoder


@dataclass
class VisionEncoderArgs:
    """Hyperparameters for the (non-HF) Pixtral vision encoder.

    `PixtralForConditionalGeneration.__init__` builds this from the entries
    of `vision_config.to_dict()` whose keys match these field names.
    """

    # Width of the encoder's token features (Linear / conv output size).
    hidden_size: int
    # Input image channels (in_channels of the patch convolution).
    num_channels: int
    # Maximum image side in pixels; image_size // patch_size sets the size of
    # the 2D RoPE table.
    image_size: int
    # Kernel size and stride of the patch convolution.
    patch_size: int
    # Hidden width of the gated feed-forward layers.
    intermediate_size: int
    # Number of TransformerBlocks in the encoder.
    num_hidden_layers: int
    # Attention heads; hidden_size must be divisible by this.
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    # NOTE(review): not read by the encoder code visible here; presumably the
    # image placeholder token id — confirm.
    image_token_id: int
    # Whether the VisionLanguageAdapter Linear layers have biases.
    adapter_bias: bool = True
    # Side of the square patch neighbourhood merged by PatchMerger.
    spatial_merge_size: int = 1
    # Add an RMSNorm before the multimodal projector.
    add_pre_mm_projector_layer_norm: bool = False
    # Set to PATCH_MERGE ("patch_merge") to enable PatchMerger.
    mm_projector_id: str = ""
def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)
"""
ndim = x.ndim
assert ndim > 1
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
freqs_cis.shape,
(x.shape[1], x.shape[-1]),
)
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    Build a 2D rotary table of unit complex numbers.

    Returns a tensor of shape (height, width, dim // 2) to be indexed by
    (row, col) patch positions. Even-indexed base frequencies rotate with
    the row index, and odd-indexed ones rotate with the column index.
    """
    # (dim / 2) frequency bases
    exponents = torch.arange(0, dim, 2).float() / dim
    base_freqs = 1.0 / (theta**exponents)

    rows = torch.arange(height, device=base_freqs.device)
    cols = torch.arange(width, device=base_freqs.device)
    row_angles = torch.outer(rows, base_freqs[::2]).float()  # (height, dim/4)
    col_angles = torch.outer(cols, base_freqs[1::2]).float()  # (width, dim/4)

    row_part = row_angles.unsqueeze(1).repeat(1, width, 1)
    col_part = col_angles.unsqueeze(0).repeat(height, 1, 1)
    angles = torch.cat((row_part, col_part), dim=-1)

    magnitudes = torch.ones_like(angles)
    return torch.polar(magnitudes, angles)
def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key tensors by the complex rotary table `freqs_cis`.

    Adjacent pairs in the last dim of `xq`/`xk` are treated as complex
    numbers, multiplied by `freqs_cis` (broadcast over non-sequence dims),
    then converted back to real and cast to the input dtypes.
    """

    def to_complex_pairs(t: torch.Tensor) -> torch.Tensor:
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = to_complex_pairs(xq)
    k_complex = to_complex_pairs(xk)
    assert freqs_cis.dtype == torch.complex64
    rotation = _reshape_for_broadcast(freqs_cis, q_complex)

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
class FeedForward(nn.Module):
    """Gated SiLU MLP computing `w2(silu(w1(x)) * w3(x))`."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        model_dim, ff_dim = args.hidden_size, args.intermediate_size
        self.w1 = nn.Linear(model_dim, ff_dim, bias=False)
        self.w2 = nn.Linear(ff_dim, model_dim, bias=False)
        self.w3 = nn.Linear(model_dim, ff_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class Attention(nn.Module):
    """Multi-head self-attention with 2D rotary embeddings (vision encoder)."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        width = args.hidden_size
        self.wq = nn.Linear(width, width, bias=False)
        self.wk = nn.Linear(width, width, bias=False)
        self.wv = nn.Linear(width, width, bias=False)
        self.wo = nn.Linear(width, width, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        per_head = (bsz, seq_len, self.n_heads, self.head_dim)

        query = self.wq(x)
        key = self.wk(x)
        value = self.wv(x)
        query = query.reshape(per_head)
        key = key.reshape(per_head)
        value = value.reshape(per_head)

        query, key = apply_rotary_emb_vit(query, key, freqs_cis=freqs_cis)

        if USE_XFORMERS_OPS:
            attended = xops.memory_efficient_attention(
                query, key, value, attn_bias=mask
            )
        else:
            # SDPA takes (batch, heads, seq, head_dim); swap back afterwards.
            query, key, value = (t.transpose(1, 2) for t in (query, key, value))
            attended = nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=mask
            ).transpose(1, 2)

        merged = attended.reshape(bsz, seq_len, self.n_heads * self.head_dim)
        return self.wo(merged)
class TransformerBlock(nn.Module):
    """Pre-norm block: RMSNorm -> attention -> residual, RMSNorm -> MLP -> residual."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        normed = self.attention_norm(x)
        hidden = x + self.attention.forward(normed, mask=mask, freqs_cis=freqs_cis)
        ffn_out = self.feed_forward.forward(self.ffn_norm(hidden))
        return hidden + ffn_out
class Transformer(nn.Module):
    """A stack of `args.num_hidden_layers` TransformerBlocks applied in order."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor | None,
    ) -> torch.Tensor:
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask, freqs_cis=freqs_cis)
        return hidden
def position_meshgrid(
    patch_embeds_list: list[torch.Tensor],
) -> torch.Tensor:
    """Return (row, col) indices of every patch, concatenated over images.

    Each tensor contributes `shape[-2] * shape[-1]` rows in row-major order.
    The result has shape (total_patches, 2).
    """
    coords: list[torch.Tensor] = []
    for patch_grid in patch_embeds_list:
        n_rows, n_cols = patch_grid.shape[-2], patch_grid.shape[-1]
        grid_r, grid_c = torch.meshgrid(
            torch.arange(n_rows), torch.arange(n_cols), indexing="ij"
        )
        coords.append(torch.stack((grid_r, grid_c), dim=-1).reshape(-1, 2))
    return torch.cat(coords)
class VisionTransformer(nn.Module):
    """Pixtral vision encoder.

    The pipeline is a patch convolution, then RMSNorm, then a Transformer
    stack with 2D rotary position embeddings. Variable-size images are
    packed into one sequence and attend under a block-diagonal mask built
    from each image's patch count.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = Conv2dLayer(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # RoPE table, built lazily on first access of `freqs_cis`.
        self._freqs_cis: torch.Tensor | None = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.types.Device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        """2D RoPE table indexed by (row, col) patch position.

        The table has shape (max_patches_per_side, max_patches_per_side,
        head_dim // 2). It is built on first use and moved to this module's
        device whenever the devices differ.
        """
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: list[torch.Tensor],
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            A tuple with one tensor per image. The i-th tensor has shape
            (n_i, D), where n_i = h_i * w_i is the size of the patch grid
            that `patch_conv` produces for image i.
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
        # Patches per image (h_i * w_i); reused for the mask and final split.
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        if USE_XFORMERS_OPS:
            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(embed_sizes)
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask,
            )

            mask = generate_block_attention_mask(embed_sizes, patch_embeds)
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)
class VisionLanguageAdapter(nn.Module):
    """Two-layer GELU MLP from vision hidden size to text hidden size `dim`."""

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        use_bias = args.adapter_bias
        self.w_in = nn.Linear(args.hidden_size, dim, bias=use_bias)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.w_in(x)
        activated = self.gelu(projected)
        return self.w_out(activated)
class PatchMerger(nn.Module):
    """
    Merge each spatial_merge_size x spatial_merge_size group of patches into
    one token with a learned linear projection.
    """

    def __init__(
        self,
        vision_encoder_dim: int,
        spatial_merge_size: int,
        use_mlp_bias: bool = False,
    ) -> None:
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        self.mlp_input_dim = vision_encoder_dim * spatial_merge_size**2
        self.merging_layer = nn.Linear(
            self.mlp_input_dim,
            vision_encoder_dim,
            bias=use_mlp_bias,
        )

    def forward(
        self, x: torch.Tensor, image_sizes: list[tuple[int, int]]
    ) -> torch.Tensor:
        """Map (N, D) patch tokens to (N / m**2, D) merged tokens.

        `image_sizes` gives (height, width) per image, in tokens, and m is
        `spatial_merge_size`.
        """
        assert sum(h * w for h, w in image_sizes) == len(x)

        # (N, D) -> (N / m**2, D * m**2)
        grouped = self.permute(x, image_sizes)
        # (N / m**2, D * m**2) -> (N / m**2, D)
        return self.merging_layer(grouped)

    def permute(
        self,
        x: torch.Tensor,
        image_sizes: list[tuple[int, int]],
    ) -> torch.Tensor:
        """
        Args:
            x: (N, D) flattened and concatenated patch tokens for all images
            image_sizes: (height, width) in tokens for each image
        Returns:
            (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2).
            Each row holds one contiguous (spatial_merge_size,
            spatial_merge_size) neighbourhood of patches.
        """
        # each grid: d x sub_grid_size x sub_grid_size x n_patches
        grids = get_sub_grids(
            x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size
        )
        rows = [grid.view(-1, grid.shape[-1]).t() for grid in grids]
        return torch.cat(rows, dim=0)
def get_sub_grids(
    x: torch.Tensor,
    image_sizes: list[tuple[int, int]],
    spatial_merge_size: int,
) -> list[torch.Tensor]:
    """Cut each image's token grid into non-overlapping k x k tiles.

    Here k is `spatial_merge_size`. `x` is (N, d) with the tokens of all
    images concatenated, and `image_sizes` gives each image's (h, w) in
    tokens. For each image, the result is a tensor of shape
    (d, k, k, n_tiles).
    """
    hidden_dim = x.shape[-1]
    k = spatial_merge_size
    per_image_tokens = x.split([h * w for h, w in image_sizes])

    tiles_per_image: list[torch.Tensor] = []
    for (h, w), tokens in zip(image_sizes, per_image_tokens):
        # (h * w, d) -> 1 x d x h x w
        grid = tokens.view(h, w, hidden_dim).permute(2, 0, 1).unsqueeze(0)
        unfolded = torch.nn.functional.unfold(grid, kernel_size=k, stride=k)
        # 1 x d x k x k x n_tiles -> drop the leading batch dim
        tiles_per_image.append(unfolded.view(1, hidden_dim, k, k, -1)[0])
    return tiles_per_image
#### HF Transformers version of Pixtral ####
# Based on https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the Llava family, meaning image embeddings are placed
# instead of the `[IMG]` token placeholders.
# The model uses [`PixtralVisionModel`] for its vision encoder,
# and [`MistralForCausalLM`] for its language decoder.
class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
    """Token-count helpers for the HF-format Pixtral vision encoder."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        grid_cols, grid_rows = self.get_patch_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return grid_rows * grid_cols

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        # spatial_merge_size is needed for Mistral3: each output token then
        # covers (patch_size * spatial_merge_size) pixels per side.
        merge = getattr(self.hf_config, "spatial_merge_size", 1)
        return self.vision_config.patch_size * merge

    def get_patch_grid_length(self) -> int:
        # Since interpolation is applied, the image size need not be divisible
        # by the patch size.
        return self.get_image_size() // self.get_patch_size()

    # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99
    def get_patch_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return (ncols, nrows) of output tokens for an image of this size.

        Oversized images are first scaled down (flooring both sides) so that
        neither side exceeds the maximum image size.
        """
        max_side = self.get_image_size()
        patch = self.get_patch_size()

        scale = max(image_width / max_side, image_height / max_side)
        if scale > 1:
            image_width = math.floor(image_width / scale)
            image_height = math.floor(image_height / scale)

        nrows, ncols = _get_pixtral_hf_num_image_tokens(
            (image_height, image_width),
            (patch, patch),
        )  # type: ignore

        return ncols, nrows
class PixtralHFMLP(nn.Module):
    """Gated MLP using a merged gate/up projection and tensor-parallel layers."""

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()
        assert config.intermediate_size is not None
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden,
            output_sizes=[inner, inner],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=inner,
            output_size=hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _bias = self.gate_up_proj(x)
        activated = self.act_and_mul(gate_up)
        projected, _bias = self.down_proj(activated)
        return projected
class PixtralHFAttention(nn.Module):
    """Tensor-parallel self-attention that applies HF's Pixtral rotary embedding."""

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Heads held by this tensor-parallel rank.
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Return (attention output, None).

        `position_embeddings` is unpacked into (cos, sin).
        """
        batch, patches, _ = hidden_states.size()
        head_shape = (batch, patches, self.n_heads, self.head_dim)

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # HF's rotary helper works on (batch, heads, patches, head_dim).
        q = q.view(head_shape).transpose(1, 2)
        k = k.view(head_shape).transpose(1, 2)
        v = v.view(head_shape)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # xformers expects (batch, patches, heads, head_dim).
            out = xops.memory_efficient_attention(
                q.transpose(1, 2).contiguous(),
                k.transpose(1, 2).contiguous(),
                v,
                attn_bias=attention_mask,
            )
        else:
            out = nn.functional.scaled_dot_product_attention(
                q, k, v.transpose(1, 2), attn_mask=attention_mask
            ).transpose(1, 2)

        flat = out.reshape(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(flat)

        return attn_output, None
class PixtralHFTransformerBlock(nn.Module):
    """One pre-norm encoder layer.

    It computes ``h = x + attention(attention_norm(x))`` and then
    ``out = h + feed_forward(ffn_norm(h))``.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.attention = PixtralHFAttention(
            config, quant_config=quant_config, prefix=f"{prefix}.attention"
        )
        self.feed_forward = PixtralHFMLP(
            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
        )
        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            hidden_states: Input activations.
            attention_mask: Forwarded unchanged to the attention module.
            position_embeddings: ``(cos, sin)`` rotary tables forwarded to the
                attention module.

        Returns:
            The layer output, with the same shape as ``hidden_states``.
        """
        # Invoke submodules through ``__call__`` (not ``.forward``) so that
        # any registered forward hooks / pre-hooks are honoured.
        r, _ = self.attention(
            self.attention_norm(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        h = hidden_states + r
        r = self.feed_forward(self.ffn_norm(h))
        return h + r
class PixtralHFTransformer(nn.Module):
    """Stack of ``PixtralHFTransformerBlock`` layers.

    The depth is ``config.num_hidden_layers`` unless
    ``num_hidden_layers_override`` is given.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        num_hidden_layers = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )

        self.layers = nn.ModuleList(
            [
                PixtralHFTransformerBlock(
                    config=config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
                for layer_idx in range(num_hidden_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        return_all_hidden_states: bool,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Run every layer in order.

        Args:
            x: Input activations for the first layer.
            attention_mask: Passed unchanged to every layer.
            position_embeddings: ``(cos, sin)`` rotary tables passed to every
                layer.
            return_all_hidden_states: If true, return a list holding the
                input followed by each layer's output, so callers with
                multiple feature sample layers can pick them by index.
                Otherwise return only the final layer's output.

        Returns:
            The final hidden states, or the list described above.
        """
        if not return_all_hidden_states:
            for layer in self.layers:
                x = layer(x, attention_mask, position_embeddings)
            return x

        # Index 0 is the encoder input; index i + 1 is the output of layer i.
        hidden_states_pool = [x]
        for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)
            hidden_states_pool.append(x)
        return hidden_states_pool
class PixtralHFVisionModel(nn.Module):
    """Pixtral vision encoder using the Hugging Face checkpoint layout.

    Each image is patchified independently by a strided convolution. The
    patches of all images are concatenated into a single sequence and
    encoded together, and a block-diagonal attention mask (built from the
    per-image patch counts) separates the images.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.patch_conv = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers."
            )

        if require_post_norm is True:
            msg = "PixtralHFVisionModel does not have post-layernorm"
            raise ValueError(msg)

        # dtype/device of the parameters at construction time.
        self.dtype = next(self.parameters()).dtype
        self.device = next(self.parameters()).device
        self.patch_positional_embedding = PixtralRotaryEmbedding(config, self.device)

    def forward(
        self,
        pixel_values: list[torch.Tensor],
        *,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            pixel_values: Each image to be processed will be a separate tensor
                in pixel_values. This means it will be a list of tensors
                because multiple requests batched can have multiple images,
                each with their own shape potentially. Each image is
                unsqueezed to a batch of one before the patch convolution
                (presumably ``(C, H, W)`` — confirm against callers).
            select_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.
            feature_select_strategy: Forwarded to
                ``resolve_visual_encoder_outputs``.

        Returns:
            A tuple with one tensor per input image, in input order. Each
            tensor holds that image's token features, of shape
            ``(num_patches_i, D)``, obtained by splitting the encoded
            sequence. The tuple is empty when ``pixel_values`` is empty.
        """
        # ``torch.cat`` below rejects an empty sequence. ``len`` (rather than
        # truthiness) also works if a stacked tensor is passed in.
        if len(pixel_values) == 0:
            return ()

        # Pass images through the initial convolution independently.
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values
        ]

        # (1, D, H, W) -> (1, H*W, D) for each image.
        patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
        # Patch count per image (H*W); these are the block sizes of the mask.
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # Flatten all images into a single sequence.
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # Positional embeddings. Follow the activations' device rather than
        # the device cached at construction, in case the module was moved.
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list,
            max_width=self.config.image_size // self.config.patch_size,
        ).to(patch_embeds.device)
        position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)

        # Block-diagonal mask so each image only attends to its own patches.
        if USE_XFORMERS_OPS:
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                embed_sizes,
            )
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask,
            )

            attention_mask = generate_block_attention_mask(embed_sizes, patch_embeds)

        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=select_layers is not None,
        )

        out = resolve_visual_encoder_outputs(
            out,
            None,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

        # Squeeze dim 0 and split into separate tensors for each image.
        return torch.split(out.squeeze(0), embed_sizes)

    # (TODO) Add prefix argument for filtering out weights to be loaded
    # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with the matching
        shard id. Weights for transformer layers beyond the configured depth
        (e.g. under ``num_hidden_layers_override``) are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            Names of the parameters that were loaded (after fused renaming).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # Omit layers when num_hidden_layers_override is set.
            if name.startswith("transformer.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Fused parameter: load this shard into its slot.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked weight: load it directly.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params