[Bugfix][Multi Modal] Fix incorrect Molmo image processing (#26563)
Signed-off-by: sanghol <sanghol@allenai.org>
commit 55392bc879
parent ddaff2938e
@@ -114,11 +114,11 @@ class MolmoImageInputs(TensorSchema):
         TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
     ]
 
-    feat_is_patch: Annotated[
+    image_input_idx: Annotated[
         Union[torch.Tensor, list[torch.Tensor]],
         TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
     ]
-    # A boolean mask indicating which image features correspond to patch tokens.
+    # An index tensor that maps image features to their corresponding patch tokens.
     num_crops: Annotated[torch.Tensor, TensorShape("bn")]
 
 
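The schema change above swaps a boolean mask (one flag per feature) for an index tensor (one destination position per feature, with negative values marking padding). A minimal sketch of the two representations, using made-up shapes and values for illustration only:

import torch

# Hypothetical toy example: 6 image features, 4 of which are real patch tokens.
# Old representation: a boolean mask only says WHICH features are patches.
feat_is_patch = torch.tensor([True, True, False, True, True, False])

# New representation: an index tensor also says WHERE each feature goes in the
# prompt; negative entries mark padding features to drop.
image_input_idx = torch.tensor([2, 0, -1, 3, 1, -1])

patch_mask = image_input_idx >= 0    # recovers the same information as the old mask
print(patch_mask)                    # tensor([True, True, False, True, True, False])
print(image_input_idx[patch_mask])   # tensor([2, 0, 3, 1]) -> ordering info the mask lacks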
@@ -1177,7 +1177,7 @@ class MolmoProcessorWrapper:
         num_crops = torch.tensor(tilings).prod(-1) + 1
         assert num_crops.sum() == len(feat_is_patch)
 
-        outputs["feat_is_patch"] = feat_is_patch
+        outputs["image_input_idx"] = image_input_idx
         outputs["num_crops"] = num_crops
         outputs["img_patch_id"] = self.image_patch_id
 
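For context, `num_crops` holds one entry per image: the product of that image's tiling grid plus one extra crop (presumably the full-image view, given the `+ 1`). A quick check of the arithmetic with hypothetical tilings:

import torch

# Hypothetical tilings for two images: a 2x3 grid and a 1x1 grid.
tilings = [[2, 3], [1, 1]]

num_crops = torch.tensor(tilings).prod(-1) + 1
print(num_crops)  # tensor([7, 2]): 6 tiles + 1 extra crop, and 1 tile + 1 extra crop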
@@ -1211,8 +1211,9 @@ class MolmoProcessingInfo(BaseProcessingInfo):
         image_token_length_w = processor.image_token_length_w
         image_token_length_h = processor.image_token_length_h
 
-        extra = image_token_length_w * image_token_length_h
-        joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
+        # Calculate total tokens: 2 for start/end + (w+1)*h for column separators
+        extra = 2 + (image_token_length_w + 1) * image_token_length_h
+        joint = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size)
 
         return extra + joint
 
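To see how much the old formula undercounted, a worked example with hypothetical values (image_token_length_w = image_token_length_h = 12, ncols = nrows = 11, pooling_size = 2):

# Hypothetical values, for illustration only:
image_token_length_w = image_token_length_h = 12
ncols = nrows = 11
pooling_size = 2

# Before the fix (no start/end tokens, no separator column):
extra_old = image_token_length_w * image_token_length_h                            # 144
joint_old = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)          # 36

# After the fix: 2 start/end tokens plus one extra column per row.
extra_new = 2 + (image_token_length_w + 1) * image_token_length_h                  # 158
joint_new = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size)  # 44

print(extra_old + joint_old, extra_new + joint_new)  # 180 202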
@@ -1299,7 +1300,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
         return dict(
             images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
             image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
-            feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
+            image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
             num_crops=MultiModalFieldConfig.batched("image"),
             img_patch_id=MultiModalFieldConfig.shared("image", num_images),
         )
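These field configs tell the processor how to slice flattened outputs back into per-image pieces: `flat_from_sizes` splits the first dimension by each image's `num_crops`, while `batched`/`shared` keep one entry per image. A rough sketch of the splitting behaviour using plain torch ops (not the vLLM API itself):

import torch

# Hypothetical flattened output: 7 + 2 = 9 crops across two images.
num_crops = torch.tensor([7, 2])
image_input_idx_flat = torch.arange(9 * 3).reshape(9, 3)  # fake (crops, tokens) data

per_image = image_input_idx_flat.split(num_crops.tolist())
print([t.shape for t in per_image])  # [torch.Size([7, 3]), torch.Size([2, 3])]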
@@ -1444,7 +1445,7 @@ class MolmoForCausalLM(
     ) -> Optional[MolmoImageInputs]:
         images = kwargs.pop("images", None)
         image_masks = kwargs.pop("image_masks", None)
-        feat_is_patch = kwargs.pop("feat_is_patch", None)
+        image_input_idx = kwargs.pop("image_input_idx", None)
         num_crops = kwargs.pop("num_crops", None)
 
         if images is None:
@@ -1466,7 +1467,7 @@ class MolmoForCausalLM(
         return MolmoImageInputs(
             images=images,
             image_masks=image_masks,
-            feat_is_patch=feat_is_patch,
+            image_input_idx=image_input_idx,
             num_crops=num_crops,
         )
 
@@ -1476,7 +1477,7 @@ class MolmoForCausalLM(
     ) -> list[torch.Tensor]:
         images = image_input["images"]
         image_masks = image_input["image_masks"]
-        feat_is_patch = image_input["feat_is_patch"]
+        image_input_idx = image_input["image_input_idx"]
         num_crops = image_input["num_crops"]
 
         # Call the vision backbone on the whole batch at once
@@ -1484,7 +1485,7 @@ class MolmoForCausalLM(
         image_masks_flat = (
             None if image_masks is None else flatten_bn(image_masks, concat=True)
         )
-        feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
+        image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
 
         image_features_flat = self.vision_backbone(
             images=images_flat.unsqueeze(0),
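Here `flatten_bn(..., concat=True)` collapses a list of per-image tensors into one tensor along the first dimension so the vision backbone runs once over every crop in the batch. Conceptually it behaves like torch.cat; a sketch under that assumption (the 576 tokens-per-crop figure is made up), not vLLM's actual helper:

import torch

# Hypothetical per-image index tensors with different crop counts.
image_input_idx = [torch.zeros(7, 576, dtype=torch.long),
                   torch.zeros(2, 576, dtype=torch.long)]

# Roughly what flatten_bn(..., concat=True) produces: one (9, 576) tensor.
image_input_idx_flat = torch.cat(image_input_idx, dim=0)
print(image_input_idx_flat.shape)  # torch.Size([9, 576])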
@@ -1494,13 +1495,18 @@ class MolmoForCausalLM(
         ).squeeze(0)
 
         # Only the features corresponding to patch tokens are relevant
-        return [
-            feats[f_is_patch]
-            for feats, f_is_patch in zip(
-                image_features_flat.split(num_crops.tolist()),
-                feat_is_patch_flat.split(num_crops.tolist()),
-            )
-        ]
+        # Re-order the features using the image_input_idx tensor
+        results = []
+        num_crops_list = num_crops.tolist()
+        for feats, img_idx in zip(
+            image_features_flat.split(num_crops_list),
+            image_input_idx_flat.split(num_crops_list),
+        ):
+            is_valid = img_idx >= 0
+            valid_img_idx = img_idx[is_valid]
+            order = torch.argsort(valid_img_idx)
+            results.append(feats[is_valid][order])
+        return results
 
     def get_language_model(self) -> torch.nn.Module:
         return self.model
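This hunk is the heart of the bugfix: the old code only masked out non-patch features and implicitly assumed the survivors were already in prompt order; the new loop uses image_input_idx to restore the order the prompt actually expects. A self-contained toy run of the same select-and-argsort pattern, with hypothetical values:

import torch

# Hypothetical features for one image: 5 feature vectors of size 4.
feats = torch.arange(20.0).reshape(5, 4)
# Destination slots in the prompt; -1 marks padded (invalid) features.
img_idx = torch.tensor([3, 0, -1, 2, 1])

is_valid = img_idx >= 0
order = torch.argsort(img_idx[is_valid])   # permutation into prompt order
reordered = feats[is_valid][order]

# The feature destined for slot 0 now comes first, and so on.
print(img_idx[is_valid][order])  # tensor([0, 1, 2, 3])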