[VLM] Merged multi-modal processor for Molmo (#12966)

Authored by Cyrus Leung on 2025-02-13 20:34:00 +08:00; committed via GitHub
parent fdcf64d3c6
commit c9d3ecf016
9 changed files with 745 additions and 493 deletions


@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generativ
 - * `MolmoForCausalLM`
   * Molmo
   * T + I
-  * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc.
+  * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
   * ✅︎
   * ✅︎
   * ✅︎
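
For reference, a minimal vLLM offline-inference sketch for the checkpoints listed above. The image path and sampling settings are illustrative, and the "User: ... Assistant:" prompt style is an assumption based on the formatter used in this repo's Molmo tests:

from PIL import Image

from vllm import LLM, SamplingParams

# Illustrative local image; any RGB image works.
image = Image.open("example.jpg").convert("RGB")

# Molmo ships custom modeling code, so trust_remote_code is required.
llm = LLM(
    model="allenai/Molmo-7B-D-0924",
    trust_remote_code=True,
    max_model_len=4096,
)

outputs = llm.generate(
    {
        "prompt": "User: Describe this image. Assistant:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)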


@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
     pytest.param(
-        "THUDM/chatglm3-6b",  # ChatGLM (text-only)
+        "THUDM/chatglm3-6b",  # chatglm (text-only)
     ),
     pytest.param(
         "meta-llama/Llama-3.2-1B-Instruct",  # llama


@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = {
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
         test_type=(VLMTestType.IMAGE),
-        prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:",  # noqa: E501
+        prompt_formatter=identity,
         max_model_len=4096,
         max_num_seqs=2,
-        image_size_factors=[(),(1.0, 1.0, 1.0)],
-        patch_hf_runner=model_utils.mlomo_patch_hf_runner,
+        patch_hf_runner=model_utils.molmo_patch_hf_runner,
         postprocess_inputs=model_utils.molmo_post_processor,
     ),
     # Tests for phi3v currently live in another file because of a bug in


@@ -6,7 +6,7 @@ typically specific to a small subset of models.
 import re
 import types
 from pathlib import PosixPath
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 from PIL.Image import Image

@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
-from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
-                           PromptImageInput, PromptVideoInput, _ImageAssets)
-from ....utils import TokensTextLogprobs
+from .....conftest import HfRunner, ImageAsset, _ImageAssets
 from .types import RunnerOutput
@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     return hf_model
 
 
-def _generate_greedy_logprobs_limit(
-    self,
-    prompts: List[str],
-    max_tokens: int,
-    num_logprobs: int,
-    images: Optional[PromptImageInput] = None,
-    audios: Optional[PromptAudioInput] = None,
-    videos: Optional[PromptVideoInput] = None,
-    **kwargs: Any,
-) -> List[TokensTextLogprobs]:
-    all_inputs = self.get_inputs(prompts,
-                                 images=images,
-                                 videos=videos,
-                                 audios=audios)
-
-    # Process in batches for inference.
-    if len(all_inputs):
-        input_ids_lst = []
-        images_lst = []
-        images_input_idx_lst = []
-        imges_masks_lst = []
-
-        for inputs in all_inputs:
-            input_ids_lst.append(inputs["input_ids"])
-            images_lst.append(inputs["images"])
-            images_input_idx_lst.append(inputs["image_input_idx"])
-            imges_masks_lst.append(inputs["image_masks"])
-
-        batch_inputs = {}
-        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
-        batch_inputs['images'] = torch.cat(images_lst, dim=0)
-        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
-                                                    dim=0)
-        batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
-
-        outputs = self.model.generate_from_batch(
-            batch=self.wrap_device(batch_inputs,
-                                   device=self.model.device.type),
-            generation_config=GenerationConfig(
-                max_new_tokens=max_tokens,
-                stop_strings="<|endoftext|>",
-                do_sample=False,
-            ),
-            tokenizer=self.tokenizer,
-            output_hidden_states=True,
-            return_dict_in_generate=True,
-        )
-
-    all_logprobs: List[List[Dict[int, float]]] = []
-    all_output_ids: List[List[int]] = []
-    all_output_strs: List[str] = []
-
-    for index in range(len(all_inputs)):
-        (
-            seq_logprobs_lst,
-            output_len,
-        ) = self._hidden_states_to_logprobs(outputs.hidden_states,
-                                            num_logprobs)
-        all_logprobs.append(seq_logprobs_lst)
-        seq_ids = outputs.sequences[index]
-        output_ids = seq_ids[-output_len:]
-        all_output_ids.append(output_ids.tolist())
-        all_output_strs.append(self.tokenizer.decode(output_ids))
-
-    outputs = zip(all_output_ids, all_output_strs, all_logprobs)
-    return [(output_ids, output_str, output_logprobs)
-            for output_ids, output_str, output_logprobs in outputs]
-
-
-####### Molmo-specific HuggingFace runner patchers
-def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for Molmo."""
 
     hf_processor = hf_model.processor
@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
 
     hf_model.processor = _processor
 
-    setattr(  # noqa: B010
-        hf_model,
-        "generate_greedy_logprobs_limit",
-        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
-    )
+    def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
+        batch = {
+            k: kwargs.pop(k)
+            for k in ("input_ids", "images", "image_input_idx", "image_masks")
+            if k in kwargs
+        }
+
+        return self.generate_from_batch(
+            batch,
+            generation_config=GenerationConfig(
+                max_new_tokens=max_new_tokens,
+                stop_strings="<|endoftext|>",
+                do_sample=do_sample,
+            ),
+            **kwargs,
+        )
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
 
     return hf_model
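
The replacement _generate above is bound onto the model instance with types.MethodType, so later calls to hf_model.model.generate route through it. A self-contained sketch of that binding pattern, using a toy stand-in class rather than the real remote-code Molmo model:

import types

class ToyModel:
    """Stand-in for the remote-code Molmo model object (an assumption)."""

    def generate_from_batch(self, batch, **kwargs):
        return f"generate_from_batch got keys {sorted(batch)}"

def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
    # Collect the Molmo-specific inputs into one batch dict, mirroring
    # the patched _generate in the hunk above.
    batch = {
        k: kwargs.pop(k)
        for k in ("input_ids", "images", "image_input_idx", "image_masks")
        if k in kwargs
    }
    return self.generate_from_batch(batch, **kwargs)

model = ToyModel()
# Bind _generate to this instance; model.generate(...) now dispatches to
# it with `self` set to `model`, as with hf_model.model.generate above.
model.generate = types.MethodType(_generate, model)
print(model.generate(input_ids=[1, 2], images=[3]))
# -> generate_from_batch got keys ['images', 'input_ids']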


@@ -168,6 +168,8 @@ def _test_processing_correctness(
     "mistral-community/pixtral-12b",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
+    "allenai/Molmo-7B-D-0924",
+    "allenai/Molmo-7B-O-0924",
     "nvidia/NVLM-D-72B",
     "Qwen/Qwen-VL-Chat",
     "Qwen/Qwen2-VL-2B-Instruct",


@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6",
                                 trust_remote_code=True),
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
+                                        extras={"olmo": "allenai/Molmo-7B-O-0924"},  # noqa: E501
                                         trust_remote_code=True),
     "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
                               trust_remote_code=True),

File diff suppressed because it is too large


@@ -353,17 +353,17 @@ class MultiModalFieldConfig:
         Example:
 
         .. code-block::
 
             Input:
                 Data: [[AAAA]
                        [BBBB]
                        [CCCC]]
 
             Output:
                 Element 1: [AAAA]
                 Element 2: [BBBB]
                 Element 3: [CCCC]
         """
         return MultiModalFieldConfig(
             field=MultiModalBatchedField(),

@@ -384,18 +384,18 @@ class MultiModalFieldConfig:
         Example:
 
         .. code-block::
 
             Given:
                 slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
 
             Input:
                 Data: [AAABBBBCC]
 
             Output:
                 Element 1: [AAA]
                 Element 2: [BBBB]
                 Element 3: [CC]
         """
         return MultiModalFieldConfig(
             field=MultiModalFlatField(slices=slices),

@@ -416,18 +416,18 @@ class MultiModalFieldConfig:
         Example:
 
         .. code-block::
 
             Given:
                 size_per_item: [3, 4, 2]
 
             Input:
                 Data: [AAABBBBCC]
 
             Output:
                 Element 1: [AAA]
                 Element 2: [BBBB]
                 Element 3: [CC]
 
         See also:
             :func:`MultiModalFieldConfig.flat`

@@ -456,19 +456,19 @@ class MultiModalFieldConfig:
         Example:
 
         .. code-block::
 
             Given:
                 batch_size: 4
 
             Input:
                 Data: [XYZ]
 
             Output:
                 Element 1: [XYZ]
                 Element 2: [XYZ]
                 Element 3: [XYZ]
                 Element 4: [XYZ]
         """
         return MultiModalFieldConfig(
             field=MultiModalSharedField(batch_size),
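
The three docstrings above describe how batched, flat, and shared fields split multi-modal data into per-item elements. A standalone sketch of the same splitting semantics in plain PyTorch (toy tensors, not the actual MultiModalField classes):

import torch

# Batched field: dim 0 indexes items, so a [3, 4] tensor yields three
# elements of shape [4] (the [[AAAA] [BBBB] [CCCC]] case above).
data = torch.arange(12).reshape(3, 4)
batched_elems = [data[i] for i in range(data.shape[0])]

# Flat field: items are concatenated along dim 0 and recovered with
# explicit slices ([AAABBBBCC] with slices (0,3), (3,7), (7,9)).
flat = torch.arange(9)
slices = [slice(0, 3), slice(3, 7), slice(7, 9)]
flat_elems = [flat[s] for s in slices]
print([e.tolist() for e in flat_elems])
# -> [[0, 1, 2], [3, 4, 5, 6], [7, 8]]

# Shared field: one tensor is reused unchanged for every item in the
# batch (the [XYZ] repeated four times case above).
shared = torch.tensor([10, 11, 12])
shared_elems = [shared] * 4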


@@ -33,8 +33,7 @@ from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     Dict, Generator, Generic, Iterator, List, Literal,
-                    NamedTuple, Optional, Tuple, Type, TypeVar, Union,
-                    overload)
+                    NamedTuple, Optional, Tuple, Type, TypeVar, Union)
 from uuid import uuid4
 
 import cloudpickle
@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
 """A nested JSON structure where the leaves need not be JSON-serializable."""
 
 
-@overload
-def json_map_leaves(
-    func: Callable[[T], U],
-    value: Dict[str, JSONTree[T]],
-) -> Dict[str, JSONTree[U]]:
-    ...
-
-
-@overload
-def json_map_leaves(
-    func: Callable[[T], U],
-    value: List[JSONTree[T]],
-) -> List[JSONTree[U]]:
-    ...
-
-
-@overload
-def json_map_leaves(
-    func: Callable[[T], U],
-    value: Tuple[JSONTree[T], ...],
-) -> Tuple[JSONTree[U], ...]:
-    ...
-
-
-@overload
-def json_map_leaves(
-    func: Callable[[T], U],
-    value: JSONTree[T],
-) -> JSONTree[U]:
-    ...
-
-
 def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
     if isinstance(value, dict):
         return {k: json_map_leaves(func, v) for k, v in value.items()}
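
The hunk cuts off after the dict branch of json_map_leaves. For reference, a self-contained version whose list/tuple/leaf branches are reconstructed by analogy with the dict case shown (an assumption, not quoted from the file), plus a usage example:

from typing import Callable, Dict, List, Tuple, TypeVar, Union

T = TypeVar("T")
U = TypeVar("U")
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
                 Tuple["JSONTree[T]", ...], T]

def json_map_leaves(func: Callable[[T], U],
                    value: "JSONTree[T]") -> "JSONTree[U]":
    # Recurse through dict/list/tuple containers; apply func at the leaves.
    if isinstance(value, dict):
        return {k: json_map_leaves(func, v) for k, v in value.items()}
    if isinstance(value, list):
        return [json_map_leaves(func, v) for v in value]
    if isinstance(value, tuple):
        return tuple(json_map_leaves(func, v) for v in value)
    return func(value)

print(json_map_leaves(lambda x: x * 2, {"a": [1, (2, 3)], "b": 4}))
# -> {'a': [2, (4, 6)], 'b': 8}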