[Model] Support multi-image for Molmo (#15438)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-03-26 11:26:33 +08:00 committed by GitHub
parent e42389f9d7
commit 997c8811d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 35 deletions

View File

@@ -853,7 +853,7 @@ See [this page](#generative-models) for more information on how to use generativ
*
- * `MolmoForCausalLM`
* Molmo
* T + I
* T + I<sup>+</sup>
* `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
* ✅︎
* ✅︎

View File

@@ -431,7 +431,7 @@ VLM_TEST_SETTINGS = {
),
"molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE),
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=identity,
max_model_len=4096,
max_num_seqs=2,

View File

@@ -57,7 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
from .vision import select_patch_features
from .vision import scatter_patch_features, select_patch_features
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
@@ -71,13 +71,13 @@ POOLING_SIZE = 2
class MolmoImageInputs(TypedDict):
images: Union[torch.Tensor, List[torch.Tensor]]
images: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]]
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
"""Shape: `(batch_size, num_crops, num_patch)`"""
feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.
@@ -85,7 +85,7 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size, num_crops, num_patch)`
"""
embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
@@ -93,7 +93,7 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size, num_embeds)`
"""
num_crops: torch.Tensor
num_crops: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
@@ -1144,13 +1144,7 @@ class MolmoProcessorWrapper:
image_input_idx = outputs.pop("image_input_idx", None)
if image_input_idx is not None:
input_is_patch = input_ids == self.image_patch_id
image_input_idx_flat: torch.Tensor = image_input_idx.view(-1)
image_valid_flat = image_input_idx_flat >= 0
feat_is_patch_flat = image_valid_flat.clone()
feat_is_patch_flat[image_valid_flat] = (
input_is_patch[image_input_idx_flat[image_valid_flat]])
feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape)
feat_is_patch = image_input_idx >= 0
input_is_embed = torch.isin(
input_ids,
@@ -1165,6 +1159,17 @@ class MolmoProcessorWrapper:
embed_is_patch = embed_ids == self.image_patch_id
assert embed_is_patch.sum() == feat_is_patch.sum()
# image_tokens = extra_joint + joint
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
assert len(embed_start) == len(embed_end) == len(images)
embed_is_patch = [
embed_is_patch[start:end + 1]
for start, end in zip(embed_start, embed_end)
]
tilings = [
self.select_tiling(
image_width=image.size[0],
@@ -1180,7 +1185,7 @@ class MolmoProcessorWrapper:
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id
return BatchFeature(outputs, tensor_type=return_tensors)
return BatchFeature(outputs)
class MolmoProcessingInfo(BaseProcessingInfo):
@@ -1190,9 +1195,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
# TODO: Investigate different `embed_is_patch` between cache/no-cache
# in multi-image case
return {"image": 1}
return {"image": None}
def get_mm_max_tokens_per_item(
self,
@@ -1325,7 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
"image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
embed_is_patch=MultiModalFieldConfig.shared("image", num_images),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
@@ -1499,7 +1502,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def _process_image_input(
self,
image_input: MolmoImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
) -> Union[torch.Tensor, list[torch.Tensor]]:
if isinstance(image_input["images"], list):
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(image_input["images"], concat=True)
@@ -1530,7 +1533,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> list[torch.Tensor]:
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
@@ -1565,16 +1568,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
feats_per_image = features.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
_, _, embed_dim = features.shape
(num_embeds, ) = embed_is_patch.shape
features = torch.cat([
feats[f_is_patch]
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
])
embeds_in_batch = list[torch.Tensor]()
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = feats[f_is_patch]
embeds_in_batch.append(embeds)
return embeds_in_batch
return scatter_patch_features(features, embed_is_patch)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

View File

@@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs(
def scatter_patch_features(
features: torch.Tensor,
embed_is_patch: torch.Tensor,
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
@@ -194,14 +194,19 @@ def scatter_patch_features(
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
num_images, num_embeds = embed_is_patch.shape
num_embeds_per_image = [num_embeds] * num_images
num_embeds_per_image = [
e_is_patch.numel() for e_is_patch in embed_is_patch
]
if isinstance(embed_is_patch, torch.Tensor):
embed_is_patch_flat = embed_is_patch.view(-1)
else:
embed_is_patch_flat = torch.cat(embed_is_patch)
embeds_flat = features.new_full(
(sum(num_embeds_per_image), features.shape[-1]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)
return embeds_flat.split(num_embeds_per_image)