mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 16:25:01 +08:00
638 lines
22 KiB
Python
638 lines
22 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from abc import abstractmethod
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from typing import Annotated, Final, Literal, Protocol, TypeVar
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import (
|
|
BatchFeature,
|
|
Mistral3Config,
|
|
PixtralVisionConfig,
|
|
PretrainedConfig,
|
|
)
|
|
from transformers.models.pixtral import PixtralProcessor
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.config.multimodal import BaseDummyOptions
|
|
from vllm.model_executor.layers.activation import get_act_fn
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
|
from vllm.multimodal.inputs import (
|
|
MultiModalDataDict,
|
|
MultiModalFieldConfig,
|
|
MultiModalKwargsItems,
|
|
)
|
|
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
|
|
from vllm.multimodal.processing import (
|
|
BaseMultiModalProcessor,
|
|
BaseProcessingInfo,
|
|
InputProcessingContext,
|
|
PromptReplacement,
|
|
PromptUpdate,
|
|
PromptUpdateDetails,
|
|
)
|
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|
|
|
from .interfaces import (
|
|
MultiModalEmbeddings,
|
|
SupportsLoRA,
|
|
SupportsMultiModal,
|
|
SupportsPP,
|
|
)
|
|
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
|
from .utils import (
|
|
AutoWeightsLoader,
|
|
WeightsMapper,
|
|
init_vllm_registered_model,
|
|
maybe_prefix,
|
|
)
|
|
from .vision import get_vision_encoder_info
|
|
|
|
|
|
class Mistral3ImagePixelInputs(TensorSchema):
    """
    Schema describing raw image pixels handed to the vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    # Discriminator tag; the only value this schema admits.
    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"

    # `h` and `w` are declared dynamic because images in a batch may differ
    # in size; when they do, the pixels arrive as a list of per-image tensors
    # rather than as a single batched tensor.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
    ]
|
|
|
|
|
|
class Mistral3PatchMerger(nn.Module):
    """
    Learned merging of spatial_merge_size ** 2 patches

    Each non-overlapping ``spatial_merge_size x spatial_merge_size`` window
    of patch features is flattened into one vector and projected back to
    ``vision_hidden_size`` by a bias-free linear layer.
    """

    def __init__(
        self, vision_hidden_size: int, spatial_merge_size: int, patch_size: int
    ):
        """
        Args:
            vision_hidden_size: Feature width of a single patch token.
            spatial_merge_size: Side length, in patches, of the merged window.
            patch_size: Side length, in pixels, of one patch; used to turn
                image sizes into patch-grid sizes.
        """
        super().__init__()

        self.vision_hidden_size = vision_hidden_size
        self.spatial_merge_size = spatial_merge_size
        self.patch_size = patch_size

        # One merged token is the concatenation of every patch in its window.
        merged_width = vision_hidden_size * spatial_merge_size**2
        self.merging_layer = nn.Linear(merged_width, vision_hidden_size, bias=False)

    def forward(
        self, image_features: torch.Tensor, image_sizes: torch.Tensor
    ) -> torch.Tensor:
        """
        Merge neighbouring patch tokens of every image.

        Args:
            image_features: Patch tokens of all images concatenated along
                dim 0; the last dim is the feature width.
            image_sizes: Per-image ``(height, width)`` pairs; each entry is
                integer-divided by ``patch_size`` to get the patch grid.

        Returns:
            One row per merged window (all images concatenated), projected
            by ``merging_layer``.
        """
        merge = self.spatial_merge_size
        patch_grids = [
            (size[0] // self.patch_size, size[1] // self.patch_size)
            for size in image_sizes
        ]
        tokens_per_image = [rows * cols for rows, cols in patch_grids]
        hidden_dim = image_features.shape[-1]

        per_image_tokens = image_features.split(tokens_per_image)
        merged_chunks = []
        for (rows, cols), image_tokens in zip(patch_grids, per_image_tokens):
            # Lay the flat token sequence back out as a (1, d, rows, cols) grid.
            grid = image_tokens.view(rows, cols, hidden_dim)
            grid = grid.permute(2, 0, 1).unsqueeze(0)
            # Non-overlapping windows: stride equals the kernel size.
            windows = torch.nn.functional.unfold(
                grid,
                kernel_size=merge,
                stride=merge,
            )
            # (1, d * merge**2, n_windows) -> (n_windows, d * merge**2)
            merged_chunks.append(windows.view(hidden_dim * merge**2, -1).t())

        stacked = torch.cat(merged_chunks, dim=0)
        return self.merging_layer(stacked)
|
|
|
|
|
|
class Mistral3MultiModalProjector(nn.Module):
    """
    Projects vision features to the text hidden size.

    Pipeline: RMS norm -> patch merger -> linear -> activation -> linear.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        spatial_merge_size: int,
        patch_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            vision_hidden_size: Feature width of the incoming vision tokens.
            text_hidden_size: Output width of both linear layers.
            spatial_merge_size: Window side length given to the patch merger.
            patch_size: Patch side length given to the patch merger.
            projector_hidden_act: Activation name resolved via ``get_act_fn``.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Optional quantization config for both linear layers.
            prefix: Module-name prefix used to name the linear layers.
        """
        super().__init__()

        self.norm = RMSNorm(vision_hidden_size, eps=1e-5)
        self.patch_merger = Mistral3PatchMerger(
            vision_hidden_size=vision_hidden_size,
            spatial_merge_size=spatial_merge_size,
            patch_size=patch_size,
        )

        # Settings common to the two projection layers.
        shared_linear_kwargs = dict(
            bias=multimodal_projector_bias,
            quant_config=quant_config,
        )
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            **shared_linear_kwargs,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            **shared_linear_kwargs,
        )

    def forward(
        self, image_features: torch.Tensor, image_sizes: torch.Tensor
    ) -> torch.Tensor:
        """
        Normalize, merge and project the vision features.

        Args:
            image_features: Vision-tower output for all images.
            image_sizes: Per-image sizes, passed through to the patch merger.

        Returns:
            The output of the second linear layer.
        """
        normed = self.norm(image_features)
        merged = self.patch_merger(normed, image_sizes)
        # Both linear layers return a pair; only the first element is used.
        projected, _ = self.linear_1(merged)
        projected, _ = self.linear_2(self.act(projected))
        return projected
|
|
|
|
|
|
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes used by LLaVA-style code."""

    # Configuration of the vision encoder.
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    # A single layer index or a list of them; negative values are resolved
    # relative to the encoder depth by `_get_layer_index`.
    vision_feature_layer: Final[int | list[int]]
|
|
|
|
|
|
class LlavaLikeProcessor(Protocol):
    """Structural type for an HF processor that exposes an image token."""

    # Placeholder string that marks an image in the prompt text.
    image_token: Final[str]
|
|
|
|
|
|
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """
    Shared processing info: config access, vision-encoder queries and
    image-token accounting. Subclasses supply the HF processor.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config, requested from the context as `Mistral3Config`."""
        return self.ctx.get_hf_config(Mistral3Config)

    def get_vision_encoder_info(self):
        """Return the vision-encoder info object built from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be overridden by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; only "image" is listed."""
        # NOTE(review): `None` presumably means "no fixed limit" -- confirm
        # against BaseProcessingInfo.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the token count the vision encoder reports for this image size."""
        encoder = self.get_vision_encoder_info()
        return encoder.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` whose side is the encoder's image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)
|
|
|
|
|
|
# Type variable for processing-info classes derived from
# BaseLlavaProcessingInfo; it parametrizes the dummy-inputs builder and the
# processor factory defined later in this module.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
|
|
|
|
|
|
class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompt text and dummy images for the requested counts."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per requested image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Return dummy images sized by `get_image_size_with_most_features`.

        Args:
            seq_len: Unused here.
            mm_counts: Requested item count per modality; only "image" is read.
            mm_options: Optional per-modality overrides; the "image" entry, if
                any, is forwarded to `_get_dummy_images`.
        """
        image_count = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=image_count,
            overrides=image_overrides,
        )
        return {"image": dummy_images}
|
|
|
|
|
|
class Mistral3ProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info whose HF processor is a `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Fetch the `PixtralProcessor` from the context, forwarding `kwargs`."""
        processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return processor
|
|
|
|
|
|
class Mistral3MultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]):
    """Multimodal processor: un-pads pixel values and expands image tokens."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the parent HF processor call, then crop each image to its own size.

        Returns:
            The processor outputs; when "pixel_values" is present it is
            replaced by a list of tensors cropped to the matching entry of
            "image_sizes".
        """
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Avoid padding since we need the output for each image to be
        # independent of other images for the cache to work correctly
        image_sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(image_sizes)

        outputs["pixel_values"] = [
            padded[:, :height, :width]
            for padded, (height, width) in zip(pixel_values, image_sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare "pixel_values" and "image_embeds" as batched image fields."""
        image_field_names = ("pixel_values", "image_embeds")
        return {
            name: MultiModalFieldConfig.batched("image") for name in image_field_names
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Build the replacement that expands each image token into a token grid.

        Every row of the grid is `ncols` image tokens followed by an
        image-break token; the final token of the whole sequence is replaced
        by the image-end token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            # The last row ends with the image-end token instead of a break.
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]
|
|
|
|
|
|
def _build_mistral3_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """
    Create the processing info for a Mistral3 model.

    Asserts that the model's vision config is a `PixtralVisionConfig`.
    """
    vision_config = ctx.get_hf_config(Mistral3Config).vision_config
    assert isinstance(vision_config, PixtralVisionConfig)
    return Mistral3ProcessingInfo(ctx)
|
|
|
|
|
|
def _build_mistral3_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """
    Create the multimodal processor for a Mistral3 model.

    Args:
        info: Processing info; asserted to be a `Mistral3ProcessingInfo`.
        dummy_inputs: Dummy-inputs builder handed to the processor.
        cache: Optional processor cache handed to the processor.
    """
    assert isinstance(info, Mistral3ProcessingInfo)
    processor = Mistral3MultiModalProcessor(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
    )
    return processor
|
|
|
|
|
|
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count needed to reach the single requested feature layer,
        or the deepest one when several are requested.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list or tuple.
        TypeError: If `vision_feature_layer` is not an int, list or tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # Same exception type max() would raise on an empty sequence,
            # but with a message that points at the offending config field.
            raise ValueError(
                "vision_feature_layer must contain at least one layer index"
            )
        return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
|
|
|
|
|
|
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
|
|
"""Given a signed vision feature layer, get the number of hidden layers
|
|
needed to leverage it.
|
|
|
|
Args:
|
|
feature_layer_index: Index of a required layer in the visual encoder.
|
|
num_hidden_layers: The total number of hidden layers in the visual
|
|
encoder.
|
|
"""
|
|
if feature_layer_index < 0:
|
|
return num_hidden_layers + feature_layer_index + 1
|
|
return feature_layer_index
|
|
|
|
|
|
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> PixtralHFVisionModel:
    """
    Build the Pixtral vision tower, truncated to the layers actually needed.

    Args:
        hf_config: Model config supplying `vision_config` and the feature
            layer selection.
        quant_config: Optional quantization config passed to the tower.
        require_post_norm: Passed through to `PixtralHFVisionModel`.
        prefix: Module-name prefix passed through to `PixtralHFVisionModel`.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    layers_to_build = _get_num_hidden_layers(hf_config)

    assert isinstance(vision_config, PixtralVisionConfig)

    tower = PixtralHFVisionModel(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=layers_to_build,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )
    return tower
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(
    _build_mistral3_processor,
    info=_build_mistral3_info,
    dummy_inputs=Mistral3DummyInputsBuilder,
)
class Mistral3ForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
):
    """
    Mistral3 vision-language model: a Pixtral vision tower and a multimodal
    projector feeding a registered text model.
    """

    merge_by_field_config = True

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """
        Return the placeholder string for item `i` of `modality`.

        Returns `None` for image modalities.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """
        Build the vision tower and projector (when images are enabled) and
        the language model.

        Args:
            vllm_config: Engine configuration; its HF config is patched in
                place for the Pixtral-12B special cases below.
            prefix: Module-name prefix for all sub-modules.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        needs_default_arch = (
            text_config.architectures is None and text_config.model_type == "mistral"
        )
        if needs_default_arch:
            text_config.architectures = ["MistralForCausalLM"]

        needs_default_act = (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        )
        if needs_default_act:
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = Mistral3MultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                spatial_merge_size=config.spatial_merge_size,
                patch_size=config.vision_config.patch_size,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )
        else:
            # A falsy image limit means the vision modules are not built.
            self.vision_tower = None
            self.multi_modal_projector = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Mistral3ImagePixelInputs | None:
        """
        Extract the image input from the keyword arguments.

        Returns:
            `None` when neither "pixel_values" nor "image_embeds" is given,
            otherwise a `Mistral3ImagePixelInputs` built from "pixel_values".
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        no_image_data = pixel_values is None and image_embeds is None
        if no_image_data:
            return None

        # NOTE(review): when only `image_embeds` is supplied, the schema is
        # still built with `pixel_values=None` -- confirm whether embeds-only
        # input is meant to be supported.
        return Mistral3ImagePixelInputs(
            type="pixel_values_pixtral",
            pixel_values=pixel_values,
        )

    def _process_image_input(
        self,
        image_input: Mistral3ImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """
        Encode the images and project them to the text hidden size.

        Returns:
            The projector output directly when the vision tower returns a
            single tensor; otherwise a tuple with one embedding tensor per
            image.
        """
        # NOTE(review): the schema's `type` is declared as the literal
        # "pixel_values_pixtral", so this branch looks unreachable -- confirm.
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        pixel_values = image_input["pixel_values"]
        image_sizes = [(img.shape[-2], img.shape[-1]) for img in pixel_values]

        image_features = self.vision_tower(pixel_values)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features, image_sizes)

        # Each merged token consumes spatial_merge_size ** 2 patch tokens.
        merge_area = self.config.spatial_merge_size**2
        feature_sizes = [feature.shape[0] // merge_area for feature in image_features]

        image_embeds = self.multi_modal_projector(
            torch.cat(image_features), image_sizes
        )
        if len(feature_sizes) > 1:
            return torch.split(image_embeds, feature_sizes)
        return (image_embeds,)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings for `kwargs`, or `[]` without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Mistral3.

        The call is delegated to the inner model of the language model. The
        vision modules are not invoked here.

        NOTE(review): `input_ids` is presumably already expanded with the
        image placeholder tokens by the multimodal processor -- confirm
        against the caller.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
                When given, `inputs_embeds` is ignored.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`Mistral3ImagePixelInputs`][vllm.model_executor.models.mistral3.Mistral3ImagePixelInputs]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load checkpoint weights, remapping names via `hf_to_vllm_mapper`.

        Vision-tower and projector weights are skipped when neither module
        was built.

        Returns:
            The value returned by `AutoWeightsLoader.load_weights`.
        """
        vision_disabled = (
            self.vision_tower is None and self.multi_modal_projector is None
        )
        if vision_disabled:
            skip_prefixes = ["vision_tower.", "multi_modal_projector."]
        else:
            skip_prefixes = []

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )
|