diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 56ea8c5d8372b..f106195e10585 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -853,7 +853,7 @@ See [this page](#generative-models) for more information on how to use generativ * - * `MolmoForCausalLM` * Molmo - * T + I + * T + I<sup>+</sup> * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. * ✅︎ * ✅︎ diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 94b61b6ae7803..d500ef5d8b805 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -431,7 +431,7 @@ VLM_TEST_SETTINGS = { ), "molmo": VLMTestInfo( models=["allenai/Molmo-7B-D-0924"], - test_type=(VLMTestType.IMAGE), + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), prompt_formatter=identity, max_model_len=4096, max_num_seqs=2, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 3f0c644a5a866..146d48e522119 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -57,7 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, merge_multimodal_embeddings) -from .vision import select_patch_features +from .vision import scatter_patch_features, select_patch_features # TODO: hard-coded for now. Consider making it configurable. 
VIT_LAYERS = [-2, -9] @@ -71,13 +71,13 @@ POOLING_SIZE = 2 class MolmoImageInputs(TypedDict): - images: Union[torch.Tensor, List[torch.Tensor]] + images: Union[torch.Tensor, list[torch.Tensor]] """Shape: `(batch_size, num_crops, num_patch, patch_dim)`""" - image_masks: Optional[Union[torch.Tensor, List[torch.Tensor]]] + image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]] """Shape: `(batch_size, num_crops, num_patch)`""" - feat_is_patch: Union[torch.Tensor, List[torch.Tensor]] + feat_is_patch: Union[torch.Tensor, list[torch.Tensor]] """ A boolean mask indicating which image features correspond to patch tokens. @@ -85,7 +85,7 @@ class MolmoImageInputs(TypedDict): Shape: `(batch_size, num_crops, num_patch)` """ - embed_is_patch: Union[torch.Tensor, List[torch.Tensor]] + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] """ A boolean mask indicating which image embeddings correspond to patch tokens. @@ -93,7 +93,7 @@ class MolmoImageInputs(TypedDict): Shape: `(batch_size, num_embeds)` """ - num_crops: torch.Tensor + num_crops: Union[torch.Tensor, list[torch.Tensor]] """Shape: `(batch_size, num_images)`""" @@ -1144,13 +1144,7 @@ class MolmoProcessorWrapper: image_input_idx = outputs.pop("image_input_idx", None) if image_input_idx is not None: - input_is_patch = input_ids == self.image_patch_id - image_input_idx_flat: torch.Tensor = image_input_idx.view(-1) - image_valid_flat = image_input_idx_flat >= 0 - feat_is_patch_flat = image_valid_flat.clone() - feat_is_patch_flat[image_valid_flat] = ( - input_is_patch[image_input_idx_flat[image_valid_flat]]) - feat_is_patch = feat_is_patch_flat.view(*image_input_idx.shape) + feat_is_patch = image_input_idx >= 0 input_is_embed = torch.isin( input_ids, @@ -1165,6 +1159,17 @@ class MolmoProcessorWrapper: embed_is_patch = embed_ids == self.image_patch_id assert embed_is_patch.sum() == feat_is_patch.sum() + # image_tokens = extra_joint + joint + # Both `extra_joint` and `joint` have `im_start_id` and `im_end_id` + 
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0] + embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0] + assert len(embed_start) == len(embed_end) == len(images) + + embed_is_patch = [ + embed_is_patch[start:end + 1] + for start, end in zip(embed_start, embed_end) + ] + tilings = [ self.select_tiling( image_width=image.size[0], @@ -1180,7 +1185,7 @@ class MolmoProcessorWrapper: outputs["num_crops"] = num_crops outputs["img_patch_id"] = self.image_patch_id - return BatchFeature(outputs, tensor_type=return_tensors) + return BatchFeature(outputs) class MolmoProcessingInfo(BaseProcessingInfo): @@ -1190,9 +1195,7 @@ class MolmoProcessingInfo(BaseProcessingInfo): return MolmoProcessorWrapper(processor) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - # TODO: Investigate different `embed_is_patch` between cache/no-cache - # in multi-image case - return {"image": 1} + return {"image": None} def get_mm_max_tokens_per_item( self, @@ -1325,7 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): "image", num_crops), feat_is_patch=MultiModalFieldConfig.flat_from_sizes( "image", num_crops), - embed_is_patch=MultiModalFieldConfig.shared("image", num_images), + embed_is_patch=MultiModalFieldConfig.batched("image"), num_crops=MultiModalFieldConfig.batched("image"), img_patch_id=MultiModalFieldConfig.shared("image", num_images), ) @@ -1499,7 +1502,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, def _process_image_input( self, image_input: MolmoImageInputs, - ) -> Union[torch.Tensor, List[torch.Tensor]]: + ) -> Union[torch.Tensor, list[torch.Tensor]]: if isinstance(image_input["images"], list): # Call the vision backbone on the whole batch at once images_flat = flatten_bn(image_input["images"], concat=True) @@ -1530,7 +1533,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, feat_is_patch: torch.Tensor, # Shape: (num_crop, 
num_patch) num_crops: torch.Tensor, # Shape: (num_images,) embed_is_patch: torch.Tensor, # Shape: (num_embeds,) - ) -> list[torch.Tensor]: + ) -> tuple[torch.Tensor, ...]: """ Scatter the patch features into a contiguous tensor that corresponds to the embedding tokens defined by the multimodal processor. @@ -1565,16 +1568,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, feats_per_image = features.split(num_crops_per_image) f_is_patch_per_image = feat_is_patch.split(num_crops_per_image) - _, _, embed_dim = features.shape - (num_embeds, ) = embed_is_patch.shape + features = torch.cat([ + feats[f_is_patch] + for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image) + ]) - embeds_in_batch = list[torch.Tensor]() - for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image): - embeds = feats.new_full((num_embeds, embed_dim), torch.nan) - embeds[embed_is_patch] = feats[f_is_patch] - embeds_in_batch.append(embeds) - - return embeds_in_batch + return scatter_patch_features(features, embed_is_patch) def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 250b0ee3c2a1b..c91459398308e 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -155,7 +155,7 @@ def resolve_visual_encoder_outputs( def scatter_patch_features( features: torch.Tensor, - embed_is_patch: torch.Tensor, + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]], ) -> tuple[torch.Tensor, ...]: """ Scatter the patch features into a contiguous tensor that corresponds @@ -194,14 +194,19 @@ def scatter_patch_features( The resulting embedding tensor is: [ nan p1 p2 nan p3 p4 nan nan ] """ - num_images, num_embeds = embed_is_patch.shape - num_embeds_per_image = [num_embeds] * num_images + num_embeds_per_image = [ + e_is_patch.numel() for e_is_patch in embed_is_patch + ] + if 
isinstance(embed_is_patch, torch.Tensor): + embed_is_patch_flat = embed_is_patch.view(-1) + else: + embed_is_patch_flat = torch.cat(embed_is_patch) embeds_flat = features.new_full( (sum(num_embeds_per_image), features.shape[-1]), fill_value=torch.nan, ) - embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2) + embeds_flat[embed_is_patch_flat] = features.flatten(0, -2) return embeds_flat.split(num_embeds_per_image)