# Mirror of https://git.datalinker.icu/vllm-project/vllm.git
# (synced 2025-12-10 09:01:40 +08:00; 1317 lines, 47 KiB, Python)
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import math
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from dataclasses import dataclass, fields
|
|
from functools import cached_property
|
|
from typing import Literal, Optional, TypedDict, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
|
UserMessage)
|
|
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
|
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
|
|
from PIL import Image
|
|
from transformers import PixtralVisionConfig, TensorType
|
|
from transformers.image_utils import ImageInput
|
|
from transformers.models.pixtral.image_processing_pixtral import (
|
|
_num_image_tokens as _get_pixtral_hf_num_image_tokens)
|
|
from transformers.models.pixtral.modeling_pixtral import (
|
|
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
|
|
from transformers.tokenization_utils_base import TextInput
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
|
from vllm.model_executor.layers.activation import get_act_and_mul_fn
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
|
NestedTensors)
|
|
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
|
MultiModalDataItems)
|
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|
BaseProcessingInfo, MultiModalHashes,
|
|
PromptReplacement, PromptUpdate,
|
|
PromptUpdateDetails)
|
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
|
cached_tokenizer_from_config)
|
|
|
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
|
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
|
merge_multimodal_embeddings)
|
|
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
|
|
|
try:
    from xformers import ops as xops
except ImportError:
    # xformers is optional at import time; VisionTransformer.forward raises
    # ImportError when it is actually needed but unavailable.
    USE_XFORMERS_OPS = False
else:
    USE_XFORMERS_OPS = True

# Value of `VisionEncoderArgs.mm_projector_id` that enables the PatchMerger.
PATCH_MERGE = "patch_merge"
|
|
|
|
|
|
class PixtralImagePixelInputs(TypedDict):
    """Image pixel inputs for the Mistral-format Pixtral vision encoder."""

    type: Literal["pixel_values"]

    # Pixel data for every image in the batch: either one tensor of shape
    # `(batch_size * num_images, num_channels, image_height, image_width)`
    # or a list of per-image tensors of shape
    # `(num_channels, image_height, image_width)`; list entries may have
    # different spatial sizes.
    #
    # Each image is the `image` field of the encoding returned by
    # `mistral_common`'s `ImageEncoder` (the matching `tokens` field becomes
    # the prompt's placeholder token ids, not part of this input).
    images: Union[torch.Tensor, list[torch.Tensor]]
|
|
|
|
|
|
class PixtralProcessorAdapter:
    """
    Expose `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`
    through an interface shaped like a Hugging Face processor.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()

        self.tokenizer = tokenizer

    @property
    def image_processor(self) -> ImageEncoder:
        """The tokenizer's multimodal image encoder."""
        encoder = self.tokenizer.instruct.mm_encoder
        assert isinstance(encoder, ImageEncoder)
        return encoder

    @cached_property
    def image_break_id(self) -> int:
        return self.image_processor.special_ids.img_break

    @cached_property
    def image_token_id(self) -> int:
        return self.image_processor.special_ids.img

    @cached_property
    def image_end_id(self) -> int:
        return self.image_processor.special_ids.img_end

    @cached_property
    def image_size(self) -> int:
        return self.image_processor.mm_config.max_image_size

    @cached_property
    def patch_size(self) -> int:
        return self.image_processor.mm_config.image_patch_size

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        """
        Tokenize text, or encode images into pixels plus placeholder ids.

        Without images, `text` is tokenized with the Mistral tokenizer.
        With images, every entry of `text` must be empty (dummy text); the
        returned `input_ids` is the concatenation of all images' tokens,
        repeated once per entry of `text`. `return_tensors` and extra
        keyword arguments are accepted but unused.

        Raises:
            ValueError: if non-empty text is passed together with images.
        """
        if text is None:
            texts = []
        elif isinstance(text, list):
            texts = text
        else:
            texts = [text]

        if images is None:
            image_list = []
        elif isinstance(images, list):
            image_list = images
        else:
            image_list = [images]

        # Text-only request: plain tokenization is all that is needed.
        if len(image_list) == 0:
            token_ids = self.tokenizer(texts).input_ids
            return {"input_ids": torch.tensor(token_ids)}

        # Only empty (dummy) text may accompany images; this covers
        # profiling as well as token inputs.
        has_real_text = any(len(t) > 0 for t in texts)
        if has_real_text:
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411.")

        pixel_values: list[torch.Tensor] = []
        token_chunks: list[torch.Tensor] = []
        for img in image_list:
            encoding = self.image_processor(ImageChunk(image=img))
            pixel_values.append(torch.tensor(encoding.image))
            token_chunks.append(torch.tensor(encoding.tokens))

        all_tokens = torch.cat(token_chunks)
        return {
            "input_ids": all_tokens.unsqueeze(0).expand(len(texts), -1),
            "images": pixel_values,
        }
|
|
|
|
|
|
class PixtralProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Mistral-format Pixtral checkpoints."""

    def get_tokenizer(self) -> MistralTokenizer:
        """Return the cached tokenizer, which must be a MistralTokenizer.

        Raises:
            ValueError: if the configured tokenizer is not Mistral's.
        """
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if isinstance(tokenizer, MistralTokenizer):
            return tokenizer

        raise ValueError("This model requires `--tokenizer-mode mistral`")

    def get_hf_processor(self) -> PixtralProcessorAdapter:
        return PixtralProcessorAdapter(self.get_tokenizer())

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No fixed cap on the number of images per prompt.
        return {"image": None}

    def get_vision_config(
        self,
        processor: Optional[PixtralProcessorAdapter] = None,
    ):
        """Build a PixtralVisionConfig from the image encoder settings."""
        proc = processor if processor is not None else self.get_hf_processor()

        return PixtralVisionConfig(
            image_size=proc.image_size,
            patch_size=proc.patch_size,
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[PixtralProcessorAdapter] = None,
    ) -> int:
        """Number of image tokens (columns x rows) for an image size."""
        proc = self.get_hf_processor() if processor is None else processor

        # The encoder only exposes a per-image API, so size a blank image.
        placeholder = Image.new("RGB", (image_width, image_height))
        ncols, nrows = proc.image_processor._image_to_num_tokens(placeholder)

        return ncols * nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        encoder = self.get_hf_processor().image_processor
        side = encoder.mm_config.max_image_size

        return ImageSize(width=side, height=side)
|
|
|
|
|
|
class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
    """Builds profiling inputs by encoding a dummy chat request."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # Image placeholder tokens come from the chat encoder, so the text
        # part of the dummy prompt is empty.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return `mm_counts["image"]` images of the largest feature size."""
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Encode dummy text + images with `mistral_common` to get tokens."""
        tokenizer = self.info.get_tokenizer()

        text = self.get_dummy_text(mm_counts)
        mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
        images = mm_data.get("image", [])

        chunks = [TextChunk(text=text)]
        chunks.extend(ImageChunk(image=img) for img in images)
        request = ChatCompletionRequest(
            messages=[UserMessage(content=chunks)])

        encoded = tokenizer.mistral.encode_chat_completion(request)

        return ProcessorInputs(prompt=encoded.tokens, mm_data=mm_data)
|
|
|
|
|
|
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
                                 ):
    """Multimodal processor whose image tokens come from the chat encoder."""

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {"images": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        img_break = processor.image_break_id
        img_tok = processor.image_token_id
        img_end = processor.image_end_id

        def get_replacement(item_idx: int):
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            placeholder = Image.new("RGB", (size.width, size.height))
            ncols, nrows = processor.image_processor._image_to_num_tokens(
                placeholder)

            # Each row is `ncols` image tokens followed by a break token;
            # the very last break is replaced by the end-of-image token.
            row = [img_tok] * ncols + [img_break]
            tokens = row * nrows
            tokens[-1] = img_end

            return PromptUpdateDetails.select_token_id(tokens, img_tok)

        return [
            PromptReplacement(
                modality="image",
                # Never matches the prompt; see _cached_apply_hf_processor.
                target="",
                replacement=get_replacement,
            ),
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        return_mm_hashes: bool,
    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
        result = super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            return_mm_hashes=return_mm_hashes,
        )
        prompt_ids, mm_kwargs, mm_hashes, _ = result

        # NOTE: The chat template has already inserted the image tokens, so
        # the prompt updates are always reported as applied.
        return prompt_ids, mm_kwargs, mm_hashes, True
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
                                        info=PixtralProcessingInfo,
                                        dummy_inputs=PixtralDummyInputsBuilder)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
    """
    Pixtral for Mistral-format checkpoints.

    Images go through `VisionTransformer`, then optionally
    `pre_mm_projector_norm` and `patch_merger`, then
    `VisionLanguageAdapter` (to the text hidden size). The resulting
    embeddings replace the `image_token_id` positions of the language
    model's input embeddings.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # Keep only the vision-config entries that VisionEncoderArgs declares;
        # other keys in the config are ignored.
        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.vision_encoder = VisionTransformer(self.vision_args)

        if self.vision_args.add_pre_mm_projector_layer_norm:
            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
                                                 eps=1e-5)

        if self.vision_args.mm_projector_id == PATCH_MERGE:
            self.patch_merger = PatchMerger(
                vision_encoder_dim=self.vision_args.hidden_size,
                spatial_merge_size=self.vision_args.spatial_merge_size,
                use_mlp_bias=False,
            )

        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PixtralImagePixelInputs]:
        """
        Pop the `images` kwarg and wrap it; return None when it is absent.

        Raises:
            ValueError: if `images` is neither a tensor nor a list.
        """
        images = kwargs.pop("images", None)
        if images is None:
            return None

        if not isinstance(images, (torch.Tensor, list)):
            raise ValueError("Incorrect type of images. "
                             f"Got type: {type(images)}")

        return PixtralImagePixelInputs(
            type="pixel_values",
            images=flatten_bn(images),
        )

    def _process_image_input(
        self,
        image_input: PixtralImagePixelInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Encode and project images; return one embedding tensor per image."""
        images = image_input["images"]
        image_features = self.vision_encoder(images)
        feature_sizes = [
            image_feature.shape[0] for image_feature in image_features
        ]
        image_features = torch.cat(image_features)
        if self.vision_args.add_pre_mm_projector_layer_norm:
            image_features = self.pre_mm_projector_norm(image_features)
        if self.vision_args.mm_projector_id == PATCH_MERGE:
            patch_size = self.vision_args.patch_size
            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
            # Per-image (rows, cols) of patches from each pixel tensor's
            # spatial dims (dims 1 and 2).
            img_patch_dims = [(img.shape[1] // patch_size,
                               img.shape[2] // patch_size) for img in images]
            # Merging collapses each merge_size x merge_size block of
            # patches into one token.
            feature_sizes = [
                feature_size // spatial_merge_size_square
                for feature_size in feature_sizes
            ]
            image_features = self.patch_merger(image_features,
                                               image_sizes=img_patch_dims)
        image_embeds = self.vision_language_adapter(image_features)
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Per-image embeddings, or an empty list when no images are given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """
        Embed `input_ids` and overwrite image-token positions with
        `multimodal_embeddings` when any are provided.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        # `get_multimodal_embeddings` returns an empty list for batches with
        # no images (the v0 path in `forward` passes that straight through);
        # there is nothing to merge then, so skip the merge entirely instead
        # of handing `merge_multimodal_embeddings` an empty sequence.
        if multimodal_embeddings is not None \
                and len(multimodal_embeddings) != 0:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.vision_args.image_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for pixtral."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """
        Load checkpoint weights in a single pass.

        Weights whose names start with `vision_encoder`, `patch_merger`,
        `pre_mm_projector_norm` or `vision_language_adapter` are copied into
        the matching submodule parameter (name with its first dotted
        component stripped). All other weights are streamed to
        `self.language_model.load_weights`.
        """

        def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        def is_patch_merger(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("patch_merger")

        def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("pre_mm_projector_norm")

        # Get references to parameters for direct loading
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        patch_merger_dict = dict(self.patch_merger.named_parameters(
        )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
        pre_mm_projector_norm_dict = dict(
            self.pre_mm_projector_norm.named_parameters(
            )) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())

        def llm_weights_generator():
            # Single pass over weights
            for name, w in weights:
                if is_vision_encoder_weights((name, w)):
                    # Load vision encoder weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = vision_encoder_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_patch_merger((name, w)):
                    # Load vision patch merger weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = patch_merger_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_pre_mm_projector_norm((name, w)):
                    # Load vision pre_mm_projector_norm weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = pre_mm_projector_norm_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_vision_lang_adapter_weights((name, w)):
                    # Load vision-language adapter weights directly
                    trimmed_name = '.'.join(name.split(".")[1:])
                    param = vision_lang_adapter_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                else:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    yield (name, w)

        # Now we call the language model load with the generator
        self.language_model.load_weights(llm_weights_generator())
|
|
|
|
|
|
# Vision encoder
|
|
@dataclass
class VisionEncoderArgs:
    """
    Hyper-parameters of the Mistral-format Pixtral vision encoder.

    `PixtralForConditionalGeneration.__init__` fills it from the entries of
    `config.vision_config.to_dict()` whose keys match these field names.
    """

    hidden_size: int  # width of patch embeddings / transformer states
    num_channels: int  # input image channels (patch_conv in_channels)
    image_size: int  # max image side in pixels; with patch_size sets the RoPE grid
    patch_size: int  # patch side in pixels (conv kernel size and stride)
    intermediate_size: int  # hidden width of each FeedForward
    num_hidden_layers: int  # number of TransformerBlocks
    num_attention_heads: int  # heads per Attention; must divide hidden_size
    rope_theta: float  # base of the 2D rotary position embedding
    image_token_id: int  # prompt token id replaced by image embeddings
    adapter_bias: bool = True  # bias on VisionLanguageAdapter linears
    spatial_merge_size: int = 1  # side of the patch block merged by PatchMerger
    add_pre_mm_projector_layer_norm: bool = False  # RMSNorm before projector
    mm_projector_id: str = ""  # "patch_merge" (PATCH_MERGE) enables PatchMerger
|
|
|
|
|
|
def _reshape_for_broadcast(freqs_cis: torch.Tensor,
|
|
x: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
freqs_cis: complex - (seq_len, head_dim / 2)
|
|
x: complex - (bsz, seq_len, head_dim / 2)
|
|
"""
|
|
ndim = x.ndim
|
|
assert ndim > 1
|
|
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
|
|
freqs_cis.shape,
|
|
(x.shape[1], x.shape[-1]),
|
|
)
|
|
shape = [
|
|
d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
|
|
]
|
|
return freqs_cis.view(*shape)
|
|
|
|
|
|
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    Build the 2D rotary phase table.

    Returns a complex tensor of shape (height, width, dim // 2) indexed by
    (row, col) patch position. Even-indexed inverse frequencies rotate by
    row index and odd-indexed ones by column index; the two halves are
    concatenated along the last axis.
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (theta**exponents)

    rows = torch.arange(height, device=inv_freq.device)
    cols = torch.arange(width, device=inv_freq.device)

    row_angles = torch.outer(rows, inv_freq[::2]).float()
    col_angles = torch.outer(cols, inv_freq[1::2]).float()

    grid_angles = torch.cat(
        (
            row_angles.unsqueeze(1).expand(-1, width, -1),
            col_angles.unsqueeze(0).expand(height, -1, -1),
        ),
        dim=-1,
    )
    return torch.polar(torch.ones_like(grid_angles), grid_angles)
|
|
|
|
|
|
def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate queries and keys by the complex phases in `freqs_cis`.

    Consecutive pairs along the last dimension are read as (real, imag)
    parts. The rotation is computed in float32/complex64, and the results
    are cast back to the dtypes of `xq` and `xk`.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    assert freqs_cis.dtype == torch.complex64

    phases = _reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * phases).flatten(3)
    k_rotated = torch.view_as_real(k_complex * phases).flatten(3)

    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Gated SiLU MLP, `w2(silu(w1(x)) * w3(x))`, with bias-free linears."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        hidden, inner = args.hidden_size, args.intermediate_size
        self.w1 = nn.Linear(hidden, inner, bias=False)
        self.w2 = nn.Linear(inner, hidden, bias=False)
        self.w3 = nn.Linear(hidden, inner, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
|
|
|
|
|
|
class Attention(nn.Module):
    """
    Multi-head self-attention with 2D rotary embeddings.

    The attention itself is `xops.memory_efficient_attention`; `mask` is
    forwarded to it as `attn_bias`.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert not args.hidden_size % args.num_attention_heads
        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        width = args.hidden_size
        self.wq = nn.Linear(width, width, bias=False)
        self.wk = nn.Linear(width, width, bias=False)
        self.wv = nn.Linear(width, width, bias=False)
        self.wo = nn.Linear(width, width, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        per_head = (bsz, seq_len, self.n_heads, self.head_dim)

        q = self.wq(x).reshape(per_head)
        k = self.wk(x).reshape(per_head)
        v = self.wv(x).reshape(per_head)

        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        attended = xops.memory_efficient_attention(q, k, v, attn_bias=mask)

        merged = attended.reshape(bsz, seq_len, self.n_heads * self.head_dim)
        return self.wo(merged)
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block: RMSNorm -> attention -> residual add, then
    RMSNorm -> feed-forward -> residual add.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        # Submodules are invoked via `__call__` rather than `.forward` so
        # that registered forward hooks / pre-hooks on them actually run.
        r = self.attention(self.attention_norm(x),
                           mask=mask,
                           freqs_cis=freqs_cis)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out
|
|
|
|
|
|
class Transformer(nn.Module):
    """Stack of `args.num_hidden_layers` TransformerBlocks applied in order."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.num_hidden_layers)])

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask, freqs_cis=freqs_cis)
        return hidden
|
|
|
|
|
|
def position_meshgrid(patch_embeds_list: list[torch.Tensor]) -> torch.Tensor:
    """
    Return the (row, col) position of every patch, image after image.

    Each tensor's last two dims give its patch grid (rows, cols); the result
    has shape (sum of rows * cols, 2), in row-major order per image.
    """
    per_image: list[torch.Tensor] = []
    for patch_embed in patch_embeds_list:
        rows = torch.arange(patch_embed.shape[-2])
        cols = torch.arange(patch_embed.shape[-1])
        grid_rows, grid_cols = torch.meshgrid(rows, cols, indexing="ij")
        coords = torch.stack((grid_rows, grid_cols), dim=-1)
        per_image.append(coords.reshape(-1, 2))
    return torch.cat(per_image)
|
|
|
|
|
|
class VisionTransformer(nn.Module):
    """
    Mistral-format Pixtral vision encoder: a strided-conv patch embedding,
    RMSNorm, and a transformer with 2D RoPE whose attention is restricted to
    patches of the same image.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # The RoPE table is built lazily on first use of `freqs_cis`.
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.types.Device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Cached RoPE table, moved to this module's device on demand."""
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: list[torch.Tensor],
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tuple of N_img tensors, one per image, each of
                shape ((H // patch_size) * (W // patch_size), D)
        Raises:
            ImportError: if xformers is not installed.
        """
        # Fail before doing any work: the block-diagonal attention mask
        # below is built with xformers.
        if not USE_XFORMERS_OPS:
            raise ImportError("Xformers is required for Pixtral inference "
                              "with the Mistral format")

        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # (1, D, h, w) -> (1, h * w, D)
        patch_embeds = [
            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
        ]
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting
        # images (embed_sizes[i] == h_i * w_i patches of image i)
        mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(embed_sizes)
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)
|
|
|
|
|
|
class VisionLanguageAdapter(nn.Module):
    """
    Two-layer GELU MLP that projects vision features of width
    `args.hidden_size` to the language model width `dim`.
    """

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        use_bias = args.adapter_bias
        self.w_in = nn.Linear(args.hidden_size, dim, bias=use_bias)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.w_in(x)
        hidden = self.gelu(hidden)
        return self.w_out(hidden)
|
|
|
|
|
|
class PatchMerger(nn.Module):
    """
    Learned merging of spatial_merge_size ** 2 patches.

    Each (spatial_merge_size x spatial_merge_size) block of patch features
    is concatenated and linearly projected back to `vision_encoder_dim`.
    """

    def __init__(
        self,
        vision_encoder_dim: int,
        spatial_merge_size: int,
        use_mlp_bias: bool = False,
    ) -> None:
        super().__init__()

        self.spatial_merge_size = spatial_merge_size
        self.mlp_input_dim = vision_encoder_dim * spatial_merge_size**2

        self.merging_layer = nn.Linear(
            self.mlp_input_dim,
            vision_encoder_dim,
            bias=use_mlp_bias,
        )

    def forward(self, x: torch.Tensor,
                image_sizes: list[tuple[int, int]]) -> torch.Tensor:
        # image_sizes are (height, width) measured in patch tokens
        assert sum(h * w for h, w in image_sizes) == len(x)

        # (N, D) -> (N / s**2, D * s**2)
        grouped = self.permute(x, image_sizes)

        # (N / s**2, D * s**2) -> (N / s**2, D)
        return self.merging_layer(grouped)

    def permute(
        self,
        x: torch.Tensor,
        image_sizes: list[tuple[int, int]],
    ) -> torch.Tensor:
        """
        Args:
            x: (N, D) where N is flattened and concatenated patch tokens
                for all images
            image_sizes: list of tuple of (height, width) in tokens for
                each image
        Returns:
            image_features: patch tokens regrouped so that each
                (spatial_merge_size, spatial_merge_size) block forms one row,
                shape (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
        """
        grids = get_sub_grids(
            x=x,
            image_sizes=image_sizes,
            spatial_merge_size=self.spatial_merge_size,
        )  # each: d x sub_grid_size x sub_grid_size x n_patches

        # n_patches x (d * sub_grid_size * sub_grid_size) per image
        rows = [grid.view(-1, grid.shape[-1]).t() for grid in grids]
        return torch.cat(rows, dim=0)
|
|
|
|
|
|
def get_sub_grids(
    x: torch.Tensor,
    image_sizes: list[tuple[int, int]],
    spatial_merge_size: int,
) -> list[torch.Tensor]:
    """
    Cut each image's patch tokens into spatial_merge_size-square blocks.

    `x` holds the concatenated (h * w, d) token rows of all images, with
    `image_sizes` giving each image's (h, w) in tokens. For every image the
    result has shape (d, spatial_merge_size, spatial_merge_size, n_blocks).
    """
    hidden = x.shape[-1]
    k = spatial_merge_size
    token_chunks = x.split([h * w for h, w in image_sizes])

    grids: list[torch.Tensor] = []
    for (h, w), tokens in zip(image_sizes, token_chunks):
        # (h * w, d) -> (1, d, h, w)
        image_grid = tokens.view(h, w, hidden).permute(2, 0, 1).unsqueeze(0)
        blocks = torch.nn.functional.unfold(image_grid,
                                            kernel_size=k,
                                            stride=k)
        # (1, d * k * k, n_blocks) -> (d, k, k, n_blocks)
        grids.append(blocks.view(1, hidden, k, k, -1)[0])

    return grids
|
|
|
|
|
|
#### HF Transformers version of Pixtral ####
|
|
# Based on https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
|
|
# This model follows the Llava family, meaning image embeddings are placed
|
|
# instead of the `[IMG]` token placeholders.
|
|
# The model uses [`PixtralVisionModel`] for its vision encoder,
|
|
# and [`MistralForCausalLM`] for its language decoder.
|
|
|
|
|
|
class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
    """Token-count helpers for the HF-format Pixtral vision encoder."""

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        grid_cols, grid_rows = self.get_patch_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return grid_cols * grid_rows

    def get_image_size(self) -> int:
        return self.vision_config.image_size

    def get_patch_size(self) -> int:
        # spatial_merge_size is needed for Mistral3; absent means 1.
        merge = getattr(self.hf_config, "spatial_merge_size", 1)
        return self.vision_config.patch_size * merge

    def get_patch_grid_length(self) -> int:
        # Since interpolation is applied, the image size need not be
        # divisible by the patch size.
        return self.get_image_size() // self.get_patch_size()

    # Adapted from: https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/pixtral/image_processing_pixtral.py#L99
    def get_patch_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return (ncols, nrows) of patches after fitting into the max size."""
        max_side = self.get_image_size()
        patch = self.get_patch_size()

        # Downscale (keeping aspect ratio) only when the image is too large.
        scale = max(image_width / max_side, image_height / max_side)
        if scale > 1:
            image_width = int(math.floor(image_width / scale))
            image_height = int(math.floor(image_height / scale))

        nrows, ncols = _get_pixtral_hf_num_image_tokens(
            (image_height, image_width),
            (patch, patch),
        )  # type: ignore

        return ncols, nrows
|
|
|
|
|
|
class PixtralHFMLP(nn.Module):
    """Gated feed-forward block of an HF-format Pixtral encoder layer.

    The gate and up projections are fused into one column-parallel linear
    with two ``intermediate_size`` output halves. ``act_and_mul`` (built by
    ``get_act_and_mul_fn(config.hidden_act)``) reduces that fused output to
    ``intermediate_size`` features, which the row-parallel ``down_proj``
    maps back to ``hidden_size``.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        inner_dim = config.intermediate_size
        assert inner_dim is not None
        model_dim = config.hidden_size

        # Fused [gate; up] projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=model_dim,
            output_sizes=[inner_dim, inner_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=inner_dim,
            output_size=model_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run fused gate/up projection, gated activation, down projection."""
        fused, _ignored = self.gate_up_proj(x)
        gated = self.act_and_mul(fused)
        result, _ignored = self.down_proj(gated)
        return result
|
|
|
|
|
|
class PixtralHFAttention(nn.Module):
    """Multi-head self-attention for the HF-format Pixtral vision encoder.

    Heads are sharded across tensor-parallel ranks; ``n_heads`` is the
    per-rank head count. Rotary position embeddings are applied with HF's
    ``apply_rotary_pos_emb``. Attention runs through xFormers when
    ``USE_XFORMERS_OPS`` is set, otherwise through PyTorch SDPA.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        # Heads handled by this tensor-parallel rank.
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads

        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply rotary-embedded self-attention.

        Args:
            hidden_states: (batch, patches, hidden) input; it is unpacked
                into three dims by ``size()`` below.
            attention_mask: Passed as ``attn_bias`` to xFormers or as
                ``attn_mask`` to SDPA, depending on ``USE_XFORMERS_OPS``.
            position_embeddings: ``(cos, sin)`` pair for
                ``apply_rotary_pos_emb``.

        Returns:
            ``(attn_output, None)``. The second element is always None.
        """
        batch, patches, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv_states.chunk(3, dim=-1)

        # Transpose q and k to (batch, heads, patches, head_dim) to apply
        # HF's Rotary Position Embedding.
        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch, patches, self.n_heads, self.head_dim)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

        if USE_XFORMERS_OPS:
            # Transpose q and k back to (batch, patches, heads, head_dim).
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()

            out = xops.memory_efficient_attention(q,
                                                  k,
                                                  v,
                                                  attn_bias=attention_mask)
        else:
            v = v.transpose(1, 2)
            out = nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask)
            out = out.transpose(1, 2)

        # On the SDPA path `out` is a transposed view whose strides depend
        # on the backend's output layout. If that layout is contiguous
        # (batch, heads, patches, head_dim), merging the last two axes is
        # impossible as a `view` and would raise. `reshape` returns a view
        # when possible and copies only when necessary.
        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
        attn_output, _ = self.o_proj(out)

        return attn_output, None
|
|
|
|
|
|
class PixtralHFTransformerBlock(nn.Module):
    """One pre-norm encoder layer.

    The layer computes RMSNorm -> attention -> residual add, then
    RMSNorm -> MLP -> residual add.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.attention = PixtralHFAttention(config,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.attention")
        self.feed_forward = PixtralHFMLP(config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.feed_forward")
        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers, each with a residual add.

        ``attention_mask`` and ``position_embeddings`` (a ``(cos, sin)``
        pair) are forwarded unchanged to the attention sub-layer.
        """
        # Invoke submodules via __call__ rather than .forward() so that any
        # registered nn.Module forward hooks are run instead of skipped.
        r, _ = self.attention(self.attention_norm(hidden_states),
                              attention_mask=attention_mask,
                              position_embeddings=position_embeddings)
        h = hidden_states + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out
|
|
|
|
|
|
class PixtralHFTransformer(nn.Module):
    """Stack of ``PixtralHFTransformerBlock`` layers."""

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        """Build ``num_hidden_layers_override`` layers if given, otherwise
        ``config.num_hidden_layers`` layers."""
        super().__init__()

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            PixtralHFTransformerBlock(config=config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        return_all_hidden_states: bool,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Run every layer in order.

        Returns:
            The final hidden states when ``return_all_hidden_states`` is
            False. Otherwise it returns a list of ``len(self.layers) + 1``
            tensors: the input ``x`` followed by each layer's output.
        """
        hidden_states_pool = [x]

        for layer in self.layers:
            x = layer(x, attention_mask, position_embeddings)
            if return_all_hidden_states:
                hidden_states_pool.append(x)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return x
|
|
|
|
|
|
class PixtralHFVisionModel(nn.Module):
    """HF-format Pixtral vision encoder.

    Each image is patchified independently by a strided convolution. All
    images' patches are then concatenated into one sequence and run through
    the transformer stack. The attention mask is built from the per-image
    patch counts.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        """Build the encoder.

        Raises:
            ValueError: If more layers are requested than
                ``config.num_hidden_layers``, or if ``require_post_norm``
                is True (this model has no post-layernorm).
        """
        super().__init__()

        self.config = config

        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers.")

        if require_post_norm is True:
            msg = "PixtralHFVisionModel does not have post-layernorm"
            raise ValueError(msg)

        # dtype/device are read once from the first parameter at
        # construction time and reused in forward().
        self.dtype = next(self.parameters()).dtype
        self.device = next(self.parameters()).device
        self.patch_positional_embedding = PixtralRotaryEmbedding(
            config, self.device)

    def forward(
        self,
        pixel_values: list[torch.Tensor],
        feature_sample_layers: Optional[list[int]] = None,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            pixel_values: One tensor per image. Images may have different
                shapes, since a batch can mix images from several requests.
                Each tensor gets ``unsqueeze(0)`` before ``patch_conv``, so
                it presumably is (C, H, W) -- TODO confirm against callers.
            feature_sample_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If None,
                the last layer is used.

        Returns:
            A tuple with one tensor per input image, in input order. Each
            tensor holds that image's patch features and has shape
            (num_patches_of_image, D). The tuple is ``torch.split`` of the
            (N_toks, D) encoder output along dim 0.
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype))
            for img in pixel_values
        ]

        # Flatten each image's spatial grid into a token axis:
        # (1, hidden, h, w) -> (1, h * w, hidden).
        patch_embeds = [
            p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
        ]
        # Per-image token counts (h * w); reused for the mask and the split.
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list,
            max_width=self.config.image_size // self.config.patch_size).to(
                self.device)
        position_embedding = self.patch_positional_embedding(
            patch_embeds, position_ids)

        if USE_XFORMERS_OPS:
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                embed_sizes)
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask)
            attention_mask = generate_block_attention_mask(
                embed_sizes, patch_embeds)

        return_all_hidden_states = feature_sample_layers is not None
        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=return_all_hidden_states)

        out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
                                             self.config.num_hidden_layers)

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)

    # (TODO) Add prefix argument for filtering out weights to be loaded
    # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Separate q/k/v and gate/up checkpoint tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters by shard id.
        Weights for ``transformer.layers.<idx>`` with ``idx`` beyond the
        number of built layers are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if name.startswith("transformer.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load directly into the parameter.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|