[Bugfix][Multi Modal] Fix incorrect Molmo image processing (#26563)

Signed-off-by: sanghol <sanghol@allenai.org>
This commit is contained in:
sangho.lee 2025-10-11 00:28:23 -05:00 committed by GitHub
parent ddaff2938e
commit 55392bc879
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -114,11 +114,11 @@ class MolmoImageInputs(TensorSchema):
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}), TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
] ]
feat_is_patch: Annotated[ image_input_idx: Annotated[
Union[torch.Tensor, list[torch.Tensor]], Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}), TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
] ]
# A boolean mask indicating which image features correspond to patch tokens. # An index tensor that maps image features to their corresponding patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")] num_crops: Annotated[torch.Tensor, TensorShape("bn")]
@ -1177,7 +1177,7 @@ class MolmoProcessorWrapper:
num_crops = torch.tensor(tilings).prod(-1) + 1 num_crops = torch.tensor(tilings).prod(-1) + 1
assert num_crops.sum() == len(feat_is_patch) assert num_crops.sum() == len(feat_is_patch)
outputs["feat_is_patch"] = feat_is_patch outputs["image_input_idx"] = image_input_idx
outputs["num_crops"] = num_crops outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id outputs["img_patch_id"] = self.image_patch_id
@ -1211,8 +1211,9 @@ class MolmoProcessingInfo(BaseProcessingInfo):
image_token_length_w = processor.image_token_length_w image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h image_token_length_h = processor.image_token_length_h
extra = image_token_length_w * image_token_length_h # Calculate total tokens: 2 for start/end + (w+1)*h for column separators
joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size) extra = 2 + (image_token_length_w + 1) * image_token_length_h
joint = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size)
return extra + joint return extra + joint
@ -1299,7 +1300,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
return dict( return dict(
images=MultiModalFieldConfig.flat_from_sizes("image", num_crops), images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops), image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops), image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"), num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images), img_patch_id=MultiModalFieldConfig.shared("image", num_images),
) )
@ -1444,7 +1445,7 @@ class MolmoForCausalLM(
) -> Optional[MolmoImageInputs]: ) -> Optional[MolmoImageInputs]:
images = kwargs.pop("images", None) images = kwargs.pop("images", None)
image_masks = kwargs.pop("image_masks", None) image_masks = kwargs.pop("image_masks", None)
feat_is_patch = kwargs.pop("feat_is_patch", None) image_input_idx = kwargs.pop("image_input_idx", None)
num_crops = kwargs.pop("num_crops", None) num_crops = kwargs.pop("num_crops", None)
if images is None: if images is None:
@ -1466,7 +1467,7 @@ class MolmoForCausalLM(
return MolmoImageInputs( return MolmoImageInputs(
images=images, images=images,
image_masks=image_masks, image_masks=image_masks,
feat_is_patch=feat_is_patch, image_input_idx=image_input_idx,
num_crops=num_crops, num_crops=num_crops,
) )
@ -1476,7 +1477,7 @@ class MolmoForCausalLM(
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
images = image_input["images"] images = image_input["images"]
image_masks = image_input["image_masks"] image_masks = image_input["image_masks"]
feat_is_patch = image_input["feat_is_patch"] image_input_idx = image_input["image_input_idx"]
num_crops = image_input["num_crops"] num_crops = image_input["num_crops"]
# Call the vision backbone on the whole batch at once # Call the vision backbone on the whole batch at once
@ -1484,7 +1485,7 @@ class MolmoForCausalLM(
image_masks_flat = ( image_masks_flat = (
None if image_masks is None else flatten_bn(image_masks, concat=True) None if image_masks is None else flatten_bn(image_masks, concat=True)
) )
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True) image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
image_features_flat = self.vision_backbone( image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0), images=images_flat.unsqueeze(0),
@ -1494,13 +1495,18 @@ class MolmoForCausalLM(
).squeeze(0) ).squeeze(0)
# Only the features corresponding to patch tokens are relevant # Only the features corresponding to patch tokens are relevant
return [ # Re-order the features using the image_input_idx tensor
feats[f_is_patch] results = []
for feats, f_is_patch in zip( num_crops_list = num_crops.tolist()
image_features_flat.split(num_crops.tolist()), for feats, img_idx in zip(
feat_is_patch_flat.split(num_crops.tolist()), image_features_flat.split(num_crops_list),
) image_input_idx_flat.split(num_crops_list),
] ):
is_valid = img_idx >= 0
valid_img_idx = img_idx[is_valid]
order = torch.argsort(valid_img_idx)
results.append(feats[is_valid][order])
return results
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model