# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from math import ceil
from typing import Literal, cast

import numpy as np
import regex as re
import torch
import torch.nn as nn
from mistral_common.audio import mel_filter_bank
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import WhisperEncoder
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
    NestedTensors,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (
    MistralTokenizer,
    cached_tokenizer_from_config,
)

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix

logger = init_logger(__name__)

ISO639_1_SUPPORTED_LANGS = {
    "ar": "Arabic",
    "nl": "Dutch",
    "en": "English",
    "fr": "French",
    "de": "German",
    "hi": "Hindi",
    "it": "Italian",
    "pt": "Portuguese",
    "es": "Spanish",
}


class VoxtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    :class:`mistral_common.tokens.tokenizers.audio.AudioEncoder`.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer

    @cached_property
    def _audio_processor(self) -> AudioEncoder:
        audio_encoder = self.tokenizer.instruct.audio_encoder
        assert isinstance(audio_encoder, AudioEncoder)
        return audio_encoder

    @cached_property
    def audio_token_id(self) -> int:
        return self._audio_processor.special_ids.audio

    @cached_property
    def begin_audio_token_id(self) -> int:
        return self._audio_processor.special_ids.begin_audio

    # @cached_property
    # def begin_transcript_token_id(self) -> int:
    #     return self._audio_processor.special_ids.begin_transcript

    # @cached_property
    # def end_transcript_token_id(self) -> int:
    #     return self._audio_processor.special_ids.end_transcript

    @cached_property
    def sampling_rate(self) -> int:
        return self._audio_processor.audio_config.sampling_rate

    @cached_property
    def frame_rate(self) -> float:
        return self._audio_processor.audio_config.frame_rate

    def get_num_audio_tokens(
        self,
        audio_length: int,
    ) -> int:
        pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
            audio_length, self.sampling_rate
        )
        return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))
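
    # Illustration only (assumed values, not read from any config): with
    # sampling_rate=16_000 and frame_rate=12.5, each audio token covers
    # 16_000 // 12.5 = 1280 samples, so a padded length of 480_000 samples
    # (30 s) maps to ceil(480_000 / 1280) = 375 audio tokens.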

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        audios: np.ndarray | list[np.ndarray] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if audios is None:
            audios = []
        if not isinstance(audios, list):
            audios = [audios]

        if not audios:
            input_ids = self.tokenizer(text).input_ids
            return {"input_ids": torch.tensor(input_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(len(t) > 0 for t in text):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411."
            )

        audios_tokens = list[torch.Tensor]()
        audios_processed = list[torch.Tensor]()
        for audio in audios:
            assert isinstance(audio, np.ndarray)
            assert audio.ndim == 1

            # pad if necessary
            audio = self._audio_processor.pad(audio, self.sampling_rate)

            audio_tokens = [self.begin_audio_token_id] + [
                self.audio_token_id
            ] * self.get_num_audio_tokens(len(audio))

            audios_tokens.append(torch.tensor(audio_tokens))
            audios_processed.append(torch.tensor(audio))

        return BatchFeature(
            {
                "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
                "audio_arrays": audios_processed,
            }
        )


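# A minimal usage sketch for the adapter above (illustrative only; the
# `model_config` and `waveform` names are placeholders, and the audio is
# assumed to be mono float32 at the tokenizer's sampling rate):
#
#   tokenizer = cached_tokenizer_from_config(model_config)
#   processor = VoxtralProcessorAdapter(tokenizer)
#   features = processor(text=[""], audios=[waveform])
#   # features["input_ids"]    -> [begin_audio] + N audio placeholder tokens
#   # features["audio_arrays"] -> the padded waveform(s) as tensors
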
class VoxtralProcessingInfo(BaseProcessingInfo):
    def get_tokenizer(self) -> MistralTokenizer:
        tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
        if not isinstance(tokenizer, MistralTokenizer):
            raise ValueError("This model requires `--tokenizer-mode mistral`")

        return tokenizer

    def get_hf_processor(self) -> VoxtralProcessorAdapter:
        return VoxtralProcessorAdapter(self.get_tokenizer())

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": 5}  # Performance tends to degrade after 5

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"audio": self.get_max_audio_tokens()}

    def get_max_audio_tokens(self) -> int:
        return self.ctx.model_config.max_model_len

    def get_max_audio_array_len(self) -> int:
        processor = self.get_hf_processor()
        return self.get_max_audio_tokens() * int(
            processor.sampling_rate // processor.frame_rate
        )
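
    # Rough sense of scale (assumed values, not read from any config): with
    # sampling_rate=16_000, frame_rate=12.5 and max_model_len=32_768, the
    # profiler's dummy audio would be 32_768 * 1280 samples, i.e. roughly
    # 43 minutes of 16 kHz audio for a single item.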


class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        num_audios = mm_counts.get("audio", 0)

        target_length = self.info.get_max_audio_array_len()

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=target_length, num_audios=num_audios, overrides=audio_overrides
            )
        }

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        tokenizer = self.info.get_tokenizer()

        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
        dummy_audios = dummy_mm_data.get("audio", [])

        audio_chunks: list[AudioChunk] = []
        format = "wav"
        for audio in dummy_audios:
            audio_item = Audio(
                audio_array=audio,
                sampling_rate=self.info.get_hf_processor().sampling_rate,
                format=format,
            )
            chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
            audio_chunks.append(chunk)

        request = ChatCompletionRequest(
            messages=[
                UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
            ]
        )
        res = tokenizer.mistral.encode_chat_completion(request)
        dummy_tokens = res.tokens
        # The Voxtral tokenizer adds padding to the audio,
        # so we need to update the audio arrays
        dummy_mm_data["audio"] = [a.audio_array for a in res.audios]

        return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)


class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
    def _get_mm_fields_config(
        self,
        hf_inputs: Mapping[str, NestedTensors],
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(audio_arrays=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        audio_id = processor.audio_token_id

        def get_replacement(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio_len = audios.get_audio_length(item_idx)

            nb_audio_tokens = processor.get_num_audio_tokens(audio_len)

            return [audio_id] * nb_audio_tokens

        return [
            PromptReplacement(
                modality="audio",
                target="",  # Never match the prompt (see below note)
                replacement=get_replacement,
            ),
        ]

    def _cached_apply_hf_processor(
        self,
        prompt: str | list[int],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
            prompt=prompt,
            mm_data_items=mm_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # NOTE: The tokens are already inserted by the chat template
        return prompt_ids, mm_info, True

    def _get_data_parser(self) -> MultiModalDataParser:
        sampling_rate = self.info.get_hf_processor().sampling_rate
        return MultiModalDataParser(target_sr=sampling_rate)


@MULTIMODAL_REGISTRY.register_processor(
    VoxtralMultiModalProcessor,
    info=VoxtralProcessingInfo,
    dummy_inputs=VoxtralDummyInputsBuilder,
)
class VoxtralForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
    merge_by_field_config = True

    supported_languages = ISO639_1_SUPPORTED_LANGS

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)

        # update quant config so that ignored module and target module names
        # match the vLLM model names
        if hasattr(vllm_config, "quant_config"):
            vllm_config.quant_config = self.maybe_update_quant_config(
                vllm_config.quant_config
            )

        config = vllm_config.model_config.hf_config
        self.config = config
        self.downsample_factor = self.config.audio_config.downsample_factor

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.whisper_encoder = VoxtralEncoderModel(
            vllm_config.with_hf_config(config.audio_config),
            prefix=maybe_prefix(prefix, "whisper_encoder"),
        )
        self.audio_language_adapter = AudioLanguageAdapter(
            hidden_size=config.audio_config.d_model * self.downsample_factor,
            dim=config.text_config.hidden_size,
        )

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get module prefix for multimodal models to filter LoRA modules."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="audio_language_adapter",
            tower_model=["whisper_encoder"],
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def get_multimodal_embeddings(
        self, **kwargs
    ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
        audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
        if audio_inputs is None:
            return None

        audio_embeddings = self.whisper_encoder(audio_inputs)

        for i, audio_embedding in enumerate(audio_embeddings):
            seq_len, dim = audio_embedding.shape
            # Pad such that seq_len is divisible by downsample_factor
            target_seq_len = self.downsample_factor * math.ceil(
                seq_len / self.downsample_factor
            )
            audio_embedding = torch.nn.functional.pad(
                audio_embedding,
                (0, 0, 0, target_seq_len - seq_len),
            )
            audio_embeddings[i] = audio_embedding.reshape(
                target_seq_len // self.downsample_factor, dim * self.downsample_factor
            )

        # Concat, project and resplit
        audio_embeddings_packed = torch.cat(audio_embeddings, dim=0)
        audio_embeddings_packed = self.audio_language_adapter(audio_embeddings_packed)
        audio_embeddings = torch.split(
            audio_embeddings_packed, [a.shape[0] for a in audio_embeddings], dim=0
        )

        return audio_embeddings
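
    # Shape sketch for the method above (illustrative numbers, not taken from
    # any config): an encoder output of shape (seq_len=751, dim=1280) with
    # downsample_factor=4 is padded to 752 frames, reshaped to (188, 5120),
    # and then projected by audio_language_adapter to (188, hidden_size) of
    # the language model.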

    def _parse_and_validate_audio_arrays(
        self, **kwargs: object
    ) -> list[torch.Tensor] | None:
        audio_arrays = kwargs.pop("audio_arrays", None)
        if audio_arrays is None:
            return None

        if not isinstance(audio_arrays, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
            )

        if isinstance(audio_arrays, torch.Tensor):
            audio_arrays = list(audio_arrays.unbind(0))
        return audio_arrays

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_config = tokenizer.instruct.audio_encoder.audio_config
        max_audio_clip_s = audio_config.chunk_length_s
        sample_rate = audio_config.sampling_rate
        return SpeechToTextConfig(
            max_audio_clip_s=max_audio_clip_s,
            sample_rate=sample_rate,
            # mistral_common and whisper encoder take care of chunking
            min_energy_split_window_size=None,
        )

    @classmethod
    # for speech-to-text transcription
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        tokenizer = cached_tokenizer_from_config(model_config)
        audio = Audio(audio, int(stt_config.sample_rate), format="wav")  # lossless
        req = TranscriptionRequest(
            model=model_config.model,
            audio=RawAudio.from_audio(audio),
            language=language,
        )

        tokenized = tokenizer.instruct.encode_transcription(req)
        audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
        prompts_dict = {"multi_modal_data": {"audio": audio}}
        prompts_dict["prompt_token_ids"] = tokenized.tokens
        return cast(PromptType, prompts_dict)

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """
        Map from audio duration to number of audio tokens produced by the ASR
        model, without running a forward pass.
        This is used for estimating the amount of processing for this audio.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        adapter = VoxtralProcessorAdapter(tokenizer)
        return adapter.get_num_audio_tokens(
            int(audio_duration_s * stt_config.sample_rate)
        )
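
    # Example of the estimate above (assuming a 16 kHz sample rate,
    # 12.5 tokens per second, and that 30 s is already a whole chunk so no
    # extra padding applies): a 30 s clip is 480_000 samples, which
    # VoxtralProcessorAdapter.get_num_audio_tokens maps to 375 audio tokens.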

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        remapping_rules = [
            (r"mm_whisper_embeddings\.(.*)", r"\1"),
            (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
            (
                r"audio_language_adapter\.0\.weight",
                r"audio_language_adapter.w_in.weight",
            ),
            (
                r"audio_language_adapter\.2\.weight",
                r"audio_language_adapter.w_out.weight",
            ),
        ]
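        # For example, a checkpoint name like
        # "mm_whisper_embeddings.audio_language_projection.0.weight" is
        # rewritten by these rules, applied in order, to
        # "audio_language_adapter.w_in.weight".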

        audio_params = dict(
            nn.ModuleDict(
                {
                    "audio_language_adapter": self.audio_language_adapter,
                }
            ).named_parameters()
        )

        loaded_weights = set()

        def llm_weights_generator():
            nonlocal loaded_weights
            for name, w in weights:
                is_encoder = (
                    name.startswith("mm_whisper_embeddings")
                    and not name.startswith("mm_whisper_embeddings.tok_embeddings")
                    and not name.startswith(
                        "mm_whisper_embeddings.audio_language_projection"
                    )
                )

                for pattern, repl in remapping_rules:
                    if re.fullmatch(pattern, name):
                        name = re.sub(pattern, repl, name)

                if is_encoder:
                    name = self.whisper_encoder.load_weight((name, w))
                    loaded_weights.add(f"whisper_encoder.{name}")
                    continue

                if name in audio_params:
                    param = audio_params[name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                    loaded_weights.add(name)
                else:
                    yield (name, w)

        for name in self.language_model.load_weights(llm_weights_generator()):
            loaded_weights.add(f"language_model.{name}")

        # potentially manually add position embeddings
        sin_key = "whisper_encoder.whisper_encoder.embed_positions.weight"
        if sin_key not in loaded_weights:
            # make sure we don't hit an error here
            loaded_weights.add(sin_key)

        return loaded_weights

    def maybe_update_quant_config(
        self, quant_config: QuantizationConfig
    ) -> QuantizationConfig:
        """
        Update quant config so that ignored module and target module names
        match the vLLM model names.
        Right now this is specific for compressed-tensors format and
        load_format mistral.
        """
        remapping_rules = [
            (r"output", r"language_model.lm_head"),
            (
                r"layers\.(\d+)\.attention\.wo",
                r"language_model.model.layers.\1.self_attn.out_proj",
            ),
            (
                r"layers\.(\d+)\.attention\.w(.*)",
                r"language_model.model.layers.\1.self_attn.\2_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w1",
                r"language_model.model.layers.\1.mlp.gate_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w2",
                r"language_model.model.layers.\1.mlp.down_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w3",
                r"language_model.model.layers.\1.mlp.up_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
                r"whisper_encoder.whisper_encoder.conv1",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
                r"whisper_encoder.whisper_encoder.conv2",
            ),
            (
                r"mm_whisper_embeddings\.audio_language_projection\.0",
                r"audio_language_adapter.w_in",
            ),
            (
                r"mm_whisper_embeddings\.audio_language_projection\.2",
                r"audio_language_adapter.w_out",
            ),
        ]
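        # For example, a checkpoint-side name like "layers.0.feed_forward.w1"
        # is rewritten to "language_model.model.layers.0.mlp.gate_proj".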

        # Update ignore list
        if hasattr(quant_config, "ignore"):
            mistral_ignore = []
            for name in quant_config.ignore:
                mistral_name = name
                for pattern, repl in remapping_rules:
                    if re.fullmatch(pattern, name):
                        mistral_name = re.sub(pattern, repl, name)
                mistral_ignore.append(mistral_name)
            quant_config.ignore = mistral_ignore

        # Update target list
        if hasattr(quant_config, "config_groups"):
            config_groups = quant_config.config_groups
            for group_name in config_groups:
                if "targets" in config_groups[group_name]:
                    targets = []
                    for name in config_groups[group_name]["targets"]:
                        mistral_name = name
                        for pattern, repl in remapping_rules:
                            if re.fullmatch(pattern, name):
                                mistral_name = re.sub(pattern, repl, name)
                        targets.append(mistral_name)
                    config_groups[group_name]["targets"] = targets
            quant_config.config_groups = config_groups

        return quant_config


class AudioLanguageAdapter(nn.Module):
    def __init__(self, hidden_size: int, dim: int) -> None:
        super().__init__()
        self.w_in = nn.Linear(hidden_size, dim, bias=False)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))
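
    # Minimal shape sketch (illustrative values): with d_model=1280 and
    # downsample_factor=4, w_in maps (num_frames, 5120) to the text hidden
    # size and w_out keeps that size, i.e. a two-layer GELU MLP as wired up in
    # VoxtralForConditionalGeneration.__init__ above.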


class VoxtralEncoderModel(nn.Module):
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    mistral_remapping = [
        (
            r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
            r"whisper_encoder.conv1.\1",
        ),
        (
            r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
            r"whisper_encoder.conv2.\1",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.self_attn.out_proj.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.self_attn_layer_norm.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.mlp.fc1.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.mlp.fc2.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
            r"whisper_encoder.layers.\1.final_layer_norm.\2",
        ),
        (
            r"whisper_encoder\.transformer\.norm\.(weight|bias)",
            r"whisper_encoder.layer_norm.\1",
        ),
    ]

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
        self.dtype: torch.dtype = vllm_config.model_config.dtype
        self.whisper_encoder = WhisperEncoder(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "whisper_encoder"),
            init_in_fp32=True,
        )
        mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.window_size // 2,
            num_mel_bins=self.config.num_mel_bins,
            min_frequency=0.0,
            max_frequency=8000.0,
            sampling_rate=self.config.sampling_rate,
        )
        self.mel_filters = torch.tensor(mel_filters, dtype=torch.float32)

    def compute_whisper_melspec(
        self,
        audio_waveforms: torch.Tensor,
    ) -> torch.Tensor:
        input_dtype = audio_waveforms.dtype
        window = torch.hann_window(self.config.window_size).to(audio_waveforms.device)
        stft = torch.stft(
            audio_waveforms,
            self.config.window_size,
            self.config.hop_length,
            window=window,
            return_complex=True,
        )
        magnitudes = stft[..., :-1].abs() ** 2
        mel_spec = self.mel_filters.T @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec.to(input_dtype)
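
    # The clamp to 1e-10, the log10, the 8.0 dynamic-range cap and the
    # (x + 4) / 4 scaling above follow the reference Whisper log-mel feature
    # extraction, so the encoder receives inputs in the range it was trained on.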

    @property
    def downsample_factor(self) -> int:
        return (
            self.whisper_encoder.conv1.stride[0] * self.whisper_encoder.conv2.stride[0]
        )

    @property
    def chunk_size(self) -> int:
        return self.config.max_source_positions * self.downsample_factor
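
    # Illustration with typical Whisper-style values (assumptions, not read
    # from the config): conv strides of 1 and 2 give downsample_factor=2, and
    # max_source_positions=1500 then gives chunk_size=3000 mel frames, i.e.
    # 30 s of audio at hop_length=160 and a 16 kHz sampling rate.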

    def prepare_inputs_for_conv(
        self,
        audio_waveforms: list[torch.Tensor],
    ) -> tuple[torch.Tensor, list[int]]:
        assert isinstance(audio_waveforms, list)
        # list[num_mel_bins, seq_len]
        input_features = [
            self.compute_whisper_melspec(audio).to(self.dtype)
            for audio in audio_waveforms
        ]

        chunked_features: list[torch.Tensor] = []
        chunks_per_example: list[int] = []
        for feature in input_features:
            chunks = feature.split(self.chunk_size, dim=-1)
            chunked_features += chunks
            chunks_per_example.append(len(chunks))

        # [total_num_chunks, num_mel_bins, chunk_size]
        return torch.stack(chunked_features), chunks_per_example

    def forward(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> list[torch.Tensor]:
        if not isinstance(input_features, list):
            input_features = [input_features]

        # Split long inputs into chunks
        input_embeds, chunks_per_example = self.prepare_inputs_for_conv(input_features)

        # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
        out = self.whisper_encoder([input_embeds])

        # Re-concatenate the chunks
        chunk_idx = 0
        results = []
        for n_chunks in chunks_per_example:
            result = out[chunk_idx : chunk_idx + n_chunks].flatten(0, 1)
            results.append(result)
            chunk_idx += n_chunks

        return results

    def load_weight(self, weight: tuple[str, torch.Tensor]) -> str:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())

        name, loaded_weight = weight
        for pattern, repl in self.mistral_remapping:
            if re.fullmatch(pattern, name):
                name = re.sub(pattern, repl, name)

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

        return name