[Bugfix] Ensure correctness of Cohere2Vision processing (#23245)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-08-20 19:09:18 +08:00 committed by GitHub
parent 83e69a09d6
commit 68fcd3fa73
3 changed files with 56 additions and 19 deletions

@@ -268,6 +268,7 @@ def _test_processing_correctness_one(
     "CohereForAI/aya-vision-8b",
     "Salesforce/blip2-opt-2.7b",
     "facebook/chameleon-7b",
+    "CohereLabs/command-a-vision-07-2025",
     "deepseek-ai/deepseek-vl2-tiny",
     "microsoft/Florence-2-base",
     "adept/fuyu-8b",

@@ -250,8 +250,7 @@ class AyaVisionMultiModalProcessor(
         image_processor = hf_processor.image_processor
 
         def get_replacement(item_idx: int):
-            images: ImageProcessorItems = mm_items.get("image",
-                                                       ImageProcessorItems)
+            images = mm_items.get_items("image", ImageProcessorItems)
             image_size: ImageSize = images.get_image_size(item_idx)
             num_patches = self.info.get_num_patches(
                 image_width=image_size.width,

@@ -10,6 +10,8 @@ import torch
 from torch import nn
 from transformers import BatchFeature, PretrainedConfig
 from transformers.models.cohere2_vision import Cohere2VisionConfig
+from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import (  # noqa: E501
+    get_optimal_tiled_canvas)
 from transformers.models.cohere2_vision.processing_cohere2_vision import (
     Cohere2VisionProcessor)
@@ -150,14 +152,46 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo):
         max_patches = image_processor.max_patches
         return ImageSize(height=height * max_patches, width=width)
 
-    def get_num_patches(self, image_width: int, image_height: int) -> int:
+    def get_num_patches(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Cohere2VisionProcessor],
+    ) -> int:
         """
         Calculate the number of image patches for a given image.
         Uses the HF processor to determine the actual number of patches.
         """
-        return self.get_hf_processor(
-        ).image_processor.get_number_of_image_patches(image_height,
-                                                      image_width, {})
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        image_processor = processor.image_processor
+
+        # The current implementation of get_number_of_image_patches
+        # is incorrect, so we patch it here.
+        # return image_processor.get_number_of_image_patches(image_height,
+        #                                                    image_width, {})
+        min_patches = image_processor.min_patches
+        max_patches = image_processor.max_patches
+        patch_size = image_processor.size
+        crop_to_patches = image_processor.crop_to_patches
+
+        if not crop_to_patches:
+            return 1
+
+        num_columns, num_rows = get_optimal_tiled_canvas(
+            (image_height, image_width),
+            (patch_size["height"], patch_size["width"]),
+            min_patches,
+            max_patches,
+        )
+        num_patches = num_columns * num_rows
+        if num_patches > 1:
+            num_patches += 1  # Thumbnail image
+
+        return num_patches
 
 
 class Cohere2VisionDummyInputsBuilder(
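
Note: the patched logic above boils down to picking a tile grid whose aspect ratio best matches the image (bounded by min_patches/max_patches) and appending one thumbnail tile whenever the image is cropped at all. Below is a minimal standalone sketch of that idea; best_canvas is a simplified stand-in for HF's get_optimal_tiled_canvas (its tie-breaking may differ from the real helper) and the processor settings are made up.

# Illustrative sketch only: mirrors the shape of the patched get_num_patches()
# with a simplified stand-in for get_optimal_tiled_canvas.
from itertools import product


def best_canvas(height: int, width: int, min_p: int, max_p: int) -> tuple[int, int]:
    """Pick the (cols, rows) grid whose aspect ratio best matches the image."""
    grids = [(c, r) for c, r in product(range(1, max_p + 1), repeat=2)
             if min_p <= c * r <= max_p]
    image_ratio = width / height
    # Closest aspect ratio wins; prefer fewer tiles on ties (simplified rule).
    return min(grids, key=lambda g: (abs(g[0] / g[1] - image_ratio), g[0] * g[1]))


def num_patches(height: int, width: int, min_p: int = 1, max_p: int = 12) -> int:
    cols, rows = best_canvas(height, width, min_p, max_p)
    n = cols * rows
    # One extra thumbnail tile is appended whenever the image is cropped.
    return n + 1 if n > 1 else n


# A wide image splits into several tiles plus a thumbnail; a square image that
# already fits one tile stays a single patch (under this simplified rule).
print(num_patches(1024, 3072))  # -> 4 with the stand-in above
print(num_patches(512, 512))    # -> 1 with the stand-in above
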
@@ -208,6 +242,8 @@ class Cohere2VisionMultiModalProcessor(
         # Ensure num_patches is available for proper tensor splitting
         if "num_patches" not in processed_outputs and (
                 images := mm_data.get("images")) is not None:
+            hf_processor = self.info.get_hf_processor(**mm_kwargs)
+
             # Fallback calculation if HF processor didn't provide num_patches
             parsed_images = self._get_data_parser().parse_mm_data({
                 "image":
@@ -217,8 +253,9 @@ class Cohere2VisionMultiModalProcessor(
             num_patches = [
                 self.info.get_num_patches(
                     image_width=parsed_images.get_image_size(i).width,
-                    image_height=parsed_images.get_image_size(i).height)
-                for i in range(len(parsed_images))
+                    image_height=parsed_images.get_image_size(i).height,
+                    processor=hf_processor,
+                ) for i in range(len(parsed_images))
             ]
 
             processed_outputs["num_patches"] = torch.tensor(num_patches)
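
Note: the num_patches fallback matters because the processor may stack every crop of every image into a single tensor, and the per-image counts are what allow that tensor to be split back apart downstream. A small self-contained illustration with made-up shapes and variable names (not vLLM's actual fields):

import torch

num_patches = torch.tensor([4, 1, 7])                       # e.g. three images
pixel_values = torch.randn(int(num_patches.sum()), 3, 512, 512)

# Recover each image's patches from the stacked tensor using the counts.
per_image = pixel_values.split(num_patches.tolist(), dim=0)
assert [p.shape[0] for p in per_image] == [4, 1, 7]
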
@@ -245,25 +282,25 @@ class Cohere2VisionMultiModalProcessor(
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_token = hf_processor.image_token
+        img_tokens_per_tile = int(hf_processor.patch_size**2)
         img_line_break_token = hf_processor.img_line_break_token
         boi_token = hf_processor.boi_token
         eoi_token = hf_processor.eoi_token
 
         def get_replacement(item_idx: int):
-            images: ImageProcessorItems = mm_items.get("image",
-                                                       ImageProcessorItems)
+            images = mm_items.get_items("image", ImageProcessorItems)
             image_size: ImageSize = images.get_image_size(item_idx)
 
-            num_patches = self.info.get_num_patches(image_size.height,
-                                                    image_size.width)
-            img_tokens_per_tile = int(hf_processor.patch_size**2)
-            single_tile_tokens = image_token * img_tokens_per_tile + \
-                img_line_break_token
-            img_string = f"{boi_token}\
-{single_tile_tokens * num_patches}\
-{eoi_token}"
+            num_patches = self.info.get_num_patches(
+                image_width=image_size.width,
+                image_height=image_size.height,
+                processor=hf_processor,
+            )
+            patch_tokens = (image_token * img_tokens_per_tile +
+                            img_line_break_token)
+            repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}"
 
-            return PromptUpdateDetails.select_text(img_string, image_token)
+            return PromptUpdateDetails.select_text(repl, image_token)
 
         return [
             PromptReplacement(
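
Note: for a concrete picture of the replacement string get_replacement builds per image, here is a tiny sketch. Every token string and value below is a placeholder for illustration, not the actual Cohere2VisionProcessor vocabulary or configuration.

# Placeholder tokens; the real values come from the HF processor.
image_token = "<IMG_PATCH>"
img_line_break_token = "<IMG_LINE_BREAK>"
boi_token = "<START_OF_IMG>"
eoi_token = "<END_OF_IMG>"

img_tokens_per_tile = 4   # real value is patch_size**2
num_patches = 3           # e.g. 2 tiles + 1 thumbnail

# Each tile contributes img_tokens_per_tile image tokens plus a line break,
# and the whole block is wrapped in begin/end-of-image tokens.
patch_tokens = image_token * img_tokens_per_tile + img_line_break_token
repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}"
print(repl)
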