[BugFix][Multi Modal] Fix TensorSchema shape mismatch in Molmo (#24559)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
Wenlong Wang 2025-09-10 06:14:27 -07:00 committed by GitHub
parent f36355abfd
commit 4c04eef706
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -76,20 +76,22 @@ class MolmoImageInputs(TensorSchema):
""" """
Dimensions: Dimensions:
- bn: Batch size * number of images - bn: Batch size * number of images
- nc: Number of crops - nc: Number of crops (dynamic)
- np: Number of patches - np: Number of patches
- tp: Token sequence positions
- pd: Patch dimension - pd: Patch dimension
""" """
images: Annotated[Union[torch.Tensor, list[torch.Tensor]], images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "np", "pd")] TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"})]
# Number of crops may vary per batch and image, so pass it as a list.
image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]], image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]],
TensorShape("bn", "nc", "np")] TensorShape("bn", "nc", "np", dynamic_dims={"nc"})]
feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]], feat_is_patch: Annotated[
TensorShape("bn", "nc", "np")] Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"})]
# A boolean mask indicating which image features correspond to patch tokens. # A boolean mask indicating which image features correspond to patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")] num_crops: Annotated[torch.Tensor, TensorShape("bn")]