Voxtral (#20970)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

This commit is contained in: parent 4ffd963fa0, commit e7e3e6d263
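For orientation, here is a minimal usage sketch of what this commit enables end to end. It is not part of the diff: the server flags mirror the MISTRAL_FORMAT_ARGS added in the tests below, while the port, API key, and audio file name are placeholders.

# Hypothetical usage sketch (not part of this commit).
# Start the server with the Mistral-format flags used by the tests below:
#   vllm serve mistralai/Voxtral-Mini-3B-2507 --tokenizer_mode mistral \
#       --config_format mistral --load_format mistral --enforce-eager
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Transcribe a local audio file through the OpenAI-compatible endpoint.
with open("mary_had_lamb.ogg", "rb") as f:  # placeholder file name
    transcription = client.audio.transcriptions.create(
        model="mistralai/Voxtral-Mini-3B-2507",
        file=f,
        language="en",
        temperature=0.0,
        response_format="text",
    )
print(transcription)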
@@ -10,7 +10,7 @@ on HuggingFace model repository.

 import os
 from dataclasses import asdict
-from typing import NamedTuple, Optional
+from typing import Any, NamedTuple, Optional

 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
@@ -30,7 +30,9 @@ question_per_audio_count = {

 class ModelRequestData(NamedTuple):
     engine_args: EngineArgs
-    prompt: str
+    prompt: Optional[str] = None
+    prompt_token_ids: Optional[dict[str, list[int]]] = None
+    multi_modal_data: Optional[dict[str, Any]] = None
     stop_token_ids: Optional[list[int]] = None
     lora_requests: Optional[list[LoRARequest]] = None

@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
 # Unless specified, these settings have been tested to work on a single L4.


+# Voxtral
+def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
+    from mistral_common.audio import Audio
+    from mistral_common.protocol.instruct.messages import (
+        AudioChunk,
+        RawAudio,
+        TextChunk,
+        UserMessage,
+    )
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+
+    model_name = "mistralai/Voxtral-Mini-3B-2507"
+    tokenizer = MistralTokenizer.from_hf_hub(model_name)
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"audio": audio_count},
+        config_format="mistral",
+        load_format="mistral",
+        tokenizer_mode="mistral",
+        enforce_eager=True,
+        enable_chunked_prefill=False,
+    )
+
+    text_chunk = TextChunk(text=question)
+    audios = [
+        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
+        for i in range(audio_count)
+    ]
+    audio_chunks = [
+        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
+    ]
+
+    messages = [UserMessage(content=[*audio_chunks, text_chunk])]
+
+    req = ChatCompletionRequest(messages=messages, model=model_name)
+
+    tokens = tokenizer.encode_chat_completion(req)
+    prompt_ids, audios = tokens.tokens, tokens.audios
+
+    audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
+
+    multi_modal_data = {"audio": audios_and_sr}
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt_token_ids=prompt_ids,
+        multi_modal_data=multi_modal_data,
+    )
+
+
 # Granite Speech
 def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
     # NOTE - the setting in this example are somehat different than what is
@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:


 model_example_map = {
+    "voxtral": run_voxtral,
     "granite_speech": run_granite_speech,
     "minicpmo": run_minicpmo,
     "phi4_mm": run_phi4mm,
@@ -311,16 +368,24 @@ def main(args):
         temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
     )

-    mm_data = {}
-    if audio_count > 0:
-        mm_data = {
-            "audio": [
-                asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
-            ]
-        }
+    mm_data = req_data.multi_modal_data
+    if not mm_data:
+        mm_data = {}
+        if audio_count > 0:
+            mm_data = {
+                "audio": [
+                    asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
+                ]
+            }

     assert args.num_prompts > 0
-    inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
+    inputs = {"multi_modal_data": mm_data}
+
+    if req_data.prompt:
+        inputs["prompt"] = req_data.prompt
+    else:
+        inputs["prompt_token_ids"] = req_data.prompt_token_ids
+
     if args.num_prompts > 1:
         # Batch inference
         inputs = [inputs] * args.num_prompts

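The example above returns prompt_token_ids rather than a string prompt, so main() now feeds the token IDs and the pre-processed audio straight into the engine. A minimal standalone sketch of that path (assuming the run_voxtral() helper from the diff above is importable and using a placeholder question) could look like this:

# Sketch only; mirrors what main() does for token-id prompts.
from dataclasses import asdict

from vllm import LLM, SamplingParams

req_data = run_voxtral("What can you tell me about this audio?", audio_count=1)

llm = LLM(**asdict(req_data.engine_args))
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)

outputs = llm.generate(
    {
        "prompt_token_ids": req_data.prompt_token_ids,
        "multi_modal_data": req_data.multi_modal_data,
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)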
@@ -33,7 +33,7 @@ pyzmq >= 25.0.0
 msgspec
 gguf >= 0.13.0
 importlib_metadata; python_version < '3.10'
-mistral_common[opencv] >= 1.6.2
+mistral_common[opencv] >= 1.8.0
 opencv-python-headless >= 4.11.0 # required for video IO
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12

@@ -23,7 +23,7 @@ jiwer # required for audio tests
 timm # required for internvl test
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.6.2 # required for pixtral test
+mistral_common[opencv] >= 1.8.0 # required for voxtral test
 num2words # required for smolvlm test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test

@@ -28,7 +28,7 @@ torchvision==0.22.0
 transformers_stream_generator # required for qwen-vl test
 mamba_ssm # required for plamo2 test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.7.0 # required for pixtral test
+mistral_common[opencv] >= 1.8.0 # required for voxtral test
 num2words # required for smolvlm test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test

@@ -305,7 +305,7 @@ mbstrdecoder==1.1.3
    # typepy
 mdurl==0.1.2
    # via markdown-it-py
-mistral-common==1.7.0
+mistral-common==1.8.0
    # via -r requirements/test.in
 more-itertools==10.5.0
    # via lm-eval
@@ -518,6 +518,8 @@ pyasn1-modules==0.4.2
    # via google-auth
 pybind11==2.13.6
    # via lm-eval
+pycountry==24.6.1
+   # via pydantic-extra-types
 pycparser==2.22
    # via cffi
 pycryptodomex==3.22.0
@@ -528,9 +530,12 @@ pydantic==2.11.5
    # datamodel-code-generator
    # mistral-common
    # mteb
+   # pydantic-extra-types
    # ray
 pydantic-core==2.33.2
    # via pydantic
+pydantic-extra-types==2.10.5
+   # via mistral-common
 pygments==2.18.0
    # via rich
 pyparsing==3.2.0
@@ -835,6 +840,7 @@ typing-extensions==4.12.2
    # pqdm
    # pydantic
    # pydantic-core
+   # pydantic-extra-types
    # torch
    # typer
    # typing-inspection

setup.py
@@ -692,7 +692,8 @@ setup(
         "tensorizer": ["tensorizer==2.10.1"],
         "fastsafetensors": ["fastsafetensors >= 0.1.10"],
         "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
-        "audio": ["librosa", "soundfile"], # Required for audio processing
+        "audio": ["librosa", "soundfile",
+                  "mistral_common[audio]"], # Required for audio processing
         "video": [] # Kept for backwards compatibility
     },
     cmdclass=cmdclass,

@@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset

 from ...utils import RemoteOpenAIServer

+MISTRAL_FORMAT_ARGS = [
+    "--tokenizer_mode", "mistral", "--config_format", "mistral",
+    "--load_format", "mistral"
+]
+

 @pytest.fixture
 def mary_had_lamb():
@@ -33,9 +38,18 @@ def winning_call():


 @pytest.mark.asyncio
-async def test_basic_audio(mary_had_lamb):
-    model_name = "openai/whisper-large-v3-turbo"
+@pytest.mark.parametrize(
+    "model_name",
+    ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
+async def test_basic_audio(mary_had_lamb, model_name):
     server_args = ["--enforce-eager"]

+    if model_name.startswith("mistralai"):
+        server_args += MISTRAL_FORMAT_ARGS
+
+    # TODO(PATRICK) - REMOVE AFTER RELEASE
+    return # skip for now
+
     # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
     with RemoteOpenAIServer(model_name, server_args) as remote_server:
         client = remote_server.get_async_client()
@@ -65,10 +79,13 @@ async def test_bad_requests(mary_had_lamb):


 @pytest.mark.asyncio
-async def test_long_audio_request(mary_had_lamb):
-    model_name = "openai/whisper-large-v3-turbo"
+@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
+async def test_long_audio_request(mary_had_lamb, model_name):
     server_args = ["--enforce-eager"]

+    if model_name.startswith("openai"):
+        return
+
     mary_had_lamb.seek(0)
     audio, sr = librosa.load(mary_had_lamb)
     # Add small silence after each audio for repeatability in the split process
@@ -87,7 +104,8 @@ async def test_long_audio_request(mary_had_lamb):
         response_format="text",
         temperature=0.0)
     out = json.loads(transcription)['text']
-    assert out.count("Mary had a little lamb") == 10
+    counts = out.count("Mary had a little lamb")
+    assert counts == 10, counts


 @pytest.mark.asyncio

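Once the temporary `return # skip for now` above is removed, a Voxtral-specific test could follow the same pattern. The sketch below is hypothetical and not part of this commit: the `winning_call` fixture and the `json.loads(...)['text']` handling mirror this test file, while the asserted substring is purely illustrative.

import json

import pytest


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", ["mistralai/Voxtral-Mini-3B-2507"])
async def test_voxtral_basic_audio_sketch(winning_call, model_name):
    # Hypothetical test sketch, not part of this commit.
    server_args = ["--enforce-eager", *MISTRAL_FORMAT_ARGS]
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        client = remote_server.get_async_client()
        transcription = await client.audio.transcriptions.create(
            model=model_name,
            file=winning_call,
            language="en",
            response_format="text",
            temperature=0.0)
        out = json.loads(transcription)['text']
        assert "winning" in out.lower()  # illustrative assertion only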
@@ -440,6 +440,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                            tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501
                                            trust_remote_code=True), # noqa: E501
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
+    "VoxtralForConditionalGeneration": _HfExamplesInfo("mistralai/Voxtral-Mini-3B-2507", is_available_online=False, tokenizer_mode="mistral"), # noqa: E501
     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501

     # [Cross-encoder]
@@ -513,4 +514,4 @@ class HfExampleModels:
         raise ValueError(f"No example model defined for {model_id}")


-HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
+HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)

@@ -112,6 +112,7 @@ class OpenAISpeechToText(OpenAIServing):
             prompt = self.model_cls.get_generation_prompt(
                 audio=chunk,
                 stt_config=self.asr_config,
+                model_config=self.model_config,
                 language=lang,
                 task_type=self.task_type,
                 request_prompt=request.prompt)

@@ -722,7 +722,8 @@ class SupportsTranscription(Protocol):

     @classmethod
     def get_generation_prompt(cls, audio: np.ndarray,
-                              stt_config: SpeechToTextConfig, language: str,
+                              stt_config: SpeechToTextConfig,
+                              model_config: ModelConfig, language: str,
                               task_type: str,
                               request_prompt: str) -> PromptType:
         """Get the prompt for the ASR model.

@@ -231,6 +231,7 @@ _MULTIMODAL_MODELS = {
     "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
     "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
     "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
+    "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
     # [Encoder-decoder]
     "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
     "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501

vllm/model_executor/models/voxtral.py (new file, 691 lines)
@@ -0,0 +1,691 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from math import ceil
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mistral_common.audio import mel_filter_bank
|
||||
from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio,
|
||||
TextChunk, UserMessage)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.protocol.transcription.request import TranscriptionRequest
|
||||
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
|
||||
from transformers import TensorType, WhisperConfig
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import SupportsPP
|
||||
# yapf: disable
|
||||
from vllm.model_executor.models.whisper import (
|
||||
WhisperEncoder, WhisperForConditionalGeneration)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargs, NestedTensors)
|
||||
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, MultiModalHashes,
|
||||
PromptReplacement, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
cached_tokenizer_from_config)
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
|
||||
SupportsTranscription)
|
||||
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class VoxtralProcessorAdapter:
|
||||
"""
|
||||
Provide a HF-compatible interface for
|
||||
:class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: MistralTokenizer) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@cached_property
|
||||
def _audio_processor(self) -> AudioEncoder:
|
||||
audio_encoder = self.tokenizer.instruct.audio_encoder
|
||||
assert isinstance(audio_encoder, AudioEncoder)
|
||||
return audio_encoder
|
||||
|
||||
@cached_property
|
||||
def audio_token_id(self) -> int:
|
||||
return self._audio_processor.special_ids.audio
|
||||
|
||||
@cached_property
|
||||
def begin_audio_token_id(self) -> int:
|
||||
return self._audio_processor.special_ids.begin_audio
|
||||
|
||||
# @cached_property
|
||||
# def begin_transcript_token_id(self) -> int:
|
||||
# return self._audio_processor.special_ids.begin_transcript
|
||||
|
||||
# @cached_property
|
||||
# def end_transcript_token_id(self) -> int:
|
||||
# return self._audio_processor.special_ids.end_transcript
|
||||
|
||||
@cached_property
|
||||
def sampling_rate(self) -> int:
|
||||
return self._audio_processor.audio_config.sampling_rate
|
||||
|
||||
@cached_property
|
||||
def frame_rate(self) -> float:
|
||||
return self._audio_processor.audio_config.frame_rate
|
||||
|
||||
def get_num_audio_tokens(
|
||||
self,
|
||||
audio_length: int,
|
||||
) -> int:
|
||||
pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
|
||||
audio_length, self.sampling_rate)
|
||||
return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[TextInput, list[TextInput]]] = None,
|
||||
audios: Optional[Union[np.ndarray, list[np.ndarray]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs,
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if audios is None:
|
||||
audios = []
|
||||
if not isinstance(audios, list):
|
||||
audios = [audios]
|
||||
|
||||
if not audios:
|
||||
input_ids = self.tokenizer(text).input_ids
|
||||
return {"input_ids": torch.tensor(input_ids)}
|
||||
|
||||
# Allow dummy text, which is used for profiling as well as token inputs
|
||||
if any(len(t) > 0 for t in text):
|
||||
raise ValueError(
|
||||
"You've passed text inputs instead of token inputs. "
|
||||
"Make sure to process your input via `mistral_common`'s "
|
||||
"tokenizer or pass a chat completion request. "
|
||||
"For more info, see: "
|
||||
"https://github.com/vllm-project/vllm/issues/8411.")
|
||||
|
||||
audios_tokens = list[torch.Tensor]()
|
||||
audios_processed = list[torch.Tensor]()
|
||||
for audio in audios:
|
||||
assert isinstance(audio, np.ndarray)
|
||||
assert audio.ndim == 1
|
||||
|
||||
# pad if necessary
|
||||
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||
|
||||
audio_tokens = [
|
||||
self.begin_audio_token_id
|
||||
] + [self.audio_token_id] * self.get_num_audio_tokens(len(audio))
|
||||
|
||||
audios_tokens.append(torch.tensor(audio_tokens))
|
||||
audios_processed.append(torch.tensor(audio))
|
||||
|
||||
return {
|
||||
"input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
|
||||
"audio_arrays": audios_processed,
|
||||
}
|
||||
|
||||
|
||||
class VoxtralProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_tokenizer(self) -> MistralTokenizer:
|
||||
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
|
||||
if not isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError("This model requires `--tokenizer-mode mistral`")
|
||||
|
||||
return tokenizer
|
||||
|
||||
def get_hf_processor(self) -> VoxtralProcessorAdapter:
|
||||
return VoxtralProcessorAdapter(self.get_tokenizer())
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": 5} # Performance tends to degrade after 5
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"audio": self.get_max_audio_tokens()}
|
||||
|
||||
def get_max_audio_tokens(self) -> int:
|
||||
return self.ctx.model_config.max_model_len
|
||||
|
||||
def get_max_audio_array_len(self) -> int:
|
||||
processor = self.get_hf_processor()
|
||||
return self.get_max_audio_tokens() * int(
|
||||
processor.sampling_rate // processor.frame_rate)
|
||||
|
||||
|
||||
class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
target_length = self.info.get_max_audio_array_len()
|
||||
|
||||
return {
|
||||
"audio":
|
||||
self._get_dummy_audios(length=target_length, num_audios=num_audios)
|
||||
}
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
dummy_text = self.get_dummy_text(mm_counts)
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
|
||||
dummy_audios = dummy_mm_data.get("audio", [])
|
||||
|
||||
audio_chunks: list[AudioChunk] = []
|
||||
format = "wav"
|
||||
for audio in dummy_audios:
|
||||
audio_item = Audio(
|
||||
audio_array=audio,
|
||||
sampling_rate=self.info.get_hf_processor().sampling_rate,
|
||||
format=format,
|
||||
)
|
||||
chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
|
||||
audio_chunks.append(chunk)
|
||||
|
||||
request = ChatCompletionRequest(messages=[
|
||||
UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
|
||||
])
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
dummy_tokens = res.tokens
|
||||
# whixtral tokenizer adds padding to the audio
|
||||
# so we need to update the audio arrays
|
||||
dummy_mm_data["audio"] = [a.audio_array for a in res.audios]
|
||||
|
||||
return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)
|
||||
|
||||
|
||||
class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
|
||||
):
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(audio_arrays=MultiModalFieldConfig.batched("audio"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
audio_id = processor.audio_token_id
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||
audio_len = audios.get_audio_length(item_idx)
|
||||
|
||||
nb_audio_tokens = processor.get_num_audio_tokens(audio_len)
|
||||
|
||||
return [audio_id] * nb_audio_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target="", # Never match the prompt (see below note)
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
def _cached_apply_hf_processor(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
*,
|
||||
return_mm_hashes: bool,
|
||||
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
|
||||
prompt_ids, mm_kwargs, mm_hashes, _ = super(
|
||||
)._cached_apply_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data_items=mm_data_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
return_mm_hashes=return_mm_hashes,
|
||||
)
|
||||
|
||||
# NOTE: The tokens are already inserted by the chat template
|
||||
return prompt_ids, mm_kwargs, mm_hashes, True
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
sampling_rate = self.info.get_hf_processor().sampling_rate
|
||||
return MultiModalDataParser(target_sr=sampling_rate)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(VoxtralMultiModalProcessor,
|
||||
info=VoxtralProcessingInfo,
|
||||
dummy_inputs=VoxtralDummyInputsBuilder)
|
||||
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP, SupportsTranscription):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.config = config
|
||||
self.downsample_factor = self.config.audio_config.downsample_factor
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
self.whisper_encoder = VoxtralEncoderModel(
|
||||
vllm_config.with_hf_config(config.audio_config),
|
||||
prefix=maybe_prefix(prefix, "whisper_encoder"),
|
||||
)
|
||||
self.audio_language_adapter = AudioLanguageAdapter(
|
||||
hidden_size=config.audio_config.d_model * self.downsample_factor,
|
||||
dim=config.text_config.hidden_size,
|
||||
)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
audio_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
audio_embeddings)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs
|
||||
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...],
|
||||
None]:
|
||||
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
|
||||
if audio_inputs is None:
|
||||
return None
|
||||
|
||||
audio_embeddings = self.whisper_encoder(audio_inputs)
|
||||
|
||||
for i, audio_embedding in enumerate(audio_embeddings):
|
||||
seq_len, dim = audio_embedding.shape
|
||||
# Pad such that seq_len is divisible by downsample_factor
|
||||
target_seq_len = self.downsample_factor * math.ceil(
|
||||
seq_len / self.downsample_factor)
|
||||
audio_embedding = torch.nn.functional.pad(
|
||||
audio_embedding,
|
||||
(0, 0, 0, target_seq_len - seq_len),
|
||||
)
|
||||
audio_embeddings[i] = audio_embedding.reshape(
|
||||
target_seq_len // self.downsample_factor,
|
||||
dim * self.downsample_factor)
|
||||
|
||||
# Concat, project and resplit
|
||||
audio_embeddings_packed = torch.cat(audio_embeddings, dim=0)
|
||||
audio_embeddings_packed = self.audio_language_adapter(
|
||||
audio_embeddings_packed)
|
||||
audio_embeddings = torch.split(audio_embeddings_packed,
|
||||
[a.shape[0] for a in audio_embeddings],
|
||||
dim=0)
|
||||
|
||||
return audio_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
audio_encoder = self.tokenizer.instruct.audio_encoder
|
||||
audio_tok_id = audio_encoder.audio_token
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id)
|
||||
return inputs_embeds
|
||||
|
||||
def _parse_and_validate_audio_arrays(
|
||||
self, **kwargs: object) -> Union[list[torch.Tensor], None]:
|
||||
audio_arrays = kwargs.pop("audio_arrays", None)
|
||||
if audio_arrays is None:
|
||||
return None
|
||||
|
||||
if not isinstance(audio_arrays, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_arrays. "
|
||||
f"Got type: {type(audio_arrays)}")
|
||||
|
||||
audio_arrays = flatten_bn(audio_arrays)
|
||||
if isinstance(audio_arrays, torch.Tensor):
|
||||
audio_arrays = list(audio_arrays.unbind(0))
|
||||
return audio_arrays
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(cls, model_config: ModelConfig,
|
||||
task_type: str) -> SpeechToTextConfig:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
audio_config = tokenizer.instruct.audio_encoder.audio_config
|
||||
max_audio_clip_s = audio_config.chunk_length_s
|
||||
sample_rate = audio_config.sampling_rate
|
||||
return SpeechToTextConfig(
|
||||
max_audio_clip_s=max_audio_clip_s,
|
||||
sample_rate=sample_rate,
|
||||
# mistral_common and whisper encoder take care of chunking
|
||||
min_energy_split_window_size=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# for speech-to-text transcription
|
||||
def get_generation_prompt(cls, audio: np.ndarray,
|
||||
model_config: ModelConfig,
|
||||
stt_config: SpeechToTextConfig, language: str,
|
||||
task_type: str,
|
||||
request_prompt: str) -> PromptType:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
audio = Audio(audio, int(stt_config.sample_rate),
|
||||
format="wav") # lossless
|
||||
req = TranscriptionRequest(model=model_config.model,
|
||||
audio=RawAudio.from_audio(audio),
|
||||
language=language)
|
||||
|
||||
tokenized = tokenizer.instruct.encode_transcription(req)
|
||||
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
|
||||
prompts_dict = {"multi_modal_data": {"audio": audio}}
|
||||
prompts_dict["prompt_token_ids"] = tokenized.tokens
|
||||
return cast(PromptType, prompts_dict)
|
||||
|
||||
@classmethod
|
||||
def validate_language(cls, language: str) -> bool:
|
||||
# same as whisper
|
||||
return WhisperForConditionalGeneration.validate_language(language)
|
||||
|
||||
@classmethod
|
||||
def get_num_audio_tokens(cls, audio_duration_s: float,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig) -> Optional[int]:
|
||||
"""
|
||||
Map from audio duration to number of audio tokens produced by the ASR
|
||||
model, without running a forward pass.
|
||||
This is used for estimating the amount of processing for this audio.
|
||||
"""
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
adapter = VoxtralProcessorAdapter(tokenizer)
|
||||
return adapter.get_num_audio_tokens(
|
||||
int(audio_duration_s * stt_config.sample_rate))
|
||||
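As a sanity check on the duration-to-token mapping above, here is a quick back-of-the-envelope calculation. The numbers are illustrative assumptions (a 16 kHz sampling rate and a 12.5 Hz token frame rate, i.e. 1280 samples per audio token); the real values come from mistral_common's audio config.

import math

# Illustrative values only; the real ones come from
# tokenizer.instruct.audio_encoder.audio_config.
sampling_rate = 16_000   # samples per second
frame_rate = 12.5        # audio tokens per second of audio
samples_per_token = sampling_rate // frame_rate  # 1280.0

# 30 s of audio, already padded to a whole chunk by next_multiple_of_chunk_frames
audio_length = 30 * sampling_rate                 # 480_000 samples
num_tokens = math.ceil(audio_length / samples_per_token)
print(num_tokens)                                 # 375 tokens per 30-second chunk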
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
# fmt: off
|
||||
remapping_rules = [
|
||||
(r"mm_whisper_embeddings\.(.*)", r"\1"),
|
||||
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
|
||||
(r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"), # noqa: E501
|
||||
(r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"), # noqa: E501
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
audio_params = dict(
|
||||
nn.ModuleDict({
|
||||
"audio_language_adapter":
|
||||
self.audio_language_adapter,
|
||||
}).named_parameters())
|
||||
|
||||
loaded_weights = set()
|
||||
|
||||
def llm_weights_generator():
|
||||
nonlocal loaded_weights
|
||||
for name, w in weights:
|
||||
is_encoder = (
|
||||
name.startswith("mm_whisper_embeddings") and
|
||||
not name.startswith("mm_whisper_embeddings.tok_embeddings")
|
||||
and not name.startswith(
|
||||
"mm_whisper_embeddings.audio_language_projection"))
|
||||
|
||||
for pattern, repl in remapping_rules:
|
||||
if re.fullmatch(pattern, name):
|
||||
name = re.sub(pattern, repl, name)
|
||||
|
||||
if is_encoder:
|
||||
name = self.whisper_encoder.load_weight((name, w))
|
||||
loaded_weights.add(f"whisper_encoder.{name}")
|
||||
continue
|
||||
|
||||
if name in audio_params:
|
||||
param = audio_params[name]
|
||||
with torch.no_grad():
|
||||
default_weight_loader(param, w)
|
||||
loaded_weights.add(name)
|
||||
else:
|
||||
yield (name, w)
|
||||
|
||||
for name in self.language_model.load_weights(llm_weights_generator()):
|
||||
loaded_weights.add(f"language_model.{name}")
|
||||
|
||||
# potentially manually add position embeddings
|
||||
sin_key = "whisper_encoder.whisper_encoder.embed_positions.weight"
|
||||
if sin_key not in loaded_weights:
|
||||
# make sure we don't hit an error here
|
||||
loaded_weights.add(sin_key)
|
||||
|
||||
return loaded_weights
|
||||
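To make the two-stage weight-name remapping concrete, here is a small standalone sketch tracing one encoder weight name through it. The regex rules are copied from load_weights above and from VoxtralEncoderModel.mistral_remapping below; the weight name itself is illustrative.

import regex as re

# Stage 1: rules from VoxtralForConditionalGeneration.load_weights (above).
remapping_rules = [
    (r"mm_whisper_embeddings\.(.*)", r"\1"),
    (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
]
# Stage 2: one rule from VoxtralEncoderModel.mistral_remapping (below).
encoder_rule = (r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
                r"whisper_encoder.conv1.\1")

name = "mm_whisper_embeddings.whisper_encoder.conv_layers.0.weight"  # illustrative
for pattern, repl in remapping_rules:
    if re.fullmatch(pattern, name):
        name = re.sub(pattern, repl, name)
# -> "whisper_encoder.conv_layers.0.weight"

pattern, repl = encoder_rule
if re.fullmatch(pattern, name):
    name = re.sub(pattern, repl, name)
print(name)  # whisper_encoder.conv1.weight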
|
||||
|
||||
class AudioLanguageAdapter(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size: int, dim: int) -> None:
|
||||
super().__init__()
|
||||
self.w_in = nn.Linear(hidden_size, dim, bias=False)
|
||||
self.gelu = nn.GELU()
|
||||
self.w_out = nn.Linear(dim, dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w_out(self.gelu(self.w_in(x)))
|
||||
|
||||
|
||||
class VoxtralEncoderModel(nn.Module):
|
||||
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
||||
|
||||
# fmt: off
|
||||
mistral_remapping = [
|
||||
(r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"), # noqa: E501
|
||||
(r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"), # noqa: E501
|
||||
(r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"), # noqa: E501
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
|
||||
self.dtype: torch.dtype = vllm_config.model_config.dtype
|
||||
self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "whisper_encoder"),
|
||||
is_standalone_encoder=True,
|
||||
init_in_fp32=True)
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=1 + self.config.window_size // 2,
|
||||
num_mel_bins=self.config.num_mel_bins,
|
||||
min_frequency=0.0,
|
||||
max_frequency=8000.0,
|
||||
sampling_rate=self.config.sampling_rate,
|
||||
)
|
||||
self.mel_filters = torch.tensor(mel_filters, dtype=torch.float32)
|
||||
|
||||
def compute_whisper_melspec(
|
||||
self,
|
||||
audio_waveforms: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
input_dtype = audio_waveforms.dtype
|
||||
window = torch.hann_window(self.config.window_size).to(
|
||||
audio_waveforms.device)
|
||||
stft = torch.stft(
|
||||
audio_waveforms,
|
||||
self.config.window_size,
|
||||
self.config.hop_length,
|
||||
window=window,
|
||||
return_complex=True,
|
||||
)
|
||||
magnitudes = stft[..., :-1].abs()**2
|
||||
mel_spec = self.mel_filters.T @ magnitudes
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec.to(input_dtype)
|
||||
|
||||
@property
|
||||
def downsample_factor(self) -> int:
|
||||
return self.whisper_encoder.conv1.stride[
|
||||
0] * self.whisper_encoder.conv2.stride[0]
|
||||
|
||||
@property
|
||||
def chunk_size(self) -> int:
|
||||
return self.config.max_source_positions * self.downsample_factor
|
||||
|
||||
def prepare_inputs_for_conv(
|
||||
self,
|
||||
audio_waveforms: list[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, list[int]]:
|
||||
assert isinstance(audio_waveforms, list)
|
||||
# list[num_mel_bins, seq_len]
|
||||
input_features = [
|
||||
self.compute_whisper_melspec(audio).to(self.dtype)
|
||||
for audio in audio_waveforms
|
||||
]
|
||||
|
||||
chunked_features: list[torch.Tensor] = []
|
||||
chunks_per_example: list[int] = []
|
||||
for feature in input_features:
|
||||
chunks = feature.split(self.chunk_size, dim=-1)
|
||||
chunked_features += chunks
|
||||
chunks_per_example.append(len(chunks))
|
||||
|
||||
# [total_num_chunks, num_mel_bins, chunk_size]
|
||||
return torch.stack(chunked_features), chunks_per_example
|
||||
|
||||
def forward(
|
||||
self, input_features: Union[torch.Tensor, list[torch.Tensor]]
|
||||
) -> list[torch.Tensor]:
|
||||
if not isinstance(input_features, list):
|
||||
input_features = [input_features]
|
||||
|
||||
# Split long inputs into chunks
|
||||
input_embeds, chunks_per_example = (
|
||||
self.prepare_inputs_for_conv(input_features))
|
||||
|
||||
# [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
|
||||
out = self.whisper_encoder([input_embeds])
|
||||
|
||||
# Re-concatenate the chunks
|
||||
chunk_idx = 0
|
||||
results = []
|
||||
for n_chunks in chunks_per_example:
|
||||
result = out[chunk_idx:chunk_idx + n_chunks].flatten(0, 1)
|
||||
results.append(result)
|
||||
chunk_idx += n_chunks
|
||||
|
||||
return results
|
||||
|
||||
def load_weight(self, weight: tuple[str, torch.Tensor]) -> str:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
name, loaded_weight = weight
|
||||
for pattern, repl in self.mistral_remapping:
|
||||
if re.fullmatch(pattern, name):
|
||||
name = re.sub(pattern, repl, name)
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
return name
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional, TypedDict, Union, cast
|
||||
|
||||
import numpy as np
|
||||
@ -13,6 +14,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
|
||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
|
||||
VllmConfig)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@ -26,6 +28,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
|
||||
@ -178,6 +181,7 @@ class WhisperAttention(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
standalone_encoder: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -213,16 +217,24 @@ class WhisperAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=self.attn_type,
|
||||
)
|
||||
if standalone_encoder:
|
||||
self.attn = MultiHeadAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
)
|
||||
else:
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=self.attn_type,
|
||||
)
|
||||
|
||||
def _init_qkv(
|
||||
self,
|
||||
@ -357,7 +369,11 @@ class WhisperMLP(nn.Module):
|
||||
|
||||
class WhisperEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
is_standalone_encoder: bool = False):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
@ -371,6 +387,7 @@ class WhisperEncoderLayer(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
standalone_encoder=is_standalone_encoder,
|
||||
)
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.mlp = WhisperMLP(
|
||||
@ -462,10 +479,16 @@ class WhisperDecoderLayer(nn.Module):
|
||||
|
||||
class WhisperEncoder(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
is_standalone_encoder: bool = False,
|
||||
init_in_fp32: bool = False):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
embed_dim = config.d_model
|
||||
self.is_standalone_encoder = is_standalone_encoder
|
||||
self.num_mel_bins = config.num_mel_bins
|
||||
self.max_source_positions = config.max_source_positions
|
||||
self.embed_scale = (math.sqrt(embed_dim)
|
||||
@ -480,17 +503,25 @@ class WhisperEncoder(nn.Module):
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1)
|
||||
self.embed_positions = nn.Embedding(self.max_source_positions,
|
||||
embed_dim)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.encoder_layers,
|
||||
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.layers"),
|
||||
prefix=f"{prefix}.layers",
|
||||
is_standalone_encoder=
|
||||
is_standalone_encoder),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
with torch.no_grad():
|
||||
maybe_fp32_init_ctx = set_default_torch_dtype(
|
||||
torch.float32) if init_in_fp32 else nullcontext()
|
||||
|
||||
with (
|
||||
torch.no_grad(),
|
||||
maybe_fp32_init_ctx,
|
||||
):
|
||||
self.embed_positions = nn.Embedding(self.max_source_positions,
|
||||
embed_dim)
|
||||
self.embed_positions.weight.copy_(
|
||||
sinusoids(*self.embed_positions.weight.shape))
|
||||
|
||||
@ -499,8 +530,10 @@ class WhisperEncoder(nn.Module):
|
||||
for features in input_features:
|
||||
embeds = nn.functional.gelu(self.conv1(features))
|
||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||
embeds = embeds.permute(1, 0)
|
||||
embeds = embeds + self.embed_positions.weight[:embeds.size(0), :]
|
||||
embeds = embeds.transpose(-1, -2)
|
||||
embeds = (embeds +
|
||||
self.embed_positions.weight[:embeds.size(-2), :]).to(
|
||||
embeds.dtype)
|
||||
hidden_states.append(embeds)
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
|
||||
@ -792,10 +825,14 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||
f"or {list(ISO639_1_OTHER_LANGS.values())}")
|
||||
|
||||
@classmethod
|
||||
def get_generation_prompt(cls, audio: np.ndarray,
|
||||
stt_config: SpeechToTextConfig, language: str,
|
||||
task_type: str,
|
||||
request_prompt: str) -> PromptType:
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
model_config: ModelConfig, # not needed here
|
||||
stt_config: SpeechToTextConfig,
|
||||
language: str,
|
||||
task_type: str,
|
||||
request_prompt: str) -> PromptType:
|
||||
prompt = {
|
||||
"encoder_prompt": {
|
||||
# Whisper does not support encoder prompt.
|
||||
|
||||
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Any

-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, WhisperConfig

 from vllm.logger import init_logger

@@ -24,9 +24,21 @@ def adapt_config_dict(config_dict: dict[str, Any],

     if bool(config_dict.get("yarn")):
         config_dict = _remap_mistral_yarn_args(config_dict)
-    if bool((config_dict.get("multimodal") or {}).get("vision_encoder_args")
-            or config_dict.get("vision_encoder")):
+
+    is_vision = ((config_dict.get("multimodal")
+                  or {}).get("vision_encoder_args")
+                 or config_dict.get("vision_encoder"))
+    is_audio = bool(
+        ((config_dict.get("multimodal") or {}).get("whisper_model_args")
+         or {}).get("encoder_args"))
+
+    assert not (is_vision and is_audio), \
+        "Vision and audio are mutually exclusive"
+
+    if is_vision:
         config_dict = _remap_mistral_vision_args(config_dict)
+    if is_audio:
+        config_dict = _remap_mistral_audio_args(config_dict)

     config = PretrainedConfig.from_dict(config_dict)

@@ -118,3 +130,35 @@ def _remap_mistral_quantization_args(config: dict) -> dict:
     config["quantization_config"] = quantization_config

     return config
+
+
+def _remap_mistral_audio_args(config: dict) -> dict:
+    whisper_args = config["multimodal"].pop("whisper_model_args")
+    encoder_args = whisper_args["encoder_args"]
+    downsample_args = whisper_args["downsample_args"]
+
+    quant_config = config.get("quantization_config")
+    config = {
+        "model_type":
+        "whixtral",
+        "architectures": ["VoxtralForConditionalGeneration"],
+        "text_config":
+        PretrainedConfig.from_dict(config),
+        "audio_config":
+        WhisperConfig(
+            num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
+            window_size=encoder_args["audio_encoding_args"]["window_size"],
+            sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
+            hop_length=encoder_args["audio_encoding_args"]["hop_length"],
+            downsample_factor=downsample_args["downsample_factor"],
+            d_model=encoder_args["dim"],
+            encoder_layers=encoder_args["n_layers"],
+            encoder_ffn_dim=encoder_args["hidden_dim"],
+            encoder_attention_heads=encoder_args["n_heads"],
+            vocab_size=encoder_args["vocab_size"],
+            max_source_positions=encoder_args["max_source_positions"],
+        )
+    }
+    if quant_config:
+        config["quantization_config"] = quant_config
+    return config
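For readers unfamiliar with Mistral-format checkpoints, _remap_mistral_audio_args expects a params.json whose multimodal section looks roughly like the sketch below. The key names are taken from the lookups in the function above (assuming it is importable from this module); every numeric value is an illustrative placeholder, not Voxtral's real hyperparameters.

# Illustrative input for _remap_mistral_audio_args; values are placeholders.
config_dict = {
    "dim": 3072,        # text-model fields, later wrapped into text_config
    "n_layers": 30,
    "multimodal": {
        "whisper_model_args": {
            "encoder_args": {
                "audio_encoding_args": {
                    "num_mel_bins": 128,
                    "window_size": 400,
                    "sampling_rate": 16000,
                    "hop_length": 160,
                },
                "dim": 1280,
                "n_layers": 32,
                "hidden_dim": 5120,
                "n_heads": 20,
                "vocab_size": 51866,
                "max_source_positions": 1500,
            },
            "downsample_args": {"downsample_factor": 4},
        }
    },
}

remapped = _remap_mistral_audio_args(config_dict)
# remapped["model_type"] == "whixtral"
# remapped["architectures"] == ["VoxtralForConditionalGeneration"]
# remapped["audio_config"] is a transformers.WhisperConfig built from encoder_args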