[Bugfix][Multi Modal] Fix incorrect Molmo image processing (#26563)

Signed-off-by: sanghol <sanghol@allenai.org>
This commit is contained in:
sangho.lee 2025-10-11 00:28:23 -05:00 committed by GitHub
parent ddaff2938e
commit 55392bc879
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -114,11 +114,11 @@ class MolmoImageInputs(TensorSchema):
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
]
feat_is_patch: Annotated[
image_input_idx: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
]
# A boolean mask indicating which image features correspond to patch tokens.
# An index tensor that maps image features to their corresponding patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")]
@ -1177,7 +1177,7 @@ class MolmoProcessorWrapper:
num_crops = torch.tensor(tilings).prod(-1) + 1
assert num_crops.sum() == len(feat_is_patch)
outputs["feat_is_patch"] = feat_is_patch
outputs["image_input_idx"] = image_input_idx
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id
@ -1211,8 +1211,9 @@ class MolmoProcessingInfo(BaseProcessingInfo):
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
extra = image_token_length_w * image_token_length_h
joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
# Calculate total tokens: 2 for start/end + (w+1)*h for column separators
extra = 2 + (image_token_length_w + 1) * image_token_length_h
joint = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size)
return extra + joint
@ -1299,7 +1300,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
return dict(
images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
@ -1444,7 +1445,7 @@ class MolmoForCausalLM(
) -> Optional[MolmoImageInputs]:
images = kwargs.pop("images", None)
image_masks = kwargs.pop("image_masks", None)
feat_is_patch = kwargs.pop("feat_is_patch", None)
image_input_idx = kwargs.pop("image_input_idx", None)
num_crops = kwargs.pop("num_crops", None)
if images is None:
@ -1466,7 +1467,7 @@ class MolmoForCausalLM(
return MolmoImageInputs(
images=images,
image_masks=image_masks,
feat_is_patch=feat_is_patch,
image_input_idx=image_input_idx,
num_crops=num_crops,
)
@ -1476,7 +1477,7 @@ class MolmoForCausalLM(
) -> list[torch.Tensor]:
images = image_input["images"]
image_masks = image_input["image_masks"]
feat_is_patch = image_input["feat_is_patch"]
image_input_idx = image_input["image_input_idx"]
num_crops = image_input["num_crops"]
# Call the vision backbone on the whole batch at once
@ -1484,7 +1485,7 @@ class MolmoForCausalLM(
image_masks_flat = (
None if image_masks is None else flatten_bn(image_masks, concat=True)
)
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
@ -1494,13 +1495,18 @@ class MolmoForCausalLM(
).squeeze(0)
# Only the features corresponding to patch tokens are relevant
return [
feats[f_is_patch]
for feats, f_is_patch in zip(
image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
]
# Re-order the features using the image_input_idx tensor
results = []
num_crops_list = num_crops.tolist()
for feats, img_idx in zip(
image_features_flat.split(num_crops_list),
image_input_idx_flat.split(num_crops_list),
):
is_valid = img_idx >= 0
valid_img_idx = img_idx[is_valid]
order = torch.argsort(valid_img_idx)
results.append(feats[is_valid][order])
return results
def get_language_model(self) -> torch.nn.Module:
return self.model