[VLM] Support pan-and-scan for Gemma3 multi-modal processor (#14672)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Roger Wang <ywang@roblox.com>

parent a73122de96 · commit 382403921f
@@ -763,7 +763,7 @@ See [this page](#generative-models) for more information on how to use generative models.
  * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
  * ✅︎
  * ✅︎
  * ✅︎\*
  * ⚠️
- * `GLM4VForCausalLM`<sup>^</sup>
  * GLM-4V
  * T + I
@@ -856,12 +856,12 @@ See [this page](#generative-models) for more information on how to use generative models.
  * ✅︎
  * ✅︎
- * `PaliGemmaForConditionalGeneration`
  * PaliGemma ⚠️, PaliGemma 2 ⚠️
  * PaliGemma, PaliGemma 2
  * T + I<sup>E</sup>
  * `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
  *
  * ✅︎
  * ✅︎
  * ⚠️
- * `Phi3VForCausalLM`
  * Phi-3-Vision, Phi-3.5-Vision
  * T + I<sup>E+</sup>
@@ -926,34 +926,15 @@ See [this page](#generative-models) for more information on how to use generative models.
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.

:::{warning}
vLLM does not currently support the PrefixLM attention mask, so our PaliGemma implementation uses regular causal attention, which causes the model output to be unstable.

We may deprecate this model series in a future release.
:::

:::{note}
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support backends other than FlashAttention.
:::

:::{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
:::

:::{note}
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: <gh-pr:4087#issuecomment-2250397630>
:::

:::{note}
To use Qwen2.5-VL series models, you have to install the Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
:::

:::{note}
:::{important}
To use Gemma3 series models, you have to install the Hugging Face Transformers library from source via
`pip install git+https://github.com/huggingface/transformers`.
The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357).

Pan-and-scan image pre-processing is currently supported on V0 (but not V1).
You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`.
:::

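For offline inference, the same option can be set programmatically through `mm_processor_kwargs`. A minimal sketch, assuming a V0 run (pan-and-scan is not yet honored on V1) and a local `PIL.Image` named `image`; it mirrors the example scripts updated later in this commit:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="google/gemma-3-4b-it",
    max_model_len=2048,
    # Opt into pan-and-scan; the default is False.
    mm_processor_kwargs={"do_pan_and_scan": True},
)

prompt = ("<bos><start_of_turn>user\n"
          "<start_of_image>Describe this image.<end_of_turn>\n"
          "<start_of_turn>model\n")
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```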
:::{warning}
Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:

@@ -969,9 +950,23 @@ V1 currently uses a simplified attention pattern:
- Will be updated in the future to support the correct behavior

This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
:::

Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views.
Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions.
:::{note}
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support backends other than FlashAttention.
:::

:::{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
:::

:::{note}
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: <gh-pr:4087#issuecomment-2250397630>
:::

:::{warning}
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
:::

### Pooling Models

@@ -123,10 +123,14 @@ def run_gemma3(questions: list[str], modality: str):
    assert modality == "image"
    model_name = "google/gemma-3-4b-it"

    llm = LLM(model=model_name,
              max_model_len=2048,
              max_num_seqs=2,
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
    llm = LLM(
        model=model_name,
        max_model_len=2048,
        max_num_seqs=2,
        # Default is False; setting it to True is not supported in V1 yet
        mm_processor_kwargs={"do_pan_and_scan": True},
        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
    )

    prompts = [("<bos><start_of_turn>user\n"
                f"<start_of_image>{question}<end_of_turn>\n"

@@ -83,10 +83,14 @@ def load_deepseek_vl2(question: str, image_urls: list[str]):
def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
    model_name = "google/gemma-3-4b-it"

    llm = LLM(model=model_name,
              max_model_len=8192,
              max_num_seqs=2,
              limit_mm_per_prompt={"image": len(image_urls)})
    llm = LLM(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        # Default is False; setting it to True is not supported in V1 yet
        mm_processor_kwargs={"do_pan_and_scan": True},
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{

@@ -9,7 +9,7 @@ from pathlib import PosixPath

import pytest
from packaging.version import Version
from transformers import AutoModelForVision2Seq
from transformers import AutoModelForPreTraining, AutoModelForVision2Seq
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm.platforms import current_platform
@@ -234,6 +234,23 @@ VLM_TEST_SETTINGS = {
        num_logprobs=10,
        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
    ),
    "gemma3": VLMTestInfo(
        models=["google/gemma-3-4b-it"],
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
        prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n",  # noqa: E501
        single_image_prompts=IMAGE_ASSETS.prompts({
            "stop_sign": "<start_of_image>What's the content in the center of the image?",  # noqa: E501
            "cherry_blossom": "<start_of_image>What is the season?",  # noqa: E501
        }),
        multi_image_prompt="<start_of_image><start_of_image>Describe the two images in detail.",  # noqa: E501
        max_model_len=4096,
        max_num_seqs=2,
        # TODO: Use AutoModelForVision2Seq once transformers supports this
        auto_cls=AutoModelForPreTraining,
        dtype="bfloat16",
        vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
        patch_hf_runner=model_utils.gemma3_patch_hf_runner,
    ),
    "glm4v": VLMTestInfo(
        models=["THUDM/glm-4v-9b"],
        test_type=VLMTestType.IMAGE,

@@ -304,6 +304,18 @@ def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    return hf_model


def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for Gemma 3."""
    hf_processor = hf_model.processor

    def processor(*args, **kwargs):
        return hf_processor(*args, do_pan_and_scan=True, **kwargs)

    hf_model.processor = processor

    return hf_model


def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for GLM4."""
    hf_processor = hf_model.processor

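The Gemma 3 patch wraps the HF processor so every call in the HF reference path opts into pan-and-scan, matching the `mm_processor_kwargs` passed to the vLLM runner in the test settings above. The same wrapping pattern in isolation, with a hypothetical processor callable:

```python
def force_pan_and_scan(processor):
    """Return a wrapper that always enables pan-and-scan."""
    def wrapped(*args, **kwargs):
        kwargs.setdefault("do_pan_and_scan", True)
        return processor(*args, **kwargs)
    return wrapped

# e.g. hf_model.processor = force_pan_and_scan(hf_model.processor)
```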
@@ -348,7 +348,11 @@ class InputRegistry:
        dummy_factory = self._get_dummy_data_factory(model_cls)
        mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
        mm_processor_kwargs = get_allowed_kwarg_only_overrides(
            dummy_factory, overrides=model_config.mm_processor_kwargs)
            dummy_factory,
            overrides=model_config.mm_processor_kwargs,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        dummy_data = dummy_factory(InputContext(model_config), seq_len,
                                   _MultiModalCounts(mm_counts),
@@ -381,6 +385,7 @@ class InputRegistry:
        self,
        ctx: InputContext,
        inputs: ProcessorInputs,
        **kwargs: object,
    ) -> ProcessorInputs:
        """The default input processor is a no-op."""
        return inputs
@@ -447,6 +452,8 @@ class InputRegistry:
            model_config.mm_processor_kwargs,
            inputs.get("mm_processor_kwargs", {}),  # type: ignore
            processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        processed_inputs = processor(

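`get_allowed_kwarg_only_overrides` filters the user's `mm_processor_kwargs` down to parameters the target callable can actually accept; passing `requires_kw_only=False` and `allow_var_kwargs=True` lets overrides such as `do_pan_and_scan` reach positional-or-keyword parameters and `**kwargs` catch-alls on HF processors. A simplified sketch of that filtering idea (not the actual vLLM implementation):

```python
import inspect
from typing import Any, Callable, Mapping


def filter_overrides(fn: Callable, overrides: Mapping[str, Any], *,
                     requires_kw_only: bool, allow_var_kwargs: bool) -> dict:
    params = inspect.signature(fn).parameters
    has_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD
                         for p in params.values())

    def accepted(name: str) -> bool:
        if name in params:
            kind = params[name].kind
            if requires_kw_only:
                return kind is inspect.Parameter.KEYWORD_ONLY
            return kind in (inspect.Parameter.KEYWORD_ONLY,
                            inspect.Parameter.POSITIONAL_OR_KEYWORD)
        return allow_var_kwargs and has_var_kwargs

    return {k: v for k, v in overrides.items() if accepted(k)}


def hf_processor_call(text, images=None, *, do_pan_and_scan=False, **kwargs):
    ...


# Both keys survive because the callable takes **kwargs; with
# allow_var_kwargs=False only "do_pan_and_scan" would be kept.
print(filter_overrides(hf_processor_call,
                       {"do_pan_and_scan": True, "unknown_key": 1},
                       requires_kw_only=False, allow_var_kwargs=True))
```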
@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
import math
from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
                    Tuple, TypedDict, Union)

import torch
from torch import nn
from transformers import BatchFeature, Gemma3Config, ProcessorMixin
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs

from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -14,10 +16,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
                                        PromptUpdate, encode_tokens)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@@ -31,8 +34,15 @@ logger = init_logger(__name__)


class Gemma3ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
    pixel_values: torch.Tensor
    """
    Shape: `(num_crops_total, num_channels, height, width)`

    `num_crops_total` is the total number of crops
    over each image over each prompt in the batch.
    """
    num_crops: torch.Tensor
    """Shape: `(batch_size * num_images,)`"""


Gemma3ImageInputs = Gemma3ImagePixelInputs
@@ -40,6 +50,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs


class Gemma3ProcessingInfo(BaseProcessingInfo):

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

@@ -48,22 +61,160 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        hf_config = self.ctx.get_hf_config()
        return {"image": hf_config.mm_tokens_per_image}
        return {"image": self.get_max_image_tokens()}

    def _resolve_image_kwargs(
        self,
        processor: Gemma3Processor,
        keys: set[str],
    ) -> dict[str, Any]:
        image_processor = processor.image_processor
        kwargs = processor._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
        )

        images_kwargs = kwargs["images_kwargs"]

        def _resolve_kw(key: str):
            val = getattr(image_processor, key)
            if val is None:
                val = images_kwargs[key]

            return val

        return {k: _resolve_kw(k) for k in keys}

    def get_num_crops(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor],
    ) -> int:
        if processor is None:
            processor = self.get_hf_processor()

        images_kwargs = self._resolve_image_kwargs(
            processor, {
                "do_pan_and_scan", "pan_and_scan_min_crop_size",
                "pan_and_scan_max_num_crops",
                "pan_and_scan_min_ratio_to_activate"
            })

        do_pan_and_scan = images_kwargs["do_pan_and_scan"]
        pan_and_scan_min_crop_size = images_kwargs[
            "pan_and_scan_min_crop_size"]
        pan_and_scan_max_num_crops = images_kwargs[
            "pan_and_scan_max_num_crops"]
        pan_and_scan_min_ratio_to_activate = images_kwargs[
            "pan_and_scan_min_ratio_to_activate"]

        if not do_pan_and_scan:
            return 0

        # Based on Gemma3ImageProcessor.pan_and_scan
        if image_width >= image_height:
            if image_width / image_height < pan_and_scan_min_ratio_to_activate:
                return 0

            num_crops_w = min(
                int(math.floor(image_width / pan_and_scan_min_crop_size)),
                int(math.floor(image_width / image_height + 0.5)),
            )

            num_crops_w = max(2, num_crops_w)
            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
            num_crops_h = 1
        else:
            if image_height / image_width < pan_and_scan_min_ratio_to_activate:
                return 0

            num_crops_h = min(
                int(math.floor(image_height / pan_and_scan_min_crop_size)),
                int(math.floor(image_height / image_width + 0.5)),
            )

            num_crops_h = max(2, num_crops_h)
            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
            num_crops_w = 1

        crop_size_w = int(math.ceil(image_width / num_crops_w))
        crop_size_h = int(math.ceil(image_height / num_crops_h))

        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
            return 0

        return num_crops_w * num_crops_h

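As a worked example of the crop arithmetic above, here is a standalone sketch that re-implements the same math outside the processor; the parameter values are illustrative defaults rather than values read from the HF config:

```python
import math


def pan_and_scan_num_crops(width: int, height: int,
                           min_crop_size: int = 256,
                           max_num_crops: int = 4,
                           min_ratio_to_activate: float = 1.2) -> int:
    """Number of extra crops produced for one image (0 means no pan-and-scan)."""
    if width >= height:
        if width / height < min_ratio_to_activate:
            return 0
        num_w = min(math.floor(width / min_crop_size),
                    math.floor(width / height + 0.5))
        num_w = min(max_num_crops, max(2, num_w))
        num_h = 1
    else:
        if height / width < min_ratio_to_activate:
            return 0
        num_h = min(math.floor(height / min_crop_size),
                    math.floor(height / width + 0.5))
        num_h = min(max_num_crops, max(2, num_h))
        num_w = 1

    if min(math.ceil(width / num_w), math.ceil(height / num_h)) < min_crop_size:
        return 0
    return num_w * num_h


print(pan_and_scan_num_crops(1024, 1024))  # 0: square image, ratio below threshold
print(pan_and_scan_num_crops(1024, 512))   # 2: wide image split into two crops
print(pan_and_scan_num_crops(512, 2048))   # 4: tall image split into four crops
```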
    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3Processor],
    ) -> str:
        if processor is None:
            processor = self.get_hf_processor()

        image_token = processor.boi_token

        num_crops = self.get_num_crops(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        if num_crops == 0:
            image_text = image_token
        else:
            crops_image_tokens = " ".join(image_token
                                          for _ in range(num_crops))
            image_text = (
                f"Here is the original image {image_token} and here are some "
                f"crops to help you see better {crops_image_tokens}")

        return image_text.replace(image_token, processor.full_image_sequence)

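To make the replacement text above concrete: with two crops, the intermediate string holds one placeholder for the original image plus one per crop, and each placeholder is then expanded to the processor's full per-image token sequence. A rough sketch with hypothetical token values (the real ones come from the HF `Gemma3Processor`):

```python
boi_token = "<start_of_image>"  # hypothetical placeholder value
# Stand-in for processor.full_image_sequence; the real sequence wraps
# mm_tokens_per_image (e.g. 256) image soft tokens.
full_image_sequence = "<start_of_image>" + "<image_soft_token>" * 4 + "<end_of_image>"

num_crops = 2
crops = " ".join([boi_token] * num_crops)
image_text = (f"Here is the original image {boi_token} and here are some "
              f"crops to help you see better {crops}")

prompt_piece = image_text.replace(boi_token, full_image_sequence)
print(prompt_piece.count("<image_soft_token>"))  # 12 = (1 original + 2 crops) * 4
```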
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[ProcessorMixin],
        processor: Optional[Gemma3Processor],
    ) -> int:
        hf_config = self.ctx.get_hf_config()
        return hf_config.mm_tokens_per_image
        tokenizer = self.get_tokenizer()
        image_repl = self.get_image_repl(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        image_repl_tokens = encode_tokens(
            tokenizer,
            image_repl,
            add_special_tokens=False,
        )
        return len(image_repl_tokens)

    def get_image_size_with_most_features(self) -> ImageSize:
        # Result in the max possible feature size (h:w = 16:1)
        return ImageSize(height=8000, width=50)
        processor = self.get_hf_processor()

        images_kwargs = self._resolve_image_kwargs(
            processor, {"pan_and_scan_max_num_crops"})
        max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]

        # Result in the max possible feature size (h:w = max_num_crops:1)
        return ImageSize(height=50 * max_num_crops, width=50)

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=None,
        )

class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
@@ -73,10 +224,11 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        tokenizer = self.info.get_tokenizer()
        boi_token = tokenizer.boi_token
        processor = self.info.get_hf_processor()
        image_token = processor.boi_token

        num_images = mm_counts.get("image", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()

@@ -86,8 +238,13 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
                                   height=target_height,
                                   num_images=num_images)
        }

        # NOTE: We need to separate the image tokens here because
        # encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
        # with the detection of prompt updates when the image tokens are
        # right next to each other
        return ProcessorInputs(
            prompt_text=" ".join([boi_token] * num_images),
            prompt_text=" ".join([image_token] * num_images),
            mm_data=mm_data,
        )

@@ -100,22 +257,49 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        # TODO(woosuk): Support pan-and-scan.
        img_kwargs = mm_kwargs.get("images_kwargs", {})
        img_kwargs["do_pan_and_scan"] = False
        mm_kwargs["images_kwargs"] = img_kwargs
        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
        )

        # HF processor pops the `num_crops` kwarg, which is needed by vLLM
        if (images := mm_data.get("images")) is not None:
            assert isinstance(images, list)

            parsed_images = (self._get_data_parser().parse_mm_data({
                "image":
                images
            }).get_items("image", ImageProcessorItems))
            image_sizes = [
                parsed_images.get_image_size(i)
                for i in range(len(parsed_images))
            ]
            hf_processor = self.info.get_hf_processor(**mm_kwargs)

            num_crops = [
                self.info.get_num_crops(image_width=size.width,
                                        image_height=size.height,
                                        processor=hf_processor)
                for size in image_sizes
            ]

            processed_outputs["num_crops"] = torch.tensor(num_crops)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
        num_crops = hf_inputs.get("num_crops", torch.empty(0))

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", num_crops + 1),
            num_crops=MultiModalFieldConfig.batched("image"),
        )

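A small illustration of what `flat_from_sizes("image", num_crops + 1)` expresses: the flattened `pixel_values` tensor is carved into per-image groups of `num_crops[i] + 1` rows (the original image plus its crops). A standalone sketch with made-up shapes:

```python
import torch

# Two images in one prompt: the first was not pan-and-scanned (0 crops),
# the second produced 4 crops, so it contributes 1 + 4 = 5 entries.
num_crops = torch.tensor([0, 4])
sizes = num_crops + 1

pixel_values = torch.randn(int(sizes.sum()), 3, 896, 896)
per_image = torch.split(pixel_values, sizes.tolist())
print([chunk.shape[0] for chunk in per_image])  # [1, 5]
```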
    def _get_prompt_updates(
        self,
@@ -123,25 +307,23 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        tokenizer = self.info.get_tokenizer()
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()

        boi_token = tokenizer.boi_token
        image_token = tokenizer.image_token
        mm_tokens_per_image = hf_config.mm_tokens_per_image
        image_tokens_expanded = "".join([image_token] * mm_tokens_per_image)
        image_token = hf_processor.boi_token

        def get_replacement_gemma3(item_idx: int):
            return PromptUpdateDetails(
                full=hf_processor.full_image_sequence,
                features=image_tokens_expanded,
            images = mm_items.get_items("image", ImageProcessorItems)

            image_size = images.get_image_size(item_idx)
            return self.info.get_image_repl(
                image_width=image_size.width,
                image_height=image_size.height,
                processor=hf_processor,
            )

        return [
            PromptReplacement(
                modality="image",
                target=boi_token,
                target=image_token,
                replacement=get_replacement_gemma3,
            )
        ]
@@ -254,19 +436,27 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        num_crops = kwargs.pop("num_crops", None)
        image_embeds = kwargs.pop("image_embeds", None)
        assert image_embeds is None, "Gemma3 does not support image_embeds."
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list[torch.Tensor])):
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(num_crops, (torch.Tensor, list)):
            raise ValueError("Incorrect type of num_crops values. "
                             f"Got type: {type(num_crops)}")

        pixel_values = flatten_bn(pixel_values, concat=True)
        num_crops = flatten_bn(num_crops, concat=True)

        return Gemma3ImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
            pixel_values=self._validate_pixel_values(pixel_values),
            num_crops=num_crops,
        )

    def _image_pixels_to_features(
@@ -283,7 +473,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
        image_input: Gemma3ImageInputs,
    ) -> torch.Tensor:
        assert self.vision_tower is not None
        pixel_values = image_input["data"]

        pixel_values = image_input["pixel_values"]
        vision_outputs = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,

@@ -226,7 +226,11 @@ class MultiModalPlugin(ABC):

        if callable(max_mm_tokens):
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
                max_mm_tokens,
                overrides=model_config.mm_processor_kwargs,
                requires_kw_only=False,
                allow_var_kwargs=True,
            )
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

@@ -1488,11 +1488,11 @@ def get_allowed_kwarg_only_overrides(
        if requires_kw_only:
            logger.warning(
                "The following intended overrides are not keyword-only args "
                "and and will be dropped: %s", dropped_keys)
                "and will be dropped: %s", dropped_keys)
        else:
            logger.warning(
                "The following intended overrides are not keyword args "
                "and and will be dropped: %s", dropped_keys)
                "and will be dropped: %s", dropped_keys)

    return filtered_overrides