[Bugfix][V1] Fix molmo text-only inputs (#11676)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 4ca5d40adc
commit 32c9eff2ff
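
Note (not part of the commit): the title says text-only Molmo inputs were broken under the V1 engine. A minimal sketch of the request this change is presumably meant to make work, using vLLM's public API and the model named in the tests below:

# Hedged reproduction sketch; assumes the V1 engine (e.g. VLLM_USE_V1=1) and is
# not taken from the commit itself.
from vllm import LLM, SamplingParams

llm = LLM(model="allenai/Molmo-7B-D-0924", max_model_len=4096)
prompt = "User: Explain the difference between a list and a tuple. Assistant:"
# A text-only request: no image is attached, so the input processor must not
# fall into the image-handling branch.
outputs = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)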
@@ -341,6 +341,16 @@ VLM_TEST_SETTINGS = {
         ),
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
     ),
+    "molmo": VLMTestInfo(
+        models=["allenai/Molmo-7B-D-0924"],
+        test_type=(VLMTestType.IMAGE),
+        prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        image_size_factors=[(),(1.0, 1.0, 1.0)],
+        patch_hf_runner=model_utils.mlomo_patch_hf_runner,
+        postprocess_inputs=model_utils.molmo_post_processor,
+    ),
     # Tests for phi3v currently live in another file because of a bug in
     # transformers. Once this issue is fixed, we can enable them here instead.
     # https://github.com/huggingface/transformers/issues/34307
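
A note on the test entry above (my reading, not stated in the commit): each tuple in image_size_factors appears to describe the images attached to one parametrized case, so the empty tuple sends the prompt with no image at all, which is exactly the text-only path this commit fixes. Sketch of that interpretation:

# Assumed semantics of image_size_factors; illustrative only.
image_size_factors = [(), (1.0, 1.0, 1.0)]
for factors in image_size_factors:
    if not factors:
        print("text-only case: no images attached to the prompt")
    else:
        print(f"image case: {len(factors)} copies of the asset at scales {factors}")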
@@ -5,17 +5,20 @@ typically specific to a small subset of models.
 import re
 import types
 from pathlib import PosixPath
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from PIL.Image import Image
-from transformers import AutoConfig, AutoTokenizer, BatchEncoding
+from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
+                          GenerationConfig)
 
 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
-from .....conftest import HfRunner, ImageAsset, _ImageAssets
+from .....conftest import (HfRunner, ImageAsset, PromptAudioInput,
+                           PromptImageInput, PromptVideoInput, _ImageAssets)
+from ....utils import TokensTextLogprobs
 from .types import RunnerOutput
 
 
@@ -222,6 +225,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
     return {"model_inputs": hf_inputs}
 
 
+def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
+    hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
+    return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
+
+
 ####### Prompt path encoders for models that need models on disk
 def qwen_prompt_path_encoder(
         tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset],
@@ -451,3 +459,88 @@ def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     hf_model.model.generate = types.MethodType(_generate, hf_model.model)
 
     return hf_model
+
+
+def _generate_greedy_logprobs_limit(
+    self,
+    prompts: List[str],
+    max_tokens: int,
+    num_logprobs: int,
+    images: Optional[PromptImageInput] = None,
+    audios: Optional[PromptAudioInput] = None,
+    videos: Optional[PromptVideoInput] = None,
+    **kwargs: Any,
+) -> List[TokensTextLogprobs]:
+    all_inputs = self.get_inputs(prompts,
+                                 images=images,
+                                 videos=videos,
+                                 audios=audios)
+
+    # Process in batches for inference.
+    if len(all_inputs):
+        input_ids_lst = []
+        images_lst = []
+        images_input_idx_lst = []
+        imges_masks_lst = []
+        for inputs in all_inputs:
+            input_ids_lst.append(inputs["input_ids"])
+            images_lst.append(inputs["images"])
+            images_input_idx_lst.append(inputs["image_input_idx"])
+            imges_masks_lst.append(inputs["image_masks"])
+        batch_inputs = {}
+        batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0)
+        batch_inputs['images'] = torch.cat(images_lst, dim=0)
+        batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst,
+                                                    dim=0)
+        batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0)
+
+        outputs = self.model.generate_from_batch(
+            batch=self.wrap_device(batch_inputs,
+                                   device=self.model.device.type),
+            generation_config=GenerationConfig(
+                max_new_tokens=max_tokens,
+                stop_strings="<|endoftext|>",
+                do_sample=False,
+            ),
+            tokenizer=self.tokenizer,
+            output_hidden_states=True,
+            return_dict_in_generate=True,
+        )
+
+    all_logprobs: List[List[Dict[int, float]]] = []
+    all_output_ids: List[List[int]] = []
+    all_output_strs: List[str] = []
+
+    for index in range(len(all_inputs)):
+        (
+            seq_logprobs_lst,
+            output_len,
+        ) = self._hidden_states_to_logprobs(outputs.hidden_states,
+                                            num_logprobs)
+        all_logprobs.append(seq_logprobs_lst)
+        seq_ids = outputs.sequences[index]
+        output_ids = seq_ids[-output_len:]
+        all_output_ids.append(output_ids.tolist())
+        all_output_strs.append(self.tokenizer.decode(output_ids))
+    outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+    return [(output_ids, output_str, output_logprobs)
+            for output_ids, output_str, output_logprobs in outputs]
+
+
+####### Molmo-specific HuggingFace runner patchers
+def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Molmo."""
+    hf_processor = hf_model.processor
+
+    def _processor(*args, **kwargs):
+        return hf_processor.process(*args, **kwargs)
+
+    hf_model.processor = _processor
+
+    setattr(  # noqa: B010
+        hf_model,
+        "generate_greedy_logprobs_limit",
+        types.MethodType(_generate_greedy_logprobs_limit, hf_model),
+    )
+
+    return hf_model
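
For context on molmo_post_processor added above: Molmo's HF processor returns unbatched tensors, so the post-processor casts the image tensor to the requested dtype and unsqueezes every field to add a batch dimension before generate_from_batch is called. A minimal sketch with made-up shapes (only the unsqueeze(0) step mirrors the diff):

import torch

hf_inputs = {
    "input_ids": torch.randint(0, 100, (16,)),                   # (seq_len,)
    "images": torch.zeros(13, 576, 588),                         # (crops, patches, pixels) -- illustrative sizes
    "image_input_idx": torch.zeros(13, 144, dtype=torch.int32),  # illustrative
    "image_masks": torch.zeros(13, 576),                         # illustrative
}
batched = {k: v.unsqueeze(0) for k, v in hf_inputs.items()}
assert batched["input_ids"].shape == (1, 16)
assert batched["images"].shape == (1, 13, 576, 588)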
@@ -1081,45 +1081,25 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
     else:
         out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
 
+    # If there is no image, return directly.
+    if image is None:
+        new_prompt_token_ids = out["input_ids"].tolist()
+        prompt = inputs.get("prompt")
+        if prompt is None:
+            prompt = tokenizer.decode(new_prompt_token_ids)
+        return token_inputs(
+            prompt_token_ids=new_prompt_token_ids,
+            prompt=prompt,
+        )
+
     image_processor = processor.image_processor
     max_total_crops = 1 + image_processor.max_crops
-    if image is not None:
     images, image_input_idx, image_masks = pad_images(
         max_total_crops,
         out["images"],
         out["image_input_idx"],
         out.get("image_masks"),
     )
-    else:
-        base_image_input_size = image_processor.base_image_input_size
-        image_patch_size = image_processor.image_patch_size
-        image_num_patch = (
-            base_image_input_size[0] // image_patch_size,
-            base_image_input_size[1] // image_patch_size,
-        )
-        n_pixels = image_patch_size * image_patch_size * 3
-        n_patches = image_num_patch[0] * image_num_patch[1]
-
-        image_length_w = image_processor.image_token_length_w
-        image_length_h = image_processor.image_token_length_h
-        tokens_per_image = image_length_w * image_length_h
-        images = torch.full(
-            (max_total_crops, n_patches, n_pixels),
-            -1,
-            dtype=torch.float32,
-        )
-        image_input_idx = torch.full(
-            (max_total_crops, tokens_per_image),
-            -1,
-            dtype=torch.int32,
-        )
-        if image_processor.image_padding_mask:
-            image_masks = torch.full(
-                (max_total_crops, n_patches),
-                -1,
-                dtype=torch.float32,
-            )
-
     image_data = dict(
         images=images,
         image_input_idx=image_input_idx,
@@ -1143,11 +1123,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
                 offset = i
             size += 1
     image_data["image_start_end"] = (offset, offset + size)
-
     prompt = inputs.get("prompt")
     if prompt is None:
         prompt = tokenizer.decode(new_prompt_token_ids)
-
     return token_inputs(
         prompt_token_ids=new_prompt_token_ids,
         prompt=prompt,
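
Summary of the molmo.py change above (my paraphrase, not the author's wording): previously a text-only request still fell through to the image path, where placeholder tensors full of -1 were fabricated and packed into image_data; now the processor returns plain token inputs as soon as image is None. A behaviour-level sketch of the two control flows:

# Control-flow sketch only; names mirror the diff, bodies are heavily paraphrased.
def old_flow(image, out):
    if image is not None:
        image_data = {"images": out["images"]}                # real image tensors
    else:
        image_data = {"images": "placeholder tensors of -1"}  # dummies were still built
    return {"prompt_token_ids": out["input_ids"].tolist(),
            "multi_modal_data": image_data}

def new_flow(image, out):
    if image is None:  # early return: no multi-modal data for text-only prompts
        return {"prompt_token_ids": out["input_ids"].tolist()}
    return {"prompt_token_ids": out["input_ids"].tolist(),
            "multi_modal_data": {"images": out["images"]}}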