mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 11:11:19 +08:00
[Model] MiniCPM-V/O supports V1 (#15487)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
8063dfc61a
commit
ac5bc615b0
@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ
|
|||||||
* `openbmb/MiniCPM-o-2_6`, etc.
|
* `openbmb/MiniCPM-o-2_6`, etc.
|
||||||
* ✅︎
|
* ✅︎
|
||||||
* ✅︎
|
* ✅︎
|
||||||
*
|
* ✅︎
|
||||||
- * `MiniCPMV`
|
- * `MiniCPMV`
|
||||||
* MiniCPM-V
|
* MiniCPM-V
|
||||||
* T + I<sup>E+</sup> + V<sup>E+</sup>
|
* T + I<sup>E+</sup> + V<sup>E+</sup>
|
||||||
* `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
|
* `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
|
||||||
* ✅︎
|
* ✅︎
|
||||||
* ✅︎
|
* ✅︎
|
||||||
*
|
* ✅︎
|
||||||
- * `MllamaForConditionalGeneration`
|
- * `MllamaForConditionalGeneration`
|
||||||
* Llama 3.2
|
* Llama 3.2
|
||||||
* T + I<sup>+</sup>
|
* T + I<sup>+</sup>
|
||||||
|
|||||||
@ -23,8 +23,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
|
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
|
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
|
||||||
TypedDict, Union)
|
Union)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -42,8 +42,6 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
|
|||||||
MultiModalDataParser)
|
MultiModalDataParser)
|
||||||
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
|
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
|
||||||
from vllm.multimodal.profiling import ProcessorInputs
|
from vllm.multimodal.profiling import ProcessorInputs
|
||||||
from vllm.sequence import IntermediateTensors
|
|
||||||
from vllm.utils import flatten_2d_lists
|
|
||||||
|
|
||||||
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
|
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
|
||||||
MiniCPMVMultiModalDataParser,
|
MiniCPMVMultiModalDataParser,
|
||||||
@ -51,13 +49,14 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
|
|||||||
_minicpmv_field_config)
|
_minicpmv_field_config)
|
||||||
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
|
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
|
from .vision import scatter_patch_features
|
||||||
|
|
||||||
CPU_DEVICE = torch.device("cpu")
|
CPU_DEVICE = torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMOAudioFeatureInputs(TypedDict):
|
class MiniCPMOAudioFeatureInputs(TypedDict):
|
||||||
type: Literal["audio_features"]
|
type: Literal["audio_features"]
|
||||||
audio_features: torch.Tensor
|
audio_features: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
|
Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
|
||||||
Slice here means chunk. Audio that is too long will be split into slices,
|
Slice here means chunk. Audio that is too long will be split into slices,
|
||||||
@ -65,37 +64,40 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
|
|||||||
Padding is used therefore `audio_features` is `torch.Tensor`.
|
Padding is used therefore `audio_features` is `torch.Tensor`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio_feature_lens: torch.Tensor
|
audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_audios * num_slices)`
|
Shape: `(batch_size * num_audios, num_slices)`
|
||||||
|
|
||||||
This should be feature length of each audio slice,
|
This should be feature length of each audio slice,
|
||||||
which equals to `audio_features.shape[-1]`
|
which equals to `audio_features.shape[-1]`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
audio_bounds: torch.Tensor
|
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_audios * num_slices, 2)`
|
A boolean mask indicating which audio embeddings correspond
|
||||||
|
to patch tokens.
|
||||||
|
|
||||||
This should be in `(start, stop)` format.
|
Shape: `(batch_size * num_audios, num_embeds)`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMOAudioEmbeddingInputs(TypedDict):
|
class MiniCPMOAudioEmbeddingInputs(TypedDict):
|
||||||
type: Literal["audio_embeds"]
|
type: Literal["audio_embeds"]
|
||||||
audio_embeds: torch.Tensor
|
audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_images * num_slices, hidden_size)`
|
Shape: `(batch_size * num_audios, num_slices, hidden_size)`
|
||||||
|
|
||||||
`hidden_size` must match the hidden size of language model backbone.
|
`hidden_size` must match the hidden size of language model backbone.
|
||||||
instead of a batched tensor.
|
instead of a batched tensor.
|
||||||
Length of each slice may vary, so pass it as a list.
|
Length of each slice may vary, so pass it as a list.
|
||||||
"""
|
"""
|
||||||
audio_bounds: torch.Tensor
|
|
||||||
"""
|
|
||||||
Shape: `(batch_size * num_audios * num_slices, 2)`
|
|
||||||
|
|
||||||
This should be in `(start, stop)` format.
|
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
|
"""
|
||||||
|
A boolean mask indicating which audio embeddings correspond
|
||||||
|
to patch tokens.
|
||||||
|
|
||||||
|
Shape: `(batch_size * num_audios, num_embeds)`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -104,11 +106,16 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
|
|||||||
|
|
||||||
|
|
||||||
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||||
|
audio_features = hf_inputs.get("audio_features", torch.empty(0))
|
||||||
|
num_audios = len(audio_features)
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
**_minicpmv_field_config(hf_inputs),
|
**_minicpmv_field_config(hf_inputs),
|
||||||
audio_features=MultiModalFieldConfig.batched("audio"),
|
audio_features=MultiModalFieldConfig.batched("audio"),
|
||||||
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
|
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
|
||||||
audio_embeds=MultiModalFieldConfig.batched("audio"),
|
audio_embeds=MultiModalFieldConfig.batched("audio"),
|
||||||
|
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
|
||||||
|
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -149,7 +156,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
|||||||
audio_pattern = "(<audio>./</audio>)"
|
audio_pattern = "(<audio>./</audio>)"
|
||||||
|
|
||||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||||
return {"image": None, "video": None, "audio": None}
|
return {**super().get_supported_mm_limits(), "audio": None}
|
||||||
|
|
||||||
def get_mm_max_tokens_per_item(
|
def get_mm_max_tokens_per_item(
|
||||||
self,
|
self,
|
||||||
@ -157,11 +164,25 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
|||||||
mm_counts: Mapping[str, int],
|
mm_counts: Mapping[str, int],
|
||||||
) -> Mapping[str, int]:
|
) -> Mapping[str, int]:
|
||||||
return {
|
return {
|
||||||
"image": self.get_max_image_tokens(),
|
**super().get_mm_max_tokens_per_item(seq_len, mm_counts),
|
||||||
"audio": self.get_max_audio_tokens(),
|
"audio":
|
||||||
"video": self.get_max_video_tokens(seq_len),
|
self.get_max_audio_tokens(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def get_audio_placeholder(
|
||||||
|
self,
|
||||||
|
audio_lens: int,
|
||||||
|
chunk_input: bool = True,
|
||||||
|
chunk_length: int = 1,
|
||||||
|
) -> str:
|
||||||
|
hf_processor = self.get_hf_processor()
|
||||||
|
|
||||||
|
return hf_processor.get_audio_placeholder(
|
||||||
|
audio_lens,
|
||||||
|
chunk_input=chunk_input,
|
||||||
|
chunk_length=chunk_length,
|
||||||
|
)
|
||||||
|
|
||||||
def get_default_audio_pool_step(self) -> int:
|
def get_default_audio_pool_step(self) -> int:
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
@ -197,12 +218,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
|||||||
max_videos = mm_config.get_limit_per_prompt("video")
|
max_videos = mm_config.get_limit_per_prompt("video")
|
||||||
max_audios = mm_config.get_limit_per_prompt("audio")
|
max_audios = mm_config.get_limit_per_prompt("audio")
|
||||||
|
|
||||||
# count <image_idx></image_idx> tokens
|
max_image_tokens = self.get_max_image_tokens() * max_images
|
||||||
# which are not in get_max_image_tokens
|
max_audio_tokens = self.get_max_audio_tokens() * max_audios
|
||||||
max_image_tokens = self.get_max_image_tokens(
|
|
||||||
) * max_images + 4 * max_images
|
|
||||||
max_audio_tokens = self.get_max_audio_tokens(
|
|
||||||
) * max_audios + 2 * max_audios
|
|
||||||
max_total_frames = self.get_max_video_frames(seq_len -
|
max_total_frames = self.get_max_video_frames(seq_len -
|
||||||
max_image_tokens -
|
max_image_tokens -
|
||||||
max_audio_tokens)
|
max_audio_tokens)
|
||||||
@ -224,20 +241,20 @@ class MiniCPMODummyInputsBuilder(
|
|||||||
|
|
||||||
processor_inputs = super().get_dummy_processor_inputs(
|
processor_inputs = super().get_dummy_processor_inputs(
|
||||||
seq_len, mm_counts)
|
seq_len, mm_counts)
|
||||||
mm_data = {
|
|
||||||
"image":
|
audio_prompt_texts = self.info.audio_pattern * num_audios
|
||||||
processor_inputs.mm_data["image"],
|
audio_mm_data = {
|
||||||
"video":
|
|
||||||
processor_inputs.mm_data["video"],
|
|
||||||
"audio":
|
"audio":
|
||||||
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
||||||
}
|
}
|
||||||
|
|
||||||
audio_prompt_texts = self.info.audio_pattern * num_audios
|
return ProcessorInputs(
|
||||||
|
prompt_text=processor_inputs.prompt_text + audio_prompt_texts,
|
||||||
return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \
|
mm_data={
|
||||||
audio_prompt_texts,
|
**processor_inputs.mm_data,
|
||||||
mm_data=mm_data)
|
**audio_mm_data,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMOMultiModalProcessor(
|
class MiniCPMOMultiModalProcessor(
|
||||||
@ -247,22 +264,17 @@ class MiniCPMOMultiModalProcessor(
|
|||||||
return MiniCPMOMultiModalDataParser(
|
return MiniCPMOMultiModalDataParser(
|
||||||
target_sr=self.info.get_default_audio_sampling_rate())
|
target_sr=self.info.get_default_audio_sampling_rate())
|
||||||
|
|
||||||
def get_audio_prompt_texts(self,
|
def get_audio_prompt_texts(
|
||||||
audio_lens: int,
|
self,
|
||||||
chunk_input: bool = True,
|
audio_lens: int,
|
||||||
chunk_length: int = 1) -> str:
|
chunk_input: bool = True,
|
||||||
return self.info.get_hf_processor().get_audio_placeholder(
|
chunk_length: int = 1,
|
||||||
audio_lens, chunk_input, chunk_length)
|
) -> str:
|
||||||
|
return self.info.get_audio_placeholder(
|
||||||
def get_special_tokens(self) -> Dict[str, torch.Tensor]:
|
audio_lens,
|
||||||
tokenizer = self.info.get_tokenizer()
|
chunk_input=chunk_input,
|
||||||
special_tokens = super().get_special_tokens()
|
chunk_length=chunk_length,
|
||||||
if hasattr(tokenizer, "audio_start_id"):
|
)
|
||||||
special_tokens["audio_start_id"] = torch.tensor(
|
|
||||||
tokenizer.audio_start_id)
|
|
||||||
special_tokens["audio_end_id"] = torch.tensor(
|
|
||||||
tokenizer.audio_end_id)
|
|
||||||
return special_tokens
|
|
||||||
|
|
||||||
def process_audios(
|
def process_audios(
|
||||||
self,
|
self,
|
||||||
@ -274,32 +286,65 @@ class MiniCPMOMultiModalProcessor(
|
|||||||
|
|
||||||
parsed_audios = (self._get_data_parser().parse_mm_data({
|
parsed_audios = (self._get_data_parser().parse_mm_data({
|
||||||
"audio": audios
|
"audio": audios
|
||||||
}).get_items("audio", AudioProcessorItems))
|
}).get_items("audio",
|
||||||
|
(MiniCPMOAudioEmbeddingItems, AudioProcessorItems)))
|
||||||
|
|
||||||
audio_inputs = self._base_call_hf_processor(
|
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
|
||||||
prompts=[self.info.audio_pattern] * len(parsed_audios),
|
audio_inputs = {}
|
||||||
mm_data={"audios": [[audio] for audio in parsed_audios]},
|
|
||||||
mm_kwargs={
|
|
||||||
**mm_kwargs, "chunk_input": True
|
|
||||||
},
|
|
||||||
out_keys={"audio_features", "audio_feature_lens"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Avoid padding since we need the output for each audio to be
|
audio_lens = [
|
||||||
# independent of other audios for the cache to work correctly
|
self.info.get_audio_len_by_num_chunks(
|
||||||
unpadded_audio_features = [
|
sum(map(len,
|
||||||
feat[:, :feature_len] for feat, feature_len in zip(
|
parsed_audios.get(i)["audio_embeds"])))
|
||||||
audio_inputs["audio_features"],
|
for i in range(len(parsed_audios))
|
||||||
audio_inputs["audio_feature_lens"],
|
]
|
||||||
|
else:
|
||||||
|
audio_inputs = self._base_call_hf_processor(
|
||||||
|
prompts=[self.info.audio_pattern] * len(parsed_audios),
|
||||||
|
mm_data={"audios": [[audio] for audio in parsed_audios]},
|
||||||
|
mm_kwargs={
|
||||||
|
**mm_kwargs,
|
||||||
|
"chunk_input": True,
|
||||||
|
},
|
||||||
|
out_keys={"audio_features", "audio_feature_lens"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Avoid padding since we need the output for each audio to be
|
||||||
|
# independent of other audios for the cache to work correctly
|
||||||
|
unpadded_audio_features = [
|
||||||
|
feat[:, :feature_len] for feat, feature_len in zip(
|
||||||
|
audio_inputs["audio_features"],
|
||||||
|
audio_inputs["audio_feature_lens"],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
audio_inputs["audio_features"] = unpadded_audio_features
|
||||||
|
|
||||||
|
audio_lens = [
|
||||||
|
parsed_audios.get_audio_length(i)
|
||||||
|
for i in range(len(parsed_audios))
|
||||||
|
]
|
||||||
|
|
||||||
|
audio_repl_features = [
|
||||||
|
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
|
||||||
]
|
]
|
||||||
audio_inputs["audio_features"] = unpadded_audio_features
|
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
audio_repls_feature_tokens = [
|
||||||
|
tokenizer.encode(audio_repl, add_special_tokens=False)
|
||||||
|
for audio_repl in audio_repl_features
|
||||||
|
]
|
||||||
|
|
||||||
|
embed_is_patch = [
|
||||||
|
self.get_embed_is_patch(audio_repl_tokens)
|
||||||
|
for audio_repl_tokens in audio_repls_feature_tokens
|
||||||
|
]
|
||||||
|
audio_inputs["audio_embed_is_patch"] = embed_is_patch
|
||||||
|
|
||||||
|
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||||
|
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
|
||||||
|
|
||||||
return audio_inputs
|
return audio_inputs
|
||||||
|
|
||||||
def get_placeholder_match_pattern(self) -> str:
|
|
||||||
return r"\(<(image|video|audio)>./</\1>\)"
|
|
||||||
|
|
||||||
def process_mm_inputs(
|
def process_mm_inputs(
|
||||||
self,
|
self,
|
||||||
mm_data: Mapping[str, object],
|
mm_data: Mapping[str, object],
|
||||||
@ -331,8 +376,7 @@ class MiniCPMOMultiModalProcessor(
|
|||||||
if isinstance(audios, MiniCPMOAudioEmbeddingItems):
|
if isinstance(audios, MiniCPMOAudioEmbeddingItems):
|
||||||
single_audio_embeds = audios.get(item_idx)["audio_embeds"]
|
single_audio_embeds = audios.get(item_idx)["audio_embeds"]
|
||||||
audio_len = self.info.get_audio_len_by_num_chunks(
|
audio_len = self.info.get_audio_len_by_num_chunks(
|
||||||
sum(chunk_embeds.shape[0]
|
sum(map(len, single_audio_embeds)))
|
||||||
for chunk_embeds in single_audio_embeds))
|
|
||||||
else:
|
else:
|
||||||
audio_len = audios.get_audio_length(item_idx)
|
audio_len = audios.get_audio_length(item_idx)
|
||||||
|
|
||||||
@ -514,6 +558,8 @@ class MiniCPMO(MiniCPMV2_6):
|
|||||||
self.apm = self.init_audio_module(vllm_config=vllm_config,
|
self.apm = self.init_audio_module(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "apm"))
|
prefix=maybe_prefix(prefix, "apm"))
|
||||||
|
|
||||||
|
self.audio_token_id = None
|
||||||
|
|
||||||
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
# Do not use parameters temporarily
|
# Do not use parameters temporarily
|
||||||
audio_config = self.config.audio_config
|
audio_config = self.config.audio_config
|
||||||
@ -563,18 +609,30 @@ class MiniCPMO(MiniCPMV2_6):
|
|||||||
|
|
||||||
return input_lengths_after_cnn, input_lengths_after_pooling
|
return input_lengths_after_cnn, input_lengths_after_pooling
|
||||||
|
|
||||||
# Copied from HF repo of MiniCPM-o-2_6,
|
def get_audio_hidden_states(
|
||||||
# designed for batched inputs and outputs
|
self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]:
|
||||||
def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
|
chunk_length = self.config.audio_chunk_length
|
||||||
chunk_length: int) -> list[torch.Tensor]:
|
|
||||||
wavforms = data.get(
|
|
||||||
"audio_features",
|
|
||||||
[]) # (bs, 80, frames) or [], multi audios need filled in advance
|
|
||||||
audio_feature_lens_raw = [data.get("audio_feature_lens",
|
|
||||||
[])] # list, [[x1, x2], [y1], [z1]]
|
|
||||||
|
|
||||||
if len(wavforms) == 0:
|
# (bs, 80, frames) or [], multi audios need filled in advance
|
||||||
return []
|
wavforms_raw = data["audio_features"]
|
||||||
|
if isinstance(wavforms_raw, list):
|
||||||
|
B = len(wavforms_raw)
|
||||||
|
C = wavforms_raw[0].shape[-2]
|
||||||
|
L = max(item.shape[-1] for item in wavforms_raw)
|
||||||
|
device = wavforms_raw[0].device
|
||||||
|
dtype = wavforms_raw[0].dtype
|
||||||
|
|
||||||
|
wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
|
||||||
|
for i, wavforms_item in enumerate(wavforms_raw):
|
||||||
|
L_item = wavforms_item.shape[-1]
|
||||||
|
wavforms[i, ..., :L_item] = wavforms_item
|
||||||
|
else:
|
||||||
|
wavforms = wavforms_raw
|
||||||
|
|
||||||
|
# list, [[x1, x2], [y1], [z1]]
|
||||||
|
audio_feature_lens_raw = data["audio_feature_lens"]
|
||||||
|
if isinstance(audio_feature_lens_raw, torch.Tensor):
|
||||||
|
audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)
|
||||||
|
|
||||||
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
|
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
|
||||||
batch_size, _, max_mel_seq_len = wavforms.shape
|
batch_size, _, max_mel_seq_len = wavforms.shape
|
||||||
@ -625,159 +683,104 @@ class MiniCPMO(MiniCPMV2_6):
|
|||||||
|
|
||||||
num_audio_tokens = feature_lens_after_pooling
|
num_audio_tokens = feature_lens_after_pooling
|
||||||
|
|
||||||
final_audio_embeds = []
|
final_audio_embeds = list[torch.Tensor]()
|
||||||
idx = 0
|
idx = 0
|
||||||
for i in range(len(audio_feature_lens_raw)):
|
for i in range(len(audio_feature_lens_raw)):
|
||||||
target_audio_embeds = []
|
target_audio_embeds_lst = list[torch.Tensor]()
|
||||||
for _ in range(len(audio_feature_lens_raw[i])):
|
for _ in range(len(audio_feature_lens_raw[i])):
|
||||||
target_audio_embeds.append(
|
target_audio_embeds_lst.append(
|
||||||
audio_embeds[idx, :num_audio_tokens[idx], :])
|
audio_embeds[idx, :num_audio_tokens[idx], :])
|
||||||
idx += 1
|
idx += 1
|
||||||
final_audio_embeds.append(target_audio_embeds)
|
|
||||||
|
final_audio_embeds.append(torch.cat(target_audio_embeds_lst))
|
||||||
|
|
||||||
return final_audio_embeds
|
return final_audio_embeds
|
||||||
|
|
||||||
def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
|
def _parse_and_validate_audio_input(
|
||||||
audio_inputs: MiniCPMOAudioInputs,
|
self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
|
||||||
chunk_length: int) -> torch.Tensor:
|
|
||||||
device, dtype = vlm_embedding.device, vlm_embedding.dtype
|
|
||||||
if audio_inputs["type"] == "audio_embeds":
|
|
||||||
audio_embeddings = [
|
|
||||||
item.to(device=device, dtype=dtype)
|
|
||||||
for item in audio_inputs["audio_embeds"]
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
audio_embeddings = self.get_audio_hidden_states(
|
|
||||||
audio_inputs, chunk_length)[0]
|
|
||||||
if audio_embeddings is None or len(audio_embeddings) == 0:
|
|
||||||
return vlm_embedding
|
|
||||||
audio_bounds = audio_inputs["audio_bounds"]
|
|
||||||
if self.config.chunk_input:
|
|
||||||
audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device,
|
|
||||||
dtype=dtype)
|
|
||||||
audio_start_pos = 0
|
|
||||||
for bound in audio_bounds:
|
|
||||||
audio_len = bound[1] - bound[0]
|
|
||||||
vlm_embedding[bound[0]:bound[1]] = audio_embs[
|
|
||||||
audio_start_pos:audio_start_pos + audio_len, :]
|
|
||||||
audio_start_pos += audio_len
|
|
||||||
else:
|
|
||||||
for embs, bound in zip(audio_embeddings, audio_bounds):
|
|
||||||
audio_indices = torch.arange(bound[0],
|
|
||||||
bound[1],
|
|
||||||
dtype=torch.long).to(device)
|
|
||||||
|
|
||||||
if embs.shape[0] != len(audio_indices):
|
|
||||||
raise ValueError(
|
|
||||||
"Shape mismatch: Trying to assign embeddings "
|
|
||||||
f"of shape {embs.shape} "
|
|
||||||
f"to input indices of length {len(audio_indices)}")
|
|
||||||
vlm_embedding[audio_indices] = embs.to(dtype)
|
|
||||||
return vlm_embedding
|
|
||||||
|
|
||||||
def _get_audio_bounds(self, input_ids: torch.Tensor,
|
|
||||||
audio_start_id: torch.Tensor,
|
|
||||||
audio_end_id: torch.Tensor) -> torch.Tensor:
|
|
||||||
audio_start_tokens, = torch.where(input_ids == audio_start_id[0])
|
|
||||||
audio_start_tokens += 1
|
|
||||||
audio_end_tokens, = torch.where(input_ids == audio_end_id[0])
|
|
||||||
valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens))
|
|
||||||
return torch.hstack([
|
|
||||||
audio_start_tokens[:valid_audio_nums].unsqueeze(-1),
|
|
||||||
audio_end_tokens[:valid_audio_nums].unsqueeze(-1)
|
|
||||||
])
|
|
||||||
|
|
||||||
def _parse_and_validate_audio_inputs(
|
|
||||||
self, input_ids: torch.Tensor,
|
|
||||||
**kwargs: object) -> Optional[MiniCPMOAudioInputs]:
|
|
||||||
audio_features = kwargs.pop("audio_features", None)
|
audio_features = kwargs.pop("audio_features", None)
|
||||||
audio_embeds = kwargs.pop("audio_embeds", None)
|
audio_embeds = kwargs.pop("audio_embeds", None)
|
||||||
|
|
||||||
if audio_features is None and audio_embeds is None:
|
if audio_features is None and audio_embeds is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
audio_start_id = kwargs.pop("audio_start_id")
|
audio_token_id = kwargs.pop("audio_token_id")
|
||||||
if not isinstance(audio_start_id, torch.Tensor):
|
if audio_token_id is not None:
|
||||||
raise ValueError("Incorrect type of audio_start_id. "
|
assert isinstance(audio_token_id, torch.Tensor)
|
||||||
f"Got type: {type(audio_start_id)}")
|
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
|
||||||
|
|
||||||
audio_end_id = kwargs.pop("audio_end_id")
|
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
|
||||||
if not isinstance(audio_end_id, torch.Tensor):
|
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of audio_end_id. "
|
raise ValueError("Incorrect type of audio_embed_is_patch. "
|
||||||
f"Got type: {type(audio_end_id)}")
|
f"Got type: {type(audio_embed_is_patch)}")
|
||||||
|
|
||||||
|
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
|
||||||
|
|
||||||
if audio_embeds is not None:
|
if audio_embeds is not None:
|
||||||
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of audio_embeds. "
|
raise ValueError("Incorrect type of audio_embeds. "
|
||||||
f"Got type: {type(audio_embeds)}")
|
f"Got type: {type(audio_embeds)}")
|
||||||
|
|
||||||
|
audio_embeds_flat = flatten_bn(audio_embeds)
|
||||||
|
|
||||||
return MiniCPMOAudioEmbeddingInputs(
|
return MiniCPMOAudioEmbeddingInputs(
|
||||||
type="audio_embeds",
|
type="audio_embeds",
|
||||||
audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
|
audio_embeds=audio_embeds_flat,
|
||||||
concat=True),
|
embed_is_patch=audio_embed_is_patch,
|
||||||
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
|
|
||||||
audio_end_id),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if audio_features is not None:
|
if not isinstance(audio_features, (torch.Tensor, list)):
|
||||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
raise ValueError("Incorrect type of audio_features. "
|
||||||
raise ValueError("Incorrect type of audio_features. "
|
f"Got type: {type(audio_features)}")
|
||||||
f"Got type: {type(audio_features)}")
|
|
||||||
|
|
||||||
audio_feature_lens = kwargs.pop("audio_feature_lens")
|
audio_feature_lens = kwargs.pop("audio_feature_lens")
|
||||||
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
|
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of audio_feature_lens. "
|
raise ValueError("Incorrect type of audio_feature_lens. "
|
||||||
f"Got type: {type(audio_feature_lens)}")
|
f"Got type: {type(audio_feature_lens)}")
|
||||||
|
|
||||||
return MiniCPMOAudioFeatureInputs(
|
audio_features_flat = flatten_bn(audio_features)
|
||||||
type="audio_features",
|
audio_feature_lens_flat = flatten_bn(audio_feature_lens)
|
||||||
audio_features=flatten_bn(audio_features, concat=True),
|
|
||||||
audio_feature_lens=flatten_bn(
|
|
||||||
flatten_2d_lists(audio_feature_lens), concat=True),
|
|
||||||
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
|
|
||||||
audio_end_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
raise AssertionError("This line should be unreachable.")
|
return MiniCPMOAudioFeatureInputs(
|
||||||
|
type="audio_features",
|
||||||
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
|
audio_features=audio_features_flat,
|
||||||
**kwargs: object):
|
audio_feature_lens=audio_feature_lens_flat,
|
||||||
image_inputs = self._parse_and_validate_image_inputs(
|
embed_is_patch=audio_embed_is_patch,
|
||||||
input_ids, **kwargs)
|
|
||||||
if not any("audio" in key for key in kwargs):
|
|
||||||
return image_inputs, None
|
|
||||||
audio_inputs = self._parse_and_validate_audio_inputs(
|
|
||||||
input_ids, **kwargs)
|
|
||||||
return image_inputs, audio_inputs
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if intermediate_tensors is not None:
|
|
||||||
vlm_embeddings = None
|
|
||||||
else:
|
|
||||||
image_inputs, audio_inputs = \
|
|
||||||
self._parse_and_validate_inputs(input_ids, **kwargs)
|
|
||||||
vlm_embeddings = self.get_embedding_with_vision(
|
|
||||||
input_ids, image_inputs)
|
|
||||||
|
|
||||||
if audio_inputs is not None:
|
|
||||||
vlm_embeddings = self.get_embedding_with_audios(
|
|
||||||
vlm_embeddings, audio_inputs,
|
|
||||||
self.config.audio_chunk_length)
|
|
||||||
|
|
||||||
# always pass the input via `inputs_embeds`
|
|
||||||
# to make sure the computation graph is consistent
|
|
||||||
# for `torch.compile` integration
|
|
||||||
input_ids = None
|
|
||||||
|
|
||||||
output = self.llm.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
inputs_embeds=vlm_embeddings,
|
|
||||||
)
|
)
|
||||||
return output
|
|
||||||
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||||
|
modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)
|
||||||
|
|
||||||
|
# Preserve the order of modalities if there are multiple of them
|
||||||
|
# from the order of kwargs.
|
||||||
|
for input_key in kwargs:
|
||||||
|
if input_key in ("audio_features",
|
||||||
|
"audio_embeds") and "audios" not in modalities:
|
||||||
|
modalities["audios"] = self._parse_and_validate_audio_input(
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
return modalities
|
||||||
|
|
||||||
|
def _process_audio_input(
|
||||||
|
self,
|
||||||
|
audio_input: MiniCPMOAudioInputs,
|
||||||
|
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||||
|
if audio_input["type"] == "audio_embeds":
|
||||||
|
return audio_input["audio_embeds"]
|
||||||
|
|
||||||
|
return self.get_audio_hidden_states(audio_input)
|
||||||
|
|
||||||
|
def _process_multimodal_inputs(self, modalities: dict):
|
||||||
|
multimodal_embeddings = super()._process_multimodal_inputs(modalities)
|
||||||
|
|
||||||
|
for modality in modalities:
|
||||||
|
if modality == "audios":
|
||||||
|
audio_input = modalities["audios"]
|
||||||
|
audio_features = self._process_audio_input(audio_input)
|
||||||
|
multimodal_embeddings += tuple(
|
||||||
|
scatter_patch_features(
|
||||||
|
audio_features,
|
||||||
|
audio_input["embed_is_patch"],
|
||||||
|
))
|
||||||
|
|
||||||
|
return multimodal_embeddings
|
||||||
|
|||||||
@ -23,17 +23,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
|
||||||
import math
|
import math
|
||||||
import re
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from functools import cached_property, partial
|
from functools import cached_property, partial
|
||||||
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
|
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
|
||||||
TypedDict, Union)
|
Union)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.types
|
import torch.types
|
||||||
from PIL import Image
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import BatchFeature, PretrainedConfig
|
from transformers import BatchFeature, PretrainedConfig
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
@ -50,9 +48,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
|
|||||||
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
|
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
|
||||||
MultiModalInputs, NestedTensors,
|
|
||||||
PlaceholderRange)
|
|
||||||
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
|
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
|
||||||
ImageProcessorItems, ImageSize,
|
ImageProcessorItems, ImageSize,
|
||||||
ModalityData, ModalityDataItems,
|
ModalityData, ModalityDataItems,
|
||||||
@ -67,13 +63,11 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from vllm.utils import flatten_2d_lists
|
from vllm.utils import flatten_2d_lists
|
||||||
|
|
||||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||||
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
|
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||||
SupportsV0Only)
|
SupportsMultiModal, SupportsPP)
|
||||||
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
|
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||||
|
merge_multimodal_embeddings)
|
||||||
CPU_DEVICE = torch.device("cpu")
|
from .vision import scatter_patch_features, select_patch_features
|
||||||
|
|
||||||
RawImageType = Union[Image.Image, torch.Tensor]
|
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMVImagePixelInputs(TypedDict):
|
class MiniCPMVImagePixelInputs(TypedDict):
|
||||||
@ -86,13 +80,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
|
|||||||
instead of a batched tensor.
|
instead of a batched tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_bounds: torch.Tensor
|
|
||||||
"""
|
|
||||||
Shape: `(batch_size * num_images * num_slices, 2)`
|
|
||||||
|
|
||||||
This should be in `(start, stop)` format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
tgt_sizes: torch.Tensor
|
tgt_sizes: torch.Tensor
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_images * num_slices, 2)`
|
Shape: `(batch_size * num_images * num_slices, 2)`
|
||||||
@ -100,23 +87,34 @@ class MiniCPMVImagePixelInputs(TypedDict):
|
|||||||
This should be in `(height, width)` format.
|
This should be in `(height, width)` format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
|
"""
|
||||||
|
A boolean mask indicating which image embeddings correspond
|
||||||
|
to patch tokens.
|
||||||
|
|
||||||
|
Shape: `(batch_size * num_images, num_embeds)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_slices: torch.Tensor
|
||||||
|
"""Shape: `(batch_size * num_images)`"""
|
||||||
|
|
||||||
|
|
||||||
class MiniCPMVImageEmbeddingInputs(TypedDict):
|
class MiniCPMVImageEmbeddingInputs(TypedDict):
|
||||||
type: Literal["image_embeds"]
|
type: Literal["image_embeds"]
|
||||||
image_embeds: torch.Tensor
|
image_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_images * num_slices,
|
Shape: `(batch_size * num_images, num_slices, hidden_size)`
|
||||||
image_feature_size, hidden_size)`
|
|
||||||
|
|
||||||
`hidden_size` must match the hidden size of language model backbone.
|
`hidden_size` must match the hidden size of language model backbone.
|
||||||
instead of a batched tensor.
|
instead of a batched tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_bounds: torch.Tensor
|
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
"""
|
"""
|
||||||
Shape: `(batch_size * num_images * num_slices, 2)`
|
A boolean mask indicating which image embeddings correspond
|
||||||
|
to patch tokens.
|
||||||
|
|
||||||
This should be in `(start, stop)` format.
|
Shape: `(batch_size * num_images, num_embeds)`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -233,15 +231,25 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
|
|||||||
|
|
||||||
|
|
||||||
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||||
|
pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
|
||||||
|
num_images = len(pixel_values)
|
||||||
|
|
||||||
|
video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
|
||||||
|
num_videos = len(video_pixel_values)
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||||
image_sizes=MultiModalFieldConfig.batched("image"),
|
image_sizes=MultiModalFieldConfig.batched("image"),
|
||||||
tgt_sizes=MultiModalFieldConfig.batched("image"),
|
tgt_sizes=MultiModalFieldConfig.batched("image"),
|
||||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||||
|
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||||
video_pixel_values=MultiModalFieldConfig.batched("video"),
|
video_pixel_values=MultiModalFieldConfig.batched("video"),
|
||||||
video_image_sizes=MultiModalFieldConfig.batched("video"),
|
video_image_sizes=MultiModalFieldConfig.batched("video"),
|
||||||
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
|
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
|
||||||
video_embeds=MultiModalFieldConfig.batched("video"),
|
video_embeds=MultiModalFieldConfig.batched("video"),
|
||||||
|
video_embed_is_patch=MultiModalFieldConfig.batched("video"),
|
||||||
|
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||||
|
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -348,10 +356,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
|||||||
return get_version_by_config(self.get_hf_config())
|
return get_version_by_config(self.get_hf_config())
|
||||||
|
|
||||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||||
|
mm_limits = {"image": None}
|
||||||
if self.get_model_version() == (2, 6):
|
if self.get_model_version() == (2, 6):
|
||||||
return {"image": None, "video": None}
|
mm_limits["video"] = None
|
||||||
else:
|
|
||||||
return {"image": None}
|
return mm_limits
|
||||||
|
|
||||||
def get_mm_max_tokens_per_item(
|
def get_mm_max_tokens_per_item(
|
||||||
self,
|
self,
|
||||||
@ -361,70 +370,79 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
|||||||
mm_max_tokens = {"image": self.get_max_image_tokens()}
|
mm_max_tokens = {"image": self.get_max_image_tokens()}
|
||||||
if self.get_model_version() == (2, 6):
|
if self.get_model_version() == (2, 6):
|
||||||
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
|
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
|
||||||
|
|
||||||
return mm_max_tokens
|
return mm_max_tokens
|
||||||
|
|
||||||
|
def get_slice_image_placeholder(
|
||||||
|
self,
|
||||||
|
image_size: ImageSize,
|
||||||
|
# For MiniCPM V/O 2.6
|
||||||
|
image_idx: int = 0,
|
||||||
|
max_slice_nums: Optional[int] = None,
|
||||||
|
use_image_id: bool = True,
|
||||||
|
) -> str:
|
||||||
|
image_processor = self.get_image_processor()
|
||||||
|
version = self.get_model_version()
|
||||||
|
|
||||||
|
if version == (2, 0) or version == (2, 5):
|
||||||
|
return image_processor.get_slice_image_placeholder(image_size)
|
||||||
|
|
||||||
|
return image_processor.get_slice_image_placeholder(
|
||||||
|
image_size,
|
||||||
|
image_idx=image_idx,
|
||||||
|
max_slice_nums=max_slice_nums,
|
||||||
|
use_image_id=use_image_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_num_image_tokens(
|
||||||
|
self,
|
||||||
|
image_size: ImageSize,
|
||||||
|
max_slice_nums: Optional[int] = None,
|
||||||
|
use_image_id: bool = True,
|
||||||
|
) -> int:
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
image_placeholders = self.get_slice_image_placeholder(
|
||||||
|
image_size,
|
||||||
|
max_slice_nums=max_slice_nums,
|
||||||
|
use_image_id=use_image_id,
|
||||||
|
)
|
||||||
|
image_token_ids = tokenizer.encode(image_placeholders,
|
||||||
|
add_special_tokens=False)
|
||||||
|
|
||||||
|
return len(image_token_ids)
|
||||||
|
|
||||||
|
def get_max_image_tokens(self) -> int:
|
||||||
|
image_size = self.get_image_size_with_most_features()
|
||||||
|
return self.get_num_image_tokens(image_size)
|
||||||
|
|
||||||
|
def get_image_max_slice_num(self) -> int:
|
||||||
|
return getattr(self.get_hf_config(), "max_slice_num", 9)
|
||||||
|
|
||||||
|
def get_image_size_with_most_features(self) -> ImageSize:
|
||||||
|
image_size = getattr(self.get_hf_config(), "image_size", 448)
|
||||||
|
max_slice_num = self.get_image_max_slice_num()
|
||||||
|
return ImageSize(width=image_size, height=image_size * max_slice_num)
|
||||||
|
|
||||||
def get_max_video_frame_tokens(self) -> int:
|
def get_max_video_frame_tokens(self) -> int:
|
||||||
frame_size = self.get_video_frame_size_with_most_features()
|
frame_size = self.get_video_frame_size_with_most_features()
|
||||||
return self.get_num_image_tokens(frame_size,
|
|
||||||
self.get_video_max_slice_num())
|
return self.get_num_image_tokens(
|
||||||
|
frame_size,
|
||||||
|
max_slice_nums=self.get_video_max_slice_num(),
|
||||||
|
use_image_id=False,
|
||||||
|
)
|
||||||
|
|
||||||
def get_max_video_tokens(self, seq_len: int) -> int:
|
def get_max_video_tokens(self, seq_len: int) -> int:
|
||||||
return self.get_max_video_frame_tokens(
|
return self.get_max_video_frame_tokens(
|
||||||
) * self.get_num_frames_with_most_features(seq_len)
|
) * self.get_num_frames_with_most_features(seq_len)
|
||||||
|
|
||||||
def get_slice_query_num(self) -> int:
|
|
||||||
hf_config = self.get_hf_config()
|
|
||||||
query_num = getattr(hf_config, "query_num", 64)
|
|
||||||
return query_num
|
|
||||||
|
|
||||||
def get_max_slice_num(self) -> int:
|
|
||||||
hf_config = self.get_hf_config()
|
|
||||||
max_slice_num = getattr(hf_config, "max_slice_num", 9)
|
|
||||||
return max_slice_num
|
|
||||||
|
|
||||||
def get_sliced_grid(self, image_size: ImageSize,
|
|
||||||
max_slice_num: int) -> Tuple[int, int]:
|
|
||||||
if self.get_model_version() == (2, 6):
|
|
||||||
slice_grid = self.get_image_processor().get_sliced_grid(
|
|
||||||
image_size, max_slice_num)
|
|
||||||
else:
|
|
||||||
slice_grid = self.get_image_processor().get_sliced_grid(image_size)
|
|
||||||
return slice_grid
|
|
||||||
|
|
||||||
def get_num_image_tokens(self, image_size: ImageSize,
|
|
||||||
max_slice_num: int) -> int:
|
|
||||||
slice_grid = self.get_sliced_grid(image_size, max_slice_num)
|
|
||||||
num_tokens = self.get_slice_query_num(
|
|
||||||
) + 2 # <image>(<unk> * query_num)</image>
|
|
||||||
if slice_grid is not None:
|
|
||||||
if self.get_model_version() == (2, 6):
|
|
||||||
num_additional_tokens = 0
|
|
||||||
else:
|
|
||||||
# <slice><image>(<unk> * query_num)</image></slice>
|
|
||||||
num_additional_tokens = 2
|
|
||||||
num_tokens += ((self.get_slice_query_num() + 2) \
|
|
||||||
* slice_grid[0] * slice_grid[1]) \
|
|
||||||
+ slice_grid[1] - 1 + num_additional_tokens
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
def get_image_slice_nums(self, image_size: torch.Tensor,
|
|
||||||
max_slice_nums: int) -> int:
|
|
||||||
grid = self.get_sliced_grid(image_size, max_slice_nums)
|
|
||||||
return 1 if grid is None else grid[0] * grid[1] + 1
|
|
||||||
|
|
||||||
def get_max_image_tokens(self) -> int:
|
|
||||||
image_size = self.get_image_size_with_most_features()
|
|
||||||
return self.get_num_image_tokens(image_size, self.get_max_slice_num())
|
|
||||||
|
|
||||||
def get_image_size_with_most_features(self) -> ImageSize:
|
|
||||||
# Result in the max possible feature size (h:w = 9:1)
|
|
||||||
return self.get_default_image_sizes(self.get_max_slice_num())
|
|
||||||
|
|
||||||
def get_video_max_slice_num(self) -> int:
|
def get_video_max_slice_num(self) -> int:
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def get_video_frame_size_with_most_features(self) -> ImageSize:
|
def get_video_frame_size_with_most_features(self) -> ImageSize:
|
||||||
return self.get_default_image_sizes(self.get_video_max_slice_num())
|
image_size = getattr(self.get_hf_config(), "image_size", 448)
|
||||||
|
max_slice_num = self.get_video_max_slice_num()
|
||||||
|
return ImageSize(width=image_size, height=image_size * max_slice_num)
|
||||||
|
|
||||||
def get_max_video_frames(self, max_tokens: int) -> int:
|
def get_max_video_frames(self, max_tokens: int) -> int:
|
||||||
num_frame_tokens = self.get_max_video_frame_tokens()
|
num_frame_tokens = self.get_max_video_frame_tokens()
|
||||||
@ -436,10 +454,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
|||||||
max_images = mm_config.get_limit_per_prompt("image")
|
max_images = mm_config.get_limit_per_prompt("image")
|
||||||
max_videos = mm_config.get_limit_per_prompt("video")
|
max_videos = mm_config.get_limit_per_prompt("video")
|
||||||
|
|
||||||
# count <image_idx></image_idx> tokens
|
max_image_tokens = self.get_max_image_tokens() * max_images
|
||||||
# which are not in get_max_image_tokens
|
|
||||||
max_image_tokens = self.get_max_image_tokens(
|
|
||||||
) * max_images + 4 * max_images
|
|
||||||
max_total_frames = self.get_max_video_frames(seq_len -
|
max_total_frames = self.get_max_video_frames(seq_len -
|
||||||
max_image_tokens)
|
max_image_tokens)
|
||||||
|
|
||||||
@ -447,10 +462,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
|||||||
|
|
||||||
return num_frames
|
return num_frames
|
||||||
|
|
||||||
def get_default_image_sizes(self, num_slices: int) -> ImageSize:
|
|
||||||
image_size = getattr(self.get_hf_config(), "image_size", 448)
|
|
||||||
return ImageSize(width=image_size, height=image_size * num_slices)
|
|
||||||
|
|
||||||
|
|
||||||
_I = TypeVar("_I",
|
_I = TypeVar("_I",
|
||||||
bound=MiniCPMVProcessingInfo,
|
bound=MiniCPMVProcessingInfo,
|
||||||
@ -499,42 +510,30 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
def _get_data_parser(self) -> MultiModalDataParser:
|
def _get_data_parser(self) -> MultiModalDataParser:
|
||||||
return MiniCPMVMultiModalDataParser()
|
return MiniCPMVMultiModalDataParser()
|
||||||
|
|
||||||
def get_slice_image_placeholder(self, image_size: ImageSize,
|
|
||||||
**kwargs) -> str:
|
|
||||||
image_processor = self.info.get_image_processor()
|
|
||||||
version = self.info.get_model_version()
|
|
||||||
if version == (2, 0) or version == (2, 5):
|
|
||||||
return image_processor.get_slice_image_placeholder(image_size)
|
|
||||||
return image_processor.get_slice_image_placeholder(
|
|
||||||
image_size, **kwargs)
|
|
||||||
|
|
||||||
def get_image_prompt_texts(self,
|
def get_image_prompt_texts(self,
|
||||||
image_size: ImageSize,
|
image_size: ImageSize,
|
||||||
image_idx: int = 0) -> str:
|
image_idx: int = 0) -> str:
|
||||||
return self.get_slice_image_placeholder(image_size,
|
return self.info.get_slice_image_placeholder(
|
||||||
image_idx=image_idx)
|
image_size,
|
||||||
|
image_idx=image_idx,
|
||||||
|
)
|
||||||
|
|
||||||
def get_video_prompt_texts(self, image_size: ImageSize,
|
def get_video_prompt_texts(self, image_size: ImageSize,
|
||||||
num_frames: int) -> str:
|
num_frames: int) -> str:
|
||||||
return self.get_slice_image_placeholder(
|
return self.info.get_slice_image_placeholder(
|
||||||
image_size=image_size,
|
image_size=image_size,
|
||||||
image_idx=0,
|
image_idx=0,
|
||||||
max_slice_nums=self.info.get_video_max_slice_num(),
|
max_slice_nums=self.info.get_video_max_slice_num(),
|
||||||
use_image_id=False,
|
use_image_id=False,
|
||||||
) * num_frames
|
) * num_frames
|
||||||
|
|
||||||
def get_special_tokens(self) -> Dict[str, torch.Tensor]:
|
def get_embed_is_patch(
|
||||||
|
self,
|
||||||
|
input_ids: list[int],
|
||||||
|
) -> torch.Tensor:
|
||||||
tokenizer = self.info.get_tokenizer()
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||||
special_tokens = {
|
return torch.tensor(input_ids) == unk_token_id
|
||||||
"im_start_id": tokenizer.im_start_id,
|
|
||||||
"im_end_id": tokenizer.im_end_id,
|
|
||||||
}
|
|
||||||
if hasattr(tokenizer, "slice_start_id"):
|
|
||||||
special_tokens["slice_start_id"] = tokenizer.slice_start_id
|
|
||||||
special_tokens["slice_end_id"] = tokenizer.slice_end_id
|
|
||||||
|
|
||||||
return {k: torch.tensor(v) for k, v in special_tokens.items()}
|
|
||||||
|
|
||||||
def process_images(
|
def process_images(
|
||||||
self,
|
self,
|
||||||
@ -546,14 +545,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
|
|
||||||
parsed_images = (self._get_data_parser().parse_mm_data({
|
parsed_images = (self._get_data_parser().parse_mm_data({
|
||||||
"image": images
|
"image": images
|
||||||
}).get_items("image", ImageProcessorItems))
|
}).get_items("image",
|
||||||
|
(MiniCPMVImageEmbeddingItems, ImageProcessorItems)))
|
||||||
|
|
||||||
return self._base_call_hf_processor(
|
if isinstance(parsed_images, MiniCPMVImageEmbeddingItems):
|
||||||
prompts=[self.info.image_pattern] * len(parsed_images),
|
image_inputs = {}
|
||||||
mm_data={"images": [[image] for image in parsed_images]},
|
else:
|
||||||
mm_kwargs=mm_kwargs,
|
image_inputs = self._base_call_hf_processor(
|
||||||
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
prompts=[self.info.image_pattern] * len(parsed_images),
|
||||||
)
|
mm_data={"images": [[image] for image in parsed_images]},
|
||||||
|
mm_kwargs=mm_kwargs,
|
||||||
|
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
||||||
|
)
|
||||||
|
|
||||||
|
image_sizes = [
|
||||||
|
parsed_images.get_image_size(i) for i in range(len(parsed_images))
|
||||||
|
]
|
||||||
|
image_repl_features = [
|
||||||
|
self.get_image_prompt_texts(size, idx)
|
||||||
|
for idx, size in enumerate(image_sizes)
|
||||||
|
]
|
||||||
|
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
image_repls_feature_tokens = [
|
||||||
|
tokenizer.encode(image_repl, add_special_tokens=False)
|
||||||
|
for image_repl in image_repl_features
|
||||||
|
]
|
||||||
|
|
||||||
|
embed_is_patch = [
|
||||||
|
self.get_embed_is_patch(image_repl_tokens)
|
||||||
|
for image_repl_tokens in image_repls_feature_tokens
|
||||||
|
]
|
||||||
|
image_inputs["embed_is_patch"] = embed_is_patch
|
||||||
|
|
||||||
|
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||||
|
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
def process_videos(
|
def process_videos(
|
||||||
self,
|
self,
|
||||||
@ -565,25 +593,55 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
|
|
||||||
parsed_videos = (self._get_data_parser().parse_mm_data({
|
parsed_videos = (self._get_data_parser().parse_mm_data({
|
||||||
"video": videos
|
"video": videos
|
||||||
}).get_items("video", VideoProcessorItems))
|
}).get_items("video",
|
||||||
|
(MiniCPMVVideoEmbeddingItems, VideoProcessorItems)))
|
||||||
|
|
||||||
max_slice_num = self.info.get_video_max_slice_num()
|
if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems):
|
||||||
|
video_inputs = {}
|
||||||
|
else:
|
||||||
|
video_inputs = self._base_call_hf_processor(
|
||||||
|
prompts=[
|
||||||
|
self.info.image_pattern * len(video)
|
||||||
|
for video in parsed_videos
|
||||||
|
],
|
||||||
|
mm_data={"images": list(parsed_videos)},
|
||||||
|
mm_kwargs={
|
||||||
|
**mm_kwargs,
|
||||||
|
"max_slice_nums":
|
||||||
|
self.info.get_video_max_slice_num(),
|
||||||
|
},
|
||||||
|
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
||||||
|
)
|
||||||
|
|
||||||
video_inputs = self._base_call_hf_processor(
|
frame_sizes = [
|
||||||
prompts=[
|
parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
|
||||||
self.info.image_pattern * len(video) for video in parsed_videos
|
]
|
||||||
],
|
num_frames = [
|
||||||
mm_data={"images": list(parsed_videos)},
|
parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
|
||||||
mm_kwargs={
|
]
|
||||||
**mm_kwargs, "max_slice_nums": max_slice_num
|
video_repl_features = [
|
||||||
},
|
self.get_video_prompt_texts(size, nframes)
|
||||||
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
|
for size, nframes in zip(frame_sizes, num_frames)
|
||||||
)
|
]
|
||||||
|
|
||||||
return {f"video_{k}": v for k, v in video_inputs.items()}
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
video_repls_feature_tokens = [
|
||||||
|
tokenizer.encode(video_repl, add_special_tokens=False)
|
||||||
|
for video_repl in video_repl_features
|
||||||
|
]
|
||||||
|
|
||||||
def get_placeholder_match_pattern(self) -> str:
|
embed_is_patch = [
|
||||||
return r"\(<(image|video)>./</\1>\)"
|
self.get_embed_is_patch(video_repl_tokens)
|
||||||
|
for video_repl_tokens in video_repls_feature_tokens
|
||||||
|
]
|
||||||
|
video_inputs["embed_is_patch"] = embed_is_patch
|
||||||
|
|
||||||
|
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
|
||||||
|
|
||||||
|
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||||
|
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
|
||||||
|
|
||||||
|
return video_inputs
|
||||||
|
|
||||||
def process_mm_inputs(
|
def process_mm_inputs(
|
||||||
self,
|
self,
|
||||||
@ -602,7 +660,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
mm_kwargs: Mapping[str, object],
|
mm_kwargs: Mapping[str, object],
|
||||||
*,
|
*,
|
||||||
out_keys: set[str],
|
out_keys: set[str],
|
||||||
) -> Mapping[str, NestedTensors]:
|
) -> dict[str, NestedTensors]:
|
||||||
# This processor supports zipping prompt and mm_data together
|
# This processor supports zipping prompt and mm_data together
|
||||||
if self.info.get_model_version() == (2, 6):
|
if self.info.get_model_version() == (2, 6):
|
||||||
inputs = super()._call_hf_processor(
|
inputs = super()._call_hf_processor(
|
||||||
@ -635,14 +693,13 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
mm_data: Mapping[str, object],
|
mm_data: Mapping[str, object],
|
||||||
mm_kwargs: Mapping[str, object],
|
mm_kwargs: Mapping[str, object],
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
# Do not support combination inputs of images and videos for now
|
|
||||||
# Try to handle interleaved multimodal data
|
|
||||||
tokenizer = self.info.get_tokenizer()
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
|
||||||
|
input_ids = torch.tensor([tokenizer.encode(prompt)])
|
||||||
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)
|
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)
|
||||||
|
|
||||||
return BatchFeature({
|
return BatchFeature({
|
||||||
"input_ids":
|
"input_ids": input_ids,
|
||||||
torch.tensor([tokenizer.encode(prompt)]),
|
|
||||||
**mm_inputs,
|
**mm_inputs,
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -701,39 +758,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
return _minicpmv_field_config(hf_inputs)
|
return _minicpmv_field_config(hf_inputs)
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
prompt: Union[str, List[int]],
|
|
||||||
mm_data: MultiModalDataDict,
|
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
|
||||||
return_mm_hashes: bool = False,
|
|
||||||
) -> MultiModalInputs:
|
|
||||||
if isinstance(prompt, list):
|
|
||||||
prompt = self.info.get_tokenizer().decode(prompt)
|
|
||||||
matches = re.findall(self.get_placeholder_match_pattern(), prompt)
|
|
||||||
mm_orders = {
|
|
||||||
f"{modality}_orders":
|
|
||||||
torch.tensor(
|
|
||||||
[index for index, m in enumerate(matches) if m == modality])
|
|
||||||
for modality in self.info.get_supported_mm_limits()
|
|
||||||
}
|
|
||||||
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
|
|
||||||
return_mm_hashes)
|
|
||||||
# Exclude <image_id>x</image_id> from placeholders
|
|
||||||
if "image" in result["mm_placeholders"] and \
|
|
||||||
self.info.get_model_version() == (2, 6):
|
|
||||||
result["mm_placeholders"]["image"] = [
|
|
||||||
PlaceholderRange(offset=p["offset"] + 3 + idx // 10,
|
|
||||||
length=p["length"] - 3 - idx // 10)
|
|
||||||
for idx, p in enumerate(result["mm_placeholders"]["image"])
|
|
||||||
]
|
|
||||||
result["mm_kwargs"].update(**mm_orders)
|
|
||||||
result["mm_kwargs"].update(**self.get_special_tokens())
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||||
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
|
|
||||||
SupportsV0Only):
|
|
||||||
"""
|
"""
|
||||||
The abstract class of MiniCPMV can only be inherited, but cannot be
|
The abstract class of MiniCPMV can only be inherited, but cannot be
|
||||||
instantiated.
|
instantiated.
|
||||||
@ -767,6 +793,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
prefix=maybe_prefix(
|
prefix=maybe_prefix(
|
||||||
prefix, "resampler"))
|
prefix, "resampler"))
|
||||||
|
|
||||||
|
self.mm_token_ids = set[int]()
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.llm.make_empty_intermediate_tensors)
|
self.llm.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
@ -777,233 +804,191 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
|
|||||||
|
|
||||||
return get_sampler()
|
return get_sampler()
|
||||||
|
|
||||||
def get_embedding_with_vision(
|
def _parse_and_validate_vision_input(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
modality: str,
|
||||||
image_inputs: Optional[MiniCPMVImageInputs],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
|
|
||||||
|
|
||||||
if image_inputs is None:
|
|
||||||
return vlm_embedding
|
|
||||||
|
|
||||||
if image_inputs["type"] == "image_embeds":
|
|
||||||
vision_hidden_states = image_inputs["image_embeds"].to(
|
|
||||||
device=vlm_embedding.device,
|
|
||||||
dtype=vlm_embedding.dtype,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
|
|
||||||
|
|
||||||
# See NOTE in _parse_and_validate_inputs
|
|
||||||
image_bounds = image_inputs["image_bounds"]
|
|
||||||
if len(image_bounds) > 0:
|
|
||||||
image_indices = torch.stack([
|
|
||||||
torch.arange(start, end, dtype=torch.long)
|
|
||||||
for start, end in image_bounds.tolist()
|
|
||||||
]).to(vlm_embedding.device)
|
|
||||||
|
|
||||||
vlm_embedding.scatter_(
|
|
||||||
0,
|
|
||||||
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
|
|
||||||
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
|
|
||||||
)
|
|
||||||
|
|
||||||
return vlm_embedding
|
|
||||||
|
|
||||||
def _get_image_bounds(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
im_start_id: torch.Tensor,
|
|
||||||
im_end_id: torch.Tensor,
|
|
||||||
slice_start_id: Optional[torch.Tensor] = None,
|
|
||||||
slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
# All the images in the batch should share the same special image
|
|
||||||
# bound token ids.
|
|
||||||
start_cond = input_ids == im_start_id[0]
|
|
||||||
end_cond = input_ids == im_end_id[0]
|
|
||||||
if slice_start_id is not None:
|
|
||||||
start_cond |= (input_ids == slice_start_id[0])
|
|
||||||
end_cond |= (input_ids == slice_end_id[0])
|
|
||||||
|
|
||||||
image_start_tokens, = torch.where(start_cond)
|
|
||||||
image_start_tokens += 1
|
|
||||||
image_end_tokens, = torch.where(end_cond)
|
|
||||||
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
|
|
||||||
|
|
||||||
if valid_image_nums == 0:
|
|
||||||
return torch.zeros((0, 2), device=input_ids.device)
|
|
||||||
|
|
||||||
return torch.hstack([
|
|
||||||
image_start_tokens[:valid_image_nums].unsqueeze(-1),
|
|
||||||
image_end_tokens[:valid_image_nums].unsqueeze(-1),
|
|
||||||
])
|
|
||||||
|
|
||||||
def _parse_and_validate_image_inputs(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
**kwargs: object,
|
**kwargs: object,
|
||||||
) -> Optional[MiniCPMVImageInputs]:
|
) -> Optional[MiniCPMVImageInputs]:
|
||||||
image_keys = {"pixel_values", "tgt_sizes"}
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
pixel_data = {
|
image_embeds = kwargs.pop("image_embeds", None)
|
||||||
"image": {
|
|
||||||
key: kwargs.pop(key, None)
|
|
||||||
for key in image_keys
|
|
||||||
},
|
|
||||||
"video": {
|
|
||||||
key: kwargs.pop("video_" + key, None)
|
|
||||||
for key in image_keys
|
|
||||||
}
|
|
||||||
}
|
|
||||||
embed_data = {
|
|
||||||
"image": kwargs.pop("image_embeds", None),
|
|
||||||
"video": kwargs.pop("video_embeds", None),
|
|
||||||
}
|
|
||||||
|
|
||||||
all_pixel_data = [
|
if pixel_values is None and image_embeds is None:
|
||||||
v for vs in pixel_data.values() for v in vs.values()
|
|
||||||
if v is not None
|
|
||||||
]
|
|
||||||
all_embed_data = [v for v in embed_data.values() if v is not None]
|
|
||||||
if len(all_pixel_data) == 0 and len(all_embed_data) == 0:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
im_start_id = kwargs.pop("im_start_id")
|
image_token_id = kwargs.pop("image_token_id")
|
||||||
if not isinstance(im_start_id, torch.Tensor):
|
if image_token_id is not None:
|
||||||
raise ValueError("Incorrect type of im_start_id. "
|
assert isinstance(image_token_id, torch.Tensor)
|
||||||
f"Got type: {type(im_start_id)}")
|
self.mm_token_ids.add(image_token_id.flatten().unique().item())
|
||||||
|
|
||||||
im_end_id = kwargs.pop("im_end_id")
|
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||||
if not isinstance(im_end_id, torch.Tensor):
|
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of im_end_id. "
|
raise ValueError(
|
||||||
f"Got type: {type(im_end_id)}")
|
f"Incorrect type of embed_is_patch for {modality=}. "
|
||||||
|
f"Got type: {type(embed_is_patch)}")
|
||||||
|
|
||||||
slice_start_id = kwargs.pop("slice_start_id", None)
|
embed_is_patch = flatten_bn(embed_is_patch)
|
||||||
if slice_start_id is not None and not isinstance(
|
|
||||||
slice_start_id, torch.Tensor):
|
|
||||||
raise ValueError("Incorrect type of slice_start_id. "
|
|
||||||
f"Got type: {type(slice_start_id)}")
|
|
||||||
|
|
||||||
slice_end_id = kwargs.pop("slice_end_id", None)
|
if image_embeds is not None:
|
||||||
if slice_end_id is not None and not isinstance(slice_end_id,
|
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||||
torch.Tensor):
|
raise ValueError(
|
||||||
raise ValueError("Incorrect type of slice_end_id. "
|
f"Incorrect type of image_embeds for {modality=}. "
|
||||||
f"Got type: {type(slice_end_id)}")
|
f"Got type: {type(image_embeds)}")
|
||||||
|
|
||||||
if len(all_embed_data) > 0:
|
image_embeds_flat = flatten_bn(image_embeds)
|
||||||
if len(all_embed_data) > 1:
|
|
||||||
raise ValueError("Incorrect inputs for vision embeddings. "
|
|
||||||
"Image embeds and video embeds can not "
|
|
||||||
"exist simultaneously.")
|
|
||||||
|
|
||||||
vision_embeds, = all_embed_data
|
|
||||||
if not isinstance(vision_embeds, (torch.Tensor, list)):
|
|
||||||
raise ValueError(f"Incorrect type of vision_embeds. "
|
|
||||||
f"Got type: {type(vision_embeds)}")
|
|
||||||
|
|
||||||
return MiniCPMVImageEmbeddingInputs(
|
return MiniCPMVImageEmbeddingInputs(
|
||||||
type="image_embeds",
|
type="image_embeds",
|
||||||
image_embeds=flatten_bn(flatten_2d_lists(vision_embeds),
|
image_embeds=image_embeds_flat,
|
||||||
concat=True),
|
embed_is_patch=embed_is_patch,
|
||||||
image_bounds=self._get_image_bounds(input_ids, im_start_id,
|
|
||||||
im_end_id, slice_start_id,
|
|
||||||
slice_end_id),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]()
|
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||||
for modality in ("image", "video"):
|
raise ValueError(
|
||||||
modality_orders = kwargs.pop(f"{modality}_orders", None)
|
f"Incorrect type of pixel_values for {modality=}. "
|
||||||
if modality_orders is not None:
|
f"Got type: {type(pixel_values)}")
|
||||||
if not isinstance(modality_orders, (torch.Tensor, list)):
|
|
||||||
raise ValueError(f"Incorrect type of {modality}_orders. "
|
|
||||||
f"Got type: {type(modality_orders)}")
|
|
||||||
|
|
||||||
order_data[modality] = modality_orders
|
tgt_sizes = kwargs.pop("tgt_sizes")
|
||||||
|
if not isinstance(tgt_sizes, (torch.Tensor, list)):
|
||||||
|
raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. "
|
||||||
|
f"Got type: {type(tgt_sizes)}")
|
||||||
|
|
||||||
batch_sizes = {
|
num_slices = [[len(p) for p in ps] for ps in pixel_values]
|
||||||
modality: len(modality_orders)
|
num_slices_flat = flatten_bn(torch.tensor(num_slices))
|
||||||
for modality, modality_orders in order_data.items()
|
|
||||||
}
|
|
||||||
unique_batch_sizes = set(batch_sizes.values())
|
|
||||||
assert len(unique_batch_sizes) == 1, (
|
|
||||||
f"Found inconsistent batch sizes: {batch_sizes}")
|
|
||||||
batch_size, = unique_batch_sizes
|
|
||||||
|
|
||||||
pixel_values_flat = list[torch.Tensor]()
|
pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
|
||||||
tgt_sizes_flat = list[torch.Tensor]()
|
tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
|
||||||
for b in range(batch_size):
|
|
||||||
mm_orders_b = [(idx_b.item(), modality)
|
|
||||||
for modality, modality_orders in order_data.items()
|
|
||||||
for idx_b in modality_orders[b]]
|
|
||||||
|
|
||||||
for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
|
|
||||||
modality_pixel_data = pixel_data[modality]
|
|
||||||
|
|
||||||
modality_pixel_values = modality_pixel_data["pixel_values"]
|
|
||||||
if not isinstance(modality_pixel_values, (torch.Tensor, list)):
|
|
||||||
raise ValueError(
|
|
||||||
f"Incorrect type of pixel_values for {modality=}. "
|
|
||||||
f"Got type: {type(modality_pixel_values)}")
|
|
||||||
|
|
||||||
modality_tgt_sizes = modality_pixel_data["tgt_sizes"]
|
|
||||||
if not isinstance(modality_tgt_sizes, (torch.Tensor, list)):
|
|
||||||
raise ValueError(
|
|
||||||
f"Incorrect type of tgt_sizes for {modality=}. "
|
|
||||||
f"Got type: {type(modality_tgt_sizes)}")
|
|
||||||
|
|
||||||
pixel_values_flat += flatten_2d_lists(modality_pixel_values[b])
|
|
||||||
tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b])
|
|
||||||
|
|
||||||
# NOTE: Input IDs does not contain image tokens during memory profiling,
|
|
||||||
# so we allow it to be empty
|
|
||||||
if len(pixel_values_flat) != len(tgt_sizes_flat):
|
if len(pixel_values_flat) != len(tgt_sizes_flat):
|
||||||
raise ValueError("Inconsistent flattened lengths, found: "
|
raise ValueError("Inconsistent flattened lengths, found: "
|
||||||
f"{len(pixel_values_flat)} vs. "
|
f"{len(pixel_values_flat)} vs. "
|
||||||
f"{len(tgt_sizes_flat)}")
|
f"{len(tgt_sizes_flat)}")
|
||||||
|
|
||||||
if len(pixel_values_flat) == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return MiniCPMVImagePixelInputs(
|
return MiniCPMVImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
pixel_values=pixel_values_flat,
|
pixel_values=pixel_values_flat,
|
||||||
tgt_sizes=torch.stack(tgt_sizes_flat),
|
tgt_sizes=tgt_sizes_flat,
|
||||||
image_bounds=self._get_image_bounds(input_ids, im_start_id,
|
embed_is_patch=embed_is_patch,
|
||||||
im_end_id, slice_start_id,
|
num_slices=num_slices_flat,
|
||||||
slice_end_id),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||||
**kwargs: object):
|
modalities = {}
|
||||||
return self._parse_and_validate_image_inputs(input_ids, **kwargs)
|
|
||||||
|
# Preserve the order of modalities if there are multiple of them
|
||||||
|
# from the order of kwargs.
|
||||||
|
for input_key in kwargs:
|
||||||
|
if input_key in ("pixel_values",
|
||||||
|
"image_embeds") and "images" not in modalities:
|
||||||
|
modalities["images"] = self._parse_and_validate_vision_input(
|
||||||
|
"images", **kwargs)
|
||||||
|
if input_key in ("video_pixel_values",
|
||||||
|
"video_embeds") and "videos" not in modalities:
|
||||||
|
|
||||||
|
def _image_key(video_key: str):
|
||||||
|
if video_key == "video_token_id":
|
||||||
|
return "image_token_id"
|
||||||
|
|
||||||
|
return video_key.removeprefix("video_")
|
||||||
|
|
||||||
|
modalities["videos"] = self._parse_and_validate_vision_input(
|
||||||
|
"videos", **{
|
||||||
|
_image_key(k): v
|
||||||
|
for k, v in kwargs.items()
|
||||||
|
})
|
||||||
|
|
||||||
|
return modalities
|
||||||
|
|
||||||
|
def _process_vision_input(
|
||||||
|
self,
|
||||||
|
image_input: MiniCPMVImageInputs,
|
||||||
|
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
|
||||||
|
if image_input["type"] == "image_embeds":
|
||||||
|
return image_input["image_embeds"]
|
||||||
|
|
||||||
|
image_features_flat = self.get_vision_hidden_states(image_input)
|
||||||
|
|
||||||
|
# Reconstruct the batch dimension
|
||||||
|
return image_features_flat.split(image_input["num_slices"].tolist())
|
||||||
|
|
||||||
|
def _process_multimodal_inputs(self, modalities: dict):
|
||||||
|
# The result multimodal_embeddings is tuple of tensors, with each
|
||||||
|
# tensor correspoending to a multimodal data item (image or video).
|
||||||
|
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||||
|
|
||||||
|
# NOTE: It is important to iterate over the keys in this dictionary
|
||||||
|
# to preserve the order of the modalities.
|
||||||
|
for modality in modalities:
|
||||||
|
if modality == "images":
|
||||||
|
image_input = modalities["images"]
|
||||||
|
image_features = self._process_vision_input(image_input)
|
||||||
|
multimodal_embeddings += tuple(
|
||||||
|
scatter_patch_features(
|
||||||
|
image_features,
|
||||||
|
image_input["embed_is_patch"],
|
||||||
|
))
|
||||||
|
if modality == "videos":
|
||||||
|
video_input = modalities["videos"]
|
||||||
|
video_features = self._process_vision_input(video_input)
|
||||||
|
multimodal_embeddings += tuple(
|
||||||
|
scatter_patch_features(
|
||||||
|
video_features,
|
||||||
|
video_input["embed_is_patch"],
|
||||||
|
))
|
||||||
|
|
||||||
|
return multimodal_embeddings
|
||||||
|
|
||||||
|
def get_multimodal_embeddings(
|
||||||
|
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||||
|
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||||
|
if not modalities:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._process_multimodal_inputs(modalities)
|
||||||
|
|
||||||
|
def get_input_embeddings(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
inputs_embeds = self.llm.get_input_embeddings(input_ids)
|
||||||
|
if multimodal_embeddings is not None:
|
||||||
|
assert len(self.mm_token_ids) > 0
|
||||||
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
select_patch_features(multimodal_embeddings),
|
||||||
|
list(self.mm_token_ids),
|
||||||
|
)
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if intermediate_tensors is not None:
|
if intermediate_tensors is not None:
|
||||||
vlm_embeddings = None
|
inputs_embeds = None
|
||||||
else:
|
|
||||||
image_inputs = \
|
|
||||||
self._parse_and_validate_inputs(input_ids, **kwargs)
|
|
||||||
vlm_embeddings = self.get_embedding_with_vision(
|
|
||||||
input_ids, image_inputs)
|
|
||||||
|
|
||||||
# always pass the input via `inputs_embeds`
|
# NOTE: In v1, inputs_embeds is always generated at model runner from
|
||||||
# to make sure the computation graph is consistent
|
# `get_multimodal_embeddings` and `get_input_embeddings`, this
|
||||||
# for `torch.compile` integration
|
# condition is only for v0 compatibility.
|
||||||
input_ids = None
|
elif inputs_embeds is None:
|
||||||
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
|
|
||||||
output = self.llm.model(
|
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||||
|
vision_embeddings)
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
|
hidden_states = self.llm.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=vlm_embeddings,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
return output
|
return hidden_states
|
||||||
|
|
||||||
def compute_logits(
|
def compute_logits(
|
||||||
self,
|
self,
|
||||||
@ -1105,9 +1090,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.model.embed_tokens(input_ids)
|
|
||||||
|
|
||||||
def init_resampler(self,
|
def init_resampler(self,
|
||||||
embed_dim: int,
|
embed_dim: int,
|
||||||
vision_dim: int,
|
vision_dim: int,
|
||||||
|
|||||||
@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict):
|
|||||||
Shape: `(batch_size * num_images, num_embeds)`
|
Shape: `(batch_size * num_images, num_embeds)`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
num_crops: Union[torch.Tensor, list[torch.Tensor]]
|
num_crops: torch.Tensor
|
||||||
"""Shape: `(batch_size, num_images)`"""
|
"""Shape: `(batch_size * num_images)`"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|||||||
self.img_patch_id = img_patch_id.flatten().unique().item()
|
self.img_patch_id = img_patch_id.flatten().unique().item()
|
||||||
|
|
||||||
embed_is_patch = flatten_bn(embed_is_patch)
|
embed_is_patch = flatten_bn(embed_is_patch)
|
||||||
|
num_crops = flatten_bn(num_crops, concat=True)
|
||||||
|
|
||||||
return MolmoImageInputs(
|
return MolmoImageInputs(
|
||||||
images=images,
|
images=images,
|
||||||
@ -1510,31 +1511,24 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
|||||||
feat_is_patch = image_input["feat_is_patch"]
|
feat_is_patch = image_input["feat_is_patch"]
|
||||||
num_crops = image_input["num_crops"]
|
num_crops = image_input["num_crops"]
|
||||||
|
|
||||||
if isinstance(images, list):
|
# Call the vision backbone on the whole batch at once
|
||||||
# Call the vision backbone on the whole batch at once
|
images_flat = flatten_bn(images, concat=True)
|
||||||
images_flat = flatten_bn(images, concat=True)
|
image_masks_flat = (None if image_masks is None else flatten_bn(
|
||||||
image_masks_flat = (None if image_masks is None else flatten_bn(
|
image_masks, concat=True))
|
||||||
image_masks, concat=True))
|
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
|
||||||
|
|
||||||
image_features_flat = self.vision_backbone(
|
image_features_flat = self.vision_backbone(
|
||||||
images=images_flat.unsqueeze(0),
|
images=images_flat.unsqueeze(0),
|
||||||
image_masks=(None if image_masks_flat is None else
|
image_masks=(None if image_masks_flat is None else
|
||||||
image_masks_flat.unsqueeze(0)),
|
image_masks_flat.unsqueeze(0)),
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
|
|
||||||
# Reconstruct the batch dimension
|
|
||||||
num_crops_per_image = [nc.sum().item() for nc in num_crops]
|
|
||||||
image_features = image_features_flat.split(num_crops_per_image)
|
|
||||||
else:
|
|
||||||
image_features = self.vision_backbone(
|
|
||||||
images=images,
|
|
||||||
image_masks=image_masks,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only the features corresponding to patch tokens are relevant
|
# Only the features corresponding to patch tokens are relevant
|
||||||
return [
|
return [
|
||||||
feats[f_is_patch]
|
feats[f_is_patch] for feats, f_is_patch in zip(
|
||||||
for feats, f_is_patch in zip(image_features, feat_is_patch)
|
image_features_flat.split(num_crops.tolist()),
|
||||||
|
feat_is_patch_flat.split(num_crops.tolist()),
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_multimodal_embeddings(
|
def get_multimodal_embeddings(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user