diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 4fb571122abbf..7e1d478562a4c 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -55,14 +55,15 @@ class FuyuImagePatchInputs(TensorSchema):
     """
     Dimensions:
         - bn: Batch size * number of images
-        - fn: Num channels * patch_size_x * patch_size_y
+        - bnp: Batch size * number of images * number of patches
+        - fn: patch_size_x * patch_size_y * num_channels
     """
 
     type: Literal["image_patches"] = "image_patches"
 
     flat_data: Annotated[
         torch.Tensor,
-        TensorShape("bn", "fn"),
+        TensorShape("bnp", "fn"),
     ]
 
     patches_per_image: Annotated[list[int], TensorShape("bn")]
@@ -309,8 +310,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         image_patches = kwargs.pop("image_patches", None)
         if image_patches is not None:
             image_patches_flat = flatten_bn(image_patches)
-            flat_data = flatten_bn(image_patches, concat=True).data.to(
-                self.vision_embed_tokens.weight.dtype)
+            flat_data = flatten_bn(image_patches_flat, concat=True)
+
             return FuyuImagePatchInputs(
                 type="image_patches",
                 flat_data=flat_data,
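
For reviewers, here is a minimal sketch of why the `flat_data` schema changes from `("bn", "fn")` to `("bnp", "fn")`: each image can contribute a different number of patches, so concatenating the per-image patch tensors produces one row per patch (summed over all images), not one row per image, and `patches_per_image` records the split points. Note `toy_flatten_bn` below is a hypothetical stand-in covering only the list-of-tensors path of vLLM's `flatten_bn`, and the patch counts are made up for illustration:

```python
import torch


def toy_flatten_bn(x: list[torch.Tensor], *, concat: bool = False):
    """Toy stand-in for flatten_bn: optionally concatenate per-image tensors."""
    if concat:
        return torch.cat(list(x))
    return list(x)


fn = 30 * 30 * 3  # patch_size_x * patch_size_y * num_channels (assumed sizes)
# Two images with different patch counts (e.g. different resolutions).
image_patches = [torch.randn(4, fn), torch.randn(9, fn)]

image_patches_flat = toy_flatten_bn(image_patches)
flat_data = toy_flatten_bn(image_patches_flat, concat=True)
patches_per_image = [x.size(0) for x in image_patches_flat]

print(flat_data.shape)    # torch.Size([13, 2700]): ("bnp", "fn"), bnp = 4 + 9
print(patches_per_image)  # [4, 9]: one entry per image, i.e. shape ("bn",)

# Downstream, per-image embeddings can be recovered by splitting on bnp:
per_image = flat_data.split(patches_per_image)
assert [t.shape[0] for t in per_image] == patches_per_image
```

Under this reading, the old `("bn", "fn")` annotation could not validate inputs where images produce unequal patch counts, while `("bnp", "fn")` matches what the concatenation actually yields.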