[Frontend][Core] Add plumbing to support audio language models (#7446)
commit 00c3d68e45 (parent e20233d361)
@@ -112,6 +112,8 @@ autodoc_mock_imports = [
     "tensorizer",
     "pynvml",
     "outlines",
+    "librosa",
+    "soundfile",
     "gguf",
     "lark",
 ]
@@ -15,14 +15,14 @@ This document walks you through the steps to extend a vLLM model so that it acce
 It is assumed that you have already implemented the model in vLLM according to :ref:`these steps <adding_a_new_model>`.
 Further update the model as follows:
 
-- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsVision` interface.
+- Implement the :class:`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
 
   .. code-block:: diff
 
-    + from vllm.model_executor.models.interfaces import SupportsVision
+    + from vllm.model_executor.models.interfaces import SupportsMultiModal
 
       - class YourModelForImage2Seq(nn.Module):
-    + class YourModelForImage2Seq(nn.Module, SupportsVision):
+    + class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 
   .. note::
       The model class does not have to be named :code:`*ForCausalLM`.
@@ -51,11 +51,11 @@ This decorator accepts a function that maps multi-modal inputs to the keyword ar
 
 .. code-block:: diff
 
-      from vllm.model_executor.models.interfaces import SupportsVision
+      from vllm.model_executor.models.interfaces import SupportsMultiModal
     + from vllm.multimodal import MULTIMODAL_REGISTRY
 
     + @MULTIMODAL_REGISTRY.register_image_input_mapper()
-      class YourModelForImage2Seq(nn.Module, SupportsVision):
+      class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 
 A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
 
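The input mapper described above converts the raw multi-modal object into the keyword arguments the model's forward pass receives. A minimal sketch of a custom mapper (illustrative only; the "pixel_values" key and the preprocessing below are placeholders for whatever your model actually expects, not something defined by this commit):

    import numpy as np
    import torch
    from PIL import Image

    from vllm.inputs.registry import InputContext
    from vllm.multimodal.base import MultiModalInputs


    def my_image_input_mapper(ctx: InputContext, data: object) -> MultiModalInputs:
        # "data" is the object passed under multi_modal_data["image"],
        # typically a PIL image; turn it into the tensors the model consumes.
        assert isinstance(data, Image.Image)
        pixel_values = torch.from_numpy(np.asarray(data)).permute(2, 0, 1).float() / 255.0
        return MultiModalInputs({"pixel_values": pixel_values.unsqueeze(0)})

    # Passed to the decorator shown in the documentation above:
    # @MULTIMODAL_REGISTRY.register_image_input_mapper(my_image_input_mapper)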
@@ -72,13 +72,13 @@ and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.regis
 .. code-block:: diff
 
      from vllm.inputs import INPUT_REGISTRY
-     from vllm.model_executor.models.interfaces import SupportsVision
+     from vllm.model_executor.models.interfaces import SupportsMultiModal
      from vllm.multimodal import MULTIMODAL_REGISTRY
 
      @MULTIMODAL_REGISTRY.register_image_input_mapper()
    + @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
      @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
-     class YourModelForImage2Seq(nn.Module, SupportsVision):
+     class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 
 Here are some examples:
 
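The value given to register_max_image_tokens can be a constant or a callable that inspects the model config. A hedged sketch (the num_image_tokens attribute and the 576 fallback are invented for illustration):

    from vllm.inputs.registry import InputContext


    def get_max_my_model_image_tokens(ctx: InputContext) -> int:
        # Worst-case number of placeholder tokens one image can expand into.
        hf_config = ctx.model_config.hf_config
        return getattr(hf_config, "num_image_tokens", 576)

    # @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_my_model_image_tokens)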
@@ -98,13 +98,13 @@ In such cases, you can define your own dummy data by registering a factory metho
 .. code-block:: diff
 
      from vllm.inputs import INPUT_REGISTRY
-     from vllm.model_executor.models.interfaces import SupportsVision
+     from vllm.model_executor.models.interfaces import SupportsMultiModal
      from vllm.multimodal import MULTIMODAL_REGISTRY
 
      @MULTIMODAL_REGISTRY.register_image_input_mapper()
      @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
    + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
-     class YourModelForImage2Seq(nn.Module, SupportsVision):
+     class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 
 .. note::
     The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
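A dummy data factory returns a sequence filled with the maximum number of placeholder tokens plus matching dummy multi-modal data. A rough sketch, assuming that SequenceData in this vLLM version accepts a plain list of token ids (the token id, token count and image size below are hypothetical and not prescribed by this commit):

    from PIL import Image

    from vllm.inputs.registry import InputContext
    from vllm.sequence import SequenceData


    def dummy_data_for_my_model(ctx: InputContext, seq_len: int):
        image_token_id = 32000    # hypothetical placeholder token id
        num_image_tokens = 576    # should match the maximum registered above

        # Fill the sequence with placeholder tokens, then pad to seq_len.
        token_ids = [image_token_id] * num_image_tokens
        token_ids += [0] * (seq_len - num_image_tokens)
        seq_data = SequenceData(token_ids)

        mm_data = {"image": Image.new("RGB", (336, 336), color=0)}
        return seq_data, mm_data

    # @INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_model)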
@@ -128,14 +128,14 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce
 .. code-block:: diff
 
      from vllm.inputs import INPUT_REGISTRY
-     from vllm.model_executor.models.interfaces import SupportsVision
+     from vllm.model_executor.models.interfaces import SupportsMultiModal
      from vllm.multimodal import MULTIMODAL_REGISTRY
 
      @MULTIMODAL_REGISTRY.register_image_input_mapper()
      @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
      @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
    + @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
-     class YourModelForImage2Seq(nn.Module, SupportsVision):
+     class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 
 A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
 Here are some examples:
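The test added by this commit (tests/entrypoints/openai/test_audio.py, reproduced further below) registers exactly such an input processor for audio; condensed, the pattern looks like this (token id 62 is the arbitrary placeholder the test uses, one placeholder per second of audio):

    import math

    from vllm.inputs.data import LLMInputs
    from vllm.inputs.registry import InputContext
    from vllm.multimodal.image import (cached_get_tokenizer,
                                       repeat_and_pad_image_tokens)


    def my_audio_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
        multi_modal_data = llm_inputs.get("multi_modal_data")
        if multi_modal_data is None or "audio" not in multi_modal_data:
            return llm_inputs

        audio, sr = multi_modal_data["audio"]
        new_prompt, new_token_ids = repeat_and_pad_image_tokens(
            cached_get_tokenizer(ctx.model_config.tokenizer),
            llm_inputs.get("prompt"),
            llm_inputs["prompt_token_ids"],
            image_token_id=62,
            repeat_count=math.ceil(len(audio) / sr))

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
                         multi_modal_data=multi_modal_data)

    # @INPUT_REGISTRY.register_input_processor(my_audio_input_processor)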
@@ -20,4 +20,6 @@ outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
 typing_extensions >= 4.10
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
 pyzmq
+librosa  # Required for audio processing
+soundfile  # Required for audio processing
 gguf == 0.9.1
tests/entrypoints/openai/test_audio.py (new file, 351 lines):

import math
import sys
import time
from typing import Dict, List, Optional, Tuple, Union, cast
from unittest.mock import patch

import librosa
import numpy as np
import openai
import pytest
import requests
import torch

from vllm import ModelRegistry
from vllm.config import MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_tokenizer,
                                   repeat_and_pad_image_tokens)
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
from vllm.utils import get_open_port

from ...utils import VLLM_PATH

chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()

MODEL_NAME = "facebook/opt-125m"
TEST_AUDIO_URLS = [
    "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
]


def server_function(port):

    def fake_input_mapper(ctx: InputContext, data: object):
        assert isinstance(data, tuple)
        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)

        # Resample it to 1 sample per second
        audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
        return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})

    def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
        multi_modal_data = llm_inputs.get("multi_modal_data")
        if multi_modal_data is None or "audio" not in multi_modal_data:
            return llm_inputs

        audio, sr = multi_modal_data.get("audio")
        audio_duration = math.ceil(len(audio) / sr)

        new_prompt, new_token_ids = repeat_and_pad_image_tokens(
            cached_get_tokenizer(ctx.model_config.tokenizer),
            llm_inputs.get("prompt"),
            llm_inputs["prompt_token_ids"],
            image_token_id=62,  # "_"
            repeat_count=audio_duration)

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
                         multi_modal_data=multi_modal_data)

    @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
    @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
        "audio", lambda *_, **__: 100)
    @INPUT_REGISTRY.register_input_processor(fake_input_processor)
    class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):

        def __init__(self, *args, multimodal_config: MultiModalConfig,
                     **kwargs):
            assert multimodal_config is not None
            super().__init__(*args, **kwargs)

        def forward(
            self,
            *args,
            processed_audio: Optional[torch.Tensor] = None,
            **kwargs,
        ) -> torch.Tensor:
            return super().forward(*args, **kwargs)

    ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)

    with patch("vllm.entrypoints.chat_utils._mm_token_str",
               lambda *_, **__: "_"):
        sys.argv = ["placeholder.py"] + \
            (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
             "--dtype bfloat16 --enforce-eager --api-key token-abc123 "
             f"--port {port} --chat-template {chatml_jinja_path} "
             "--disable-frontend-multiprocessing").split()
        import runpy
        runpy.run_module('vllm.entrypoints.openai.api_server',
                         run_name='__main__')


@pytest.fixture(scope="module")
def client():
    port = get_open_port()
    ctx = torch.multiprocessing.get_context("spawn")
    server = ctx.Process(target=server_function, args=(port, ))
    server.start()
    MAX_SERVER_START_WAIT_S = 60
    client = openai.AsyncOpenAI(
        base_url=f"http://localhost:{port}/v1",
        api_key="token-abc123",
    )
    # run health check
    health_url = f"http://localhost:{port}/health"
    start = time.time()
    while True:
        try:
            if requests.get(health_url).status_code == 200:
                break
        except Exception as err:
            result = server.exitcode
            if result is not None:
                raise RuntimeError("Server exited unexpectedly.") from err

            time.sleep(0.5)
            if time.time() - start > MAX_SERVER_START_WAIT_S:
                raise RuntimeError("Server failed to start in time.") from err

    try:
        yield client
    finally:
        server.kill()


@pytest.fixture(scope="session")
def base64_encoded_audio() -> Dict[str, str]:
    return {
        audio_url: encode_audio_base64(*fetch_audio(audio_url))
        for audio_url in TEST_AUDIO_URLS
    }


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
                                         model_name: str, audio_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                }
            },
            {
                "type": "text",
                "text": "What's happening in this audio?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=model_name,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=36, total_tokens=46)

    message = choice.message
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio_base64encoded(
        client: openai.AsyncOpenAI, model_name: str, audio_url: str,
        base64_encoded_audio: Dict[str, str]):

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "audio_url",
                "audio_url": {
                    "url":
                    f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
                }
            },
            {
                "type": "text",
                "text": "What's happening in this audio?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=model_name,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=36, total_tokens=46)

    message = choice.message
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
                                    model_name: str, audio_url: str):
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                }
            },
            {
                "type": "text",
                "text": "What's happening in this audio?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
    )
    output = chat_completion.choices[0].message.content
    stop_reason = chat_completion.choices[0].finish_reason

    # test streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
        stream=True,
    )
    chunks: List[str] = []
    finish_reason_count = 0
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.role:
            assert delta.role == "assistant"
        if delta.content:
            chunks.append(delta.content)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # finish reason should only return in last block
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == stop_reason
    assert delta.content
    assert "".join(chunks) == output


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
                                 audio_url: str):

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                }
            },
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                }
            },
            {
                "type": "text",
                "text": "What's happening in this audio?"
            },
        ],
    }]

    with pytest.raises(openai.BadRequestError):  # test multi-audio input
        await client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_tokens=10,
            temperature=0.0,
        )

    # the server should still work afterwards
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    completion = completion.choices[0].text
    assert completion is not None and len(completion) >= 0
@@ -2,7 +2,8 @@ import codecs
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
+from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
+                    Union, cast)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -21,12 +22,27 @@ from typing_extensions import Required, TypedDict
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
-from vllm.multimodal.utils import async_get_and_parse_image
+from vllm.multimodal.utils import (async_get_and_parse_audio,
+                                   async_get_and_parse_image)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 logger = init_logger(__name__)
 
 
+class AudioURL(TypedDict, total=False):
+    url: Required[str]
+    """
+    Either a URL of the audio or a data URL with base64 encoded audio data.
+    """
+
+
+class ChatCompletionContentPartAudioParam(TypedDict, total=False):
+    audio_url: Required[AudioURL]
+
+    type: Required[Literal["audio_url"]]
+    """The type of the content part."""
+
+
 class CustomChatCompletionContentPartParam(TypedDict, total=False):
     __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore
 
@@ -35,6 +51,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
 
 
 ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
+                                       ChatCompletionContentPartAudioParam,
                                        CustomChatCompletionContentPartParam]
 
 
@@ -97,10 +114,11 @@ def load_chat_template(
 
 
 @lru_cache(maxsize=None)
-def _image_token_str(model_config: ModelConfig,
-                     tokenizer: PreTrainedTokenizer) -> Optional[str]:
+def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
+                  modality: Literal["image", "audio"]) -> Optional[str]:
     # TODO: Let user specify how to insert image tokens into prompt
     # (similar to chat template)
-    model_type = model_config.hf_config.model_type
-    if model_type == "phi3_v":
-        # Workaround since this token is not defined in the tokenizer
+    if modality == "image":
+        model_type = model_config.hf_config.model_type
+        if model_type == "phi3_v":
+            # Workaround since this token is not defined in the tokenizer
@@ -114,17 +132,23 @@ def _image_token_str(model_config: ModelConfig,
-        return tokenizer.decode(model_config.hf_config.image_token_index)
-    if model_type in ("chameleon", "internvl_chat"):
-        return "<image>"
-
-    raise TypeError(f"Unknown model type: {model_type}")
+            return tokenizer.decode(model_config.hf_config.image_token_index)
+        if model_type in ("chameleon", "internvl_chat"):
+            return "<image>"
+
+        raise TypeError(f"Unknown model type: {model_type}")
+    elif modality == "audio":
+        raise TypeError("No audio models are supported yet.")
+    else:
+        raise TypeError(f"Unknown modality: {modality}")
 
 
-# TODO: Let user specify how to insert image tokens into prompt
+# TODO: Let user specify how to insert multimodal tokens into prompt
 # (similar to chat template)
-def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
-    """Combine image and text prompts for vision language model"""
+def _get_full_multimodal_text_prompt(placeholder_token_str: str,
+                                     text_prompt: str) -> str:
+    """Combine multimodal prompts for a multimodal language model"""
 
     # NOTE: For now we assume all model architectures use the same
-    # image + text prompt format. This may change in the future.
-    return f"{image_token_str}\n{text_prompt}"
+    # placeholder + text prompt format. This may change in the future.
+    return f"{placeholder_token_str}\n{text_prompt}"
 
 
 def _parse_chat_message_content_parts(
@@ -135,6 +159,7 @@ def _parse_chat_message_content_parts(
 ) -> ChatMessageParseResult:
     texts: List[str] = []
     mm_futures: List[Awaitable[MultiModalDataDict]] = []
+    modality: Literal["image", "audio"] = "image"
 
     for part in parts:
         part_type = part["type"]
@@ -142,9 +167,10 @@ def _parse_chat_message_content_parts(
             text = cast(ChatCompletionContentPartTextParam, part)["text"]
             texts.append(text)
         elif part_type == "image_url":
+            modality = "image"
             if len(mm_futures) > 0:
                 raise NotImplementedError(
-                    "Multiple 'image_url' input is currently not supported.")
+                    "Multiple multimodal inputs is currently not supported.")
 
             image_url = cast(ChatCompletionContentPartImageParam,
                              part)["image_url"]
@@ -156,21 +182,32 @@ def _parse_chat_message_content_parts(
 
             image_future = async_get_and_parse_image(image_url["url"])
             mm_futures.append(image_future)
+        elif part_type == "audio_url":
+            modality = "audio"
+            if len(mm_futures) > 0:
+                raise NotImplementedError(
+                    "Multiple multimodal inputs is currently not supported.")
+
+            audio_url = cast(ChatCompletionContentPartAudioParam,
+                             part)["audio_url"]
+            audio_future = async_get_and_parse_audio(audio_url["url"])
+            mm_futures.append(audio_future)
         else:
             raise NotImplementedError(f"Unknown part type: {part_type}")
 
     text_prompt = "\n".join(texts)
 
     if mm_futures:
-        image_token_str = _image_token_str(model_config, tokenizer)
-        if image_token_str is not None:
-            if image_token_str in text_prompt:
+        placeholder_token_str = _mm_token_str(model_config, tokenizer,
+                                              modality)
+        if placeholder_token_str is not None:
+            if placeholder_token_str in text_prompt:
                 logger.warning(
-                    "Detected image token string in the text prompt. "
+                    "Detected multi-modal token string in the text prompt. "
                     "Skipping prompt formatting.")
             else:
-                text_prompt = _get_full_image_text_prompt(
-                    image_token_str=image_token_str,
+                text_prompt = _get_full_multimodal_text_prompt(
+                    placeholder_token_str=placeholder_token_str,
                     text_prompt=text_prompt,
                 )
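For reference, a chat request exercising the new audio_url content part looks like this (the URL is a placeholder); _parse_chat_message_content_parts resolves it through async_get_and_parse_audio and, when the model defines a placeholder token, prepends that token to the text via _get_full_multimodal_text_prompt:

    messages = [{
        "role": "user",
        "content": [
            {
                "type": "audio_url",
                "audio_url": {"url": "https://example.com/clip.ogg"},
            },
            {
                "type": "text",
                "text": "What's happening in this audio?",
            },
        ],
    }]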
@@ -44,6 +44,7 @@ if TYPE_CHECKING:
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
+    VLLM_AUDIO_FETCH_TIMEOUT: int = 5
     VLLM_TARGET_DEVICE: str = "cuda"
     MAX_JOBS: Optional[str] = None
     NVCC_THREADS: Optional[str] = None
@@ -321,6 +322,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_IMAGE_FETCH_TIMEOUT":
     lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
 
+    # Timeout for fetching audio when serving multimodal models
+    # Default is 5 seconds
+    "VLLM_AUDIO_FETCH_TIMEOUT":
+    lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")),
+
     # Path to the XLA persistent cache directory.
     # Only used for XLA devices such as TPUs.
     "VLLM_XLA_CACHE_PATH":
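The new knob is read the same way as the image one; for example, to allow slower audio downloads (the value is in seconds, with 5 as the default):

    import os

    # Must be set before vLLM reads its environment variables.
    os.environ["VLLM_AUDIO_FETCH_TIMEOUT"] = "30"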
@@ -38,7 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     safetensors_weights_iterator)
 from vllm.model_executor.models.interfaces import (has_inner_state,
                                                    supports_lora,
-                                                   supports_vision)
+                                                   supports_multimodal)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
@@ -131,7 +131,7 @@ def _get_model_initialization_kwargs(
             "be added in the future. If this is important to you, "
             "please open an issue on github.")
 
-    if supports_vision(model_class):
+    if supports_multimodal(model_class):
         if multimodal_config is None:
             raise ValueError("Provide vision related configurations "
                              "through LLM entrypoint or engine arguments.")
@@ -20,8 +20,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
 
 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
-from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .interfaces import SupportsMultiModal
+from .utils import merge_multimodal_embeddings
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -457,7 +457,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
-class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
+class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: Blip2Config,
@@ -621,8 +621,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
             vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.language_model.get_input_embeddings(input_ids)
 
-            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vision_embeddings,
-                                                    BLIP2_IMAGE_TOKEN_ID)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                BLIP2_IMAGE_TOKEN_ID)
 
             input_ids = None
@@ -35,7 +35,7 @@ from vllm.multimodal.image import (cached_get_tokenizer,
 from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
 from vllm.utils import print_warning_once
 
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 
 logger = init_logger(__name__)
 
@@ -886,7 +886,7 @@ class ChameleonModel(nn.Module):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
-class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
+class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(
         self,
@@ -40,8 +40,8 @@ from vllm.multimodal.image import (cached_get_image_processor,
                                    cached_get_tokenizer)
 from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
 
-from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .interfaces import SupportsMultiModal
+from .utils import merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -209,7 +209,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
-class FuyuForCausalLM(nn.Module, SupportsVision):
+class FuyuForCausalLM(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: FuyuConfig,
@@ -271,8 +271,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vision_embeddings,
-                                                    self.image_token_id)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.image_token_id)
 
         else:
@@ -10,12 +10,15 @@ logger = init_logger(__name__)
 
 
 @runtime_checkable
-class SupportsVision(Protocol):
-    """The interface required for all vision language models (VLMs)."""
-
-    supports_vision: ClassVar[Literal[True]] = True
+class SupportsMultiModal(Protocol):
     """
-    A flag that indicates this model supports vision inputs.
+    The interface required for all multimodal (vision or audio) language
+    models.
+    """
+
+    supports_multimodal: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports multimodal inputs.
 
     Note:
         There is no need to redefine this flag if this class is in the
@@ -29,30 +32,31 @@ class SupportsVision(Protocol):
 # We can't use runtime_checkable with ClassVar for issubclass checks
 # so we need to treat the class as an instance and use isinstance instead
 @runtime_checkable
-class _SupportsVisionType(Protocol):
-    supports_vision: Literal[True]
+class _SupportsMultiModalType(Protocol):
+    supports_multimodal: Literal[True]
 
     def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
         ...
 
 
 @overload
-def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]:
+def supports_multimodal(
+        model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
     ...
 
 
 @overload
-def supports_vision(model: object) -> TypeIs[SupportsVision]:
+def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
     ...
 
 
-def supports_vision(
+def supports_multimodal(
     model: Union[Type[object], object],
-) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]:
+) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
     if isinstance(model, type):
-        return isinstance(model, _SupportsVisionType)
+        return isinstance(model, _SupportsMultiModalType)
 
-    return isinstance(model, SupportsVision)
+    return isinstance(model, SupportsMultiModal)
 
 
 @runtime_checkable
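The loader change above is the main consumer of the renamed helper; a quick sanity check of the new predicate (Blip2 declares SupportsMultiModal in this commit, while plain OPT does not):

    from vllm.model_executor.models.blip2 import Blip2ForConditionalGeneration
    from vllm.model_executor.models.interfaces import supports_multimodal
    from vllm.model_executor.models.opt import OPTForCausalLM

    assert supports_multimodal(Blip2ForConditionalGeneration)
    assert not supports_multimodal(OPTForCausalLM)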
@@ -27,9 +27,9 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 from .utils import (filter_weights, init_vllm_registered_model,
-                    merge_vision_embeddings)
+                    merge_multimodal_embeddings)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -292,7 +292,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
-class InternVLChatModel(nn.Module, SupportsVision):
+class InternVLChatModel(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -451,8 +451,8 @@ class InternVLChatModel(nn.Module, SupportsVision):
             inputs_embeds = self.language_model.model.get_input_embeddings(
                 input_ids)
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vision_embeddings,
-                                                    self.img_context_token_id)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.img_context_token_id)
             input_ids = None
         else:
@@ -19,12 +19,12 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
                    input_processor_for_clip)
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
 from .utils import (filter_weights, init_vllm_registered_model,
-                    merge_vision_embeddings)
+                    merge_multimodal_embeddings)
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -181,7 +181,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
-class LlavaForConditionalGeneration(nn.Module, SupportsVision):
+class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: LlavaConfig,
@@ -338,7 +338,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
             inputs_embeds = self.language_model.model.get_input_embeddings(
                 input_ids)
 
-            inputs_embeds = merge_vision_embeddings(
+            inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.config.image_token_index)
 
@@ -23,13 +23,13 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_clip_image_feature_size,
                    get_clip_patch_grid_length, input_processor_for_clip)
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
 from .utils import (filter_weights, init_vllm_registered_model,
-                    merge_vision_embeddings)
+                    merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
@@ -275,7 +275,7 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
-class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
+class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: LlavaNextConfig,
@@ -571,7 +571,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             inputs_embeds = self.language_model.model.get_input_embeddings(
                 input_ids)
 
-            inputs_embeds = merge_vision_embeddings(
+            inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.config.image_token_index)
 
@@ -48,7 +48,7 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsVision
+from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.models.minicpm import MiniCPMModel
 from vllm.model_executor.models.qwen2 import Qwen2Model
@@ -479,7 +479,7 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
     return llm_inputs
 
 
-class MiniCPMVBaseModel(nn.Module, SupportsVision):
+class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
     instantiated.
@@ -19,10 +19,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import merge_vision_embeddings
+from .utils import merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -130,7 +130,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
-class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
+class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: PaliGemmaConfig,
@@ -244,7 +244,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
 
             inputs_embeds = self.language_model.get_input_embeddings(input_ids)
 
-            inputs_embeds = merge_vision_embeddings(
+            inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.config.image_token_index)
 
@@ -42,8 +42,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    input_processor_for_clip)
-from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .interfaces import SupportsMultiModal
+from .utils import merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -453,7 +453,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
-class Phi3VForCausalLM(nn.Module, SupportsVision):
+class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -568,8 +568,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.model.get_input_embeddings(input_ids)
-            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vision_embeddings,
-                                                    self.image_token_id)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.image_token_id)
             input_ids = None
         else:
@@ -54,41 +54,42 @@ def init_vllm_registered_model(
     )
 
 
-def merge_vision_embeddings(input_ids: torch.Tensor,
-                            inputs_embeds: torch.Tensor,
-                            vision_embeddings: BatchedTensors,
-                            image_token_id: int) -> torch.Tensor:
+def merge_multimodal_embeddings(input_ids: torch.Tensor,
+                                inputs_embeds: torch.Tensor,
+                                multimodal_embeddings: BatchedTensors,
+                                placeholder_token_id: int) -> torch.Tensor:
     """
-    Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
-    positions in ``inputs_embeds`` corresponding to placeholder image tokens in
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` corresponding to placeholder tokens in
     ``input_ids``.
 
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    mask = (input_ids == image_token_id)
+    mask = (input_ids == placeholder_token_id)
     num_expected_tokens = mask.sum()
 
-    if isinstance(vision_embeddings, torch.Tensor):
-        batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
+    if isinstance(multimodal_embeddings, torch.Tensor):
+        batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
         total_tokens = batch_size * batch_tokens
         if num_expected_tokens != total_tokens:
             expr = f"{batch_size} x {batch_tokens}"
             raise ValueError(
                 f"Attempted to assign {expr} = {total_tokens} "
-                f"image tokens to {num_expected_tokens} placeholders")
+                f"multimodal tokens to {num_expected_tokens} placeholders")
 
-        inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
+        inputs_embeds[mask] = multimodal_embeddings.view(
+            total_tokens, embed_dim)
     else:
-        size_per_batch = [t.shape[0] for t in vision_embeddings]
+        size_per_batch = [t.shape[0] for t in multimodal_embeddings]
         total_tokens = sum(size_per_batch)
         if num_expected_tokens != total_tokens:
             expr = ' + '.join(map(str, size_per_batch))
             raise ValueError(
                 f"Attempted to assign {expr} = {total_tokens} "
-                f"image tokens to {num_expected_tokens} placeholders")
+                f"multimodal tokens to {num_expected_tokens} placeholders")
 
-        inputs_embeds[mask] = torch.cat(vision_embeddings)
+        inputs_embeds[mask] = torch.cat(multimodal_embeddings)
 
     return inputs_embeds
 
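A toy illustration of the renamed helper with made-up shapes: two placeholder positions in input_ids are overwritten by a (1, 2, embed_dim) batch of multimodal embeddings.

    import torch

    from vllm.model_executor.models.utils import merge_multimodal_embeddings

    PLACEHOLDER_ID = 7
    input_ids = torch.tensor([1, 7, 7, 2])   # two placeholder positions
    inputs_embeds = torch.zeros(4, 8)        # (seq_len, embed_dim)
    mm_embeds = torch.ones(1, 2, 8)          # (batch, tokens, embed_dim)

    out = merge_multimodal_embeddings(input_ids, inputs_embeds, mm_embeds,
                                      PLACEHOLDER_ID)
    assert torch.equal(out[1:3], torch.ones(2, 8))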
vllm/multimodal/audio.py (new file, +17)
@@ -0,0 +1,17 @@
+from vllm.inputs.registry import InputContext
+from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin
+
+
+class AudioPlugin(MultiModalPlugin):
+    """Plugin for audio data."""
+
+    def get_data_key(self) -> str:
+        return "audio"
+
+    def _default_input_mapper(self, ctx: InputContext,
+                              data: object) -> MultiModalInputs:
+        raise NotImplementedError("There is no default audio input mapper")
+
+    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
+        raise NotImplementedError(
+            "There is no default maximum multimodal tokens")
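Because both defaults raise ``NotImplementedError``, any audio model has to register its own input mapper and maximum-token count. A hedged sketch of what that registration could look like; the generic ``register_input_mapper``/``register_max_multimodal_tokens`` decorators are assumed counterparts of the image-specific helpers, and the mapper body and token budget are placeholders, not shown in this diff:

.. code-block:: python

    import torch
    import torch.nn as nn

    from vllm.inputs.registry import InputContext
    from vllm.model_executor.models.interfaces import SupportsMultiModal
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.base import MultiModalInputs


    def your_audio_input_mapper(ctx: InputContext,
                                data: object) -> MultiModalInputs:
        # `data` follows the builtin format: (waveform, sampling_rate).
        waveform, sampling_rate = data
        # Placeholder featurization; a real model would run its own feature
        # extractor (e.g. log-mel spectrograms) here.
        features = torch.as_tensor(waveform).unsqueeze(0)
        return MultiModalInputs({"audio_features": features})


    # Decorator names are assumed generic counterparts of
    # register_image_input_mapper / register_max_image_tokens.
    @MULTIMODAL_REGISTRY.register_input_mapper("audio", your_audio_input_mapper)
    @MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", 750)
    class YourModelForAudio2Seq(nn.Module, SupportsMultiModal):
        ...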
@@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from typing import Any, Callable, Dict, List, Optional
 from typing import Sequence as GenericSequence
-from typing import Type, TypedDict, TypeVar, Union, cast
+from typing import Tuple, Type, TypedDict, TypeVar, Union, cast
 
+import numpy as np
 import torch
 import torch.types
 from PIL import Image
@@ -121,6 +122,9 @@ class MultiModalDataBuiltins(TypedDict, total=False):
     image: Image.Image
     """The input image."""
 
+    audio: Tuple[np.ndarray, Union[int, float]]
+    """The input audio and its sampling rate."""
+
 
 MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
 """
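With ``audio`` added to ``MultiModalDataBuiltins``, a clip plus its sampling rate can be passed to the engine the same way images already are. A hedged offline-inference sketch; the model name, the placeholder string in the prompt, and the audio file are illustrative only, since this PR only adds the plumbing:

.. code-block:: python

    import librosa

    from vllm import LLM, SamplingParams

    # Load a clip at its native sampling rate; the path is illustrative.
    audio, sampling_rate = librosa.load("speech.wav", sr=None)

    llm = LLM(model="your-org/your-audio-model")  # hypothetical audio LLM
    outputs = llm.generate(
        {
            "prompt": "<|audio|>\nWhat is being said in this clip?",
            "multi_modal_data": {"audio": (audio, sampling_rate)},
        },
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)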
@@ -6,6 +6,7 @@ import torch
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
 
+from .audio import AudioPlugin
 from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
                    MultiModalPlugin, MultiModalTokensCalc)
 from .image import ImagePlugin
@@ -19,7 +20,7 @@ class MultiModalRegistry:
     :class:`~vllm.multimodal.MultiModalPlugin` for each modality.
     """
 
-    DEFAULT_PLUGINS = (ImagePlugin(), )
+    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin())
 
     def __init__(
         self,
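``DEFAULT_PLUGINS`` now ships with an image and an audio plugin; further modalities would follow the same pattern. A sketch of how another plugin could be wired in, modelled directly on the ``AudioPlugin`` above; the "video" modality and the ``register_plugin`` call are assumptions for illustration, not part of this diff:

.. code-block:: python

    from vllm.inputs.registry import InputContext
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin


    class VideoPlugin(MultiModalPlugin):
        """Plugin for video data (illustrative only)."""

        def get_data_key(self) -> str:
            return "video"

        def _default_input_mapper(self, ctx: InputContext,
                                  data: object) -> MultiModalInputs:
            raise NotImplementedError("There is no default video input mapper")

        def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
            raise NotImplementedError(
                "There is no default maximum multimodal tokens")


    # Assumed registry hook for plugins beyond DEFAULT_PLUGINS.
    MULTIMODAL_REGISTRY.register_plugin(VideoPlugin())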
@@ -1,11 +1,14 @@
 import base64
 from io import BytesIO
-from typing import Union
+from typing import Tuple, Union
 
+import librosa
+import numpy as np
+import soundfile
 from PIL import Image
 
 from vllm.connections import global_http_connection
-from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
+from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
 from vllm.multimodal.base import MultiModalDataDict
@@ -63,11 +66,62 @@ async def async_fetch_image(image_url: str,
     return image.convert(image_mode)
 
 
+def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
+    """
+    Load audio from a URL.
+    """
+    if audio_url.startswith("http"):
+        audio_bytes = global_http_connection.get_bytes(
+            audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
+    elif audio_url.startswith("data:audio"):
+        _, audio_base64 = audio_url.split(",", 1)
+        audio_bytes = base64.b64decode(audio_base64)
+    else:
+        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
+                         "with either 'data:audio' or 'http'.")
+
+    return librosa.load(BytesIO(audio_bytes), sr=None)
+
+
+async def async_fetch_audio(
+        audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
+    """
+    Asynchronously fetch audio from a URL.
+    """
+    if audio_url.startswith("http"):
+        audio_bytes = await global_http_connection.async_get_bytes(
+            audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
+    elif audio_url.startswith("data:audio"):
+        _, audio_base64 = audio_url.split(",", 1)
+        audio_bytes = base64.b64decode(audio_base64)
+    else:
+        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
+                         "with either 'data:audio' or 'http'.")
+
+    return librosa.load(BytesIO(audio_bytes), sr=None)
+
+
+async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
+    audio, sr = await async_fetch_audio(audio_url)
+    return {"audio": (audio, sr)}
+
+
 async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
     image = await async_fetch_image(image_url)
     return {"image": image}
 
 
+def encode_audio_base64(
+    audio: np.ndarray,
+    sampling_rate: int,
+) -> str:
+    """Encode audio as base64."""
+    buffered = BytesIO()
+    soundfile.write(buffered, audio, sampling_rate, format="WAV")
+
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
 def encode_image_base64(
     image: Image.Image,
     *,
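Together, ``encode_audio_base64`` and ``fetch_audio`` let audio be round-tripped through a base64 data URL, which is the form a server request would carry. A sketch, assuming these helpers live in ``vllm.multimodal.utils`` (the hunks above do not name the file) and using an illustrative WAV path:

.. code-block:: python

    import librosa
    import numpy as np

    # Assumed module path for the helpers added above.
    from vllm.multimodal.utils import encode_audio_base64, fetch_audio

    # Load a local clip at its native sampling rate (path is illustrative).
    audio, sampling_rate = librosa.load("speech.wav", sr=None)

    # Build a data URL of the form fetch_audio accepts ("data:audio..." prefix).
    audio_base64 = encode_audio_base64(audio, int(sampling_rate))
    audio_url = f"data:audio/wav;base64,{audio_base64}"

    decoded_audio, decoded_sr = fetch_audio(audio_url)
    assert int(decoded_sr) == int(sampling_rate)
    # WAV encoding quantizes samples, so allow a small tolerance.
    assert np.allclose(audio, decoded_audio, atol=1e-3)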
@@ -40,7 +40,7 @@ from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models.interfaces import (supports_lora,
-                                                    supports_vision)
+                                                    supports_multimodal)
 from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
@@ -900,9 +900,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         if self.lora_config:
             assert supports_lora(self.model), "Model does not support LoRA"
-            assert not supports_vision(
+            assert not supports_multimodal(
                 self.model
-            ), "To be tested: vision language model with LoRA settings."
+            ), "To be tested: multimodal language model with LoRA settings."
 
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -1054,7 +1054,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         # of images processed.
         model_config = self.model_config
 
-        if supports_vision(self.model):
+        if supports_multimodal(self.model):
             max_mm_tokens = MULTIMODAL_REGISTRY \
                 .get_max_multimodal_tokens(model_config)
             max_num_seqs_orig = max_num_seqs
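The worker changes are a mechanical rename: every ``supports_vision`` check becomes ``supports_multimodal``, which tests whether the loaded model implements the ``SupportsMultiModal`` interface. A hedged sketch of the same check used standalone (the ``model`` object is a placeholder); the identical rename appears in the XPU model runner in the next hunks:

.. code-block:: python

    from vllm.model_executor.models.interfaces import supports_multimodal
    from vllm.multimodal import MULTIMODAL_REGISTRY


    def max_multimodal_tokens_for(model, model_config) -> int:
        # Mirrors the worker logic above: only query the registry when the
        # model implements SupportsMultiModal; otherwise no budget is needed.
        if supports_multimodal(model):
            return MULTIMODAL_REGISTRY.get_max_multimodal_tokens(model_config)
        return 0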
@@ -12,7 +12,7 @@ from vllm.distributed import broadcast_tensor_dict
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models.interfaces import supports_vision
+from vllm.model_executor.models.interfaces import supports_multimodal
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
 from vllm.sampling_params import SamplingParams
@@ -165,7 +165,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
         # of images processed.
         model_config = self.model_config
 
-        if supports_vision(self.model):
+        if supports_multimodal(self.model):
             max_mm_tokens = MULTIMODAL_REGISTRY \
                 .get_max_multimodal_tokens(model_config)
             max_num_seqs_orig = max_num_seqs