[Frontend][Core] Add plumbing to support audio language models (#7446)

Peter Salas 2024-08-13 10:39:33 -07:00 committed by GitHub
parent e20233d361
commit 00c3d68e45
24 changed files with 599 additions and 120 deletions

View File

@ -112,6 +112,8 @@ autodoc_mock_imports = [
"tensorizer",
"pynvml",
"outlines",
"librosa",
"soundfile",
"gguf",
"lark",
]

View File

@ -15,14 +15,14 @@ This document walks you through the steps to extend a vLLM model so that it acce
It is assumed that you have already implemented the model in vLLM according to :ref:`these steps <adding_a_new_model>`.
Further update the model as follows:
- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.
- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
.. code-block:: diff
+ from vllm.model_executor.models.interfaces import SupportsVision
+ from vllm.model_executor.models.interfaces import SupportsMultiModal
- class YourModelForImage2Seq(nn.Module):
+ class YourModelForImage2Seq(nn.Module, SupportsVision):
+ class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
.. note::
The model class does not have to be named :code:`*ForCausalLM`.
@ -51,11 +51,11 @@ This decorator accepts a function that maps multi-modal inputs to the keyword ar
.. code-block:: diff
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_image_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
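For illustration, a custom mapper is just a function from the raw multi-modal input to the keyword arguments of the model's forward(); a minimal sketch, not part of this commit (the "pixel_values" key and the preprocessing are hypothetical):

import numpy as np
import torch
from torch import nn

from vllm.inputs.registry import InputContext
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs


def my_image_input_mapper(ctx: InputContext,
                          data: object) -> MultiModalInputs:
    # `data` is the value stored under the "image" key of multi_modal_data
    # (a PIL image by default); convert it to whatever forward() expects.
    pixel_values = torch.as_tensor(np.asarray(data), dtype=torch.float32)
    return MultiModalInputs({"pixel_values": pixel_values})


@MULTIMODAL_REGISTRY.register_image_input_mapper(my_image_input_mapper)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
    ...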
@ -72,13 +72,13 @@ and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.regis
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
Here are some examples:
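For instance, a minimal (hypothetical) calculation is simply a function of the InputContext; this sketch is illustrative and not part of this commit:

from vllm.inputs.registry import InputContext


def get_max_my_image_tokens(ctx: InputContext) -> int:
    # Hypothetical budget: each image is encoded into a fixed 24 x 24 patch
    # grid, so it can never occupy more than 576 prompt positions.
    return 24 * 24


# Registered as shown in the diff above, e.g.
# @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_my_image_tokens)
# (a plain integer is also accepted when the budget is constant).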
@ -98,13 +98,13 @@ In such cases, you can define your own dummy data by registering a factory metho
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
.. note::
The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
@ -128,14 +128,14 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
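To make this concrete, here is an illustrative sketch, not part of this commit (the token id, feature count, and model class are hypothetical), of a processor that expands each image placeholder so the engine reserves one prompt position per image feature:

from torch import nn

from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.models.interfaces import SupportsMultiModal

IMAGE_TOKEN_ID = 32000          # hypothetical placeholder token id
IMAGE_FEATURE_TOKENS = 576      # hypothetical number of features per image


def my_input_processor(ctx: InputContext, llm_inputs: LLMInputs) -> LLMInputs:
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    # Repeat each image placeholder so that one prompt position is reserved
    # for every image feature the model will later scatter into the sequence.
    new_token_ids = []
    for token_id in llm_inputs["prompt_token_ids"]:
        repeat = IMAGE_FEATURE_TOKENS if token_id == IMAGE_TOKEN_ID else 1
        new_token_ids.extend([token_id] * repeat)

    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=llm_inputs.get("prompt"),
                     multi_modal_data=multi_modal_data)


@INPUT_REGISTRY.register_input_processor(my_input_processor)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
    ...

Real models typically delegate this to helpers such as input_processor_for_clip or repeat_and_pad_image_tokens rather than expanding tokens by hand.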
Here are some examples:

View File

@ -20,4 +20,6 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1

View File

@ -0,0 +1,351 @@
import math
import sys
import time
from typing import Dict, List, Optional, Tuple, Union, cast
from unittest.mock import patch
import librosa
import numpy as np
import openai
import pytest
import requests
import torch
from vllm import ModelRegistry
from vllm.config import MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
from vllm.utils import get_open_port
from ...utils import VLLM_PATH
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
MODEL_NAME = "facebook/opt-125m"
TEST_AUDIO_URLS = [
"https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
]
def server_function(port):
def fake_input_mapper(ctx: InputContext, data: object):
assert isinstance(data, tuple)
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
# Resample it to 1 sample per second
audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return llm_inputs
audio, sr = multi_modal_data.get("audio")
audio_duration = math.ceil(len(audio) / sr)
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
cached_get_tokenizer(ctx.model_config.tokenizer),
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=62, # "_"
repeat_count=audio_duration)
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", lambda *_, **__: 100)
@INPUT_REGISTRY.register_input_processor(fake_input_processor)
class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
def __init__(self, *args, multimodal_config: MultiModalConfig,
**kwargs):
assert multimodal_config is not None
super().__init__(*args, **kwargs)
def forward(
self,
*args,
processed_audio: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return super().forward(*args, **kwargs)
ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
with patch("vllm.entrypoints.chat_utils._mm_token_str",
lambda *_, **__: "_"):
sys.argv = ["placeholder.py"] + \
(f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
"--dtype bfloat16 --enforce-eager --api-key token-abc123 "
f"--port {port} --chat-template {chatml_jinja_path} "
"--disable-frontend-multiprocessing").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server',
run_name='__main__')
@pytest.fixture(scope="module")
def client():
port = get_open_port()
ctx = torch.multiprocessing.get_context("spawn")
server = ctx.Process(target=server_function, args=(port, ))
server.start()
MAX_SERVER_START_WAIT_S = 60
client = openai.AsyncOpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
# run health check
health_url = f"http://localhost:{port}/health"
start = time.time()
while True:
try:
if requests.get(health_url).status_code == 200:
break
except Exception as err:
result = server.exitcode
if result is not None:
raise RuntimeError("Server exited unexpectedly.") from err
time.sleep(0.5)
if time.time() - start > MAX_SERVER_START_WAIT_S:
raise RuntimeError("Server failed to start in time.") from err
try:
yield client
finally:
server.kill()
@pytest.fixture(scope="session")
def base64_encoded_audio() -> Dict[str, str]:
return {
audio_url: encode_audio_base64(*fetch_audio(audio_url))
for audio_url in TEST_AUDIO_URLS
}
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
message = choice.message
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio_base64encoded(
client: openai.AsyncOpenAI, model_name: str, audio_url: str,
base64_encoded_audio: Dict[str, str]):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url":
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
message = choice.message
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
model_name: str, audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
# test single completion
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
)
output = chat_completion.choices[0].message.content
stop_reason = chat_completion.choices[0].finish_reason
# test streaming
stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=True,
)
chunks: List[str] = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
if delta.content:
chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == stop_reason
assert delta.content
assert "".join(chunks) == output
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
audio_url: str):
messages = [{
"role":
"user",
"content": [
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
}
},
{
"type": "text",
"text": "What's happening in this audio?"
},
],
}]
with pytest.raises(openai.BadRequestError): # test multi-audio input
await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
)
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0

View File

@ -2,7 +2,8 @@ import codecs
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union, cast)
# yapf conflicts with isort for this block
# yapf: disable
@ -21,12 +22,27 @@ from typing_extensions import Required, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
class AudioURL(TypedDict, total=False):
url: Required[str]
"""
Either a URL of the audio or a data URL with base64 encoded audio data.
"""
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
audio_url: Required[AudioURL]
type: Required[Literal["audio_url"]]
"""The type of the content part."""
class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
@ -35,6 +51,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam]
@ -97,34 +114,41 @@ def load_chat_template(
@lru_cache(maxsize=None)
def _image_token_str(model_config: ModelConfig,
tokenizer: PreTrainedTokenizer) -> Optional[str]:
def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
raise TypeError(f"Unknown model type: {model_type}")
if modality == "image":
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
raise TypeError("No audio models are supported yet.")
else:
raise TypeError(f"Unknown modality: {modality}")
# TODO: Let user specify how to insert image tokens into prompt
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
"""Combine image and text prompts for vision language model"""
def _get_full_multimodal_text_prompt(placeholder_token_str: str,
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return f"{image_token_str}\n{text_prompt}"
# placeholder + text prompt format. This may change in the future.
return f"{placeholder_token_str}\n{text_prompt}"
def _parse_chat_message_content_parts(
@ -135,6 +159,7 @@ def _parse_chat_message_content_parts(
) -> ChatMessageParseResult:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
modality: Literal["image", "audio"] = "image"
for part in parts:
part_type = part["type"]
@ -142,9 +167,10 @@ def _parse_chat_message_content_parts(
text = cast(ChatCompletionContentPartTextParam, part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple 'image_url' input is currently not supported.")
"Multiple multimodal inputs is currently not supported.")
image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
@ -156,21 +182,32 @@ def _parse_chat_message_content_parts(
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
elif part_type == "audio_url":
modality = "audio"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
audio_url = cast(ChatCompletionContentPartAudioParam,
part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if mm_futures:
image_token_str = _image_token_str(model_config, tokenizer)
if image_token_str is not None:
if image_token_str in text_prompt:
placeholder_token_str = _mm_token_str(model_config, tokenizer,
modality)
if placeholder_token_str is not None:
if placeholder_token_str in text_prompt:
logger.warning(
"Detected image token string in the text prompt. "
"Detected multi-modal token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_image_text_prompt(
image_token_str=image_token_str,
text_prompt = _get_full_multimodal_text_prompt(
placeholder_token_str=placeholder_token_str,
text_prompt=text_prompt,
)
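With this in place, the chat endpoint accepts audio parts alongside text, mirroring the image path. An illustrative client-side request (placeholder model name and URL; the served model must have audio support registered):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="my-audio-model",                      # placeholder model name
    messages=[{
        "role": "user",
        "content": [
            {"type": "audio_url",
             "audio_url": {"url": "https://example.com/clip.ogg"}},
            {"type": "text", "text": "What's happening in this audio?"},
        ],
    }],
    max_tokens=64,
)
print(chat.choices[0].message.content)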

View File

@ -44,6 +44,7 @@ if TYPE_CHECKING:
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_AUDIO_FETCH_TIMEOUT: int = 5
VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
@ -321,6 +322,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Timeout for fetching audio when serving multimodal models
# Default is 5 seconds
"VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH":

View File

@ -38,7 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator)
from vllm.model_executor.models.interfaces import (has_inner_state,
supports_lora,
supports_vision)
supports_multimodal)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
@ -131,7 +131,7 @@ def _get_model_initialization_kwargs(
"be added in the future. If this is important to you, "
"please open an issue on github.")
if supports_vision(model_class):
if supports_multimodal(model_class):
if multimodal_config is None:
raise ValueError("Provide vision related configurations "
"through LLM entrypoint or engine arguments.")

View File

@ -20,8 +20,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
@ -457,7 +457,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: Blip2Config,
@ -621,9 +621,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)
input_ids = None
else:

View File

@ -35,7 +35,7 @@ from vllm.multimodal.image import (cached_get_tokenizer,
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.utils import print_warning_once
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
logger = init_logger(__name__)
@ -886,7 +886,7 @@ class ChameleonModel(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(
self,

View File

@ -40,8 +40,8 @@ from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
@ -209,7 +209,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsVision):
class FuyuForCausalLM(nn.Module, SupportsMultiModal):
def __init__(self,
config: FuyuConfig,
@ -271,9 +271,9 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
self.image_token_id)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = None

View File

@ -10,12 +10,15 @@ logger = init_logger(__name__)
@runtime_checkable
class SupportsVision(Protocol):
"""The interface required for all vision language models (VLMs)."""
supports_vision: ClassVar[Literal[True]] = True
class SupportsMultiModal(Protocol):
"""
A flag that indicates this model supports vision inputs.
The interface required for all multimodal (vision or audio) language
models.
"""
supports_multimodal: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multimodal inputs.
Note:
There is no need to redefine this flag if this class is in the
@ -29,30 +32,31 @@ class SupportsVision(Protocol):
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsVisionType(Protocol):
supports_vision: Literal[True]
class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
...
@overload
def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]:
def supports_multimodal(
model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
...
@overload
def supports_vision(model: object) -> TypeIs[SupportsVision]:
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
...
def supports_vision(
def supports_multimodal(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]:
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type):
return isinstance(model, _SupportsVisionType)
return isinstance(model, _SupportsMultiModalType)
return isinstance(model, SupportsVision)
return isinstance(model, SupportsMultiModal)
@runtime_checkable

View File

@ -27,9 +27,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
merge_multimodal_embeddings)
IMG_START = '<img>'
IMG_END = '</img>'
@ -292,7 +292,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsVision):
class InternVLChatModel(nn.Module, SupportsMultiModal):
def __init__(self,
config: PretrainedConfig,
@ -451,9 +451,9 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
self.img_context_token_id)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
input_ids = None
else:
inputs_embeds = None

View File

@ -19,12 +19,12 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict):
@ -181,7 +181,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: LlavaConfig,
@ -338,7 +338,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_vision_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)

View File

@ -23,13 +23,13 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
merge_multimodal_embeddings)
logger = init_logger(__name__)
@ -275,7 +275,7 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: LlavaNextConfig,
@ -571,7 +571,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_vision_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)

View File

@ -48,7 +48,7 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.qwen2 import Qwen2Model
@ -479,7 +479,7 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
class MiniCPMVBaseModel(nn.Module, SupportsVision):
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated.

View File

@ -19,10 +19,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_vision_embeddings
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
@ -130,7 +130,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: PaliGemmaConfig,
@ -244,7 +244,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)

View File

@ -42,8 +42,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
@ -453,7 +453,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
def __init__(self,
config: PretrainedConfig,
@ -568,9 +568,9 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
self.image_token_id)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
input_ids = None
else:
inputs_embeds = None

View File

@ -54,41 +54,42 @@ def init_vllm_registered_model(
)
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: BatchedTensors,
image_token_id: int) -> torch.Tensor:
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: BatchedTensors,
placeholder_token_id: int) -> torch.Tensor:
"""
Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder image tokens in
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
Note:
This updates ``inputs_embeds`` in place.
"""
mask = (input_ids == image_token_id)
mask = (input_ids == placeholder_token_id)
num_expected_tokens = mask.sum()
if isinstance(vision_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
if isinstance(multimodal_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
total_tokens = batch_size * batch_tokens
if num_expected_tokens != total_tokens:
expr = f"{batch_size} x {batch_tokens}"
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
inputs_embeds[mask] = multimodal_embeddings.view(
total_tokens, embed_dim)
else:
size_per_batch = [t.shape[0] for t in vision_embeddings]
size_per_batch = [t.shape[0] for t in multimodal_embeddings]
total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch))
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"image tokens to {num_expected_tokens} placeholders")
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = torch.cat(vision_embeddings)
inputs_embeds[mask] = torch.cat(multimodal_embeddings)
return inputs_embeds
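As a quick illustration of the contract (the placeholder id and shapes below are made up), a toy call on CPU tensors:

import torch

from vllm.model_executor.models.utils import merge_multimodal_embeddings

PLACEHOLDER_ID = 32000                       # hypothetical placeholder token id
input_ids = torch.tensor([1, PLACEHOLDER_ID, PLACEHOLDER_ID, 2])
inputs_embeds = torch.zeros(4, 8)            # (num_tokens, embed_dim)
mm_embeds = torch.ones(1, 2, 8)              # (batch, mm_tokens, embed_dim)

out = merge_multimodal_embeddings(input_ids, inputs_embeds, mm_embeds,
                                  PLACEHOLDER_ID)
# The two placeholder rows are overwritten in place; the text rows stay zero.
assert torch.equal(out[1], torch.ones(8))
assert torch.equal(out[0], torch.zeros(8))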

vllm/multimodal/audio.py
View File

@ -0,0 +1,17 @@
from vllm.inputs.registry import InputContext
from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin
class AudioPlugin(MultiModalPlugin):
"""Plugin for audio data."""
def get_data_key(self) -> str:
return "audio"
def _default_input_mapper(self, ctx: InputContext,
data: object) -> MultiModalInputs:
raise NotImplementedError("There is no default audio input mapper")
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
raise NotImplementedError(
"There is no default maximum multimodal tokens")

View File

@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import Any, Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypedDict, TypeVar, Union, cast
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast
import numpy as np
import torch
import torch.types
from PIL import Image
@ -121,6 +122,9 @@ class MultiModalDataBuiltins(TypedDict, total=False):
image: Image.Image
"""The input image."""
audio: Tuple[np.ndarray, Union[int, float]]
"""The input audio and its sampling rate."""
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
"""

View File

@ -6,6 +6,7 @@ import torch
from vllm.config import ModelConfig
from vllm.logger import init_logger
from .audio import AudioPlugin
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
MultiModalPlugin, MultiModalTokensCalc)
from .image import ImagePlugin
@ -19,7 +20,7 @@ class MultiModalRegistry:
:class:`~vllm.multimodal.MultiModalPlugin` for each modality.
"""
DEFAULT_PLUGINS = (ImagePlugin(), )
DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin())
def __init__(
self,

View File

@ -1,11 +1,14 @@
import base64
from io import BytesIO
from typing import Union
from typing import Tuple, Union
import librosa
import numpy as np
import soundfile
from PIL import Image
from vllm.connections import global_http_connection
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.base import MultiModalDataDict
@ -63,11 +66,62 @@ async def async_fetch_image(image_url: str,
return image.convert(image_mode)
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
"""
Load audio from a URL.
"""
if audio_url.startswith("http"):
audio_bytes = global_http_connection.get_bytes(
audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
elif audio_url.startswith("data:audio"):
_, audio_base64 = audio_url.split(",", 1)
audio_bytes = base64.b64decode(audio_base64)
else:
raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
"with either 'data:audio' or 'http'.")
return librosa.load(BytesIO(audio_bytes), sr=None)
async def async_fetch_audio(
audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
"""
Asynchronously fetch audio from a URL.
"""
if audio_url.startswith("http"):
audio_bytes = await global_http_connection.async_get_bytes(
audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
elif audio_url.startswith("data:audio"):
_, audio_base64 = audio_url.split(",", 1)
audio_bytes = base64.b64decode(audio_base64)
else:
raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
"with either 'data:audio' or 'http'.")
return librosa.load(BytesIO(audio_bytes), sr=None)
async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
audio, sr = await async_fetch_audio(audio_url)
return {"audio": (audio, sr)}
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await async_fetch_image(image_url)
return {"image": image}
def encode_audio_base64(
audio: np.ndarray,
sampling_rate: int,
) -> str:
"""Encode audio as base64."""
buffered = BytesIO()
soundfile.write(buffered, audio, sampling_rate, format="WAV")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def encode_image_base64(
image: Image.Image,
*,
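For example (illustrative URL), these helpers turn a remote clip into the data-URL form that the chat endpoint accepts, just as the new tests do:

from vllm.multimodal.utils import encode_audio_base64, fetch_audio

waveform, sampling_rate = fetch_audio("https://example.com/clip.ogg")
audio_data_url = f"data:audio/wav;base64,{encode_audio_base64(waveform, sampling_rate)}"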

View File

@ -40,7 +40,7 @@ from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
supports_multimodal)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
@ -900,9 +900,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if self.lora_config:
assert supports_lora(self.model), "Model does not support LoRA"
assert not supports_vision(
assert not supports_multimodal(
self.model
), "To be tested: vision language model with LoRA settings."
), "To be tested: multimodal language model with LoRA settings."
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
@ -1054,7 +1054,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# of images processed.
model_config = self.model_config
if supports_vision(self.model):
if supports_multimodal(self.model):
max_mm_tokens = MULTIMODAL_REGISTRY \
.get_max_multimodal_tokens(model_config)
max_num_seqs_orig = max_num_seqs

View File

@ -12,7 +12,7 @@ from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_vision
from vllm.model_executor.models.interfaces import supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
@ -165,7 +165,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# of images processed.
model_config = self.model_config
if supports_vision(self.model):
if supports_multimodal(self.model):
max_mm_tokens = MULTIMODAL_REGISTRY \
.get_max_multimodal_tokens(model_config)
max_num_seqs_orig = max_num_seqs