[Bugfix] Fix shape checking for Fuyu (#21709)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-07-28 15:05:56 +08:00 committed by GitHub
parent 18cc33dd60
commit 139a97ec56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -55,14 +55,15 @@ class FuyuImagePatchInputs(TensorSchema):
     """
     Dimensions:
         - bn: Batch size * number of images
-        - fn: Num channels * patch_size_x * patch_size_y
+        - bnp: Batch size * number of images * number of patches
+        - fn: patch_size_x * patch_size_y * num_channels
     """
     type: Literal["image_patches"] = "image_patches"
     flat_data: Annotated[
         torch.Tensor,
-        TensorShape("bn", "fn"),
+        TensorShape("bnp", "fn"),
     ]
     patches_per_image: Annotated[list[int], TensorShape("bn")]
@@ -309,8 +310,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         image_patches = kwargs.pop("image_patches", None)
         if image_patches is not None:
             image_patches_flat = flatten_bn(image_patches)
-            flat_data = flatten_bn(image_patches, concat=True).data.to(
-                self.vision_embed_tokens.weight.dtype)
+            flat_data = flatten_bn(image_patches_flat, concat=True)
             return FuyuImagePatchInputs(
                 type="image_patches",
                 flat_data=flat_data,