[Model] MiniCPM-V/O supports V1 (#15487)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-03-27 21:07:29 +08:00 committed by GitHub
parent 8063dfc61a
commit ac5bc615b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 573 additions and 594 deletions

View File

@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ
* `openbmb/MiniCPM-o-2_6`, etc. * `openbmb/MiniCPM-o-2_6`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
* * ✅︎
- * `MiniCPMV` - * `MiniCPMV`
* MiniCPM-V * MiniCPM-V
* T + I<sup>E+</sup> + V<sup>E+</sup> * T + I<sup>E+</sup> + V<sup>E+</sup>
* `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. * `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
* * ✅︎
- * `MllamaForConditionalGeneration` - * `MllamaForConditionalGeneration`
* Llama 3.2 * Llama 3.2
* T + I<sup>+</sup> * T + I<sup>+</sup>

View File

@ -23,8 +23,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple, from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
TypedDict, Union) Union)
import torch import torch
from torch import nn from torch import nn
@ -42,8 +42,6 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
MiniCPMVMultiModalDataParser, MiniCPMVMultiModalDataParser,
@ -51,13 +49,14 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
_minicpmv_field_config) _minicpmv_field_config)
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix) maybe_prefix)
from .vision import scatter_patch_features
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
class MiniCPMOAudioFeatureInputs(TypedDict): class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
audio_features: torch.Tensor audio_features: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_audios * num_slices, num_channels, length)` Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
Slice here means chunk. Audio that is too long will be split into slices, Slice here means chunk. Audio that is too long will be split into slices,
@ -65,37 +64,40 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
Padding is used therefore `audio_features` is `torch.Tensor`. Padding is used therefore `audio_features` is `torch.Tensor`.
""" """
audio_feature_lens: torch.Tensor audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_audios * num_slices)` Shape: `(batch_size * num_audios, num_slices)`
This should be feature length of each audio slice, This should be feature length of each audio slice,
which equals to `audio_features.shape[-1]` which equals to `audio_features.shape[-1]`
""" """
audio_bounds: torch.Tensor embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_audios * num_slices, 2)` A boolean mask indicating which audio embeddings correspond
to patch tokens.
This should be in `(start, stop)` format. Shape: `(batch_size * num_audios, num_embeds)`
""" """
class MiniCPMOAudioEmbeddingInputs(TypedDict): class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
audio_embeds: torch.Tensor audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_images * num_slices, hidden_size)` Shape: `(batch_size * num_audios, num_slices, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor. instead of a batched tensor.
Length of each slice may vary, so pass it as a list. Length of each slice may vary, so pass it as a list.
""" """
audio_bounds: torch.Tensor
"""
Shape: `(batch_size * num_audios * num_slices, 2)`
This should be in `(start, stop)` format. embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
""" """
@ -104,11 +106,16 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features = hf_inputs.get("audio_features", torch.empty(0))
num_audios = len(audio_features)
return dict( return dict(
**_minicpmv_field_config(hf_inputs), **_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.batched("audio"), audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"), audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
) )
@ -149,7 +156,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
audio_pattern = "(<audio>./</audio>)" audio_pattern = "(<audio>./</audio>)"
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None, "audio": None} return {**super().get_supported_mm_limits(), "audio": None}
def get_mm_max_tokens_per_item( def get_mm_max_tokens_per_item(
self, self,
@ -157,11 +164,25 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> Mapping[str, int]: ) -> Mapping[str, int]:
return { return {
"image": self.get_max_image_tokens(), **super().get_mm_max_tokens_per_item(seq_len, mm_counts),
"audio": self.get_max_audio_tokens(), "audio":
"video": self.get_max_video_tokens(seq_len), self.get_max_audio_tokens(),
} }
def get_audio_placeholder(
self,
audio_lens: int,
chunk_input: bool = True,
chunk_length: int = 1,
) -> str:
hf_processor = self.get_hf_processor()
return hf_processor.get_audio_placeholder(
audio_lens,
chunk_input=chunk_input,
chunk_length=chunk_length,
)
def get_default_audio_pool_step(self) -> int: def get_default_audio_pool_step(self) -> int:
return 2 return 2
@ -197,12 +218,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
max_videos = mm_config.get_limit_per_prompt("video") max_videos = mm_config.get_limit_per_prompt("video")
max_audios = mm_config.get_limit_per_prompt("audio") max_audios = mm_config.get_limit_per_prompt("audio")
# count <image_idx></image_idx> tokens max_image_tokens = self.get_max_image_tokens() * max_images
# which are not in get_max_image_tokens max_audio_tokens = self.get_max_audio_tokens() * max_audios
max_image_tokens = self.get_max_image_tokens(
) * max_images + 4 * max_images
max_audio_tokens = self.get_max_audio_tokens(
) * max_audios + 2 * max_audios
max_total_frames = self.get_max_video_frames(seq_len - max_total_frames = self.get_max_video_frames(seq_len -
max_image_tokens - max_image_tokens -
max_audio_tokens) max_audio_tokens)
@ -224,20 +241,20 @@ class MiniCPMODummyInputsBuilder(
processor_inputs = super().get_dummy_processor_inputs( processor_inputs = super().get_dummy_processor_inputs(
seq_len, mm_counts) seq_len, mm_counts)
mm_data = {
"image": audio_prompt_texts = self.info.audio_pattern * num_audios
processor_inputs.mm_data["image"], audio_mm_data = {
"video":
processor_inputs.mm_data["video"],
"audio": "audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios) self._get_dummy_audios(length=audio_len, num_audios=num_audios)
} }
audio_prompt_texts = self.info.audio_pattern * num_audios return ProcessorInputs(
prompt_text=processor_inputs.prompt_text + audio_prompt_texts,
return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \ mm_data={
audio_prompt_texts, **processor_inputs.mm_data,
mm_data=mm_data) **audio_mm_data,
},
)
class MiniCPMOMultiModalProcessor( class MiniCPMOMultiModalProcessor(
@ -247,22 +264,17 @@ class MiniCPMOMultiModalProcessor(
return MiniCPMOMultiModalDataParser( return MiniCPMOMultiModalDataParser(
target_sr=self.info.get_default_audio_sampling_rate()) target_sr=self.info.get_default_audio_sampling_rate())
def get_audio_prompt_texts(self, def get_audio_prompt_texts(
audio_lens: int, self,
chunk_input: bool = True, audio_lens: int,
chunk_length: int = 1) -> str: chunk_input: bool = True,
return self.info.get_hf_processor().get_audio_placeholder( chunk_length: int = 1,
audio_lens, chunk_input, chunk_length) ) -> str:
return self.info.get_audio_placeholder(
def get_special_tokens(self) -> Dict[str, torch.Tensor]: audio_lens,
tokenizer = self.info.get_tokenizer() chunk_input=chunk_input,
special_tokens = super().get_special_tokens() chunk_length=chunk_length,
if hasattr(tokenizer, "audio_start_id"): )
special_tokens["audio_start_id"] = torch.tensor(
tokenizer.audio_start_id)
special_tokens["audio_end_id"] = torch.tensor(
tokenizer.audio_end_id)
return special_tokens
def process_audios( def process_audios(
self, self,
@ -274,32 +286,65 @@ class MiniCPMOMultiModalProcessor(
parsed_audios = (self._get_data_parser().parse_mm_data({ parsed_audios = (self._get_data_parser().parse_mm_data({
"audio": audios "audio": audios
}).get_items("audio", AudioProcessorItems)) }).get_items("audio",
(MiniCPMOAudioEmbeddingItems, AudioProcessorItems)))
audio_inputs = self._base_call_hf_processor( if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
prompts=[self.info.audio_pattern] * len(parsed_audios), audio_inputs = {}
mm_data={"audios": [[audio] for audio in parsed_audios]},
mm_kwargs={
**mm_kwargs, "chunk_input": True
},
out_keys={"audio_features", "audio_feature_lens"},
)
# Avoid padding since we need the output for each audio to be audio_lens = [
# independent of other audios for the cache to work correctly self.info.get_audio_len_by_num_chunks(
unpadded_audio_features = [ sum(map(len,
feat[:, :feature_len] for feat, feature_len in zip( parsed_audios.get(i)["audio_embeds"])))
audio_inputs["audio_features"], for i in range(len(parsed_audios))
audio_inputs["audio_feature_lens"], ]
else:
audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios),
mm_data={"audios": [[audio] for audio in parsed_audios]},
mm_kwargs={
**mm_kwargs,
"chunk_input": True,
},
out_keys={"audio_features", "audio_feature_lens"},
) )
# Avoid padding since we need the output for each audio to be
# independent of other audios for the cache to work correctly
unpadded_audio_features = [
feat[:, :feature_len] for feat, feature_len in zip(
audio_inputs["audio_features"],
audio_inputs["audio_feature_lens"],
)
]
audio_inputs["audio_features"] = unpadded_audio_features
audio_lens = [
parsed_audios.get_audio_length(i)
for i in range(len(parsed_audios))
]
audio_repl_features = [
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
] ]
audio_inputs["audio_features"] = unpadded_audio_features
tokenizer = self.info.get_tokenizer()
audio_repls_feature_tokens = [
tokenizer.encode(audio_repl, add_special_tokens=False)
for audio_repl in audio_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(audio_repl_tokens)
for audio_repl_tokens in audio_repls_feature_tokens
]
audio_inputs["audio_embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
return audio_inputs return audio_inputs
def get_placeholder_match_pattern(self) -> str:
return r"\(<(image|video|audio)>./</\1>\)"
def process_mm_inputs( def process_mm_inputs(
self, self,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
@ -331,8 +376,7 @@ class MiniCPMOMultiModalProcessor(
if isinstance(audios, MiniCPMOAudioEmbeddingItems): if isinstance(audios, MiniCPMOAudioEmbeddingItems):
single_audio_embeds = audios.get(item_idx)["audio_embeds"] single_audio_embeds = audios.get(item_idx)["audio_embeds"]
audio_len = self.info.get_audio_len_by_num_chunks( audio_len = self.info.get_audio_len_by_num_chunks(
sum(chunk_embeds.shape[0] sum(map(len, single_audio_embeds)))
for chunk_embeds in single_audio_embeds))
else: else:
audio_len = audios.get_audio_length(item_idx) audio_len = audios.get_audio_length(item_idx)
@ -514,6 +558,8 @@ class MiniCPMO(MiniCPMV2_6):
self.apm = self.init_audio_module(vllm_config=vllm_config, self.apm = self.init_audio_module(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "apm")) prefix=maybe_prefix(prefix, "apm"))
self.audio_token_id = None
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Do not use parameters temporarily # Do not use parameters temporarily
audio_config = self.config.audio_config audio_config = self.config.audio_config
@ -563,18 +609,30 @@ class MiniCPMO(MiniCPMV2_6):
return input_lengths_after_cnn, input_lengths_after_pooling return input_lengths_after_cnn, input_lengths_after_pooling
# Copied from HF repo of MiniCPM-o-2_6, def get_audio_hidden_states(
# designed for batched inputs and outputs self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]:
def get_audio_hidden_states(self, data: MiniCPMOAudioInputs, chunk_length = self.config.audio_chunk_length
chunk_length: int) -> list[torch.Tensor]:
wavforms = data.get(
"audio_features",
[]) # (bs, 80, frames) or [], multi audios need filled in advance
audio_feature_lens_raw = [data.get("audio_feature_lens",
[])] # list, [[x1, x2], [y1], [z1]]
if len(wavforms) == 0: # (bs, 80, frames) or [], multi audios need filled in advance
return [] wavforms_raw = data["audio_features"]
if isinstance(wavforms_raw, list):
B = len(wavforms_raw)
C = wavforms_raw[0].shape[-2]
L = max(item.shape[-1] for item in wavforms_raw)
device = wavforms_raw[0].device
dtype = wavforms_raw[0].dtype
wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
for i, wavforms_item in enumerate(wavforms_raw):
L_item = wavforms_item.shape[-1]
wavforms[i, ..., :L_item] = wavforms_item
else:
wavforms = wavforms_raw
# list, [[x1, x2], [y1], [z1]]
audio_feature_lens_raw = data["audio_feature_lens"]
if isinstance(audio_feature_lens_raw, torch.Tensor):
audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)
audio_feature_lens = torch.hstack(audio_feature_lens_raw) audio_feature_lens = torch.hstack(audio_feature_lens_raw)
batch_size, _, max_mel_seq_len = wavforms.shape batch_size, _, max_mel_seq_len = wavforms.shape
@ -625,159 +683,104 @@ class MiniCPMO(MiniCPMV2_6):
num_audio_tokens = feature_lens_after_pooling num_audio_tokens = feature_lens_after_pooling
final_audio_embeds = [] final_audio_embeds = list[torch.Tensor]()
idx = 0 idx = 0
for i in range(len(audio_feature_lens_raw)): for i in range(len(audio_feature_lens_raw)):
target_audio_embeds = [] target_audio_embeds_lst = list[torch.Tensor]()
for _ in range(len(audio_feature_lens_raw[i])): for _ in range(len(audio_feature_lens_raw[i])):
target_audio_embeds.append( target_audio_embeds_lst.append(
audio_embeds[idx, :num_audio_tokens[idx], :]) audio_embeds[idx, :num_audio_tokens[idx], :])
idx += 1 idx += 1
final_audio_embeds.append(target_audio_embeds)
final_audio_embeds.append(torch.cat(target_audio_embeds_lst))
return final_audio_embeds return final_audio_embeds
def get_embedding_with_audios(self, vlm_embedding: torch.Tensor, def _parse_and_validate_audio_input(
audio_inputs: MiniCPMOAudioInputs, self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
chunk_length: int) -> torch.Tensor:
device, dtype = vlm_embedding.device, vlm_embedding.dtype
if audio_inputs["type"] == "audio_embeds":
audio_embeddings = [
item.to(device=device, dtype=dtype)
for item in audio_inputs["audio_embeds"]
]
else:
audio_embeddings = self.get_audio_hidden_states(
audio_inputs, chunk_length)[0]
if audio_embeddings is None or len(audio_embeddings) == 0:
return vlm_embedding
audio_bounds = audio_inputs["audio_bounds"]
if self.config.chunk_input:
audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device,
dtype=dtype)
audio_start_pos = 0
for bound in audio_bounds:
audio_len = bound[1] - bound[0]
vlm_embedding[bound[0]:bound[1]] = audio_embs[
audio_start_pos:audio_start_pos + audio_len, :]
audio_start_pos += audio_len
else:
for embs, bound in zip(audio_embeddings, audio_bounds):
audio_indices = torch.arange(bound[0],
bound[1],
dtype=torch.long).to(device)
if embs.shape[0] != len(audio_indices):
raise ValueError(
"Shape mismatch: Trying to assign embeddings "
f"of shape {embs.shape} "
f"to input indices of length {len(audio_indices)}")
vlm_embedding[audio_indices] = embs.to(dtype)
return vlm_embedding
def _get_audio_bounds(self, input_ids: torch.Tensor,
audio_start_id: torch.Tensor,
audio_end_id: torch.Tensor) -> torch.Tensor:
audio_start_tokens, = torch.where(input_ids == audio_start_id[0])
audio_start_tokens += 1
audio_end_tokens, = torch.where(input_ids == audio_end_id[0])
valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens))
return torch.hstack([
audio_start_tokens[:valid_audio_nums].unsqueeze(-1),
audio_end_tokens[:valid_audio_nums].unsqueeze(-1)
])
def _parse_and_validate_audio_inputs(
self, input_ids: torch.Tensor,
**kwargs: object) -> Optional[MiniCPMOAudioInputs]:
audio_features = kwargs.pop("audio_features", None) audio_features = kwargs.pop("audio_features", None)
audio_embeds = kwargs.pop("audio_embeds", None) audio_embeds = kwargs.pop("audio_embeds", None)
if audio_features is None and audio_embeds is None: if audio_features is None and audio_embeds is None:
return None return None
audio_start_id = kwargs.pop("audio_start_id") audio_token_id = kwargs.pop("audio_token_id")
if not isinstance(audio_start_id, torch.Tensor): if audio_token_id is not None:
raise ValueError("Incorrect type of audio_start_id. " assert isinstance(audio_token_id, torch.Tensor)
f"Got type: {type(audio_start_id)}") self.mm_token_ids.add(audio_token_id.flatten().unique().item())
audio_end_id = kwargs.pop("audio_end_id") audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
if not isinstance(audio_end_id, torch.Tensor): if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_end_id. " raise ValueError("Incorrect type of audio_embed_is_patch. "
f"Got type: {type(audio_end_id)}") f"Got type: {type(audio_embed_is_patch)}")
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
if audio_embeds is not None: if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)): if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. " raise ValueError("Incorrect type of audio_embeds. "
f"Got type: {type(audio_embeds)}") f"Got type: {type(audio_embeds)}")
audio_embeds_flat = flatten_bn(audio_embeds)
return MiniCPMOAudioEmbeddingInputs( return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds", type="audio_embeds",
audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds), audio_embeds=audio_embeds_flat,
concat=True), embed_is_patch=audio_embed_is_patch,
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
) )
if audio_features is not None: if not isinstance(audio_features, (torch.Tensor, list)):
if not isinstance(audio_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio_features. "
raise ValueError("Incorrect type of audio_features. " f"Got type: {type(audio_features)}")
f"Got type: {type(audio_features)}")
audio_feature_lens = kwargs.pop("audio_feature_lens") audio_feature_lens = kwargs.pop("audio_feature_lens")
if not isinstance(audio_feature_lens, (torch.Tensor, list)): if not isinstance(audio_feature_lens, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_feature_lens. " raise ValueError("Incorrect type of audio_feature_lens. "
f"Got type: {type(audio_feature_lens)}") f"Got type: {type(audio_feature_lens)}")
return MiniCPMOAudioFeatureInputs( audio_features_flat = flatten_bn(audio_features)
type="audio_features", audio_feature_lens_flat = flatten_bn(audio_feature_lens)
audio_features=flatten_bn(audio_features, concat=True),
audio_feature_lens=flatten_bn(
flatten_2d_lists(audio_feature_lens), concat=True),
audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
audio_end_id),
)
raise AssertionError("This line should be unreachable.") return MiniCPMOAudioFeatureInputs(
type="audio_features",
def _parse_and_validate_inputs(self, input_ids: torch.Tensor, audio_features=audio_features_flat,
**kwargs: object): audio_feature_lens=audio_feature_lens_flat,
image_inputs = self._parse_and_validate_image_inputs( embed_is_patch=audio_embed_is_patch,
input_ids, **kwargs)
if not any("audio" in key for key in kwargs):
return image_inputs, None
audio_inputs = self._parse_and_validate_audio_inputs(
input_ids, **kwargs)
return image_inputs, audio_inputs
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any,
) -> torch.Tensor:
if intermediate_tensors is not None:
vlm_embeddings = None
else:
image_inputs, audio_inputs = \
self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings = self.get_embedding_with_vision(
input_ids, image_inputs)
if audio_inputs is not None:
vlm_embeddings = self.get_embedding_with_audios(
vlm_embeddings, audio_inputs,
self.config.audio_chunk_length)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
output = self.llm.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings,
) )
return output
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("audio_features",
"audio_embeds") and "audios" not in modalities:
modalities["audios"] = self._parse_and_validate_audio_input(
**kwargs)
return modalities
def _process_audio_input(
self,
audio_input: MiniCPMOAudioInputs,
) -> Union[torch.Tensor, list[torch.Tensor]]:
if audio_input["type"] == "audio_embeds":
return audio_input["audio_embeds"]
return self.get_audio_hidden_states(audio_input)
def _process_multimodal_inputs(self, modalities: dict):
multimodal_embeddings = super()._process_multimodal_inputs(modalities)
for modality in modalities:
if modality == "audios":
audio_input = modalities["audios"]
audio_features = self._process_audio_input(audio_input)
multimodal_embeddings += tuple(
scatter_patch_features(
audio_features,
audio_input["embed_is_patch"],
))
return multimodal_embeddings

View File

@ -23,17 +23,15 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math import math
import re
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial from functools import cached_property, partial
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
TypedDict, Union) Union)
import numpy as np import numpy as np
import torch import torch
import torch.types import torch.types
from PIL import Image
from torch import nn from torch import nn
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from typing_extensions import TypeVar from typing_extensions import TypeVar
@ -50,9 +48,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
MultiModalInputs, NestedTensors,
PlaceholderRange)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
ImageProcessorItems, ImageSize, ImageProcessorItems, ImageSize,
ModalityData, ModalityDataItems, ModalityData, ModalityDataItems,
@ -67,13 +63,11 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists from vllm.utils import flatten_2d_lists
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsV0Only) SupportsMultiModal, SupportsPP)
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
CPU_DEVICE = torch.device("cpu") from .vision import scatter_patch_features, select_patch_features
RawImageType = Union[Image.Image, torch.Tensor]
class MiniCPMVImagePixelInputs(TypedDict): class MiniCPMVImagePixelInputs(TypedDict):
@ -86,13 +80,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
instead of a batched tensor. instead of a batched tensor.
""" """
image_bounds: torch.Tensor
"""
Shape: `(batch_size * num_images * num_slices, 2)`
This should be in `(start, stop)` format.
"""
tgt_sizes: torch.Tensor tgt_sizes: torch.Tensor
""" """
Shape: `(batch_size * num_images * num_slices, 2)` Shape: `(batch_size * num_images * num_slices, 2)`
@ -100,23 +87,34 @@ class MiniCPMVImagePixelInputs(TypedDict):
This should be in `(height, width)` format. This should be in `(height, width)` format.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_slices: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class MiniCPMVImageEmbeddingInputs(TypedDict): class MiniCPMVImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
image_embeds: torch.Tensor image_embeds: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_images * num_slices, Shape: `(batch_size * num_images, num_slices, hidden_size)`
image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor. instead of a batched tensor.
""" """
image_bounds: torch.Tensor embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_images * num_slices, 2)` A boolean mask indicating which image embeddings correspond
to patch tokens.
This should be in `(start, stop)` format. Shape: `(batch_size * num_images, num_embeds)`
""" """
@ -233,15 +231,25 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
num_images = len(pixel_values)
video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
num_videos = len(video_pixel_values)
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"), image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.batched("image"), tgt_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"), video_pixel_values=MultiModalFieldConfig.batched("video"),
video_image_sizes=MultiModalFieldConfig.batched("video"), video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"), video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.batched("video"),
video_embed_is_patch=MultiModalFieldConfig.batched("video"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
) )
@ -348,10 +356,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return get_version_by_config(self.get_hf_config()) return get_version_by_config(self.get_hf_config())
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
mm_limits = {"image": None}
if self.get_model_version() == (2, 6): if self.get_model_version() == (2, 6):
return {"image": None, "video": None} mm_limits["video"] = None
else:
return {"image": None} return mm_limits
def get_mm_max_tokens_per_item( def get_mm_max_tokens_per_item(
self, self,
@ -361,70 +370,79 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
mm_max_tokens = {"image": self.get_max_image_tokens()} mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6): if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len) mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
return mm_max_tokens return mm_max_tokens
def get_slice_image_placeholder(
self,
image_size: ImageSize,
# For MiniCPM V/O 2.6
image_idx: int = 0,
max_slice_nums: Optional[int] = None,
use_image_id: bool = True,
) -> str:
image_processor = self.get_image_processor()
version = self.get_model_version()
if version == (2, 0) or version == (2, 5):
return image_processor.get_slice_image_placeholder(image_size)
return image_processor.get_slice_image_placeholder(
image_size,
image_idx=image_idx,
max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
)
def get_num_image_tokens(
self,
image_size: ImageSize,
max_slice_nums: Optional[int] = None,
use_image_id: bool = True,
) -> int:
tokenizer = self.get_tokenizer()
image_placeholders = self.get_slice_image_placeholder(
image_size,
max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
)
image_token_ids = tokenizer.encode(image_placeholders,
add_special_tokens=False)
return len(image_token_ids)
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size_with_most_features()
return self.get_num_image_tokens(image_size)
def get_image_max_slice_num(self) -> int:
return getattr(self.get_hf_config(), "max_slice_num", 9)
def get_image_size_with_most_features(self) -> ImageSize:
image_size = getattr(self.get_hf_config(), "image_size", 448)
max_slice_num = self.get_image_max_slice_num()
return ImageSize(width=image_size, height=image_size * max_slice_num)
def get_max_video_frame_tokens(self) -> int: def get_max_video_frame_tokens(self) -> int:
frame_size = self.get_video_frame_size_with_most_features() frame_size = self.get_video_frame_size_with_most_features()
return self.get_num_image_tokens(frame_size,
self.get_video_max_slice_num()) return self.get_num_image_tokens(
frame_size,
max_slice_nums=self.get_video_max_slice_num(),
use_image_id=False,
)
def get_max_video_tokens(self, seq_len: int) -> int: def get_max_video_tokens(self, seq_len: int) -> int:
return self.get_max_video_frame_tokens( return self.get_max_video_frame_tokens(
) * self.get_num_frames_with_most_features(seq_len) ) * self.get_num_frames_with_most_features(seq_len)
def get_slice_query_num(self) -> int:
hf_config = self.get_hf_config()
query_num = getattr(hf_config, "query_num", 64)
return query_num
def get_max_slice_num(self) -> int:
hf_config = self.get_hf_config()
max_slice_num = getattr(hf_config, "max_slice_num", 9)
return max_slice_num
def get_sliced_grid(self, image_size: ImageSize,
max_slice_num: int) -> Tuple[int, int]:
if self.get_model_version() == (2, 6):
slice_grid = self.get_image_processor().get_sliced_grid(
image_size, max_slice_num)
else:
slice_grid = self.get_image_processor().get_sliced_grid(image_size)
return slice_grid
def get_num_image_tokens(self, image_size: ImageSize,
max_slice_num: int) -> int:
slice_grid = self.get_sliced_grid(image_size, max_slice_num)
num_tokens = self.get_slice_query_num(
) + 2 # <image>(<unk> * query_num)</image>
if slice_grid is not None:
if self.get_model_version() == (2, 6):
num_additional_tokens = 0
else:
# <slice><image>(<unk> * query_num)</image></slice>
num_additional_tokens = 2
num_tokens += ((self.get_slice_query_num() + 2) \
* slice_grid[0] * slice_grid[1]) \
+ slice_grid[1] - 1 + num_additional_tokens
return num_tokens
def get_image_slice_nums(self, image_size: torch.Tensor,
max_slice_nums: int) -> int:
grid = self.get_sliced_grid(image_size, max_slice_nums)
return 1 if grid is None else grid[0] * grid[1] + 1
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size_with_most_features()
return self.get_num_image_tokens(image_size, self.get_max_slice_num())
def get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 9:1)
return self.get_default_image_sizes(self.get_max_slice_num())
def get_video_max_slice_num(self) -> int: def get_video_max_slice_num(self) -> int:
return 1 return 1
def get_video_frame_size_with_most_features(self) -> ImageSize: def get_video_frame_size_with_most_features(self) -> ImageSize:
return self.get_default_image_sizes(self.get_video_max_slice_num()) image_size = getattr(self.get_hf_config(), "image_size", 448)
max_slice_num = self.get_video_max_slice_num()
return ImageSize(width=image_size, height=image_size * max_slice_num)
def get_max_video_frames(self, max_tokens: int) -> int: def get_max_video_frames(self, max_tokens: int) -> int:
num_frame_tokens = self.get_max_video_frame_tokens() num_frame_tokens = self.get_max_video_frame_tokens()
@ -436,10 +454,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
max_images = mm_config.get_limit_per_prompt("image") max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.get_limit_per_prompt("video") max_videos = mm_config.get_limit_per_prompt("video")
# count <image_idx></image_idx> tokens max_image_tokens = self.get_max_image_tokens() * max_images
# which are not in get_max_image_tokens
max_image_tokens = self.get_max_image_tokens(
) * max_images + 4 * max_images
max_total_frames = self.get_max_video_frames(seq_len - max_total_frames = self.get_max_video_frames(seq_len -
max_image_tokens) max_image_tokens)
@ -447,10 +462,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return num_frames return num_frames
def get_default_image_sizes(self, num_slices: int) -> ImageSize:
image_size = getattr(self.get_hf_config(), "image_size", 448)
return ImageSize(width=image_size, height=image_size * num_slices)
_I = TypeVar("_I", _I = TypeVar("_I",
bound=MiniCPMVProcessingInfo, bound=MiniCPMVProcessingInfo,
@ -499,42 +510,30 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMVMultiModalDataParser() return MiniCPMVMultiModalDataParser()
def get_slice_image_placeholder(self, image_size: ImageSize,
**kwargs) -> str:
image_processor = self.info.get_image_processor()
version = self.info.get_model_version()
if version == (2, 0) or version == (2, 5):
return image_processor.get_slice_image_placeholder(image_size)
return image_processor.get_slice_image_placeholder(
image_size, **kwargs)
def get_image_prompt_texts(self, def get_image_prompt_texts(self,
image_size: ImageSize, image_size: ImageSize,
image_idx: int = 0) -> str: image_idx: int = 0) -> str:
return self.get_slice_image_placeholder(image_size, return self.info.get_slice_image_placeholder(
image_idx=image_idx) image_size,
image_idx=image_idx,
)
def get_video_prompt_texts(self, image_size: ImageSize, def get_video_prompt_texts(self, image_size: ImageSize,
num_frames: int) -> str: num_frames: int) -> str:
return self.get_slice_image_placeholder( return self.info.get_slice_image_placeholder(
image_size=image_size, image_size=image_size,
image_idx=0, image_idx=0,
max_slice_nums=self.info.get_video_max_slice_num(), max_slice_nums=self.info.get_video_max_slice_num(),
use_image_id=False, use_image_id=False,
) * num_frames ) * num_frames
def get_special_tokens(self) -> Dict[str, torch.Tensor]: def get_embed_is_patch(
self,
input_ids: list[int],
) -> torch.Tensor:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
special_tokens = { return torch.tensor(input_ids) == unk_token_id
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
}
if hasattr(tokenizer, "slice_start_id"):
special_tokens["slice_start_id"] = tokenizer.slice_start_id
special_tokens["slice_end_id"] = tokenizer.slice_end_id
return {k: torch.tensor(v) for k, v in special_tokens.items()}
def process_images( def process_images(
self, self,
@ -546,14 +545,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
parsed_images = (self._get_data_parser().parse_mm_data({ parsed_images = (self._get_data_parser().parse_mm_data({
"image": images "image": images
}).get_items("image", ImageProcessorItems)) }).get_items("image",
(MiniCPMVImageEmbeddingItems, ImageProcessorItems)))
return self._base_call_hf_processor( if isinstance(parsed_images, MiniCPMVImageEmbeddingItems):
prompts=[self.info.image_pattern] * len(parsed_images), image_inputs = {}
mm_data={"images": [[image] for image in parsed_images]}, else:
mm_kwargs=mm_kwargs, image_inputs = self._base_call_hf_processor(
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, prompts=[self.info.image_pattern] * len(parsed_images),
) mm_data={"images": [[image] for image in parsed_images]},
mm_kwargs=mm_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
image_repl_features = [
self.get_image_prompt_texts(size, idx)
for idx, size in enumerate(image_sizes)
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(image_repl_tokens)
for image_repl_tokens in image_repls_feature_tokens
]
image_inputs["embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"]
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
return image_inputs
def process_videos( def process_videos(
self, self,
@ -565,25 +593,55 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
parsed_videos = (self._get_data_parser().parse_mm_data({ parsed_videos = (self._get_data_parser().parse_mm_data({
"video": videos "video": videos
}).get_items("video", VideoProcessorItems)) }).get_items("video",
(MiniCPMVVideoEmbeddingItems, VideoProcessorItems)))
max_slice_num = self.info.get_video_max_slice_num() if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems):
video_inputs = {}
else:
video_inputs = self._base_call_hf_processor(
prompts=[
self.info.image_pattern * len(video)
for video in parsed_videos
],
mm_data={"images": list(parsed_videos)},
mm_kwargs={
**mm_kwargs,
"max_slice_nums":
self.info.get_video_max_slice_num(),
},
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
video_inputs = self._base_call_hf_processor( frame_sizes = [
prompts=[ parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
self.info.image_pattern * len(video) for video in parsed_videos ]
], num_frames = [
mm_data={"images": list(parsed_videos)}, parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
mm_kwargs={ ]
**mm_kwargs, "max_slice_nums": max_slice_num video_repl_features = [
}, self.get_video_prompt_texts(size, nframes)
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, for size, nframes in zip(frame_sizes, num_frames)
) ]
return {f"video_{k}": v for k, v in video_inputs.items()} tokenizer = self.info.get_tokenizer()
video_repls_feature_tokens = [
tokenizer.encode(video_repl, add_special_tokens=False)
for video_repl in video_repl_features
]
def get_placeholder_match_pattern(self) -> str: embed_is_patch = [
return r"\(<(image|video)>./</\1>\)" self.get_embed_is_patch(video_repl_tokens)
for video_repl_tokens in video_repls_feature_tokens
]
video_inputs["embed_is_patch"] = embed_is_patch
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
unk_token_id = tokenizer.get_vocab()["<unk>"]
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
return video_inputs
def process_mm_inputs( def process_mm_inputs(
self, self,
@ -602,7 +660,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
*, *,
out_keys: set[str], out_keys: set[str],
) -> Mapping[str, NestedTensors]: ) -> dict[str, NestedTensors]:
# This processor supports zipping prompt and mm_data together # This processor supports zipping prompt and mm_data together
if self.info.get_model_version() == (2, 6): if self.info.get_model_version() == (2, 6):
inputs = super()._call_hf_processor( inputs = super()._call_hf_processor(
@ -635,14 +693,13 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
# Do not support combination inputs of images and videos for now
# Try to handle interleaved multimodal data
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
input_ids = torch.tensor([tokenizer.encode(prompt)])
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs) mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)
return BatchFeature({ return BatchFeature({
"input_ids": "input_ids": input_ids,
torch.tensor([tokenizer.encode(prompt)]),
**mm_inputs, **mm_inputs,
}) })
@ -701,39 +758,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return _minicpmv_field_config(hf_inputs) return _minicpmv_field_config(hf_inputs)
def apply(
self,
prompt: Union[str, List[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
if isinstance(prompt, list):
prompt = self.info.get_tokenizer().decode(prompt)
matches = re.findall(self.get_placeholder_match_pattern(), prompt)
mm_orders = {
f"{modality}_orders":
torch.tensor(
[index for index, m in enumerate(matches) if m == modality])
for modality in self.info.get_supported_mm_limits()
}
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
# Exclude <image_id>x</image_id> from placeholders
if "image" in result["mm_placeholders"] and \
self.info.get_model_version() == (2, 6):
result["mm_placeholders"]["image"] = [
PlaceholderRange(offset=p["offset"] + 3 + idx // 10,
length=p["length"] - 3 - idx // 10)
for idx, p in enumerate(result["mm_placeholders"]["image"])
]
result["mm_kwargs"].update(**mm_orders)
result["mm_kwargs"].update(**self.get_special_tokens())
return result
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
SupportsV0Only):
""" """
The abstract class of MiniCPMV can only be inherited, but cannot be The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated. instantiated.
@ -767,6 +793,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
prefix=maybe_prefix( prefix=maybe_prefix(
prefix, "resampler")) prefix, "resampler"))
self.mm_token_ids = set[int]()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors) self.llm.make_empty_intermediate_tensors)
@ -777,233 +804,191 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP,
return get_sampler() return get_sampler()
def get_embedding_with_vision( def _parse_and_validate_vision_input(
self, self,
input_ids: torch.Tensor, modality: str,
image_inputs: Optional[MiniCPMVImageInputs],
) -> torch.Tensor:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None:
return vlm_embedding
if image_inputs["type"] == "image_embeds":
vision_hidden_states = image_inputs["image_embeds"].to(
device=vlm_embedding.device,
dtype=vlm_embedding.dtype,
)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack([
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding
def _get_image_bounds(
self,
input_ids: torch.Tensor,
im_start_id: torch.Tensor,
im_end_id: torch.Tensor,
slice_start_id: Optional[torch.Tensor] = None,
slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
# All the images in the batch should share the same special image
# bound token ids.
start_cond = input_ids == im_start_id[0]
end_cond = input_ids == im_end_id[0]
if slice_start_id is not None:
start_cond |= (input_ids == slice_start_id[0])
end_cond |= (input_ids == slice_end_id[0])
image_start_tokens, = torch.where(start_cond)
image_start_tokens += 1
image_end_tokens, = torch.where(end_cond)
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
if valid_image_nums == 0:
return torch.zeros((0, 2), device=input_ids.device)
return torch.hstack([
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
])
def _parse_and_validate_image_inputs(
self,
input_ids: torch.Tensor,
**kwargs: object, **kwargs: object,
) -> Optional[MiniCPMVImageInputs]: ) -> Optional[MiniCPMVImageInputs]:
image_keys = {"pixel_values", "tgt_sizes"} pixel_values = kwargs.pop("pixel_values", None)
pixel_data = { image_embeds = kwargs.pop("image_embeds", None)
"image": {
key: kwargs.pop(key, None)
for key in image_keys
},
"video": {
key: kwargs.pop("video_" + key, None)
for key in image_keys
}
}
embed_data = {
"image": kwargs.pop("image_embeds", None),
"video": kwargs.pop("video_embeds", None),
}
all_pixel_data = [ if pixel_values is None and image_embeds is None:
v for vs in pixel_data.values() for v in vs.values()
if v is not None
]
all_embed_data = [v for v in embed_data.values() if v is not None]
if len(all_pixel_data) == 0 and len(all_embed_data) == 0:
return None return None
im_start_id = kwargs.pop("im_start_id") image_token_id = kwargs.pop("image_token_id")
if not isinstance(im_start_id, torch.Tensor): if image_token_id is not None:
raise ValueError("Incorrect type of im_start_id. " assert isinstance(image_token_id, torch.Tensor)
f"Got type: {type(im_start_id)}") self.mm_token_ids.add(image_token_id.flatten().unique().item())
im_end_id = kwargs.pop("im_end_id") embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(im_end_id, torch.Tensor): if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of im_end_id. " raise ValueError(
f"Got type: {type(im_end_id)}") f"Incorrect type of embed_is_patch for {modality=}. "
f"Got type: {type(embed_is_patch)}")
slice_start_id = kwargs.pop("slice_start_id", None) embed_is_patch = flatten_bn(embed_is_patch)
if slice_start_id is not None and not isinstance(
slice_start_id, torch.Tensor):
raise ValueError("Incorrect type of slice_start_id. "
f"Got type: {type(slice_start_id)}")
slice_end_id = kwargs.pop("slice_end_id", None) if image_embeds is not None:
if slice_end_id is not None and not isinstance(slice_end_id, if not isinstance(image_embeds, (torch.Tensor, list)):
torch.Tensor): raise ValueError(
raise ValueError("Incorrect type of slice_end_id. " f"Incorrect type of image_embeds for {modality=}. "
f"Got type: {type(slice_end_id)}") f"Got type: {type(image_embeds)}")
if len(all_embed_data) > 0: image_embeds_flat = flatten_bn(image_embeds)
if len(all_embed_data) > 1:
raise ValueError("Incorrect inputs for vision embeddings. "
"Image embeds and video embeds can not "
"exist simultaneously.")
vision_embeds, = all_embed_data
if not isinstance(vision_embeds, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of vision_embeds. "
f"Got type: {type(vision_embeds)}")
return MiniCPMVImageEmbeddingInputs( return MiniCPMVImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
image_embeds=flatten_bn(flatten_2d_lists(vision_embeds), image_embeds=image_embeds_flat,
concat=True), embed_is_patch=embed_is_patch,
image_bounds=self._get_image_bounds(input_ids, im_start_id,
im_end_id, slice_start_id,
slice_end_id),
) )
order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]() if not isinstance(pixel_values, (torch.Tensor, list)):
for modality in ("image", "video"): raise ValueError(
modality_orders = kwargs.pop(f"{modality}_orders", None) f"Incorrect type of pixel_values for {modality=}. "
if modality_orders is not None: f"Got type: {type(pixel_values)}")
if not isinstance(modality_orders, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {modality}_orders. "
f"Got type: {type(modality_orders)}")
order_data[modality] = modality_orders tgt_sizes = kwargs.pop("tgt_sizes")
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. "
f"Got type: {type(tgt_sizes)}")
batch_sizes = { num_slices = [[len(p) for p in ps] for ps in pixel_values]
modality: len(modality_orders) num_slices_flat = flatten_bn(torch.tensor(num_slices))
for modality, modality_orders in order_data.items()
}
unique_batch_sizes = set(batch_sizes.values())
assert len(unique_batch_sizes) == 1, (
f"Found inconsistent batch sizes: {batch_sizes}")
batch_size, = unique_batch_sizes
pixel_values_flat = list[torch.Tensor]() pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
tgt_sizes_flat = list[torch.Tensor]() tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
for b in range(batch_size):
mm_orders_b = [(idx_b.item(), modality)
for modality, modality_orders in order_data.items()
for idx_b in modality_orders[b]]
for _, modality in sorted(mm_orders_b, key=lambda x: x[0]):
modality_pixel_data = pixel_data[modality]
modality_pixel_values = modality_pixel_data["pixel_values"]
if not isinstance(modality_pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel_values for {modality=}. "
f"Got type: {type(modality_pixel_values)}")
modality_tgt_sizes = modality_pixel_data["tgt_sizes"]
if not isinstance(modality_tgt_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of tgt_sizes for {modality=}. "
f"Got type: {type(modality_tgt_sizes)}")
pixel_values_flat += flatten_2d_lists(modality_pixel_values[b])
tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b])
# NOTE: Input IDs does not contain image tokens during memory profiling,
# so we allow it to be empty
if len(pixel_values_flat) != len(tgt_sizes_flat): if len(pixel_values_flat) != len(tgt_sizes_flat):
raise ValueError("Inconsistent flattened lengths, found: " raise ValueError("Inconsistent flattened lengths, found: "
f"{len(pixel_values_flat)} vs. " f"{len(pixel_values_flat)} vs. "
f"{len(tgt_sizes_flat)}") f"{len(tgt_sizes_flat)}")
if len(pixel_values_flat) == 0:
return None
return MiniCPMVImagePixelInputs( return MiniCPMVImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=pixel_values_flat, pixel_values=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat), tgt_sizes=tgt_sizes_flat,
image_bounds=self._get_image_bounds(input_ids, im_start_id, embed_is_patch=embed_is_patch,
im_end_id, slice_start_id, num_slices=num_slices_flat,
slice_end_id),
) )
def _parse_and_validate_inputs(self, input_ids: torch.Tensor, def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
**kwargs: object): modalities = {}
return self._parse_and_validate_image_inputs(input_ids, **kwargs)
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_vision_input(
"images", **kwargs)
if input_key in ("video_pixel_values",
"video_embeds") and "videos" not in modalities:
def _image_key(video_key: str):
if video_key == "video_token_id":
return "image_token_id"
return video_key.removeprefix("video_")
modalities["videos"] = self._parse_and_validate_vision_input(
"videos", **{
_image_key(k): v
for k, v in kwargs.items()
})
return modalities
def _process_vision_input(
self,
image_input: MiniCPMVImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["image_embeds"]
image_features_flat = self.get_vision_hidden_states(image_input)
# Reconstruct the batch dimension
return image_features_flat.split(image_input["num_slices"].tolist())
def _process_multimodal_inputs(self, modalities: dict):
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
image_features = self._process_vision_input(image_input)
multimodal_embeddings += tuple(
scatter_patch_features(
image_features,
image_input["embed_is_patch"],
))
if modality == "videos":
video_input = modalities["videos"]
video_features = self._process_vision_input(video_input)
multimodal_embeddings += tuple(
scatter_patch_features(
video_features,
video_input["embed_is_patch"],
))
return multimodal_embeddings
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
return self._process_multimodal_inputs(modalities)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
assert len(self.mm_token_ids) > 0
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
list(self.mm_token_ids),
)
return inputs_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
if intermediate_tensors is not None: if intermediate_tensors is not None:
vlm_embeddings = None inputs_embeds = None
else:
image_inputs = \
self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings = self.get_embedding_with_vision(
input_ids, image_inputs)
# always pass the input via `inputs_embeds` # NOTE: In v1, inputs_embeds is always generated at model runner from
# to make sure the computation graph is consistent # `get_multimodal_embeddings` and `get_input_embeddings`, this
# for `torch.compile` integration # condition is only for v0 compatibility.
input_ids = None elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
output = self.llm.model( inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
hidden_states = self.llm.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=vlm_embeddings, inputs_embeds=inputs_embeds,
) )
return output return hidden_states
def compute_logits( def compute_logits(
self, self,
@ -1105,9 +1090,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return model return model
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_tokens(input_ids)
def init_resampler(self, def init_resampler(self,
embed_dim: int, embed_dim: int,
vision_dim: int, vision_dim: int,

View File

@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_embeds)` Shape: `(batch_size * num_images, num_embeds)`
""" """
num_crops: Union[torch.Tensor, list[torch.Tensor]] num_crops: torch.Tensor
"""Shape: `(batch_size, num_images)`""" """Shape: `(batch_size * num_images)`"""
@dataclass @dataclass
@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self.img_patch_id = img_patch_id.flatten().unique().item() self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch) embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs( return MolmoImageInputs(
images=images, images=images,
@ -1510,31 +1511,24 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feat_is_patch = image_input["feat_is_patch"] feat_is_patch = image_input["feat_is_patch"]
num_crops = image_input["num_crops"] num_crops = image_input["num_crops"]
if isinstance(images, list): # Call the vision backbone on the whole batch at once
# Call the vision backbone on the whole batch at once images_flat = flatten_bn(images, concat=True)
images_flat = flatten_bn(images, concat=True) image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks_flat = (None if image_masks is None else flatten_bn( image_masks, concat=True))
image_masks, concat=True)) feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
image_features_flat = self.vision_backbone( image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0), images=images_flat.unsqueeze(0),
image_masks=(None if image_masks_flat is None else image_masks=(None if image_masks_flat is None else
image_masks_flat.unsqueeze(0)), image_masks_flat.unsqueeze(0)),
).squeeze(0) ).squeeze(0)
# Reconstruct the batch dimension
num_crops_per_image = [nc.sum().item() for nc in num_crops]
image_features = image_features_flat.split(num_crops_per_image)
else:
image_features = self.vision_backbone(
images=images,
image_masks=image_masks,
)
# Only the features corresponding to patch tokens are relevant # Only the features corresponding to patch tokens are relevant
return [ return [
feats[f_is_patch] feats[f_is_patch] for feats, f_is_patch in zip(
for feats, f_is_patch in zip(image_features, feat_is_patch) image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
] ]
def get_multimodal_embeddings( def get_multimodal_embeddings(