Voxtral (#20970)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

parent 4ffd963fa0
commit e7e3e6d263
@@ -10,7 +10,7 @@ on HuggingFace model repository.
 import os
 from dataclasses import asdict
-from typing import NamedTuple, Optional
+from typing import Any, NamedTuple, Optional

 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
@@ -30,7 +30,9 @@ question_per_audio_count = {


 class ModelRequestData(NamedTuple):
     engine_args: EngineArgs
-    prompt: str
+    prompt: Optional[str] = None
+    prompt_token_ids: Optional[dict[str, list[int]]] = None
+    multi_modal_data: Optional[dict[str, Any]] = None
     stop_token_ids: Optional[list[int]] = None
     lora_requests: Optional[list[LoRARequest]] = None
@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
 # Unless specified, these settings have been tested to work on a single L4.


+# Voxtral
+def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
+    from mistral_common.audio import Audio
+    from mistral_common.protocol.instruct.messages import (
+        AudioChunk,
+        RawAudio,
+        TextChunk,
+        UserMessage,
+    )
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+
+    model_name = "mistralai/Voxtral-Mini-3B-2507"
+    tokenizer = MistralTokenizer.from_hf_hub(model_name)
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"audio": audio_count},
+        config_format="mistral",
+        load_format="mistral",
+        tokenizer_mode="mistral",
+        enforce_eager=True,
+        enable_chunked_prefill=False,
+    )
+
+    text_chunk = TextChunk(text=question)
+    audios = [
+        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
+        for i in range(audio_count)
+    ]
+    audio_chunks = [
+        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
+    ]
+
+    messages = [UserMessage(content=[*audio_chunks, text_chunk])]
+
+    req = ChatCompletionRequest(messages=messages, model=model_name)
+
+    tokens = tokenizer.encode_chat_completion(req)
+    prompt_ids, audios = tokens.tokens, tokens.audios
+
+    audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
+
+    multi_modal_data = {"audio": audios_and_sr}
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt_token_ids=prompt_ids,
+        multi_modal_data=multi_modal_data,
+    )
+
+
 # Granite Speech
 def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
     # NOTE - the setting in this example are somehat different than what is
@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:


 model_example_map = {
+    "voxtral": run_voxtral,
     "granite_speech": run_granite_speech,
     "minicpmo": run_minicpmo,
     "phi4_mm": run_phi4mm,
@@ -311,6 +368,8 @@ def main(args):
         temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
     )

+    mm_data = req_data.multi_modal_data
+    if not mm_data:
         mm_data = {}
         if audio_count > 0:
             mm_data = {
@@ -320,7 +379,13 @@ def main(args):
             }

     assert args.num_prompts > 0
-    inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
+    inputs = {"multi_modal_data": mm_data}
+
+    if req_data.prompt:
+        inputs["prompt"] = req_data.prompt
+    else:
+        inputs["prompt_token_ids"] = req_data.prompt_token_ids
+
     if args.num_prompts > 1:
         # Batch inference
         inputs = [inputs] * args.num_prompts
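For readers following along, a minimal sketch of how the new request shape is consumed offline. The question string, single-audio setup and print call are illustrative only; the generate call mirrors what main() above does with prompt_token_ids and multi_modal_data.

from dataclasses import asdict

from vllm import LLM, SamplingParams

# Hypothetical driver for run_voxtral(); assumes the example's audio assets
# are available locally.
req_data = run_voxtral("What can you tell me about this audio?", audio_count=1)

llm = LLM(**asdict(req_data.engine_args))
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)

# Token inputs go through "prompt_token_ids"; the decoded audio arrays and
# their sampling rates ride along as multi-modal data.
outputs = llm.generate(
    {
        "prompt_token_ids": req_data.prompt_token_ids,
        "multi_modal_data": req_data.multi_modal_data,
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)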
@@ -33,7 +33,7 @@ pyzmq >= 25.0.0
 msgspec
 gguf >= 0.13.0
 importlib_metadata; python_version < '3.10'
-mistral_common[opencv] >= 1.6.2
+mistral_common[opencv] >= 1.8.0
 opencv-python-headless >= 4.11.0 # required for video IO
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
@@ -23,7 +23,7 @@ jiwer # required for audio tests
 timm # required for internvl test
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.6.2 # required for pixtral test
+mistral_common[opencv] >= 1.8.0 # required for voxtral test
 num2words # required for smolvlm test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
@@ -28,7 +28,7 @@ torchvision==0.22.0
 transformers_stream_generator # required for qwen-vl test
 mamba_ssm # required for plamo2 test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.7.0 # required for pixtral test
+mistral_common[opencv] >= 1.8.0 # required for voxtral test
 num2words # required for smolvlm test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
@@ -305,7 +305,7 @@ mbstrdecoder==1.1.3
     # typepy
 mdurl==0.1.2
     # via markdown-it-py
-mistral-common==1.7.0
+mistral-common==1.8.0
     # via -r requirements/test.in
 more-itertools==10.5.0
     # via lm-eval
@@ -518,6 +518,8 @@ pyasn1-modules==0.4.2
     # via google-auth
 pybind11==2.13.6
     # via lm-eval
+pycountry==24.6.1
+    # via pydantic-extra-types
 pycparser==2.22
     # via cffi
 pycryptodomex==3.22.0
@@ -528,9 +530,12 @@ pydantic==2.11.5
     # datamodel-code-generator
     # mistral-common
     # mteb
+    # pydantic-extra-types
     # ray
 pydantic-core==2.33.2
     # via pydantic
+pydantic-extra-types==2.10.5
+    # via mistral-common
 pygments==2.18.0
     # via rich
 pyparsing==3.2.0
@@ -835,6 +840,7 @@ typing-extensions==4.12.2
     # pqdm
     # pydantic
     # pydantic-core
+    # pydantic-extra-types
     # torch
     # typer
     # typing-inspection
setup.py
@@ -692,7 +692,8 @@ setup(
         "tensorizer": ["tensorizer==2.10.1"],
         "fastsafetensors": ["fastsafetensors >= 0.1.10"],
         "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
-        "audio": ["librosa", "soundfile"],  # Required for audio processing
+        "audio": ["librosa", "soundfile",
+                  "mistral_common[audio]"],  # Required for audio processing
         "video": []  # Kept for backwards compatibility
     },
     cmdclass=cmdclass,
@@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset

 from ...utils import RemoteOpenAIServer

+MISTRAL_FORMAT_ARGS = [
+    "--tokenizer_mode", "mistral", "--config_format", "mistral",
+    "--load_format", "mistral"
+]
+

 @pytest.fixture
 def mary_had_lamb():
@@ -33,9 +38,18 @@ def winning_call():


 @pytest.mark.asyncio
-async def test_basic_audio(mary_had_lamb):
-    model_name = "openai/whisper-large-v3-turbo"
+@pytest.mark.parametrize(
+    "model_name",
+    ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
+async def test_basic_audio(mary_had_lamb, model_name):
     server_args = ["--enforce-eager"]
+
+    if model_name.startswith("mistralai"):
+        server_args += MISTRAL_FORMAT_ARGS
+
+        # TODO(PATRICK) - REMOVE AFTER RELEASE
+        return  # skip for now
+
     # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
     with RemoteOpenAIServer(model_name, server_args) as remote_server:
         client = remote_server.get_async_client()
@@ -65,10 +79,13 @@ async def test_bad_requests(mary_had_lamb):


 @pytest.mark.asyncio
-async def test_long_audio_request(mary_had_lamb):
-    model_name = "openai/whisper-large-v3-turbo"
+@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
+async def test_long_audio_request(mary_had_lamb, model_name):
     server_args = ["--enforce-eager"]
+
+    if model_name.startswith("openai"):
+        return
+
     mary_had_lamb.seek(0)
     audio, sr = librosa.load(mary_had_lamb)
     # Add small silence after each audio for repeatability in the split process
@@ -87,7 +104,8 @@ async def test_long_audio_request(mary_had_lamb):
         response_format="text",
         temperature=0.0)
     out = json.loads(transcription)['text']
-    assert out.count("Mary had a little lamb") == 10
+    counts = out.count("Mary had a little lamb")
+    assert counts == 10, counts


 @pytest.mark.asyncio
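Once a server is running, the same endpoint can be exercised by hand. A hedged sketch with the OpenAI client follows; the base URL, API key and the local audio file name are placeholders rather than values taken from these tests.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# "sample.wav" is a stand-in for any local recording; the server is assumed
# to have been started with the mistral-format flags listed above.
with open("sample.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="mistralai/Voxtral-Mini-3B-2507",
        file=f,
        language="en",
        response_format="text",
        temperature=0.0,
    )
print(transcription)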
@@ -440,6 +440,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                     tokenizer="Isotr0py/Florence-2-tokenizer",  # noqa: E501
                                     trust_remote_code=True),  # noqa: E501
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
+    "VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", is_available_online=False, tokenizer_mode="mistral"),  # noqa: E501
     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),  # noqa: E501

     # [Cross-encoder]
@@ -112,6 +112,7 @@ class OpenAISpeechToText(OpenAIServing):
             prompt = self.model_cls.get_generation_prompt(
                 audio=chunk,
                 stt_config=self.asr_config,
+                model_config=self.model_config,
                 language=lang,
                 task_type=self.task_type,
                 request_prompt=request.prompt)
@@ -722,7 +722,8 @@ class SupportsTranscription(Protocol):

     @classmethod
     def get_generation_prompt(cls, audio: np.ndarray,
-                              stt_config: SpeechToTextConfig, language: str,
+                              stt_config: SpeechToTextConfig,
+                              model_config: ModelConfig, language: str,
                               task_type: str,
                               request_prompt: str) -> PromptType:
         """Get the prompt for the ASR model.
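A hypothetical implementer's view of the widened protocol; the class and its return value are made up for illustration and only show where the new model_config argument lands in the signature.

import numpy as np

from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType
from vllm.model_executor.models.interfaces import SupportsTranscription


class MyASRModel(SupportsTranscription):  # hypothetical implementer

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig,
                              model_config: ModelConfig, language: str,
                              task_type: str,
                              request_prompt: str) -> PromptType:
        # model_config now travels with the request, so tokenizer-driven
        # models (such as Voxtral below) can rebuild their tokenizer from it.
        return {
            "prompt_token_ids": [],
            "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
        }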
@@ -231,6 +231,7 @@ _MULTIMODAL_MODELS = {
     "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
     "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
     "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"),  # noqa: E501
+    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
     # [Encoder-decoder]
     "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
     "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
vllm/model_executor/models/voxtral.py (new file, 691 lines)
@@ -0,0 +1,691 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from math import ceil
from typing import Optional, Union, cast

import numpy as np
import regex as re
import torch
import torch.nn as nn
from mistral_common.audio import mel_filter_bank
from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio,
                                                        TextChunk, UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
from transformers import TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
# yapf: disable
from vllm.model_executor.models.whisper import (
    WhisperEncoder, WhisperForConditionalGeneration)
# yapf: enable
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, MultiModalHashes,
                                        PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
                                               cached_tokenizer_from_config)

from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                         SupportsTranscription)
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)

logger = init_logger(__name__)


class VoxtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    :class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer

    @cached_property
    def _audio_processor(self) -> AudioEncoder:
        audio_encoder = self.tokenizer.instruct.audio_encoder
        assert isinstance(audio_encoder, AudioEncoder)
        return audio_encoder

    @cached_property
    def audio_token_id(self) -> int:
        return self._audio_processor.special_ids.audio

    @cached_property
    def begin_audio_token_id(self) -> int:
        return self._audio_processor.special_ids.begin_audio

    # @cached_property
    # def begin_transcript_token_id(self) -> int:
    #     return self._audio_processor.special_ids.begin_transcript

    # @cached_property
    # def end_transcript_token_id(self) -> int:
    #     return self._audio_processor.special_ids.end_transcript

    @cached_property
    def sampling_rate(self) -> int:
        return self._audio_processor.audio_config.sampling_rate

    @cached_property
    def frame_rate(self) -> float:
        return self._audio_processor.audio_config.frame_rate

    def get_num_audio_tokens(
        self,
        audio_length: int,
    ) -> int:
        pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
            audio_length, self.sampling_rate)
        return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))

    def __call__(
        self,
        text: Optional[Union[TextInput, list[TextInput]]] = None,
        audios: Optional[Union[np.ndarray, list[np.ndarray]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if audios is None:
            audios = []
        if not isinstance(audios, list):
            audios = [audios]

        if not audios:
            input_ids = self.tokenizer(text).input_ids
            return {"input_ids": torch.tensor(input_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(len(t) > 0 for t in text):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411.")

        audios_tokens = list[torch.Tensor]()
        audios_processed = list[torch.Tensor]()
        for audio in audios:
            assert isinstance(audio, np.ndarray)
            assert audio.ndim == 1

            # pad if necessary
            audio = self._audio_processor.pad(audio, self.sampling_rate)

            audio_tokens = [
                self.begin_audio_token_id
            ] + [self.audio_token_id] * self.get_num_audio_tokens(len(audio))

            audios_tokens.append(torch.tensor(audio_tokens))
            audios_processed.append(torch.tensor(audio))

        return {
            "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
            "audio_arrays": audios_processed,
        }


class VoxtralProcessingInfo(BaseProcessingInfo):

    def get_tokenizer(self) -> MistralTokenizer:
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("This model requires `--tokenizer-mode mistral`")

        return tokenizer

    def get_hf_processor(self) -> VoxtralProcessorAdapter:
        return VoxtralProcessorAdapter(self.get_tokenizer())

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": 5}  # Performance tends to degrade after 5

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"audio": self.get_max_audio_tokens()}

    def get_max_audio_tokens(self) -> int:
        return self.ctx.model_config.max_model_len

    def get_max_audio_array_len(self) -> int:
        processor = self.get_hf_processor()
        return self.get_max_audio_tokens() * int(
            processor.sampling_rate // processor.frame_rate)


class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_audios = mm_counts.get("audio", 0)

        target_length = self.info.get_max_audio_array_len()

        return {
            "audio":
            self._get_dummy_audios(length=target_length, num_audios=num_audios)
        }

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        tokenizer = self.info.get_tokenizer()

        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
        dummy_audios = dummy_mm_data.get("audio", [])

        audio_chunks: list[AudioChunk] = []
        format = "wav"
        for audio in dummy_audios:
            audio_item = Audio(
                audio_array=audio,
                sampling_rate=self.info.get_hf_processor().sampling_rate,
                format=format,
            )
            chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
            audio_chunks.append(chunk)

        request = ChatCompletionRequest(messages=[
            UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
        ])
        res = tokenizer.mistral.encode_chat_completion(request)
        dummy_tokens = res.tokens
        # whixtral tokenizer adds padding to the audio
        # so we need to update the audio arrays
        dummy_mm_data["audio"] = [a.audio_array for a in res.audios]

        return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)


class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
                                 ):

    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(audio_arrays=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        audio_id = processor.audio_token_id

        def get_replacement(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio_len = audios.get_audio_length(item_idx)

            nb_audio_tokens = processor.get_num_audio_tokens(audio_len)

            return [audio_id] * nb_audio_tokens

        return [
            PromptReplacement(
                modality="audio",
                target="",  # Never match the prompt (see below note)
                replacement=get_replacement,
            ),
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        return_mm_hashes: bool,
    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
        prompt_ids, mm_kwargs, mm_hashes, _ = super(
        )._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            return_mm_hashes=return_mm_hashes,
        )

        # NOTE: The tokens are already inserted by the chat template
        return prompt_ids, mm_kwargs, mm_hashes, True

    def _get_data_parser(self) -> MultiModalDataParser:
        sampling_rate = self.info.get_hf_processor().sampling_rate
        return MultiModalDataParser(target_sr=sampling_rate)


@MULTIMODAL_REGISTRY.register_processor(VoxtralMultiModalProcessor,
                                        info=VoxtralProcessingInfo,
                                        dummy_inputs=VoxtralDummyInputsBuilder)
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP, SupportsTranscription):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)

        config = vllm_config.model_config.hf_config
        self.config = config
        self.downsample_factor = self.config.audio_config.downsample_factor

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.whisper_encoder = VoxtralEncoderModel(
            vllm_config.with_hf_config(config.audio_config),
            prefix=maybe_prefix(prefix, "whisper_encoder"),
        )
        self.audio_language_adapter = AudioLanguageAdapter(
            hidden_size=config.audio_config.d_model * self.downsample_factor,
            dim=config.text_config.hidden_size,
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            audio_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      audio_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def get_multimodal_embeddings(
        self, **kwargs
    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...],
               None]:
        audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
        if audio_inputs is None:
            return None

        audio_embeddings = self.whisper_encoder(audio_inputs)

        for i, audio_embedding in enumerate(audio_embeddings):
            seq_len, dim = audio_embedding.shape
            # Pad such that seq_len is divisible by downsample_factor
            target_seq_len = self.downsample_factor * math.ceil(
                seq_len / self.downsample_factor)
            audio_embedding = torch.nn.functional.pad(
                audio_embedding,
                (0, 0, 0, target_seq_len - seq_len),
            )
            audio_embeddings[i] = audio_embedding.reshape(
                target_seq_len // self.downsample_factor,
                dim * self.downsample_factor)

        # Concat, project and resplit
        audio_embeddings_packed = torch.cat(audio_embeddings, dim=0)
        audio_embeddings_packed = self.audio_language_adapter(
            audio_embeddings_packed)
        audio_embeddings = torch.split(audio_embeddings_packed,
                                       [a.shape[0] for a in audio_embeddings],
                                       dim=0)

        return audio_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        audio_encoder = self.tokenizer.instruct.audio_encoder
        audio_tok_id = audio_encoder.audio_token

        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id)
        return inputs_embeds

    def _parse_and_validate_audio_arrays(
            self, **kwargs: object) -> Union[list[torch.Tensor], None]:
        audio_arrays = kwargs.pop("audio_arrays", None)
        if audio_arrays is None:
            return None

        if not isinstance(audio_arrays, (torch.Tensor, list)):
            raise ValueError("Incorrect type of audio_arrays. "
                             f"Got type: {type(audio_arrays)}")

        audio_arrays = flatten_bn(audio_arrays)
        if isinstance(audio_arrays, torch.Tensor):
            audio_arrays = list(audio_arrays.unbind(0))
        return audio_arrays

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    @classmethod
    def get_speech_to_text_config(cls, model_config: ModelConfig,
                                  task_type: str) -> SpeechToTextConfig:
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_config = tokenizer.instruct.audio_encoder.audio_config
        max_audio_clip_s = audio_config.chunk_length_s
        sample_rate = audio_config.sampling_rate
        return SpeechToTextConfig(
            max_audio_clip_s=max_audio_clip_s,
            sample_rate=sample_rate,
            # mistral_common and whisper encoder take care of chunking
            min_energy_split_window_size=None,
        )

    @classmethod
    # for speech-to-text transcription
    def get_generation_prompt(cls, audio: np.ndarray,
                              model_config: ModelConfig,
                              stt_config: SpeechToTextConfig, language: str,
                              task_type: str,
                              request_prompt: str) -> PromptType:
        tokenizer = cached_tokenizer_from_config(model_config)
        audio = Audio(audio, int(stt_config.sample_rate),
                      format="wav")  # lossless
        req = TranscriptionRequest(model=model_config.model,
                                   audio=RawAudio.from_audio(audio),
                                   language=language)

        tokenized = tokenizer.instruct.encode_transcription(req)
        audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
        prompts_dict = {"multi_modal_data": {"audio": audio}}
        prompts_dict["prompt_token_ids"] = tokenized.tokens
        return cast(PromptType, prompts_dict)

    @classmethod
    def validate_language(cls, language: str) -> bool:
        # same as whisper
        return WhisperForConditionalGeneration.validate_language(language)

    @classmethod
    def get_num_audio_tokens(cls, audio_duration_s: float,
                             stt_config: SpeechToTextConfig,
                             model_config: ModelConfig) -> Optional[int]:
        """
        Map from audio duration to number of audio tokens produced by the ASR
        model, without running a forward pass.
        This is used for estimating the amount of processing for this audio.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        adapter = VoxtralProcessorAdapter(tokenizer)
        return adapter.get_num_audio_tokens(
            int(audio_duration_s * stt_config.sample_rate))

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        # fmt: off
        remapping_rules = [
            (r"mm_whisper_embeddings\.(.*)", r"\1"),
            (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
            (r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"),  # noqa: E501
            (r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"),  # noqa: E501
        ]
        # fmt: on

        audio_params = dict(
            nn.ModuleDict({
                "audio_language_adapter":
                self.audio_language_adapter,
            }).named_parameters())

        loaded_weights = set()

        def llm_weights_generator():
            nonlocal loaded_weights
            for name, w in weights:
                is_encoder = (
                    name.startswith("mm_whisper_embeddings") and
                    not name.startswith("mm_whisper_embeddings.tok_embeddings")
                    and not name.startswith(
                        "mm_whisper_embeddings.audio_language_projection"))

                for pattern, repl in remapping_rules:
                    if re.fullmatch(pattern, name):
                        name = re.sub(pattern, repl, name)

                if is_encoder:
                    name = self.whisper_encoder.load_weight((name, w))
                    loaded_weights.add(f"whisper_encoder.{name}")
                    continue

                if name in audio_params:
                    param = audio_params[name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                    loaded_weights.add(name)
                else:
                    yield (name, w)

        for name in self.language_model.load_weights(llm_weights_generator()):
            loaded_weights.add(f"language_model.{name}")

        # potentially manually add position embeddings
        sin_key = "whisper_encoder.whisper_encoder.embed_positions.weight"
        if sin_key not in loaded_weights:
            # make sure we don't hit an error here
            loaded_weights.add(sin_key)

        return loaded_weights


class AudioLanguageAdapter(nn.Module):

    def __init__(self, hidden_size: int, dim: int) -> None:
        super().__init__()
        self.w_in = nn.Linear(hidden_size, dim, bias=False)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))


class VoxtralEncoderModel(nn.Module):
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # fmt: off
    mistral_remapping = [
        (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"),  # noqa: E501
        (r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"),  # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"),  # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"),  # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"),  # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"),  # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"),  # noqa: E501
        (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"),  # noqa: E501
        (r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"),  # noqa: E501
    ]
    # fmt: on

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
        self.dtype: torch.dtype = vllm_config.model_config.dtype
        self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
                                              prefix=maybe_prefix(
                                                  prefix, "whisper_encoder"),
                                              is_standalone_encoder=True,
                                              init_in_fp32=True)
        mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.window_size // 2,
            num_mel_bins=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=8000.0,
            sampling_rate=self.config.sampling_rate,
        )
        self.mel_filters = torch.tensor(mel_filters, dtype=torch.float32)

    def compute_whisper_melspec(
        self,
        audio_waveforms: torch.Tensor,
    ) -> torch.Tensor:
        input_dtype = audio_waveforms.dtype
        window = torch.hann_window(self.config.window_size).to(
            audio_waveforms.device)
        stft = torch.stft(
            audio_waveforms,
            self.config.window_size,
            self.config.hop_length,
            window=window,
            return_complex=True,
        )
        magnitudes = stft[..., :-1].abs()**2
        mel_spec = self.mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec.to(input_dtype)

    @property
    def downsample_factor(self) -> int:
        return self.whisper_encoder.conv1.stride[
            0] * self.whisper_encoder.conv2.stride[0]

    @property
    def chunk_size(self) -> int:
        return self.config.max_source_positions * self.downsample_factor

    def prepare_inputs_for_conv(
        self,
        audio_waveforms: list[torch.Tensor],
    ) -> tuple[torch.Tensor, list[int]]:
        assert isinstance(audio_waveforms, list)
        # list[num_mel_bins, seq_len]
        input_features = [
            self.compute_whisper_melspec(audio).to(self.dtype)
            for audio in audio_waveforms
        ]

        chunked_features: list[torch.Tensor] = []
        chunks_per_example: list[int] = []
        for feature in input_features:
            chunks = feature.split(self.chunk_size, dim=-1)
            chunked_features += chunks
            chunks_per_example.append(len(chunks))

        # [total_num_chunks, num_mel_bins, chunk_size]
        return torch.stack(chunked_features), chunks_per_example

    def forward(
        self, input_features: Union[torch.Tensor, list[torch.Tensor]]
    ) -> list[torch.Tensor]:
        if not isinstance(input_features, list):
            input_features = [input_features]

        # Split long inputs into chunks
        input_embeds, chunks_per_example = (
            self.prepare_inputs_for_conv(input_features))

        # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
        out = self.whisper_encoder([input_embeds])

        # Re-concatenate the chunks
        chunk_idx = 0
        results = []
        for n_chunks in chunks_per_example:
            result = out[chunk_idx:chunk_idx + n_chunks].flatten(0, 1)
            results.append(result)
            chunk_idx += n_chunks

        return results

    def load_weight(self, weight: tuple[str, torch.Tensor]) -> str:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())

        name, loaded_weight = weight
        for pattern, repl in self.mistral_remapping:
            if re.fullmatch(pattern, name):
                name = re.sub(pattern, repl, name)

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        return name
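As a sanity check on the downsampling logic in get_multimodal_embeddings above, a tiny self-contained sketch with made-up sizes: 10 encoder frames of width 8 and a downsample factor of 4 become ceil(10 / 4) = 3 stacked frames of width 8 * 4 = 32, which is exactly the shape fed to the audio-language adapter.

import math

import torch

downsample_factor = 4
audio_embedding = torch.randn(10, 8)  # illustrative (seq_len, dim)

seq_len, dim = audio_embedding.shape
# Pad so seq_len is divisible by downsample_factor, then fold groups of
# downsample_factor frames into the feature dimension.
target_seq_len = downsample_factor * math.ceil(seq_len / downsample_factor)
audio_embedding = torch.nn.functional.pad(
    audio_embedding, (0, 0, 0, target_seq_len - seq_len))
stacked = audio_embedding.reshape(target_seq_len // downsample_factor,
                                  dim * downsample_factor)
assert stacked.shape == (3, 32)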
@@ -3,6 +3,7 @@

 import math
 from collections.abc import Iterable, Mapping, Sequence
+from contextlib import nullcontext
 from typing import Optional, TypedDict, Union, cast

 import numpy as np
@@ -13,6 +14,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
 from transformers.models.whisper.modeling_whisper import sinusoids

 from vllm.attention import Attention, AttentionType
+from vllm.attention.layer import MultiHeadAttention
 from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
                          VllmConfig)
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -26,6 +28,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
@@ -178,6 +181,7 @@ class WhisperAttention(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        standalone_encoder: bool = False,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -213,6 +217,14 @@ class WhisperAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.out_proj",
         )
+        if standalone_encoder:
+            self.attn = MultiHeadAttention(
+                self.num_heads,
+                self.head_dim,
+                self.scaling,
+                num_kv_heads=self.num_kv_heads,
+            )
+        else:
             self.attn = Attention(
                 self.num_heads,
                 self.head_dim,
@@ -357,7 +369,11 @@ class WhisperMLP(nn.Module):

 class WhisperEncoderLayer(nn.Module):

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 is_standalone_encoder: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -371,6 +387,7 @@ class WhisperEncoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            standalone_encoder=is_standalone_encoder,
         )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.mlp = WhisperMLP(
@@ -462,10 +479,16 @@ class WhisperDecoderLayer(nn.Module):

 class WhisperEncoder(nn.Module):

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 is_standalone_encoder: bool = False,
+                 init_in_fp32: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         embed_dim = config.d_model
+        self.is_standalone_encoder = is_standalone_encoder
         self.num_mel_bins = config.num_mel_bins
         self.max_source_positions = config.max_source_positions
         self.embed_scale = (math.sqrt(embed_dim)
@@ -480,17 +503,25 @@ class WhisperEncoder(nn.Module):
                                kernel_size=3,
                                stride=2,
                                padding=1)
-        self.embed_positions = nn.Embedding(self.max_source_positions,
-                                            embed_dim)
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.encoder_layers,
             lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
-                                               prefix=f"{prefix}.layers"),
+                                               prefix=f"{prefix}.layers",
+                                               is_standalone_encoder=
+                                               is_standalone_encoder),
             prefix=f"{prefix}.layers",
         )
         self.layer_norm = nn.LayerNorm(config.d_model)

-        with torch.no_grad():
+        maybe_fp32_init_ctx = set_default_torch_dtype(
+            torch.float32) if init_in_fp32 else nullcontext()
+
+        with (
+                torch.no_grad(),
+                maybe_fp32_init_ctx,
+        ):
+            self.embed_positions = nn.Embedding(self.max_source_positions,
+                                                embed_dim)
             self.embed_positions.weight.copy_(
                 sinusoids(*self.embed_positions.weight.shape))
@@ -499,8 +530,10 @@ class WhisperEncoder(nn.Module):
         for features in input_features:
             embeds = nn.functional.gelu(self.conv1(features))
             embeds = nn.functional.gelu(self.conv2(embeds))
-            embeds = embeds.permute(1, 0)
-            embeds = embeds + self.embed_positions.weight[:embeds.size(0), :]
+            embeds = embeds.transpose(-1, -2)
+            embeds = (embeds +
+                      self.embed_positions.weight[:embeds.size(-2), :]).to(
+                          embeds.dtype)
             hidden_states.append(embeds)
         hidden_states = torch.cat(hidden_states)
@@ -792,8 +825,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                          f"or {list(ISO639_1_OTHER_LANGS.values())}")

     @classmethod
-    def get_generation_prompt(cls, audio: np.ndarray,
-                              stt_config: SpeechToTextConfig, language: str,
+    def get_generation_prompt(
+            cls,
+            audio: np.ndarray,
+            model_config: ModelConfig,  # not needed here
+            stt_config: SpeechToTextConfig,
+            language: str,
                               task_type: str,
                               request_prompt: str) -> PromptType:
         prompt = {
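A small sketch of the init_in_fp32 pattern introduced above, isolated from the encoder. The helper name build_positions is made up for illustration, but set_default_torch_dtype and nullcontext are the same utilities the diff imports, and the point is the same: the sinusoidal position table is built under a float32 default dtype and only cast later.

from contextlib import nullcontext

import torch
import torch.nn as nn

from vllm.model_executor.model_loader.utils import set_default_torch_dtype


def build_positions(max_positions: int, dim: int,
                    init_in_fp32: bool) -> nn.Embedding:
    # Optionally force float32 while the embedding (and the sinusoid table
    # copied into it) is created; otherwise keep the ambient default dtype.
    ctx = (set_default_torch_dtype(torch.float32)
           if init_in_fp32 else nullcontext())
    with torch.no_grad(), ctx:
        positions = nn.Embedding(max_positions, dim)
    return positions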
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Any

-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, WhisperConfig

 from vllm.logger import init_logger
@@ -24,9 +24,21 @@ def adapt_config_dict(config_dict: dict[str, Any],

     if bool(config_dict.get("yarn")):
         config_dict = _remap_mistral_yarn_args(config_dict)
-    if bool((config_dict.get("multimodal") or {}).get("vision_encoder_args")
-            or config_dict.get("vision_encoder")):
+    is_vision = ((config_dict.get("multimodal")
+                  or {}).get("vision_encoder_args")
+                 or config_dict.get("vision_encoder"))
+    is_audio = bool(
+        ((config_dict.get("multimodal") or {}).get("whisper_model_args")
+         or {}).get("encoder_args"))
+
+    assert not (is_vision and is_audio), \
+        "Vision and audio are mutually exclusive"
+
+    if is_vision:
         config_dict = _remap_mistral_vision_args(config_dict)
+    if is_audio:
+        config_dict = _remap_mistral_audio_args(config_dict)

     config = PretrainedConfig.from_dict(config_dict)
@@ -118,3 +130,35 @@ def _remap_mistral_quantization_args(config: dict) -> dict:
         config["quantization_config"] = quantization_config

     return config
+
+
+def _remap_mistral_audio_args(config: dict) -> dict:
+    whisper_args = config["multimodal"].pop("whisper_model_args")
+    encoder_args = whisper_args["encoder_args"]
+    downsample_args = whisper_args["downsample_args"]
+
+    quant_config = config.get("quantization_config")
+    config = {
+        "model_type":
+        "whixtral",
+        "architectures": ["VoxtralForConditionalGeneration"],
+        "text_config":
+        PretrainedConfig.from_dict(config),
+        "audio_config":
+        WhisperConfig(
+            num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
+            window_size=encoder_args["audio_encoding_args"]["window_size"],
+            sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
+            hop_length=encoder_args["audio_encoding_args"]["hop_length"],
+            downsample_factor=downsample_args["downsample_factor"],
+            d_model=encoder_args["dim"],
+            encoder_layers=encoder_args["n_layers"],
+            encoder_ffn_dim=encoder_args["hidden_dim"],
+            encoder_attention_heads=encoder_args["n_heads"],
+            vocab_size=encoder_args["vocab_size"],
+            max_source_positions=encoder_args["max_source_positions"],
+        )
+    }
+    if quant_config:
+        config["quantization_config"] = quant_config
+    return config
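To make the remapping concrete, an illustrative input for the _remap_mistral_audio_args helper added above. The nested keys follow the code, while the numeric values are placeholders rather than the real Voxtral-Mini parameters, and the call assumes the helper is reachable from the module patched here.

mistral_config = {
    "dim": 3072,  # placeholder text-model width
    "multimodal": {
        "whisper_model_args": {
            "encoder_args": {
                "audio_encoding_args": {
                    "num_mel_bins": 128,
                    "window_size": 400,
                    "sampling_rate": 16000,
                    "hop_length": 160,
                },
                "dim": 1280,
                "n_layers": 32,
                "hidden_dim": 5120,
                "n_heads": 20,
                "vocab_size": 51866,
                "max_source_positions": 1500,
            },
            "downsample_args": {"downsample_factor": 4},
        },
    },
}

remapped = _remap_mistral_audio_args(mistral_config)
# remapped["model_type"] == "whixtral"
# remapped["architectures"] == ["VoxtralForConditionalGeneration"]
# remapped["audio_config"] is a transformers.WhisperConfig whose d_model,
# encoder_layers, etc. are taken from encoder_args above.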