mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-11 03:37:02 +08:00
[Model] Support multi-image for Molmo (#15438)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
e42389f9d7
commit
997c8811d6
@ -853,7 +853,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
*
|
||||
- * `MolmoForCausalLM`
|
||||
* Molmo
|
||||
* T + I
|
||||
* T + I<sup>+</sup>
|
||||
* `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
|
||||
@ -431,7 +431,7 @@ VLM_TEST_SETTINGS = {
|
||||
),
|
||||
"molmo": VLMTestInfo(
|
||||
models=["allenai/Molmo-7B-D-0924"],
|
||||
test_type=(VLMTestType.IMAGE),
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=identity,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
|
||||
@ -57,7 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import select_patch_features
|
||||
from .vision import scatter_patch_features, select_patch_features
|
||||
|
||||
# TODO: hard-coded for now. Consider making it configurable.
|
||||
VIT_LAYERS = [-2, -9]
|
||||
@ -71,13 +71,13 @@ POOLING_SIZE = 2
|
||||
|
||||
|
||||
class MolmoImageInputs(TypedDict):
|
||||
images: Union[torch.Tensor, List[torch.Tensor]]
|
||||
images: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
|
||||
|
||||
image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]]
|
||||
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
|
||||
"""Shape: `(batch_size, num_crops, num_patch)`"""
|
||||
|
||||
feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
|
||||
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image features correspond
|
||||
to patch tokens.
|
||||
@ -85,7 +85,7 @@ class MolmoImageInputs(TypedDict):
|
||||
Shape: `(batch_size, num_crops, num_patch)`
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
@ -93,7 +93,7 @@ class MolmoImageInputs(TypedDict):
|
||||
Shape: `(batch_size, num_embeds)`
|
||||
"""
|
||||
|
||||
num_crops: torch.Tensor
|
||||
num_crops: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
@ -1144,13 +1144,7 @@ class MolmoProcessorWrapper:
|
||||
|
||||
image_input_idx = outputs.pop("image_input_idx", None)
|
||||
if image_input_idx is not None:
|
||||
input_is_patch = input_ids == self.image_patch_id
|
||||
image_input_idx_flat: torch.Tensor = image_input_idx.view(-1)
|
||||
image_valid_flat = image_input_idx_flat >= 0
|
||||
feat_is_patch_flat = image_valid_flat.clone()
|
||||
feat_is_patch_flat[image_valid_flat] = (
|
||||
input_is_patch[image_input_idx_flat[image_valid_flat]])
|
||||
feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape)
|
||||
feat_is_patch = image_input_idx >= 0
|
||||
|
||||
input_is_embed = torch.isin(
|
||||
input_ids,
|
||||
@ -1165,6 +1159,17 @@ class MolmoProcessorWrapper:
|
||||
embed_is_patch = embed_ids == self.image_patch_id
|
||||
assert embed_is_patch.sum() == feat_is_patch.sum()
|
||||
|
||||
# image_tokens = extra_joint + joint
|
||||
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
|
||||
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
|
||||
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
|
||||
assert len(embed_start) == len(embed_end) == len(images)
|
||||
|
||||
embed_is_patch = [
|
||||
embed_is_patch[start:end + 1]
|
||||
for start, end in zip(embed_start, embed_end)
|
||||
]
|
||||
|
||||
tilings = [
|
||||
self.select_tiling(
|
||||
image_width=image.size[0],
|
||||
@ -1180,7 +1185,7 @@ class MolmoProcessorWrapper:
|
||||
outputs["num_crops"] = num_crops
|
||||
outputs["img_patch_id"] = self.image_patch_id
|
||||
|
||||
return BatchFeature(outputs, tensor_type=return_tensors)
|
||||
return BatchFeature(outputs)
|
||||
|
||||
|
||||
class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
@ -1190,9 +1195,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
return MolmoProcessorWrapper(processor)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
# TODO: Investigate different `embed_is_patch` between cache/no-cache
|
||||
# in multi-image case
|
||||
return {"image": 1}
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
@ -1325,7 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
"image", num_crops),
|
||||
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops),
|
||||
embed_is_patch=MultiModalFieldConfig.shared("image", num_images),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
@ -1499,7 +1502,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: MolmoImageInputs,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
if isinstance(image_input["images"], list):
|
||||
# Call the vision backbone on the whole batch at once
|
||||
images_flat = flatten_bn(image_input["images"], concat=True)
|
||||
@ -1530,7 +1533,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
|
||||
num_crops: torch.Tensor, # Shape: (num_images,)
|
||||
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
|
||||
) -> list[torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
@ -1565,16 +1568,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
feats_per_image = features.split(num_crops_per_image)
|
||||
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
|
||||
|
||||
_, _, embed_dim = features.shape
|
||||
(num_embeds, ) = embed_is_patch.shape
|
||||
features = torch.cat([
|
||||
feats[f_is_patch]
|
||||
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
|
||||
])
|
||||
|
||||
embeds_in_batch = list[torch.Tensor]()
|
||||
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
|
||||
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
|
||||
embeds[embed_is_patch] = feats[f_is_patch]
|
||||
embeds_in_batch.append(embeds)
|
||||
|
||||
return embeds_in_batch
|
||||
return scatter_patch_features(features, embed_is_patch)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
|
||||
@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs(
|
||||
|
||||
def scatter_patch_features(
|
||||
features: torch.Tensor,
|
||||
embed_is_patch: torch.Tensor,
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
@ -194,14 +194,19 @@ def scatter_patch_features(
|
||||
The resulting embedding tensor is:
|
||||
[ nan p1 p2 nan p3 p4 nan nan ]
|
||||
"""
|
||||
num_images, num_embeds = embed_is_patch.shape
|
||||
num_embeds_per_image = [num_embeds] * num_images
|
||||
num_embeds_per_image = [
|
||||
e_is_patch.numel() for e_is_patch in embed_is_patch
|
||||
]
|
||||
if isinstance(embed_is_patch, torch.Tensor):
|
||||
embed_is_patch_flat = embed_is_patch.view(-1)
|
||||
else:
|
||||
embed_is_patch_flat = torch.cat(embed_is_patch)
|
||||
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_embeds_per_image), features.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
|
||||
embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)
|
||||
|
||||
return embeds_flat.split(num_embeds_per_image)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user