[Bugfix] Clean up MiniMax-VL and fix processing (#17354)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 890f104cdf
commit 00ee37efa2
@@ -979,6 +979,13 @@ See [this page](#generative-models) for more information on how to use generative models.
   * ✅︎
   * ✅︎
   * ✅︎
+- * `MiniMaxVL01ForConditionalGeneration`
+  * MiniMax-VL
+  * T + I<sup>E+</sup>
+  * `MiniMaxAI/MiniMax-VL-01`, etc.
+  *
+  * ✅︎
+  * ✅︎
 - * `Mistral3ForConditionalGeneration`
   * Mistral3
   * T + I<sup>+</sup>
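The new table row documents text-plus-image support (`T + I<sup>E+</sup>`). For orientation, here is a minimal offline-inference sketch against the newly listed checkpoint; the `<image>` placeholder in the prompt is an assumption for illustration (the real prompt format comes from the model's own processor/tokenizer config), not something this commit defines:

```python
from PIL import Image
from vllm import LLM

# Hypothetical usage; MiniMax-VL-01 ships custom processing code, so
# trust_remote_code is assumed to be required here.
llm = LLM(model="MiniMaxAI/MiniMax-VL-01", trust_remote_code=True)

image = Image.open("example.jpg")
outputs = llm.generate({
    "prompt": "<image>\nDescribe this image.",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)
```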
@@ -270,6 +270,7 @@ def _test_processing_correctness_mistral(
     "openbmb/MiniCPM-Llama3-V-2_5",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
+    "MiniMaxAI/MiniMax-VL-01",
     "allenai/Molmo-7B-D-0924",
     "allenai/Molmo-7B-O-0924",
     "nvidia/NVLM-D-72B",
@@ -12,7 +12,6 @@ from ...utils import build_model_context
 
 
 @pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
-# yapf: enable
 @pytest.mark.parametrize("num_imgs", [1, 2])
 def test_processor_override(
     image_assets: _ImageAssets,
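A note on the decorators in this hunk: stacked `pytest.mark.parametrize` markers take the Cartesian product of their argument lists, so the test runs once per `(model_id, num_imgs)` pair. A self-contained toy illustration (not code from this repo):

```python
import pytest

@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_example(model_id: str, num_imgs: int):
    # Expands to 1 x 2 = 2 test cases.
    assert num_imgs in (1, 2)
```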
@@ -1,52 +1,32 @@
 # SPDX-License-Identifier: Apache-2.0
-from abc import abstractmethod
-from collections.abc import Iterable, Mapping, Sequence
-from dataclasses import dataclass
-from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
-                    TypeVar, Union, cast)
+from collections.abc import Iterable, Mapping
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union, cast
 
-import numpy as np
 import torch
 import torch.nn as nn
-from transformers import BatchFeature, CLIPVisionConfig, PretrainedConfig
-from transformers.image_processing_utils import select_best_resolution
+from transformers import BatchFeature
 
 from vllm.config import VllmConfig
 from vllm.jsontree import json_map_leaves
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
-from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
-                                   ImageSize, MultiModalDataItems)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .llava import (BaseLlavaMultiModalProcessor, LlavaDummyInputsBuilder,
+                    init_vision_tower_for_llava)
+from .llava_next import LlavaNextProcessingInfo
 from .pixtral import PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import get_vision_encoder_info
-
-logger = init_logger(__name__)
-
-
-# For dummy input only
-@dataclass
-class MaxImageTokenMeta:
-    width: int = 1024
-    height: int = 1024
 
 
 class MiniMaxVL01ImagePixelInputs(TypedDict):
@@ -69,66 +49,8 @@ class MiniMaxVL01ImageEmbeddingInputs(TypedDict):
     """
 
 
-def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
-    if not isinstance(grid_pinpoints, list):
-        raise TypeError("grid_pinpoints should be a list of tuples or lists")
-
-    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple,
-    # otherwise it will cause wrong calculate
-    if not isinstance(image_size, (list, tuple)):
-        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
-            raise TypeError("image_size invalid type " +
-                            f"{type(image_size)} with value {image_size}")
-        image_size = image_size.tolist()
-
-    best_resolution = select_best_resolution(image_size, grid_pinpoints)
-    height, width = best_resolution
-    num_patches = 0
-    # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
-    for i in range(0, height, patch_size):
-        for j in range(0, width, patch_size):
-            num_patches += 1
-    # add the base patch
-    num_patches += 1
-    return num_patches
-
-
-def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
-    if not isinstance(grid_pinpoints, list):
-        raise TypeError("grid_pinpoints should be a list of tuples or lists")
-
-    # ! VERY IMPORTANT if image_size is tensor,
-    # must convert to into tuple,
-    # otherwise it will cause wrong calculate
-    if not isinstance(image_size, (list, tuple)):
-        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
-            raise TypeError(
-                "image_size invalid type " +
-                f"{type(image_size)} not valid, " +
-                "should be either list, tuple, np.ndarray or tensor")
-        image_size = image_size.tolist()
-
-    height, width = select_best_resolution(image_size, grid_pinpoints)
-    return height // patch_size, width // patch_size
-
-
-def unpad_image(tensor, original_size):
-    original_height, original_width = original_size
-    current_height, current_width = tensor.shape[1:]
-
-    original_aspect_ratio = original_width / original_height
-    current_aspect_ratio = current_width / current_height
-
-    if original_aspect_ratio > current_aspect_ratio:
-        new_height = int(original_height * current_width) // original_width
-        padding = (current_height - new_height) // 2
-        unpadded_tensor = tensor[:, padding:current_height - padding, :]
-    else:
-        new_width = int(original_width * current_height) // original_height
-        padding = (current_width - new_width) // 2
-        unpadded_tensor = tensor[:, :, padding:current_width - padding]
-
-    return unpadded_tensor
+MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs,
+                               MiniMaxVL01ImageEmbeddingInputs]
 
 
 class MiniMaxVL01MultiModalProjector(nn.Module):
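The deleted `image_size_to_num_patches` counted anyres patches with a double loop; its own comment pointed at the closed form. A runnable check of that equivalence (standalone sketch, not code from this commit):

```python
import math

def num_patches_closed_form(height: int, width: int, patch_size: int) -> int:
    # One patch per patch_size step in each dimension, plus the base patch:
    # exactly what the removed double loop computed.
    return math.ceil(height / patch_size) * math.ceil(width / patch_size) + 1

# e.g. a 672x1008 best resolution with 336px patches -> 2 * 3 + 1 = 7
assert num_patches_closed_form(672, 1008, 336) == 7
```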
@@ -161,144 +83,29 @@ class MiniMaxVL01MultiModalProjector(nn.Module):
         return hidden_states
 
 
-class MiniMaxVL01LikeConfig(Protocol):
-    vision_config: Final[PretrainedConfig]
-    image_token_index: Final[int]
-    vision_feature_select_strategy: Final[str]
-    vision_feature_layer: Final[Union[int, list[int]]]
-
-
-class MiniMaxVL01LikeProcessor(Protocol):
-    image_token: Final[str]
-
-
-_I = TypeVar("_I", bound=BaseProcessingInfo)
-
-
-class MiniMaxVL01DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
-
-    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
-        num_images = mm_counts.get("image", 0)
-        processor = self.info.get_hf_processor()
-        image_token = processor.image_token
-        return image_token * num_images
-
-    def get_dummy_mm_data(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> MultiModalDataDict:
-        num_images = mm_counts.get("image", 0)
-
-        return {
-            "image":
-            self._get_dummy_images(width=MaxImageTokenMeta.width,
-                                   height=MaxImageTokenMeta.height,
-                                   num_images=num_images)
-        }
-
-
-class MiniMaxVL01ProcessingInfo(BaseProcessingInfo):
+class MiniMaxVL01DummyInputsBuilder(LlavaDummyInputsBuilder):
+    pass
+
+
+class MiniMaxVL01ProcessingInfo(LlavaNextProcessingInfo):
 
     def get_hf_config(self):
         return self.ctx.get_hf_config(MiniMaxVL01Config)
 
+    def get_hf_processor(self, **kwargs: object):
+        hf_processor = self.ctx.get_hf_processor(**kwargs)
+        image_processor = hf_processor.image_processor
+        image_processor.anyres_preprocess = (
+            image_processor.anyres_for_vllm_preprocess)
+
+        return hf_processor
+
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
-    def get_vision_encoder_info(self):
-        return get_vision_encoder_info(self.get_hf_config())
-
-    def _apply_feature_select_strategy(
-        self,
-        strategy: str,
-        encoder_num_image_tokens: int,
-    ) -> int:
-        if strategy == "default":
-            return encoder_num_image_tokens - 1
-        if strategy == "full":
-            return encoder_num_image_tokens
-
-        msg = f"Unexpected feature select strategy: {strategy!r}"
-        raise NotImplementedError(msg)
-
-    def get_num_image_tokens(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-    ) -> int:
-        hf_config = self.get_hf_config()
-        vision_encoder_info = self.get_vision_encoder_info()
-
-        return self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
-            vision_encoder_info.get_num_image_tokens(
-                image_width=image_width,
-                image_height=image_height,
-            ),
-        )
-
-    def get_image_size_with_most_features(self) -> ImageSize:
-        vision_encoder_info = self.get_vision_encoder_info()
-        width = height = vision_encoder_info.get_image_size()
-        return ImageSize(width=width, height=height)
-
-    def get_max_image_tokens(self) -> int:
-        target_width, target_height = self.get_image_size_with_most_features()
-
-        return self.get_num_image_tokens(
-            image_width=target_width,
-            image_height=target_height,
-        )
-
-
-class BaseMiniMaxVL01MultiModalProcessor(BaseMultiModalProcessor[_I]):
-
-    # Copied from BaseMultiModalProcessor
-    @abstractmethod
-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        raise NotImplementedError
-
-    def _get_prompt_updates(
-        self,
-        mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
-    ) -> Sequence[PromptUpdate]:
-        hf_config = self.info.get_hf_config()
-        image_token_id = hf_config.image_token_index
-
-        def get_replacement(item_idx: int):
-            images = mm_items.get_items(
-                "image", (ImageEmbeddingItems, ImageProcessorItems))
-
-            if isinstance(images, ImageEmbeddingItems):
-                num_image_tokens = images.get_feature_size(item_idx)
-            else:
-                image_size = images.get_image_size(item_idx)
-                num_image_tokens = self.info.get_num_image_tokens(
-                    image_width=image_size.width,
-                    image_height=image_size.height,
-                )
-
-            return [image_token_id] * num_image_tokens
-
-        return [
-            PromptReplacement(
-                modality="image",
-                target=[image_token_id],
-                replacement=get_replacement,
-            ),
-        ]
-
 
 class MiniMaxVL01MultiModalProcessor(
-        BaseMiniMaxVL01MultiModalProcessor[MiniMaxVL01ProcessingInfo]):
+        BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo]):
 
     def _call_hf_processor(
         self,
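The new `get_hf_processor` override works by rebinding a method on the processor instance, redirecting the anyres entry point to its vLLM-specific variant. A toy illustration of that instance-level override pattern (the class and return values here are made up; the real attributes come from the model's remote-code image processor):

```python
class ToyImageProcessor:
    def anyres_preprocess(self, img: str) -> str:
        return f"padded({img})"        # default path

    def anyres_for_vllm_preprocess(self, img: str) -> str:
        return f"per-image({img})"     # variant that avoids cross-image padding

proc = ToyImageProcessor()
# The instance attribute shadows the class method, as in get_hf_processor above.
proc.anyres_preprocess = proc.anyres_for_vllm_preprocess
assert proc.anyres_preprocess("x") == "per-image(x)"
```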
@@ -314,10 +121,9 @@ class MiniMaxVL01MultiModalProcessor(
 
         pixel_values = processed_outputs.get("pixel_values")
         if pixel_values is not None:
+            # Avoid padding since we need the output for each image to be
+            # independent of other images for the cache to work correctly
             image_sizes = processed_outputs["image_sizes"]
-            min_len = min(len(pixel_values), len(image_sizes))
-            pixel_values = pixel_values[:min_len]
-            image_sizes = image_sizes[:min_len]
             assert len(pixel_values) == len(image_sizes)
 
             processed_outputs["pixel_values"] = [
@@ -337,65 +143,6 @@ class MiniMaxVL01MultiModalProcessor(
         }
 
 
-def _get_num_hidden_layers(hf_config: MiniMaxVL01LikeConfig) -> int:
-    """Determine the number of hidden layers to initialize up to in the
-    visual encoder.
-
-    Args:
-        hf_config: Model config with vision feature layer(s).
-    """
-    feature_layers = hf_config.vision_feature_layer
-    num_hidden_layers = hf_config.vision_config.num_hidden_layers
-    # If we have one feature layer, initialize up to that layer
-    if isinstance(feature_layers, int):
-        return _get_layer_index(feature_layers, num_hidden_layers)
-    # If we have multiple feature layers, initialize up to the deepest one
-    elif isinstance(feature_layers, (list, tuple)):
-        return max(
-            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
-    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
-                    " is not supported")
-
-
-def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
-    """Given a signed vision feature layer, get the number of hidden layers
-    needed to leverage it.
-
-    Args:
-        feature_layer_index: Index of a required layer in the visual encoder.
-        num_hidden_layers: The total number of hidden layers in the visual
-            encoder.
-    """
-    if feature_layer_index < 0:
-        return num_hidden_layers + feature_layer_index + 1
-    return feature_layer_index
-
-
-def init_vision_tower_for_MiniMaxVL01(
-    hf_config: MiniMaxVL01LikeConfig,
-    quant_config: Optional[QuantizationConfig],
-    *,
-    require_post_norm: Optional[bool] = None,
-    prefix: str = "",
-) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
-    vision_config = hf_config.vision_config
-
-    # Initialize the vision tower only up to the deepest required feature layer
-    num_hidden_layers = _get_num_hidden_layers(hf_config)
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        return CLIPVisionModel(
-            vision_config,
-            quant_config=quant_config,
-            num_hidden_layers_override=num_hidden_layers,
-            require_post_norm=require_post_norm,
-            prefix=prefix,
-        )
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
 @MULTIMODAL_REGISTRY.register_processor(
     MiniMaxVL01MultiModalProcessor,
     info=MiniMaxVL01ProcessingInfo,
@@ -419,7 +166,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.multimodal_config = multimodal_config
 
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = init_vision_tower_for_MiniMaxVL01(
+        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
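Switching to the shared `init_vision_tower_for_llava` drops the duplicated, CLIP-only helper deleted above; the shared helper dispatches on the vision config type. A standalone sketch of that dispatch shape, using stand-in classes and return values rather than vLLM's real ones:

```python
class CLIPVisionConfig: ...
class SiglipVisionConfig: ...

def init_vision_tower(vision_config: object) -> str:
    # Dispatch on the config class, as the shared Llava helper does.
    if isinstance(vision_config, CLIPVisionConfig):
        return "clip-tower"
    if isinstance(vision_config, SiglipVisionConfig):
        return "siglip-tower"
    raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")

assert init_vision_tower(CLIPVisionConfig()) == "clip-tower"
```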
@@ -476,7 +223,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def _image_pixels_to_features(
         self,
-        vision_tower: Union[CLIPVisionModel],
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
+                            PixtralHFVisionModel],
         pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         # NOTE: we skip the step to select the vision feature layer since
@@ -496,7 +244,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def _process_image_pixels(
         self,
-        inputs: Union[MiniMaxVL01ImagePixelInputs],
+        inputs: MiniMaxVL01ImagePixelInputs,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         assert self.vision_tower is not None
 
@@ -506,7 +254,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def _process_image_input(
         self,
-        image_input: MiniMaxVL01ImagePixelInputs,
+        image_input: MiniMaxVL01ImageInputs,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
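Widening the annotation to `MiniMaxVL01ImageInputs` matters because the method dispatches on the `type` tag: embedding inputs bypass the vision tower entirely. A runnable miniature of the same tagged-`TypedDict` pattern, with simplified payloads instead of tensors:

```python
from typing import Literal, TypedDict, Union

class PixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: list[int]

class EmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: list[int]

ImageInputs = Union[PixelInputs, EmbeddingInputs]

def process_image_input(image_input: ImageInputs) -> list[int]:
    if image_input["type"] == "image_embeds":
        return image_input["data"]               # embeddings pass through
    return [x * 2 for x in image_input["data"]]  # pixels go through the "tower"

assert process_image_input({"type": "image_embeds", "data": [1, 2]}) == [1, 2]
```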
@@ -539,7 +287,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[MiniMaxVL01ImagePixelInputs]:
+            self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
 