[Misc] Introduce encode_*_url utility function (#31208)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in: parent 3faa8bee57, commit bb62dda2c3
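Across the test suite, this change replaces hand-built `data:` URIs with the new `encode_*_url` helpers. A minimal sketch of the migration pattern applied in the hunks below, using the `cherry_blossom` asset that also appears in this diff (illustrative only, not part of the change):

from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import encode_image_base64, encode_image_url

image = ImageAsset("cherry_blossom").pil_image

# Before: tests built the data URL by hand and hard-coded the MIME type.
old_url = f"data:image/jpeg;base64,{encode_image_base64(image, format='JPEG')}"

# After: one call encodes the image and prepends the matching MIME type
# (PNG by default for the new helper).
new_url = encode_image_url(image)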
@@ -8,7 +8,7 @@ import pytest
 import pytest_asyncio

 from vllm.assets.audio import AudioAsset
-from vllm.multimodal.utils import encode_audio_base64, fetch_audio
+from vllm.multimodal.utils import encode_audio_base64, encode_audio_url, fetch_audio

 from ...utils import RemoteOpenAIServer

@@ -53,6 +53,14 @@ def base64_encoded_audio() -> dict[str, str]:
     }


+@pytest.fixture(scope="session")
+def url_encoded_audio() -> dict[str, str]:
+    return {
+        audio_url: encode_audio_url(*fetch_audio(audio_url))
+        for audio_url in TEST_AUDIO_URLS
+    }
+
+
 def dummy_messages_from_audio_url(
     audio_urls: str | list[str],
     content_text: str = "What's happening in this audio?",
@@ -149,11 +157,9 @@ async def test_single_chat_session_audio_base64encoded(
     client: openai.AsyncOpenAI,
     model_name: str,
     audio_url: str,
-    base64_encoded_audio: dict[str, str],
+    url_encoded_audio: dict[str, str],
 ):
-    messages = dummy_messages_from_audio_url(
-        f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
-    )
+    messages = dummy_messages_from_audio_url(url_encoded_audio[audio_url])

     # test single completion
     chat_completion = await client.chat.completions.create(
@@ -7,7 +7,7 @@ import openai
 import pytest
 import pytest_asyncio

-from vllm.multimodal.utils import encode_video_base64, fetch_video
+from vllm.multimodal.utils import encode_video_url, fetch_video

 from ...utils import RemoteOpenAIServer

@@ -48,9 +48,9 @@ async def client(server):


 @pytest.fixture(scope="session")
-def base64_encoded_video() -> dict[str, str]:
+def url_encoded_video() -> dict[str, str]:
     return {
-        video_url: encode_video_base64(fetch_video(video_url)[0])
+        video_url: encode_video_url(fetch_video(video_url)[0])
         for video_url in TEST_VIDEO_URLS
     }

@@ -175,11 +175,9 @@ async def test_single_chat_session_video_base64encoded(
     client: openai.AsyncOpenAI,
     model_name: str,
     video_url: str,
-    base64_encoded_video: dict[str, str],
+    url_encoded_video: dict[str, str],
 ):
-    messages = dummy_messages_from_video_url(
-        f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
-    )
+    messages = dummy_messages_from_video_url(url_encoded_video[video_url])

     # test single completion
     chat_completion = await client.chat.completions.create(
@@ -223,11 +221,9 @@ async def test_single_chat_session_video_base64encoded_beamsearch(
     client: openai.AsyncOpenAI,
     model_name: str,
     video_url: str,
-    base64_encoded_video: dict[str, str],
+    url_encoded_video: dict[str, str],
 ):
-    messages = dummy_messages_from_video_url(
-        f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
-    )
+    messages = dummy_messages_from_video_url(url_encoded_video[video_url])

     chat_completion = await client.chat.completions.create(
         model=model_name,
@@ -9,7 +9,7 @@ import pytest_asyncio
 from transformers import AutoProcessor

 from vllm.multimodal.base import MediaWithBytes
-from vllm.multimodal.utils import encode_image_base64, fetch_image
+from vllm.multimodal.utils import encode_image_url, fetch_image

 from ...utils import RemoteOpenAIServer

@@ -35,7 +35,7 @@ EXPECTED_MM_BEAM_SEARCH_RES = [
     ],
     [
         "The image shows a Venn diagram with three over",
-        "The image shows a colorful Venn diagram with",
+        "The image displays a Venn diagram with three over",
     ],
     [
         "This image displays a gradient of colors ranging from",
@@ -70,11 +70,9 @@ async def client(server):


 @pytest.fixture(scope="session")
-def base64_encoded_image(local_asset_server) -> dict[str, str]:
+def url_encoded_image(local_asset_server) -> dict[str, str]:
     return {
-        image_asset: encode_image_base64(
-            local_asset_server.get_image_asset(image_asset)
-        )
+        image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset))
         for image_asset in TEST_IMAGE_ASSETS
     }

@@ -234,11 +232,11 @@ async def test_single_chat_session_image_base64encoded(
     model_name: str,
     raw_image_url: str,
     image_url: str,
-    base64_encoded_image: dict[str, str],
+    url_encoded_image: dict[str, str],
 ):
     content_text = "What's in this image?"
     messages = dummy_messages_from_image_url(
-        f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}",
+        url_encoded_image[raw_image_url],
         content_text,
     )

@@ -288,15 +286,13 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
     client: openai.AsyncOpenAI,
     model_name: str,
     image_idx: int,
-    base64_encoded_image: dict[str, str],
+    url_encoded_image: dict[str, str],
 ):
     # NOTE: This test also validates that we pass MM data through beam search
     raw_image_url = TEST_IMAGE_ASSETS[image_idx]
     expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]

-    messages = dummy_messages_from_image_url(
-        f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}"
-    )
+    messages = dummy_messages_from_image_url(url_encoded_image[raw_image_url])

     chat_completion = await client.chat.completions.create(
         model=model_name,
@@ -10,7 +10,7 @@ from transformers import AutoProcessor
 from tests.utils import VLLM_PATH, RemoteOpenAIServer
 from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
 from vllm.multimodal.base import MediaWithBytes
-from vllm.multimodal.utils import encode_image_base64, fetch_image
+from vllm.multimodal.utils import fetch_image

 MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
 MAXIMUM_IMAGES = 2
@@ -48,14 +48,6 @@ def server():
         yield remote_server


-@pytest.fixture(scope="session")
-def base64_encoded_image(local_asset_server) -> dict[str, str]:
-    return {
-        image_url: encode_image_base64(local_asset_server.get_image_asset(image_url))
-        for image_url in TEST_IMAGE_ASSETS
-    }
-
-
 def get_hf_prompt_tokens(model_name, content, image_url):
     processor = AutoProcessor.from_pretrained(
         model_name, trust_remote_code=True, num_crops=4
@@ -25,9 +25,9 @@ from vllm.entrypoints.chat_utils import (
 )
 from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
 from vllm.multimodal.utils import (
-    encode_audio_base64,
-    encode_image_base64,
-    encode_video_base64,
+    encode_audio_url,
+    encode_image_url,
+    encode_video_url,
 )
 from vllm.tokenizers import get_tokenizer
 from vllm.tokenizers.mistral import MistralTokenizer
@@ -141,22 +141,19 @@ def mistral_model_config():
 @pytest.fixture(scope="module")
 def image_url():
     image = ImageAsset("cherry_blossom")
-    base64 = encode_image_base64(image.pil_image)
-    return f"data:image/jpeg;base64,{base64}"
+    return encode_image_url(image.pil_image)


 @pytest.fixture(scope="module")
 def video_url():
     video = VideoAsset("baby_reading", 1)
-    base64 = encode_video_base64(video.np_ndarrays)
-    return f"data:video/jpeg;base64,{base64}"
+    return encode_video_url(video.np_ndarrays)


 @pytest.fixture(scope="module")
 def audio_url():
     audio = AudioAsset("mary_had_lamb")
-    base64 = encode_audio_base64(*audio.audio_and_sample_rate)
-    return f"data:audio/ogg;base64,{base64}"
+    return encode_audio_url(*audio.audio_and_sample_rate)


 def _assert_mm_data_is_image_input(
@@ -8,7 +8,7 @@ from PIL.Image import Image
 from transformers import AutoProcessor

 from vllm import LLM, EngineArgs, SamplingParams
-from vllm.multimodal.utils import encode_image_base64
+from vllm.multimodal.utils import encode_image_url

 MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"

@@ -31,10 +31,7 @@ def test_keye_vl(
     question: str,
 ):
     images = [asset.pil_image for asset in image_assets]
-
-    image_urls = [
-        f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
-    ]
+    image_urls = [encode_image_url(image) for image in images]

     engine_args = EngineArgs(
         model=MODEL_NAME,
@@ -15,7 +15,7 @@ from transformers import AutoProcessor

 from vllm import LLM, EngineArgs, SamplingParams
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.multimodal.utils import encode_image_base64
+from vllm.multimodal.utils import encode_image_url
 from vllm.multimodal.video import sample_frames_from_video
 from vllm.platforms import current_platform

@@ -178,8 +178,7 @@ def build_dots_ocr_prompt(images, config):
     """Build Dots.OCR specific prompt with OCR instructions."""
     # Use only stop_sign image for Dots.OCR
     image = images[0]  # Already filtered to stop_sign

-    image_url = f"data:image/jpeg;base64,{encode_image_base64(image)}"
+    image_url = encode_image_url(image)

     placeholders = [{"type": "image_url", "image_url": {"url": image_url}}]
     messages = [
@@ -204,9 +203,7 @@ def build_processor_prompt(images, config):
         config["model_name"], trust_remote_code=True
     )

-    image_urls = [
-        f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
-    ]
+    image_urls = [encode_image_url(img) for img in images]
     placeholders = [{"type": "image", "image": url} for url in image_urls]
     messages = [
         {
@@ -225,9 +222,7 @@ def build_processor_prompt(images, config):

 def build_ovis_prompt(images, config):
     """Build Ovis2.5 specific prompt with custom format."""
-    image_urls = [
-        f"data:image/jpeg;base64,{encode_image_base64(img)}" for img in images
-    ]
+    image_urls = [encode_image_url(img) for img in images]

     placeholders = "\n".join(
         f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
@@ -31,7 +31,7 @@ import openai
 import requests

 from vllm.assets.image import ImageAsset
-from vllm.multimodal.utils import encode_image_base64
+from vllm.multimodal.utils import encode_image_url

 MAX_OUTPUT_LEN = 256

@@ -49,9 +49,7 @@ SAMPLE_PROMPTS_MM: list[dict] = [
         "content": [
             {
                 "type": "image_url",
-                "image_url": {
-                    "url": f"data:image;base64,{encode_image_base64(image_1)}"
-                },
+                "image_url": {"url": encode_image_url(image_1)},
             },
             {"type": "text", "text": "What's in this image?"},
         ],
@@ -66,9 +64,7 @@ SAMPLE_PROMPTS_MM: list[dict] = [
         "content": [
             {
                 "type": "image_url",
-                "image_url": {
-                    "url": f"data:image;base64,{encode_image_base64(image_2)}"
-                },
+                "image_url": {"url": encode_image_url(image_2)},
             },
             {
                 "type": "image_url",
@@ -8,7 +8,7 @@ import pytest
 import pytest_asyncio

 from tests.utils import RemoteOpenAIServer
-from vllm.multimodal.utils import encode_image_base64
+from vllm.multimodal.utils import encode_image_url

 # Use a small vision model for testing
 MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
@@ -52,9 +52,9 @@ async def client(image_server):


 @pytest.fixture(scope="session")
-def base64_encoded_image(local_asset_server) -> dict[str, str]:
+def url_encoded_image(local_asset_server) -> dict[str, str]:
     return {
-        image_url: encode_image_base64(local_asset_server.get_image_asset(image_url))
+        image_url: encode_image_url(local_asset_server.get_image_asset(image_url))
         for image_url in TEST_IMAGE_ASSETS
     }

@@ -95,7 +95,7 @@ async def test_single_chat_session_image_base64encoded(
     client: openai.AsyncOpenAI,
     model_name: str,
     raw_image_url: str,
-    base64_encoded_image: dict[str, str],
+    url_encoded_image: dict[str, str],
 ):
     content_text = "What's in this image?"
     messages = [
@@ -104,7 +104,7 @@ async def test_single_chat_session_image_base64encoded(
             "content": [
                 {
                     "type": "input_image",
-                    "image_url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}",  # noqa: E501
+                    "image_url": url_encoded_image[raw_image_url],
                     "detail": "auto",
                 },
                 {"type": "input_text", "text": content_text},
@@ -9,7 +9,7 @@ from PIL import Image
 from vllm import LLM, EngineArgs, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.config import KVTransferConfig
-from vllm.multimodal.utils import encode_image_base64
+from vllm.multimodal.utils import encode_image_url
 from vllm.platforms import current_platform

 MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
@@ -74,7 +74,7 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
     placeholders = [
         {
            "type": "image_url",
-            "image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"},
+            "image_url": {"url": encode_image_url(image_pil)},
        }
        for image_pil in image_urls
    ]
@@ -4,7 +4,7 @@
 import openai
 import pytest

-from vllm.multimodal.utils import encode_image_base64
+from vllm.multimodal.utils import encode_image_url
 from vllm.platforms import current_platform

 from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS
@@ -12,11 +12,9 @@ from ...utils import RemoteOpenAIServer


 @pytest.fixture(scope="session")
-def base64_encoded_image(local_asset_server) -> dict[str, str]:
+def url_encoded_image(local_asset_server) -> dict[str, str]:
     return {
-        image_asset: encode_image_base64(
-            local_asset_server.get_image_asset(image_asset)
-        )
+        image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset))
         for image_asset in TEST_IMAGE_ASSETS
     }

@@ -24,19 +22,16 @@ def base64_encoded_image(local_asset_server) -> dict[str, str]:
 @pytest.mark.asyncio
 @pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU")
 @pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"])
-async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, str]):
+async def test_basic_vision(model_name: str, url_encoded_image: dict[str, str]):
     pytest.skip("Skip this test until it's fixed.")

-    def whats_in_this_image_msg(b64):
+    def whats_in_this_image_msg(url):
         return [
             {
                 "role": "user",
                 "content": [
                     {"type": "text", "text": "What's in this image?"},
-                    {
-                        "type": "image_url",
-                        "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
-                    },
+                    {"type": "image_url", "image_url": {"url": url}},
                 ],
             }
         ]
@@ -63,14 +58,14 @@ async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, str

     # Other requests now should be much faster
     for image_url in TEST_IMAGE_ASSETS:
-        image_base64 = base64_encoded_image[image_url]
-        chat_completion_from_base64 = await client.chat.completions.create(
+        image_url = url_encoded_image[image_url]
+        chat_completion_from_url = await client.chat.completions.create(
             model=model_name,
-            messages=whats_in_this_image_msg(image_base64),
+            messages=whats_in_this_image_msg(image_url),
             max_completion_tokens=24,
             temperature=0.0,
         )
-        result = chat_completion_from_base64
+        result = chat_completion_from_url
         assert result
         choice = result.choices[0]
         assert choice.finish_reason == "length"
@@ -111,11 +111,16 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
     def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
         return librosa.load(filepath, sr=None)

-    def encode_base64(self, media: tuple[npt.NDArray, int]) -> str:
+    def encode_base64(
+        self,
+        media: tuple[npt.NDArray, int],
+        *,
+        audio_format: str = "WAV",
+    ) -> str:
         audio, sr = media

         with BytesIO() as buffer:
-            soundfile.write(buffer, audio, sr, format="WAV")
+            soundfile.write(buffer, audio, sr, format=audio_format)
             data = buffer.getvalue()

         return base64.b64encode(data).decode("utf-8")
@@ -8,8 +8,12 @@ import pybase64
 import torch
 from PIL import Image

+from vllm.logger import init_logger
+
 from .base import MediaIO, MediaWithBytes

+logger = init_logger(__file__)
+

 def rescale_image_size(
     image: Image.Image, size_factor: float, transpose: int = -1
@@ -104,8 +108,17 @@ class ImageMediaIO(MediaIO[Image.Image]):
         self,
         media: Image.Image,
         *,
-        image_format: str = "JPEG",
+        image_format: str | None = None,
     ) -> str:
+        if image_format is None:
+            logger.warning_once(
+                "The default format of `ImageMediaIO.encode_base64` will be changed "
+                'from "JPEG" to "PNG" in v0.15 to avoid lossy compression. '
+                "To continue using the old default, "
+                'pass `format="JPEG"` explicitly to silence this warning.'
+            )
+            image_format = "JPEG"
+
         image = media

         with BytesIO() as buffer:
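A consequence of the hunk above: `ImageMediaIO.encode_base64` (and `encode_image_base64`, which forwards to it) now warns once when no format is given and falls back to JPEG, ahead of the default flipping to PNG in v0.15. A hedged sketch of the resulting behavior:

from PIL import Image

from vllm.multimodal.utils import encode_image_base64, encode_image_url

img = Image.new("RGB", (8, 8), color="red")  # any PIL image

b64_implicit = encode_image_base64(img)                # warns once, encodes as JPEG
b64_explicit = encode_image_base64(img, format="PNG")  # explicit format, no warning

# The new data-URL helper (added in the utils.py hunk below) defaults to lossless PNG.
url = encode_image_url(img)  # "data:image/png;base64,..."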
@@ -3,6 +3,7 @@

 import asyncio
 import atexit
+import mimetypes
 from collections.abc import Generator, Set
 from concurrent.futures import ThreadPoolExecutor
 from itertools import groupby
@@ -357,17 +358,31 @@ class MediaConnector:
 def encode_audio_base64(
     audio: np.ndarray,
     sampling_rate: int,
+    *,
+    format: str = "WAV",
 ) -> str:
     """Encode audio as base64."""
     audio_io = AudioMediaIO()
-    return audio_io.encode_base64((audio, sampling_rate))
+    return audio_io.encode_base64((audio, sampling_rate), audio_format=format)
+
+
+def encode_audio_url(
+    audio: np.ndarray,
+    sampling_rate: int,
+    *,
+    format: str = "WAV",
+) -> str:
+    """Encode audio as a data URL."""
+    audio_b64 = encode_audio_base64(audio, sampling_rate, format=format)
+    mimetype = mimetypes.types_map.get("." + format.lower(), "audio")
+    return f"data:{mimetype};base64,{audio_b64}"


 def encode_image_base64(
     image: Image.Image,
     *,
     image_mode: str = "RGB",
-    format: str = "JPEG",
+    format: str | None = None,
 ) -> str:
     """
     Encode a pillow image to base64 format.
@@ -378,10 +393,45 @@ def encode_image_base64(
     return image_io.encode_base64(image, image_format=format)


-def encode_video_base64(frames: npt.NDArray) -> str:
+def encode_image_url(
+    image: Image.Image,
+    *,
+    image_mode: str = "RGB",
+    format: str = "PNG",
+) -> str:
+    """
+    Encode a pillow image as a data URL.
+
+    By default, the image is converted into RGB format before being encoded.
+    """
+    image_b64 = encode_image_base64(image, image_mode=image_mode, format=format)
+    mimetype = mimetypes.types_map.get("." + format.lower(), "image")
+    return f"data:{mimetype};base64,{image_b64}"
+
+
+def encode_video_base64(
+    frames: npt.NDArray,
+    *,
+    format: str = "JPEG",
+) -> str:
     image_io = ImageMediaIO()
     video_io = VideoMediaIO(image_io)
-    return video_io.encode_base64(frames)
+    return video_io.encode_base64(frames, video_format=format)
+
+
+def encode_video_url(
+    frames: npt.NDArray,
+    *,
+    format: str = "JPEG",
+) -> str:
+    video_b64 = encode_video_base64(frames, format=format)
+
+    if format.lower() == "jpeg":
+        mimetype = "video/jpeg"
+    else:
+        mimetype = mimetypes.types_map.get("." + format.lower(), "video")
+
+    return f"data:{mimetype};base64,{video_b64}"


 def argsort_mm_positions(
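Taken together, the new helpers let callers go from raw media to a chat-ready data URL in one step. A minimal usage sketch assembled from assets already referenced in this diff (the import path for VideoAsset follows vLLM's test usage and is an assumption here, not part of the change):

from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset  # assumed import path
from vllm.multimodal.utils import (
    encode_audio_url,
    encode_image_url,
    encode_video_url,
)

# Each helper base64-encodes the media and wraps it in a data: URL whose MIME
# type is derived from the chosen format (with a special case for JPEG video).
image_url = encode_image_url(ImageAsset("cherry_blossom").pil_image)
video_url = encode_video_url(VideoAsset("baby_reading", 1).np_ndarrays)
audio_url = encode_audio_url(*AudioAsset("mary_had_lamb").audio_and_sample_rate)

# The resulting URLs drop straight into OpenAI-style chat messages:
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    }
]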