Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-18 04:15:01 +08:00)
[VLM] Merged multi-modal processor for Molmo (#12966)
parent fdcf64d3c6
commit c9d3ecf016
@@ -793,7 +793,7 @@ See [this page](#generative-models) for more information on how to use generative models.
 - * `MolmoForCausalLM`
   * Molmo
   * T + I
-  * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-72B-0924`, etc.
+  * `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc.
   * ✅︎
   * ✅︎
   * ✅︎
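
As a minimal sketch of what the table row above documents (assuming vLLM's standard offline LLM API with `multi_modal_data` image inputs, a hypothetical local `example.jpg`, and the "User: ... Assistant:" prompt template that appears in the test changes below), running one of the listed checkpoints looks roughly like this:

from PIL import Image

from vllm import LLM, SamplingParams

# Sketch only: checkpoint name taken from the table above; the prompt template
# is an assumption borrowed from the old test prompt formatter in this diff.
llm = LLM(model="allenai/Molmo-7B-D-0924",
          trust_remote_code=True,
          max_model_len=4096)

image = Image.open("example.jpg")  # hypothetical local image
outputs = llm.generate(
    {
        "prompt": "User: Describe this image. Assistant:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
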
@@ -27,7 +27,7 @@ from ...utils import check_logprobs_close
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
     pytest.param(
-        "THUDM/chatglm3-6b",  # ChatGLM (text-only)
+        "THUDM/chatglm3-6b",  # chatglm (text-only)
     ),
     pytest.param(
         "meta-llama/Llama-3.2-1B-Instruct",  # llama
@@ -404,11 +404,10 @@ VLM_TEST_SETTINGS = {
     "molmo": VLMTestInfo(
         models=["allenai/Molmo-7B-D-0924"],
         test_type=(VLMTestType.IMAGE),
-        prompt_formatter=lambda img_prompt: "User: " + img_prompt + " Assistant:",  # noqa: E501
+        prompt_formatter=identity,
         max_model_len=4096,
         max_num_seqs=2,
-        image_size_factors=[(), (1.0, 1.0, 1.0)],
-        patch_hf_runner=model_utils.mlomo_patch_hf_runner,
+        patch_hf_runner=model_utils.molmo_patch_hf_runner,
         postprocess_inputs=model_utils.molmo_post_processor,
     ),
     # Tests for phi3v currently live in another file because of a bug in
@@ -6,7 +6,7 @@ typically specific to a small subset of models.
 import re
 import types
 from pathlib import PosixPath
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 from PIL.Image import Image
@@ -17,9 +17,7 @@ from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
-from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
-                           PromptImageInput, PromptVideoInput, _ImageAssets)
-from ....utils import TokensTextLogprobs
+from .....conftest import HfRunner, ImageAsset, _ImageAssets
 from .types import RunnerOutput
 
 
@@ -522,74 +520,7 @@ def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     return hf_model
 
 
-def _generate_greedy_logprobs_limit(
-    self,
-    prompts: List[str],
-    max_tokens: int,
-    num_logprobs: int,
-    images: Optional[PromptImageInput] = None,
-    audios: Optional[PromptAudioInput] = None,
-    videos: Optional[PromptVideoInput] = None,
-    **kwargs: Any,
-) -> List[TokensTextLogprobs]:
-    all_inputs = self.get_inputs(prompts,
-                                 images=images,
-                                 videos=videos,
-                                 audios=audios)
-
-    # Process in batches for inference.
-    if len(all_inputs):
-        input_ids_lst = []
-        images_lst = []
-        images_input_idx_lst = []
-        imges_masks_lst = []
-        for inputs in all_inputs:
-            input_ids_lst.append(inputs["input_ids"])
-            images_lst.append(inputs["images"])
-            images_input_idx_lst.append(inputs["image_input_idx"])
-            imges_masks_lst.append(inputs["image_masks"])
-        batch_inputs = {}
-        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
-        batch_inputs['images'] = torch.cat(images_lst, dim=0)
-        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
-                                                    dim=0)
-        batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
-
-        outputs = self.model.generate_from_batch(
-            batch=self.wrap_device(batch_inputs,
-                                   device=self.model.device.type),
-            generation_config=GenerationConfig(
-                max_new_tokens=max_tokens,
-                stop_strings="<|endoftext|>",
-                do_sample=False,
-            ),
-            tokenizer=self.tokenizer,
-            output_hidden_states=True,
-            return_dict_in_generate=True,
-        )
-
-    all_logprobs: List[List[Dict[int, float]]] = []
-    all_output_ids: List[List[int]] = []
-    all_output_strs: List[str] = []
-
-    for index in range(len(all_inputs)):
-        (
-            seq_logprobs_lst,
-            output_len,
-        ) = self._hidden_states_to_logprobs(outputs.hidden_states,
-                                            num_logprobs)
-        all_logprobs.append(seq_logprobs_lst)
-        seq_ids = outputs.sequences[index]
-        output_ids = seq_ids[-output_len:]
-        all_output_ids.append(output_ids.tolist())
-        all_output_strs.append(self.tokenizer.decode(output_ids))
-    outputs = zip(all_output_ids, all_output_strs, all_logprobs)
-    return [(output_ids, output_str, output_logprobs)
-            for output_ids, output_str, output_logprobs in outputs]
-
-
-####### Molmo-specific HuggingFace runner patchers
-def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for Molmo."""
     hf_processor = hf_model.processor
 
@@ -598,10 +529,23 @@ def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
 
     hf_model.processor = _processor
 
-    setattr(  # noqa: B010
-        hf_model,
-        "generate_greedy_logprobs_limit",
-        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
-    )
+    def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
+        batch = {
+            k: kwargs.pop(k)
+            for k in ("input_ids", "images", "image_input_idx", "image_masks")
+            if k in kwargs
+        }
+
+        return self.generate_from_batch(
+            batch,
+            generation_config=GenerationConfig(
+                max_new_tokens=max_new_tokens,
+                stop_strings="<|endoftext|>",
+                do_sample=do_sample,
+            ),
+            **kwargs,
+        )
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
 
     return hf_model
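
For context, the patched runner above is consumed by the shared HF-vs-vLLM comparison harness. A rough usage sketch follows; the prompt and image variables are placeholders, and the generate_greedy_logprobs_limit call signature is assumed from the removed helper earlier in this diff:

# Hypothetical driver code, not part of this commit.
with hf_runner("allenai/Molmo-7B-D-0924", dtype="bfloat16") as hf_model:
    hf_model = model_utils.molmo_patch_hf_runner(hf_model)
    # The patched model.generate repacks input_ids/images/image_input_idx/
    # image_masks into the batch dict expected by generate_from_batch.
    hf_outputs = hf_model.generate_greedy_logprobs_limit(prompts,
                                                         max_tokens=128,
                                                         num_logprobs=5,
                                                         images=images)
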
@@ -168,6 +168,8 @@ def _test_processing_correctness(
         "mistral-community/pixtral-12b",
         "openbmb/MiniCPM-o-2_6",
         "openbmb/MiniCPM-V-2_6",
+        "allenai/Molmo-7B-D-0924",
+        "allenai/Molmo-7B-O-0924",
         "nvidia/NVLM-D-72B",
         "Qwen/Qwen-VL-Chat",
         "Qwen/Qwen2-VL-2B-Instruct",
@@ -256,6 +256,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-V-2_6",
                                 trust_remote_code=True),
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
+                                        extras={"olmo": "allenai/Molmo-7B-O-0924"},  # noqa: E501
                                         trust_remote_code=True),
     "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
                               trust_remote_code=True),
File diff suppressed because it is too large.
@@ -33,8 +33,7 @@ from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     Dict, Generator, Generic, Iterator, List, Literal,
-                    NamedTuple, Optional, Tuple, Type, TypeVar, Union,
-                    overload)
+                    NamedTuple, Optional, Tuple, Type, TypeVar, Union)
 from uuid import uuid4
 
 import cloudpickle
@@ -826,38 +825,6 @@ JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
 """A nested JSON structure where the leaves need not be JSON-serializable."""
 
 
-@overload
-def json_map_leaves(
-    func: Callable[[T], U],
-    value: Dict[str, JSONTree[T]],
-) -> Dict[str, JSONTree[U]]:
-    ...
-
-
-@overload
-def json_map_leaves(
-    func: Callable[[T], U],
-    value: List[JSONTree[T]],
-) -> List[JSONTree[U]]:
-    ...
-
-
-@overload
-def json_map_leaves(
-    func: Callable[[T], U],
-    value: Tuple[JSONTree[T], ...],
-) -> Tuple[JSONTree[U], ...]:
-    ...
-
-
-@overload
-def json_map_leaves(
-    func: Callable[[T], U],
-    value: JSONTree[T],
-) -> JSONTree[U]:
-    ...
-
-
 def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
     if isinstance(value, dict):
         return {k: json_map_leaves(func, v) for k, v in value.items()}
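
The removed @overload stubs above only narrowed the static return type; runtime behaviour is unchanged. As a small illustrative sketch (the nested structure is made up), json_map_leaves applies a function to every leaf of a nested dict/list/tuple:

import torch

from vllm.utils import json_map_leaves

# Hypothetical nested multi-modal batch; json_map_leaves visits each tensor leaf.
batch = {
    "pixel_values": [torch.zeros(3, 224, 224)],
    "meta": {"scale": torch.tensor(1.0)},
}
as_half = json_map_leaves(lambda t: t.to(torch.float16), batch)
print(as_half["meta"]["scale"].dtype)  # torch.float16
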