Re-submit: Fix: Proper RGBA -> RGB conversion for PIL images. (#18569)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Authored by Chenheli Hua on 2025-05-22 18:59:18 -07:00, committed by GitHub
parent 46791e1b4b
commit 04eb88dc80
15 changed files with 89 additions and 20 deletions

View File

@@ -35,6 +35,7 @@ from transformers import PreTrainedTokenizerBase
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.image import convert_image_mode
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer

 logger = logging.getLogger(__name__)

@@ -257,7 +258,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
     if isinstance(image, dict) and "bytes" in image:
         image = Image.open(BytesIO(image["bytes"]))
     if isinstance(image, Image.Image):
-        image = image.convert("RGB")
+        image = convert_image_mode(image, "RGB")
         with io.BytesIO() as image_data:
             image.save(image_data, format="JPEG")
             image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
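
For context, a minimal sketch of why this call site matters (illustrative only; the helper function below is an assumption, not part of this diff): JPEG has no alpha channel, so an RGBA image must be flattened before image.save(..., format="JPEG"). A plain convert("RGB") drops alpha and leaves transparent regions with whatever RGB values were stored underneath (often black), while convert_image_mode composites them onto a white background first.

# Sketch only: flattening an RGBA image before JPEG/base64 encoding.
import base64
import io

from PIL import Image

from vllm.multimodal.image import convert_image_mode

def to_jpeg_base64(image: Image.Image) -> str:
    image = convert_image_mode(image, "RGB")   # RGBA is composited onto white
    with io.BytesIO() as buf:
        image.save(buf, format="JPEG")         # JPEG cannot store alpha
        return base64.b64encode(buf.getvalue()).decode("utf-8")

# A fully transparent image now encodes as a white JPEG rather than a black one.
encoded = to_jpeg_base64(Image.new("RGBA", (4, 4), (0, 0, 0, 0)))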

View File

@@ -11,6 +11,7 @@ from vllm import LLM, SamplingParams
 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
+from vllm.multimodal.image import convert_image_mode
 from vllm.utils import FlexibleArgumentParser

@@ -45,7 +46,8 @@ def get_mixed_modalities_query() -> QueryResult:
         "audio":
         AudioAsset("mary_had_lamb").audio_and_sample_rate,
         "image":
-        ImageAsset("cherry_blossom").pil_image.convert("RGB"),
+        convert_image_mode(
+            ImageAsset("cherry_blossom").pil_image, "RGB"),
         "video":
         VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
     },

View File

@@ -19,6 +19,7 @@ from vllm import LLM, EngineArgs, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.lora.request import LoRARequest
+from vllm.multimodal.image import convert_image_mode
 from vllm.utils import FlexibleArgumentParser

@@ -1096,8 +1097,8 @@ def get_multi_modal_input(args):
     """
     if args.modality == "image":
         # Input image and question
-        image = ImageAsset("cherry_blossom") \
-            .pil_image.convert("RGB")
+        image = convert_image_mode(
+            ImageAsset("cherry_blossom").pil_image, "RGB")
         img_questions = [
             "What is the content of this image?",
             "Describe the content of this image in detail.",

View File

@@ -4,6 +4,7 @@ import pytest

 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
+from vllm.multimodal.image import convert_image_mode

 models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]

@@ -26,8 +27,9 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None:
     give the same result.
     """
-    image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB")
-    image_stop = ImageAsset("stop_sign").pil_image.convert("RGB")
+    image_cherry = convert_image_mode(
+        ImageAsset("cherry_blossom").pil_image, "RGB")
+    image_stop = convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB")
     images = [image_cherry, image_stop]
     video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays

View File

@@ -12,7 +12,7 @@ from transformers import AutoTokenizer
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
-from vllm.multimodal.image import rescale_image_size
+from vllm.multimodal.image import convert_image_mode, rescale_image_size
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs

@@ -267,7 +267,7 @@ def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
     # use the example speech question so that the model outputs are reasonable
     audio = librosa.load(speech_question, sr=None)

-    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+    image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")

     inputs_vision_speech = [
         (

View File

@@ -4,6 +4,7 @@ import pytest

 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
+from vllm.multimodal.image import convert_image_mode

 from ..utils import create_new_process_for_each_test

@@ -58,7 +59,7 @@ def test_oot_registration_embedding(
     assert all(v == 0 for v in output.outputs.embedding)


-image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")


 @create_new_process_for_each_test()

Binary file not shown (new image asset, 219 KiB).

View File

@@ -0,0 +1,36 @@
+# SPDX-License-Identifier: Apache-2.0
+from pathlib import Path
+
+import numpy as np
+from PIL import Image, ImageChops
+
+from vllm.multimodal.image import convert_image_mode
+
+ASSETS_DIR = Path(__file__).parent / "assets"
+assert ASSETS_DIR.exists()
+
+
+def test_rgb_to_rgb():
+    # Start with an RGB image.
+    original_image = Image.open(ASSETS_DIR / "image1.png").convert("RGB")
+    converted_image = convert_image_mode(original_image, "RGB")
+
+    # RGB to RGB should be a no-op.
+    diff = ImageChops.difference(original_image, converted_image)
+    assert diff.getbbox() is None
+
+
+def test_rgba_to_rgb():
+    original_image = Image.open(ASSETS_DIR / "rgba.png")
+    original_image_numpy = np.array(original_image)
+
+    converted_image = convert_image_mode(original_image, "RGB")
+    converted_image_numpy = np.array(converted_image)
+
+    for i in range(original_image_numpy.shape[0]):
+        for j in range(original_image_numpy.shape[1]):
+            # Verify that all transparent pixels are converted to white.
+            if original_image_numpy[i][j][3] == 0:
+                assert converted_image_numpy[i][j][0] == 255
+                assert converted_image_numpy[i][j][1] == 255
+                assert converted_image_numpy[i][j][2] == 255
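
As an aside, the same check can be written without the nested pixel loop; a sketch (assuming the same rgba.png asset as the test above) using a NumPy boolean mask:

# Sketch only: vectorized equivalent of the per-pixel assertions above.
from pathlib import Path

import numpy as np
from PIL import Image

from vllm.multimodal.image import convert_image_mode

ASSETS_DIR = Path(__file__).parent / "assets"   # same assets dir as the test file

original = np.array(Image.open(ASSETS_DIR / "rgba.png"))                              # H x W x 4
converted = np.array(convert_image_mode(Image.open(ASSETS_DIR / "rgba.png"), "RGB"))  # H x W x 3

transparent = original[..., 3] == 0           # mask of fully transparent pixels
assert (converted[transparent] == 255).all()  # those pixels are pure white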

View File

@@ -10,6 +10,7 @@ import numpy as np
 import pytest
 from PIL import Image, ImageChops

+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (MediaConnector,
                                    merge_and_sort_multimodal_metadata)

@@ -53,7 +54,7 @@ def get_supported_suffixes() -> tuple[str, ...]:
 def _image_equals(a: Image.Image, b: Image.Image) -> bool:
-    return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
+    return (np.asarray(a) == np.asarray(convert_image_mode(b, a.mode))).all()


 @pytest.mark.asyncio

View File

@@ -13,7 +13,6 @@ generation. Supported dataset types include:
 TODO: Implement CustomDataset to parse a JSON file and convert its contents into
 SampleRequest instances, similar to the approach used in ShareGPT.
 """
-
 import base64
 import io
 import json

@@ -33,6 +32,7 @@ from transformers import PreTrainedTokenizerBase
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal.image import convert_image_mode
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer

 logger = logging.getLogger(__name__)

@@ -259,7 +259,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
     if isinstance(image, dict) and 'bytes' in image:
         image = Image.open(BytesIO(image['bytes']))
     if isinstance(image, Image.Image):
-        image = image.convert("RGB")
+        image = convert_image_mode(image, "RGB")
         with io.BytesIO() as image_data:
             image.save(image_data, format="JPEG")
             image_base64 = base64.b64encode(

View File

@@ -23,6 +23,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                     InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,

@@ -77,7 +78,7 @@ InternVLImageInputs = Union[InternVLImagePixelInputs,
 def build_transform(input_size: int):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
     return T.Compose([
-        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
         T.Resize((input_size, input_size),
                  interpolation=T.InterpolationMode.BICUBIC),
         T.ToTensor(),
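
For illustration, a sketch of a transform like the one above applied to an RGBA input (the 448-pixel size and the standalone Compose are assumptions for the example, not taken from this file):

# Sketch only: convert_image_mode inside a torchvision preprocessing pipeline.
import torchvision.transforms as T
from PIL import Image

from vllm.multimodal.image import convert_image_mode

transform = T.Compose([
    T.Lambda(lambda img: convert_image_mode(img, "RGB")),
    T.Resize((448, 448), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
])

rgba = Image.new("RGBA", (640, 480), (0, 0, 0, 0))  # fully transparent input
pixels = transform(rgba)
print(pixels.shape)  # torch.Size([3, 448, 448]); transparent area becomes white (1.0)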

View File

@@ -24,6 +24,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
                                                     InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,

@@ -78,7 +79,7 @@ SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
 def build_transform(input_size: int):
     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
     return T.Compose([
-        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
         T.Resize((input_size, input_size),
                  interpolation=T.InterpolationMode.BICUBIC),
         T.ToTensor(),

View File

@@ -10,6 +10,7 @@ from blake3 import blake3
 from PIL import Image

 from vllm.logger import init_logger
+from vllm.multimodal.image import convert_image_mode

 if TYPE_CHECKING:
     from vllm.inputs import TokensPrompt

@@ -35,7 +36,8 @@ class MultiModalHasher:
             return np.array(obj).tobytes()
         if isinstance(obj, Image.Image):
-            return cls.item_to_bytes("image", np.array(obj.convert("RGBA")))
+            return cls.item_to_bytes("image",
+                                     np.array(convert_image_mode(obj, "RGBA")))
         if isinstance(obj, torch.Tensor):
             return cls.item_to_bytes("tensor", obj.numpy())
         if isinstance(obj, np.ndarray):
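
A simplified sketch of what this hashing path relies on (the real MultiModalHasher keys each item and uses blake3; sha256 is used below only to keep the example self-contained): expanding to RGBA before serialization means images that differ only in their alpha channel produce different hashes.

# Sketch only: hashing an image by its RGBA bytes (simplified stand-in).
import hashlib

import numpy as np
from PIL import Image

from vllm.multimodal.image import convert_image_mode

def image_digest(img: Image.Image) -> str:
    rgba = np.array(convert_image_mode(img, "RGBA"))   # H x W x 4, alpha preserved
    return hashlib.sha256(rgba.tobytes()).hexdigest()

opaque = Image.new("RGBA", (8, 8), (255, 0, 0, 255))
transparent = Image.new("RGBA", (8, 8), (255, 0, 0, 0))
assert image_digest(opaque) != image_digest(transparent)  # alpha affects the hash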

View File

@@ -22,6 +22,25 @@ def rescale_image_size(image: Image.Image,
     return image


+# TODO: Support customizable background color to fill in.
+def rgba_to_rgb(
+        image: Image.Image, background_color=(255, 255, 255)) -> Image.Image:
+    """Convert an RGBA image to RGB with filled background color."""
+    assert image.mode == "RGBA"
+    converted = Image.new("RGB", image.size, background_color)
+    converted.paste(image, mask=image.split()[3])  # 3 is the alpha channel
+    return converted
+
+
+def convert_image_mode(image: Image.Image, to_mode: str):
+    if image.mode == to_mode:
+        return image
+    elif image.mode == "RGBA" and to_mode == "RGB":
+        return rgba_to_rgb(image)
+    else:
+        return image.convert(to_mode)
+
+
 class ImageMediaIO(MediaIO[Image.Image]):

     def __init__(self, *, image_mode: str = "RGB") -> None:

@@ -32,7 +51,7 @@ class ImageMediaIO(MediaIO[Image.Image]):
     def load_bytes(self, data: bytes) -> Image.Image:
         image = Image.open(BytesIO(data))
         image.load()
-        return image.convert(self.image_mode)
+        return convert_image_mode(image, self.image_mode)

     def load_base64(self, media_type: str, data: str) -> Image.Image:
         return self.load_bytes(base64.b64decode(data))

@@ -40,7 +59,7 @@ class ImageMediaIO(MediaIO[Image.Image]):
     def load_file(self, filepath: Path) -> Image.Image:
         image = Image.open(filepath)
         image.load()
-        return image.convert(self.image_mode)
+        return convert_image_mode(image, self.image_mode)

     def encode_base64(
         self,

@@ -51,7 +70,7 @@ class ImageMediaIO(MediaIO[Image.Image]):
         image = media

         with BytesIO() as buffer:
-            image = image.convert(self.image_mode)
+            image = convert_image_mode(image, self.image_mode)
             image.save(buffer, image_format)

             data = buffer.getvalue()
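
A usage sketch of the new helpers (illustrative values, not from the diff): on an RGBA image, convert_image_mode composites transparent pixels onto a white background, whereas PIL's plain convert('RGB') keeps whatever RGB values sat under the alpha channel.

# Sketch only: plain convert("RGB") vs. convert_image_mode on an RGBA image.
from PIL import Image

from vllm.multimodal.image import convert_image_mode

rgba = Image.new("RGBA", (2, 1))
rgba.putpixel((0, 0), (200, 30, 30, 255))   # opaque red pixel
rgba.putpixel((1, 0), (0, 0, 0, 0))         # fully transparent pixel

naive = rgba.convert("RGB")                 # alpha is simply dropped
fixed = convert_image_mode(rgba, "RGB")     # composited onto a white background

print(naive.getpixel((1, 0)))  # (0, 0, 0): transparent pixel comes out black
print(fixed.getpixel((1, 0)))  # (255, 255, 255): transparent pixel comes out white
print(fixed.getpixel((0, 0)))  # (200, 30, 30): opaque pixels are unchanged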

View File

@@ -33,6 +33,8 @@ from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
                                            Unpack)
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

+from vllm.multimodal.image import convert_image_mode
+
 __all__ = ['OvisProcessor']

 IGNORE_ID = -100

@@ -361,8 +363,8 @@ class OvisProcessor(ProcessorMixin):
         # pick the partition with maximum covering_ratio and break the tie using #sub_images
         return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]

-        if convert_to_rgb and image.mode != 'RGB':
-            image = image.convert('RGB')
+        if convert_to_rgb:
+            image = convert_image_mode(image, 'RGB')

         sides = self.get_image_size()