mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 07:37:59 +08:00
[Bugfix] Fix shape checking for Fuyu (#21709)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
18cc33dd60
commit
139a97ec56
@ -55,14 +55,15 @@ class FuyuImagePatchInputs(TensorSchema):
|
|||||||
"""
|
"""
|
||||||
Dimensions:
|
Dimensions:
|
||||||
- bn: Batch size * number of images
|
- bn: Batch size * number of images
|
||||||
- fn: Num channels * patch_size_x * patch_size_y
|
- bnp: Batch size * number of images * number of patches
|
||||||
|
- fn: patch_size_x * patch_size_y * num_channels
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal["image_patches"] = "image_patches"
|
type: Literal["image_patches"] = "image_patches"
|
||||||
|
|
||||||
flat_data: Annotated[
|
flat_data: Annotated[
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
TensorShape("bn", "fn"),
|
TensorShape("bnp", "fn"),
|
||||||
]
|
]
|
||||||
|
|
||||||
patches_per_image: Annotated[list[int], TensorShape("bn")]
|
patches_per_image: Annotated[list[int], TensorShape("bn")]
|
||||||
@ -309,8 +310,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
image_patches = kwargs.pop("image_patches", None)
|
image_patches = kwargs.pop("image_patches", None)
|
||||||
if image_patches is not None:
|
if image_patches is not None:
|
||||||
image_patches_flat = flatten_bn(image_patches)
|
image_patches_flat = flatten_bn(image_patches)
|
||||||
flat_data = flatten_bn(image_patches, concat=True).data.to(
|
flat_data = flatten_bn(image_patches_flat, concat=True)
|
||||||
self.vision_embed_tokens.weight.dtype)
|
|
||||||
return FuyuImagePatchInputs(
|
return FuyuImagePatchInputs(
|
||||||
type="image_patches",
|
type="image_patches",
|
||||||
flat_data=flat_data,
|
flat_data=flat_data,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user