[Bugfix][Multi Modal] Fix incorrect Molmo image processing (#26563)
Signed-off-by: sanghol <sanghol@allenai.org>
parent ddaff2938e
commit 55392bc879
@@ -114,11 +114,11 @@ class MolmoImageInputs(TensorSchema):
         TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
     ]

-    feat_is_patch: Annotated[
+    image_input_idx: Annotated[
         Union[torch.Tensor, list[torch.Tensor]],
         TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
     ]
-    # A boolean mask indicating which image features correspond to patch tokens.
+    # An index tensor that maps image features to their corresponding patch tokens.

     num_crops: Annotated[torch.Tensor, TensorShape("bn")]

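For context, a minimal sketch of what the two fields encode; the tensor values below are hypothetical, chosen only for illustration. Non-negative entries of image_input_idx give each image feature's position among the prompt's patch tokens, and the old boolean mask is derivable from it:

    import torch

    # Hypothetical values: one image, two crops, three feature slots per
    # crop. Non-negative entries give the position of each image feature
    # among the prompt's patch tokens; negative entries mark features
    # with no corresponding token.
    image_input_idx = torch.tensor([
        [0, 2, -1],
        [1, 3, -1],
    ])

    # The boolean mask the old feat_is_patch field carried:
    feat_is_patch = image_input_idx >= 0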
@@ -1177,7 +1177,7 @@ class MolmoProcessorWrapper:
         num_crops = torch.tensor(tilings).prod(-1) + 1
         assert num_crops.sum() == len(feat_is_patch)

-        outputs["feat_is_patch"] = feat_is_patch
+        outputs["image_input_idx"] = image_input_idx
         outputs["num_crops"] = num_crops
         outputs["img_patch_id"] = self.image_patch_id

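A quick check of the num_crops arithmetic above, assuming each entry of tilings is a (rows, cols) tile count per image and the +1 accounts for the global crop; the tiling values here are made up:

    import torch

    # Hypothetical tilings: each image yields rows * cols tiled crops
    # plus one global crop.
    tilings = [(2, 3), (1, 1)]
    num_crops = torch.tensor(tilings).prod(-1) + 1  # tensor([7, 2])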
@@ -1211,8 +1211,9 @@ class MolmoProcessingInfo(BaseProcessingInfo):
         image_token_length_w = processor.image_token_length_w
         image_token_length_h = processor.image_token_length_h

-        extra = image_token_length_w * image_token_length_h
-        joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
+        # Calculate total tokens: 2 for start/end + (w+1)*h for column separators
+        extra = 2 + (image_token_length_w + 1) * image_token_length_h
+        joint = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size)

         return extra + joint
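A worked example of the corrected token count, using hypothetical sizes (12x12 pooled tokens per tiled crop, a 24x24 grid for the joint/global crop, pooling_size of 2). The old formulas omit the two start/end tokens and the one column-separator token per row:

    # Hypothetical sizes, for illustration only.
    image_token_length_w = image_token_length_h = 12
    ncols = nrows = 24
    pooling_size = 2

    # Old count: misses start/end tokens and the column separator per row.
    old_extra = image_token_length_w * image_token_length_h                    # 144
    old_joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)  # 144

    # New count: 2 start/end tokens, plus (w + 1) tokens per row.
    new_extra = 2 + (image_token_length_w + 1) * image_token_length_h          # 158
    new_joint = 2 + ((ncols + 1) // pooling_size + 1) * ((nrows + 1) // pooling_size)  # 158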
@@ -1299,7 +1300,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
         return dict(
             images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
             image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
-            feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
+            image_input_idx=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
             num_crops=MultiModalFieldConfig.batched("image"),
             img_patch_id=MultiModalFieldConfig.shared("image", num_images),
         )
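A sketch of what flat_from_sizes implies for these fields, assuming the per-crop rows are concatenated along dim 0 and num_crops supplies the split sizes; this mirrors torch.split and is not vLLM's internal API:

    import torch

    # Hypothetical: 7 crops for image 0, 2 crops for image 1.
    num_crops = torch.tensor([7, 2])
    flat = torch.randn(9, 5)
    per_image = flat.split(num_crops.tolist())  # shapes (7, 5) and (2, 5)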
@@ -1444,7 +1445,7 @@ class MolmoForCausalLM(
     ) -> Optional[MolmoImageInputs]:
         images = kwargs.pop("images", None)
         image_masks = kwargs.pop("image_masks", None)
-        feat_is_patch = kwargs.pop("feat_is_patch", None)
+        image_input_idx = kwargs.pop("image_input_idx", None)
         num_crops = kwargs.pop("num_crops", None)

         if images is None:
@@ -1466,7 +1467,7 @@ class MolmoForCausalLM(
         return MolmoImageInputs(
             images=images,
             image_masks=image_masks,
-            feat_is_patch=feat_is_patch,
+            image_input_idx=image_input_idx,
             num_crops=num_crops,
         )

@@ -1476,7 +1477,7 @@ class MolmoForCausalLM(
     ) -> list[torch.Tensor]:
         images = image_input["images"]
         image_masks = image_input["image_masks"]
-        feat_is_patch = image_input["feat_is_patch"]
+        image_input_idx = image_input["image_input_idx"]
         num_crops = image_input["num_crops"]

         # Call the vision backbone on the whole batch at once
@@ -1484,7 +1485,7 @@ class MolmoForCausalLM(
         image_masks_flat = (
             None if image_masks is None else flatten_bn(image_masks, concat=True)
         )
-        feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
+        image_input_idx_flat = flatten_bn(image_input_idx, concat=True)

         image_features_flat = self.vision_backbone(
             images=images_flat.unsqueeze(0),
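For reference, a sketch of the flatten step, under the assumption that flatten_bn(..., concat=True) concatenates a list of per-image tensors along the first dim; the shapes are made up:

    import torch

    # Hypothetical: two images with 7 and 2 crops of 576 patches each.
    per_image = [torch.zeros(7, 576), torch.zeros(2, 576)]
    flat = torch.cat(per_image, dim=0)  # shape (9, 576)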
@@ -1494,13 +1495,18 @@ class MolmoForCausalLM(
         ).squeeze(0)

-        # Only the features corresponding to patch tokens are relevant
-        return [
-            feats[f_is_patch]
-            for feats, f_is_patch in zip(
-                image_features_flat.split(num_crops.tolist()),
-                feat_is_patch_flat.split(num_crops.tolist()),
-            )
-        ]
+        # Re-order the features using the image_input_idx tensor
+        results = []
+        num_crops_list = num_crops.tolist()
+        for feats, img_idx in zip(
+            image_features_flat.split(num_crops_list),
+            image_input_idx_flat.split(num_crops_list),
+        ):
+            is_valid = img_idx >= 0
+            valid_img_idx = img_idx[is_valid]
+            order = torch.argsort(valid_img_idx)
+            results.append(feats[is_valid][order])
+        return results

     def get_language_model(self) -> torch.nn.Module:
         return self.model
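
A toy reproduction of the bug this last hunk fixes, with made-up shapes: the old mask-only path keeps features in crop order, while the new path also sorts them by their target prompt position:

    import torch

    # Made-up shapes: two crops, three feature slots each, feature dim 1.
    feats = torch.arange(6, dtype=torch.float32).reshape(2, 3, 1)
    img_idx = torch.tensor([
        [2, 0, -1],   # crop 0's valid features belong at prompt slots 2 and 0
        [3, 1, -1],   # crop 1's valid features belong at prompt slots 3 and 1
    ])

    # Old behavior: mask only, so features stay in crop order (2, 0, 3, 1).
    old = feats[img_idx >= 0]

    # New behavior: also sort valid features by target slot (0, 1, 2, 3).
    is_valid = img_idx >= 0
    order = torch.argsort(img_idx[is_valid])
    new = feats[is_valid][order]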