# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from math import ceil
from typing import Literal, cast

import numpy as np
import regex as re
import torch
import torch.nn as nn
from mistral_common.audio import mel_filter_bank
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import WhisperEncoder
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
MultiModalUUIDDict,
NestedTensors,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (
MistralTokenizer,
cached_tokenizer_from_config,
)

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix

logger = init_logger(__name__)

ISO639_1_SUPPORTED_LANGS = {
"ar": "Arabic",
"nl": "Dutch",
"en": "English",
"fr": "French",
"de": "German",
"hi": "Hindi",
"it": "Italian",
"pt": "Portuguese",
"es": "Spanish",
}


class VoxtralProcessorAdapter:
"""
Provide a HF-compatible interface for
:class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
"""
def __init__(self, tokenizer: MistralTokenizer) -> None:
super().__init__()
self.tokenizer = tokenizer
@cached_property
def _audio_processor(self) -> AudioEncoder:
audio_encoder = self.tokenizer.instruct.audio_encoder
assert isinstance(audio_encoder, AudioEncoder)
return audio_encoder
@cached_property
def audio_token_id(self) -> int:
return self._audio_processor.special_ids.audio
@cached_property
def begin_audio_token_id(self) -> int:
return self._audio_processor.special_ids.begin_audio
# @cached_property
# def begin_transcript_token_id(self) -> int:
# return self._audio_processor.special_ids.begin_transcript
# @cached_property
# def end_transcript_token_id(self) -> int:
# return self._audio_processor.special_ids.end_transcript
@cached_property
def sampling_rate(self) -> int:
return self._audio_processor.audio_config.sampling_rate
@cached_property
def frame_rate(self) -> float:
return self._audio_processor.audio_config.frame_rate
def get_num_audio_tokens(
self,
audio_length: int,
) -> int:
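        # Round the length up to a whole number of encoder chunks, then count
        # one token per frame of (sampling_rate // frame_rate) samples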
pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
audio_length, self.sampling_rate
)
return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
audios: np.ndarray | list[np.ndarray] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if audios is None:
audios = []
if not isinstance(audios, list):
audios = [audios]
if not audios:
input_ids = self.tokenizer(text).input_ids
return {"input_ids": torch.tensor(input_ids)}
        # Allow dummy (empty) text, which is used for profiling,
        # as well as token inputs
if any(len(t) > 0 for t in text):
raise ValueError(
"You've passed text inputs instead of token inputs. "
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
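        # Each audio clip becomes one begin-audio token followed by one
        # placeholder audio token per encoder frame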
audios_tokens = list[torch.Tensor]()
audios_processed = list[torch.Tensor]()
for audio in audios:
assert isinstance(audio, np.ndarray)
assert audio.ndim == 1
# pad if necessary
audio = self._audio_processor.pad(audio, self.sampling_rate)
audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id
] * self.get_num_audio_tokens(len(audio))
audios_tokens.append(torch.tensor(audio_tokens))
audios_processed.append(torch.tensor(audio))
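        # Broadcast the concatenated audio token sequence across the (dummy)
        # text batch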
return BatchFeature(
{
"input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
"audio_arrays": audios_processed,
}
)


class VoxtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
if not isinstance(tokenizer, MistralTokenizer):
raise ValueError("This model requires `--tokenizer-mode mistral`")
return tokenizer
def get_hf_processor(self) -> VoxtralProcessorAdapter:
return VoxtralProcessorAdapter(self.get_tokenizer())
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 5} # Performance tends to degrade after 5
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"audio": self.get_max_audio_tokens()}
def get_max_audio_tokens(self) -> int:
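        # A single audio clip may occupy up to the entire context window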
return self.ctx.model_config.max_model_len
def get_max_audio_array_len(self) -> int:
processor = self.get_hf_processor()
return self.get_max_audio_tokens() * int(
processor.sampling_rate // processor.frame_rate
)


class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
target_length = self.info.get_max_audio_array_len()
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio": self._get_dummy_audios(
length=target_length, num_audios=num_audios, overrides=audio_overrides
)
}
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_audios = dummy_mm_data.get("audio", [])
audio_chunks: list[AudioChunk] = []
format = "wav"
for audio in dummy_audios:
audio_item = Audio(
audio_array=audio,
sampling_rate=self.info.get_hf_processor().sampling_rate,
format=format,
)
chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
audio_chunks.append(chunk)
request = ChatCompletionRequest(
messages=[
UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
]
)
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens
        # The Voxtral tokenizer adds padding to the audio,
        # so we need to update the audio arrays accordingly
dummy_mm_data["audio"] = [a.audio_array for a in res.audios]
return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)


class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
def _get_mm_fields_config(
self,
hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(audio_arrays=MultiModalFieldConfig.batched("audio"))
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
audio_id = processor.audio_token_id
def get_replacement(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems)
audio_len = audios.get_audio_length(item_idx)
nb_audio_tokens = processor.get_num_audio_tokens(audio_len)
return [audio_id] * nb_audio_tokens
return [
PromptReplacement(
modality="audio",
target="", # Never match the prompt (see below note)
replacement=get_replacement,
),
]
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
mm_uuids: MultiModalUUIDDict | None = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_info, True
def _get_data_parser(self) -> MultiModalDataParser:
sampling_rate = self.info.get_hf_processor().sampling_rate
return MultiModalDataParser(target_sr=sampling_rate)


@MULTIMODAL_REGISTRY.register_processor(
VoxtralMultiModalProcessor,
info=VoxtralProcessingInfo,
dummy_inputs=VoxtralDummyInputsBuilder,
)
class VoxtralForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
merge_by_field_config = True
supported_languages = ISO639_1_SUPPORTED_LANGS
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
        # Update the quant config so that ignored-module and target-module
        # names match the vLLM model names
if hasattr(vllm_config, "quant_config"):
vllm_config.quant_config = self.maybe_update_quant_config(
vllm_config.quant_config
)
config = vllm_config.model_config.hf_config
self.config = config
self.downsample_factor = self.config.audio_config.downsample_factor
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.whisper_encoder = VoxtralEncoderModel(
vllm_config.with_hf_config(config.audio_config),
prefix=maybe_prefix(prefix, "whisper_encoder"),
)
self.audio_language_adapter = AudioLanguageAdapter(
hidden_size=config.audio_config.d_model * self.downsample_factor,
dim=config.text_config.hidden_size,
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_mm_mapping(self) -> MultiModelKeys:
"""Get module prefix for multimodal models to filter LoRA modules."""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="audio_language_adapter",
tower_model=["whisper_encoder"],
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def get_multimodal_embeddings(
self, **kwargs
) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
if audio_inputs is None:
return None
audio_embeddings = self.whisper_encoder(audio_inputs)
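        # Group every `downsample_factor` consecutive encoder frames into one
        # vector so the adapter sees dim * downsample_factor features per token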
for i, audio_embedding in enumerate(audio_embeddings):
seq_len, dim = audio_embedding.shape
# Pad such that seq_len is divisible by downsample_factor
target_seq_len = self.downsample_factor * math.ceil(
seq_len / self.downsample_factor
)
audio_embedding = torch.nn.functional.pad(
audio_embedding,
(0, 0, 0, target_seq_len - seq_len),
)
audio_embeddings[i] = audio_embedding.reshape(
target_seq_len // self.downsample_factor, dim * self.downsample_factor
)
# Concat, project and resplit
audio_embeddings_packed = torch.cat(audio_embeddings, dim=0)
audio_embeddings_packed = self.audio_language_adapter(audio_embeddings_packed)
audio_embeddings = torch.split(
audio_embeddings_packed, [a.shape[0] for a in audio_embeddings], dim=0
)
return audio_embeddings
def _parse_and_validate_audio_arrays(
self, **kwargs: object
) -> list[torch.Tensor] | None:
audio_arrays = kwargs.pop("audio_arrays", None)
if audio_arrays is None:
return None
if not isinstance(audio_arrays, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
)
if isinstance(audio_arrays, torch.Tensor):
audio_arrays = list(audio_arrays.unbind(0))
return audio_arrays
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
tokenizer = cached_tokenizer_from_config(model_config)
audio_config = tokenizer.instruct.audio_encoder.audio_config
max_audio_clip_s = audio_config.chunk_length_s
sample_rate = audio_config.sampling_rate
return SpeechToTextConfig(
max_audio_clip_s=max_audio_clip_s,
sample_rate=sample_rate,
# mistral_common and whisper encoder take care of chunking
min_energy_split_window_size=None,
)
@classmethod
# for speech-to-text transcription
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
req = TranscriptionRequest(
model=model_config.model,
audio=RawAudio.from_audio(audio),
language=language,
)
tokenized = tokenizer.instruct.encode_transcription(req)
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}}
prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict)
@classmethod
def get_num_audio_tokens(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> int | None:
"""
Map from audio duration to number of audio tokens produced by the ASR
model, without running a forward pass.
This is used for estimating the amount of processing for this audio.
"""
tokenizer = cached_tokenizer_from_config(model_config)
adapter = VoxtralProcessorAdapter(tokenizer)
return adapter.get_num_audio_tokens(
int(audio_duration_s * stt_config.sample_rate)
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
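        # Remap mistral-format checkpoint names to the vLLM module names used
        # by this model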
remapping_rules = [
(r"mm_whisper_embeddings\.(.*)", r"\1"),
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
(
r"audio_language_adapter\.0\.weight",
r"audio_language_adapter.w_in.weight",
),
(
r"audio_language_adapter\.2\.weight",
r"audio_language_adapter.w_out.weight",
),
]
audio_params = dict(
nn.ModuleDict(
{
"audio_language_adapter": self.audio_language_adapter,
}
).named_parameters()
)
loaded_weights = set()
def llm_weights_generator():
nonlocal loaded_weights
for name, w in weights:
is_encoder = (
name.startswith("mm_whisper_embeddings")
and not name.startswith("mm_whisper_embeddings.tok_embeddings")
and not name.startswith(
"mm_whisper_embeddings.audio_language_projection"
)
)
for pattern, repl in remapping_rules:
if re.fullmatch(pattern, name):
name = re.sub(pattern, repl, name)
if is_encoder:
name = self.whisper_encoder.load_weight((name, w))
loaded_weights.add(f"whisper_encoder.{name}")
continue
if name in audio_params:
param = audio_params[name]
with torch.no_grad():
default_weight_loader(param, w)
loaded_weights.add(name)
else:
yield (name, w)
for name in self.language_model.load_weights(llm_weights_generator()):
loaded_weights.add(f"language_model.{name}")
        # The encoder position embeddings are not always present in the
        # checkpoint
        sin_key = "whisper_encoder.whisper_encoder.embed_positions.weight"
        if sin_key not in loaded_weights:
            # Mark the key as loaded so it is not reported as a missing weight
            loaded_weights.add(sin_key)
return loaded_weights
def maybe_update_quant_config(
self, quant_config: QuantizationConfig
) -> QuantizationConfig:
"""
Update quant config to so that ignored module and target module names
match the vLLM model names.
Right now this is specific for compressed-tensors format and
load_format mistral.
"""
remapping_rules = [
(r"output", r"language_model.lm_head"),
(
r"layers\.(\d+)\.attention\.wo",
r"language_model.model.layers.\1.self_attn.out_proj",
),
(
r"layers\.(\d+)\.attention\.w(.*)",
r"language_model.model.layers.\1.self_attn.\2_proj",
),
(
r"layers\.(\d+)\.feed_forward\.w1",
r"language_model.model.layers.\1.mlp.gate_proj",
),
(
r"layers\.(\d+)\.feed_forward\.w2",
r"language_model.model.layers.\1.mlp.down_proj",
),
(
r"layers\.(\d+)\.feed_forward\.w3",
r"language_model.model.layers.\1.mlp.up_proj",
),
(
r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj",
),
(
r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj",
),
(
r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2",
),
(
r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
r"whisper_encoder.whisper_encoder.conv1",
),
(
r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
r"whisper_encoder.whisper_encoder.conv2",
),
(
r"mm_whisper_embeddings\.audio_language_projection\.0",
r"audio_language_adapter.w_in",
),
(
r"mm_whisper_embeddings\.audio_language_projection\.2",
r"audio_language_adapter.w_out",
),
]
# Update ignore list
if hasattr(quant_config, "ignore"):
mistral_ignore = []
for name in quant_config.ignore:
mistral_name = name
for pattern, repl in remapping_rules:
if re.fullmatch(pattern, name):
mistral_name = re.sub(pattern, repl, name)
mistral_ignore.append(mistral_name)
quant_config.ignore = mistral_ignore
# Update target list
if hasattr(quant_config, "config_groups"):
config_groups = quant_config.config_groups
for group_name in config_groups:
if "targets" in config_groups[group_name]:
targets = []
for name in config_groups[group_name]["targets"]:
mistral_name = name
for pattern, repl in remapping_rules:
if re.fullmatch(pattern, name):
mistral_name = re.sub(pattern, repl, name)
targets.append(mistral_name)
config_groups[group_name]["targets"] = targets
quant_config.config_groups = config_groups
return quant_config


class AudioLanguageAdapter(nn.Module):
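    """Two-layer MLP (Linear -> GELU -> Linear) that projects stacked audio
    features into the language model's hidden space."""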
def __init__(self, hidden_size: int, dim: int) -> None:
super().__init__()
self.w_in = nn.Linear(hidden_size, dim, bias=False)
self.gelu = nn.GELU()
self.w_out = nn.Linear(dim, dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x)))


class VoxtralEncoderModel(nn.Module):
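    """Whisper-based audio encoder that computes log-mel features, runs them
    through vLLM's WhisperEncoder in fixed-size chunks, and remaps
    mistral-format checkpoint names."""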
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
mistral_remapping = [
(
r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
r"whisper_encoder.conv1.\1",
),
(
r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
r"whisper_encoder.conv2.\1",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.self_attn.out_proj.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.self_attn_layer_norm.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.mlp.fc1.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.mlp.fc2.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
r"whisper_encoder.layers.\1.final_layer_norm.\2",
),
(
r"whisper_encoder\.transformer\.norm\.(weight|bias)",
r"whisper_encoder.layer_norm.\1",
),
]
def __init__(
self,
vllm_config: VllmConfig,
*,
prefix: str = "",
) -> None:
super().__init__()
self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
self.dtype: torch.dtype = vllm_config.model_config.dtype
self.whisper_encoder = WhisperEncoder(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "whisper_encoder"),
init_in_fp32=True,
)
mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.window_size // 2,
num_mel_bins=self.config.num_mel_bins,
min_frequency=0.0,
max_frequency=8000.0,
sampling_rate=self.config.sampling_rate,
)
self.mel_filters = torch.tensor(mel_filters, dtype=torch.float32)
def compute_whisper_melspec(
self,
audio_waveforms: torch.Tensor,
) -> torch.Tensor:
input_dtype = audio_waveforms.dtype
window = torch.hann_window(self.config.window_size).to(audio_waveforms.device)
stft = torch.stft(
audio_waveforms,
self.config.window_size,
self.config.hop_length,
window=window,
return_complex=True,
)
magnitudes = stft[..., :-1].abs() ** 2
mel_spec = self.mel_filters.T @ magnitudes
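        # Whisper-style log-mel normalization: floor the spectrogram, clamp the
        # dynamic range to 8 below the max, then rescale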
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec.to(input_dtype)
@property
def downsample_factor(self) -> int:
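        # Total temporal stride of the convolutional front end
        # (conv1 stride * conv2 stride)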
return (
self.whisper_encoder.conv1.stride[0] * self.whisper_encoder.conv2.stride[0]
)
@property
def chunk_size(self) -> int:
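        # Number of mel frames that downsample to exactly max_source_positions
        # encoder positions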
return self.config.max_source_positions * self.downsample_factor
def prepare_inputs_for_conv(
self,
audio_waveforms: list[torch.Tensor],
) -> tuple[torch.Tensor, list[int]]:
assert isinstance(audio_waveforms, list)
# list[num_mel_bins, seq_len]
input_features = [
self.compute_whisper_melspec(audio).to(self.dtype)
for audio in audio_waveforms
]
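        # Split each spectrogram into chunk_size-frame chunks along the time
        # axis and record how many chunks each input produced so the encoder
        # outputs can be re-grouped per audio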
chunked_features: list[torch.Tensor] = []
chunks_per_example: list[int] = []
for feature in input_features:
chunks = feature.split(self.chunk_size, dim=-1)
chunked_features += chunks
chunks_per_example.append(len(chunks))
# [total_num_chunks, num_mel_bins, chunk_size]
return torch.stack(chunked_features), chunks_per_example
def forward(
self, input_features: torch.Tensor | list[torch.Tensor]
) -> list[torch.Tensor]:
if not isinstance(input_features, list):
input_features = [input_features]
# Split long inputs into chunks
input_embeds, chunks_per_example = self.prepare_inputs_for_conv(input_features)
# [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
out = self.whisper_encoder([input_embeds])
# Re-concatenate the chunks
chunk_idx = 0
results = []
for n_chunks in chunks_per_example:
result = out[chunk_idx : chunk_idx + n_chunks].flatten(0, 1)
results.append(result)
chunk_idx += n_chunks
return results
def load_weight(self, weight: tuple[str, torch.Tensor]) -> str:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
name, loaded_weight = weight
for pattern, repl in self.mistral_remapping:
if re.fullmatch(pattern, name):
name = re.sub(pattern, repl, name)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
return name