[BugFix][Multi Modal] Fix TensorSchema shape mismatch in Molmo (#24559)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
Wenlong Wang 2025-09-10 06:14:27 -07:00 committed by GitHub
parent f36355abfd
commit 4c04eef706
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -76,20 +76,22 @@ class MolmoImageInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- nc: Number of crops
- nc: Number of crops (dynamic)
- np: Number of patches
- tp: Token sequence positions
- pd: Patch dimension
"""
images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "np", "pd")]
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"})]
# Number of crops may vary per batch and image, so pass it as a list.
image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]],
TensorShape("bn", "nc", "np")]
TensorShape("bn", "nc", "np", dynamic_dims={"nc"})]
feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "np")]
feat_is_patch: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"})]
# A boolean mask indicating which image features correspond to patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")]