mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 03:11:23 +08:00
adapt voxtral (#31095)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
b10d47e0e0
commit
3faa8bee57
@ -111,4 +111,5 @@ async def test_online_serving(client, audio_assets: AudioTestAssets):
|
|||||||
|
|
||||||
assert len(chat_completion.choices) == 1
|
assert len(chat_completion.choices) == 1
|
||||||
choice = chat_completion.choices[0]
|
choice = chat_completion.choices[0]
|
||||||
|
assert choice.message.content == "In the first audio clip, you hear a brief"
|
||||||
assert choice.finish_reason == "length"
|
assert choice.finish_reason == "length"
|
||||||
|
|||||||
@ -860,6 +860,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
# disable this temporarily until we support HF format
|
# disable this temporarily until we support HF format
|
||||||
is_available_online=False,
|
is_available_online=False,
|
||||||
),
|
),
|
||||||
|
"VoxtralStreamingGeneration": _HfExamplesInfo(
|
||||||
|
"<place-holder>",
|
||||||
|
# disable this temporarily until we support HF format
|
||||||
|
is_available_online=False,
|
||||||
|
),
|
||||||
# [Encoder-decoder]
|
# [Encoder-decoder]
|
||||||
"WhisperForConditionalGeneration": _HfExamplesInfo(
|
"WhisperForConditionalGeneration": _HfExamplesInfo(
|
||||||
"openai/whisper-large-v3-turbo",
|
"openai/whisper-large-v3-turbo",
|
||||||
|
|||||||
@ -1542,6 +1542,10 @@ class ModelConfig:
|
|||||||
def is_multimodal_raw_input_only_model(self) -> bool:
|
def is_multimodal_raw_input_only_model(self) -> bool:
|
||||||
return self._model_info.supports_multimodal_raw_input_only
|
return self._model_info.supports_multimodal_raw_input_only
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_raw_input_tokens(self) -> bool:
|
||||||
|
return self._model_info.requires_raw_input_tokens
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_cross_encoder(self) -> bool:
|
def is_cross_encoder(self) -> bool:
|
||||||
return (
|
return (
|
||||||
|
|||||||
@ -94,6 +94,12 @@ class SupportsMultiModal(Protocol):
|
|||||||
`multimodal_config.mm_encoder_tp_mode="data"`.
|
`multimodal_config.mm_encoder_tp_mode="data"`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
requires_raw_input_tokens: ClassVar[bool] = False
|
||||||
|
"""
|
||||||
|
A flag that indicates this model processes input id tokens
|
||||||
|
in their raw form and not input embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
merge_by_field_config: ClassVar[bool | None] = None
|
merge_by_field_config: ClassVar[bool | None] = None
|
||||||
"""
|
"""
|
||||||
[DEPRECATED] A flag that indicates which implementation of
|
[DEPRECATED] A flag that indicates which implementation of
|
||||||
@ -306,6 +312,10 @@ def supports_multimodal_raw_input_only(model: type[object] | object) -> bool:
|
|||||||
return getattr(model, "supports_multimodal_raw_input_only", False)
|
return getattr(model, "supports_multimodal_raw_input_only", False)
|
||||||
|
|
||||||
|
|
||||||
|
def requires_raw_input_tokens(model: type[object] | object) -> bool:
|
||||||
|
return getattr(model, "requires_raw_input_tokens", False)
|
||||||
|
|
||||||
|
|
||||||
def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
|
def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
|
||||||
return getattr(model, "supports_encoder_tp_data", False)
|
return getattr(model, "supports_encoder_tp_data", False)
|
||||||
|
|
||||||
|
|||||||
@ -46,6 +46,7 @@ from .interfaces import (
|
|||||||
has_noops,
|
has_noops,
|
||||||
is_attention_free,
|
is_attention_free,
|
||||||
is_hybrid,
|
is_hybrid,
|
||||||
|
requires_raw_input_tokens,
|
||||||
supports_cross_encoding,
|
supports_cross_encoding,
|
||||||
supports_mamba_prefix_caching,
|
supports_mamba_prefix_caching,
|
||||||
supports_multimodal,
|
supports_multimodal,
|
||||||
@ -422,6 +423,7 @@ _MULTIMODAL_MODELS = {
|
|||||||
),
|
),
|
||||||
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
||||||
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
|
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
|
||||||
|
"VoxtralStreamingGeneration": ("voxtral_streaming", "VoxtralStreamingGeneration"), # noqa: E501
|
||||||
# [Encoder-decoder]
|
# [Encoder-decoder]
|
||||||
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
|
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
|
||||||
}
|
}
|
||||||
@ -539,6 +541,7 @@ class _ModelInfo:
|
|||||||
supports_cross_encoding: bool
|
supports_cross_encoding: bool
|
||||||
supports_multimodal: bool
|
supports_multimodal: bool
|
||||||
supports_multimodal_raw_input_only: bool
|
supports_multimodal_raw_input_only: bool
|
||||||
|
requires_raw_input_tokens: bool
|
||||||
supports_multimodal_encoder_tp_data: bool
|
supports_multimodal_encoder_tp_data: bool
|
||||||
supports_pp: bool
|
supports_pp: bool
|
||||||
has_inner_state: bool
|
has_inner_state: bool
|
||||||
@ -562,6 +565,7 @@ class _ModelInfo:
|
|||||||
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
|
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
|
||||||
model
|
model
|
||||||
),
|
),
|
||||||
|
requires_raw_input_tokens=requires_raw_input_tokens(model),
|
||||||
supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
|
supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data(
|
||||||
model
|
model
|
||||||
),
|
),
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import inspect
|
||||||
import math
|
import math
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
@ -116,10 +117,7 @@ class VoxtralProcessorAdapter:
|
|||||||
self,
|
self,
|
||||||
audio_length: int,
|
audio_length: int,
|
||||||
) -> int:
|
) -> int:
|
||||||
pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
|
return ceil(audio_length / (self.sampling_rate // self.frame_rate))
|
||||||
audio_length, self.sampling_rate
|
|
||||||
)
|
|
||||||
return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -158,7 +156,14 @@ class VoxtralProcessorAdapter:
|
|||||||
assert audio.ndim == 1
|
assert audio.ndim == 1
|
||||||
|
|
||||||
# pad if necessary
|
# pad if necessary
|
||||||
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
# TODO(Patrick) - remove once mistral-common is bumped
|
||||||
|
sig = inspect.signature(self._audio_processor.pad)
|
||||||
|
if "is_online_streaming" in sig.parameters:
|
||||||
|
audio = self._audio_processor.pad(
|
||||||
|
audio, self.sampling_rate, is_online_streaming=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||||
|
|
||||||
audio_tokens = [self.begin_audio_token_id] + [
|
audio_tokens = [self.begin_audio_token_id] + [
|
||||||
self.audio_token_id
|
self.audio_token_id
|
||||||
@ -510,6 +515,7 @@ class VoxtralForConditionalGeneration(
|
|||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||||
remapping_rules = [
|
remapping_rules = [
|
||||||
|
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
|
||||||
(r"mm_whisper_embeddings\.(.*)", r"\1"),
|
(r"mm_whisper_embeddings\.(.*)", r"\1"),
|
||||||
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
|
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
|
||||||
(
|
(
|
||||||
@ -535,13 +541,16 @@ class VoxtralForConditionalGeneration(
|
|||||||
def llm_weights_generator():
|
def llm_weights_generator():
|
||||||
nonlocal loaded_weights
|
nonlocal loaded_weights
|
||||||
for name, w in weights:
|
for name, w in weights:
|
||||||
is_encoder = (
|
is_encoder = False
|
||||||
name.startswith("mm_whisper_embeddings")
|
for k in [
|
||||||
and not name.startswith("mm_whisper_embeddings.tok_embeddings")
|
"mm_whisper_embeddings",
|
||||||
and not name.startswith(
|
"mm_streams_embeddings.embedding_module",
|
||||||
"mm_whisper_embeddings.audio_language_projection"
|
]:
|
||||||
|
is_encoder |= (
|
||||||
|
name.startswith(k)
|
||||||
|
and not name.startswith(f"{k}.tok_embeddings")
|
||||||
|
and not name.startswith(f"{k}.audio_language_projection")
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
for pattern, repl in remapping_rules:
|
for pattern, repl in remapping_rules:
|
||||||
if re.fullmatch(pattern, name):
|
if re.fullmatch(pattern, name):
|
||||||
@ -676,6 +685,7 @@ class VoxtralEncoderModel(nn.Module):
|
|||||||
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
||||||
|
|
||||||
mistral_remapping = [
|
mistral_remapping = [
|
||||||
|
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
|
||||||
(
|
(
|
||||||
r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
|
r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
|
||||||
r"whisper_encoder.conv1.\1",
|
r"whisper_encoder.conv1.\1",
|
||||||
@ -684,6 +694,14 @@ class VoxtralEncoderModel(nn.Module):
|
|||||||
r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
|
r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
|
||||||
r"whisper_encoder.conv2.\1",
|
r"whisper_encoder.conv2.\1",
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)",
|
||||||
|
r"whisper_encoder.conv1.\1",
|
||||||
|
), # noqa: E501
|
||||||
|
(
|
||||||
|
r"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)",
|
||||||
|
r"whisper_encoder.conv2.\1",
|
||||||
|
), # noqa: E501
|
||||||
(
|
(
|
||||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501
|
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501
|
||||||
r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
|
r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
|
||||||
|
|||||||
243
vllm/model_executor/models/voxtral_streaming.py
Normal file
243
vllm/model_executor/models/voxtral_streaming.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config.vllm import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||||
|
from vllm.model_executor.models.voxtral import (
|
||||||
|
VoxtralDummyInputsBuilder,
|
||||||
|
VoxtralForConditionalGeneration,
|
||||||
|
VoxtralMultiModalProcessor,
|
||||||
|
VoxtralProcessingInfo,
|
||||||
|
)
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.cache import _I, BaseMultiModalProcessorCache
|
||||||
|
from vllm.multimodal.inputs import (
|
||||||
|
MultiModalKwargsOptionalItems,
|
||||||
|
)
|
||||||
|
from vllm.multimodal.parse import MultiModalDataItems
|
||||||
|
from vllm.multimodal.processing import (
|
||||||
|
MultiModalPromptUpdates,
|
||||||
|
PlaceholderFeaturesInfo,
|
||||||
|
)
|
||||||
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
_flatten_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VoxtralStreamingMultiModalProcessor(VoxtralMultiModalProcessor):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
info: _I,
|
||||||
|
dummy_inputs: BaseDummyInputsBuilder[_I],
|
||||||
|
*,
|
||||||
|
cache: BaseMultiModalProcessorCache | None = None,
|
||||||
|
) -> None:
|
||||||
|
# streaming can't make use of a cache yet
|
||||||
|
super().__init__(info, dummy_inputs, cache=None)
|
||||||
|
|
||||||
|
def _maybe_apply_prompt_updates(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
prompt_ids: list[int],
|
||||||
|
mm_kwargs: MultiModalKwargsOptionalItems,
|
||||||
|
mm_prompt_updates: MultiModalPromptUpdates,
|
||||||
|
is_update_applied: bool,
|
||||||
|
) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
|
||||||
|
# there are no placeholder audio tokens for streaming
|
||||||
|
# so we need to build the place placeholder positions manually
|
||||||
|
|
||||||
|
# in streaming there is always only one audio input
|
||||||
|
audios = mm_kwargs.get("audio", [])
|
||||||
|
assert len(audios) == 1, (
|
||||||
|
f"Expected only one audio input for streaming, got {mm_kwargs=}"
|
||||||
|
)
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
audio_config = tokenizer.instruct.audio_encoder.audio_config
|
||||||
|
|
||||||
|
num_audio_samples = audios[0]["audio_arrays"].data.shape[0]
|
||||||
|
length = audio_config.num_audio_tokens(num_audio_samples)
|
||||||
|
|
||||||
|
features_info = PlaceholderFeaturesInfo(
|
||||||
|
modality="audio",
|
||||||
|
item_idx=0,
|
||||||
|
start_idx=0,
|
||||||
|
tokens=length
|
||||||
|
* [0], # only used for length computation, so we can take dummy inputs
|
||||||
|
is_embed=None,
|
||||||
|
)
|
||||||
|
return prompt_ids, {"audio": [features_info]}
|
||||||
|
|
||||||
|
|
||||||
|
class TimeEmbedding(torch.nn.Module):
|
||||||
|
"""Sinusoidal Embedding for encoding time"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
inv_freq = torch.exp(
|
||||||
|
-math.log(self.theta)
|
||||||
|
* torch.arange(self.dim // 2).float()
|
||||||
|
/ (self.dim // 2)
|
||||||
|
)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
|
def forward(self, t: torch.Tensor) -> torch.Tensor:
|
||||||
|
t = t[..., None] # (B,) -> (B, 1) or (B, T) -> (B, T, 1)
|
||||||
|
inv_freq = self.inv_freq.to(device=t.device, dtype=t.dtype)
|
||||||
|
emb = (
|
||||||
|
t * inv_freq
|
||||||
|
) # (B, 1) x (D/2,) -> (B, D/2) or (B, T, 1) x (D/2,) -> (B, T, D/2)
|
||||||
|
return torch.cat((emb.cos(), emb.sin()), dim=-1) # (B, D) or (B, T, D)
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
VoxtralStreamingMultiModalProcessor,
|
||||||
|
info=VoxtralProcessingInfo,
|
||||||
|
dummy_inputs=VoxtralDummyInputsBuilder,
|
||||||
|
)
|
||||||
|
class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
|
||||||
|
requires_raw_input_tokens = True
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
self.time_embedding: TimeEmbedding = TimeEmbedding(
|
||||||
|
dim=self.config.text_config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_config = self.tokenizer.instruct.audio_encoder.audio_config
|
||||||
|
_n_delay_tokens = (
|
||||||
|
audio_config.frame_rate * audio_config.transcription_delay_ms / 1000
|
||||||
|
)
|
||||||
|
assert _n_delay_tokens.is_integer(), (
|
||||||
|
f"n_delay_tokens must be integer, got {_n_delay_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.n_delay_tokens = int(_n_delay_tokens)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_config(self):
|
||||||
|
return self.tokenizer.instruct.audio_encoder.audio_config
|
||||||
|
|
||||||
|
def embed_input_ids(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||||
|
*,
|
||||||
|
is_multimodal: torch.Tensor | None = None,
|
||||||
|
# Multi-modal token ID may exceed vocab size
|
||||||
|
handle_oov_mm_token: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Pass post-conv embeddings directly as input"""
|
||||||
|
# for streaming we simply flatten the multimodal embeddings
|
||||||
|
# to be in tensor format, we treat the input ids later
|
||||||
|
assert multimodal_embeddings is not None
|
||||||
|
assert len(multimodal_embeddings) > 0, (
|
||||||
|
"For streaming you must provide a multimodal_embedding at every step."
|
||||||
|
)
|
||||||
|
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
|
||||||
|
return mm_embeds_flat
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
|
**kwargs: object,
|
||||||
|
) -> torch.Tensor | IntermediateTensors:
|
||||||
|
assert inputs_embeds is not None
|
||||||
|
assert input_ids is not None
|
||||||
|
|
||||||
|
pool_size = self.config.audio_config.block_pool_size
|
||||||
|
inputs_embeds = inputs_embeds.view(
|
||||||
|
inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_hidden_states = self.whisper_encoder.whisper_encoder.forward_layers(
|
||||||
|
inputs_embeds
|
||||||
|
)
|
||||||
|
|
||||||
|
num_tokens, audio_hidden_size = audio_hidden_states.shape
|
||||||
|
assert num_tokens % self.downsample_factor == 0
|
||||||
|
audio_hidden_states = audio_hidden_states.reshape(
|
||||||
|
num_tokens // self.downsample_factor,
|
||||||
|
audio_hidden_size * self.downsample_factor,
|
||||||
|
)
|
||||||
|
audio_text_embeds = self.audio_language_adapter(audio_hidden_states)
|
||||||
|
|
||||||
|
text_embeds = self.language_model.embed_input_ids(input_ids)
|
||||||
|
|
||||||
|
# sum pool text and audio embeddings
|
||||||
|
inputs_embeds = audio_text_embeds + text_embeds
|
||||||
|
|
||||||
|
time_tensor = torch.tensor(
|
||||||
|
[self.n_delay_tokens],
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
dtype=inputs_embeds.dtype,
|
||||||
|
)
|
||||||
|
inputs_embeds = inputs_embeds + self.time_embedding(time_tensor)
|
||||||
|
|
||||||
|
hidden_states = self.language_model.model(
|
||||||
|
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def embed_multimodal(
|
||||||
|
self, **kwargs
|
||||||
|
) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
|
||||||
|
"""Transform audio waveforms -> initial whisper post-conv embeddings"""
|
||||||
|
audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
|
||||||
|
|
||||||
|
assert audio_inputs is not None, (
|
||||||
|
"For streaming you must provide an audio input at every step."
|
||||||
|
)
|
||||||
|
|
||||||
|
multiple_of = self.audio_config.raw_audio_length_per_tok
|
||||||
|
assert all(
|
||||||
|
(this_audio := audio.shape[0]) % multiple_of == 0 for audio in audio_inputs
|
||||||
|
), (
|
||||||
|
f"Every input audio waveform has to be a multiple of {multiple_of}, but"
|
||||||
|
f" one is {this_audio} with {(this_audio / multiple_of)=}."
|
||||||
|
)
|
||||||
|
|
||||||
|
mel_features = [
|
||||||
|
self.whisper_encoder.compute_whisper_melspec(audio).to(
|
||||||
|
self.whisper_encoder.dtype
|
||||||
|
)
|
||||||
|
for audio in audio_inputs
|
||||||
|
]
|
||||||
|
seq_lens = [mel.shape[1] for mel in mel_features]
|
||||||
|
# [total_num_20ms_frames, hidden_size]
|
||||||
|
audio_embeddings = self.whisper_encoder.whisper_encoder.forward_conv(
|
||||||
|
mel_features
|
||||||
|
)[0]
|
||||||
|
conv_stride = self.whisper_encoder.whisper_encoder.total_stride
|
||||||
|
audio_embeddings_per_sample = audio_embeddings.split(
|
||||||
|
[s // conv_stride for s in seq_lens], dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# audio_embeddings per sample need to be divisible by 4
|
||||||
|
pool_size = self.config.audio_config.block_pool_size
|
||||||
|
assert all(
|
||||||
|
(this_shape := sample.shape[0]) % pool_size == 0
|
||||||
|
for sample in audio_embeddings_per_sample
|
||||||
|
), f"Every audio embedding has to be a multiple of 4, but one is {this_shape}."
|
||||||
|
|
||||||
|
audio_embeddings_per_sample = [
|
||||||
|
e.view(e.shape[0] // pool_size, e.shape[1] * pool_size)
|
||||||
|
for e in audio_embeddings_per_sample
|
||||||
|
]
|
||||||
|
return audio_embeddings_per_sample
|
||||||
@ -1,9 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import enum
|
||||||
import math
|
import math
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
from functools import partial
|
||||||
from typing import Annotated, Literal, cast
|
from typing import Annotated, Literal, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -16,7 +18,10 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||||
|
|
||||||
from vllm.attention.layer import Attention, AttentionType
|
from vllm.attention.backends.abstract import (
|
||||||
|
AttentionType,
|
||||||
|
)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
from vllm.attention.layers.cross_attention import CrossAttention
|
from vllm.attention.layers.cross_attention import CrossAttention
|
||||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||||
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
||||||
@ -34,6 +39,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.whisper_utils import (
|
||||||
|
ISO639_1_SUPPORTED_LANGS,
|
||||||
|
WhisperAttentionWithBlockPooling,
|
||||||
|
WhisperCausalConv1d,
|
||||||
|
)
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import (
|
||||||
MultiModalDataDict,
|
MultiModalDataDict,
|
||||||
@ -64,67 +74,11 @@ from .utils import (
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
|
|
||||||
|
|
||||||
ISO639_1_SUPPORTED_LANGS = {
|
class WhisperPosEmbedType(enum.Enum):
|
||||||
"af": "Afrikaans",
|
SINUSOIDAL = "sinusoidal"
|
||||||
"ar": "Arabic",
|
NOPE = "nope"
|
||||||
"hy": "Armenian",
|
LEARNED = "learned"
|
||||||
"az": "Azerbaijani",
|
|
||||||
"be": "Belarusian",
|
|
||||||
"bs": "Bosnian",
|
|
||||||
"bg": "Bulgarian",
|
|
||||||
"ca": "Catalan",
|
|
||||||
"zh": "Chinese",
|
|
||||||
"hr": "Croatian",
|
|
||||||
"cs": "Czech",
|
|
||||||
"da": "Danish",
|
|
||||||
"nl": "Dutch",
|
|
||||||
"en": "English",
|
|
||||||
"et": "Estonian",
|
|
||||||
"fi": "Finnish",
|
|
||||||
"fr": "French",
|
|
||||||
"gl": "Galician",
|
|
||||||
"de": "German",
|
|
||||||
"el": "Greek",
|
|
||||||
"he": "Hebrew",
|
|
||||||
"hi": "Hindi",
|
|
||||||
"hu": "Hungarian",
|
|
||||||
"is": "Icelandic",
|
|
||||||
"id": "Indonesian",
|
|
||||||
"it": "Italian",
|
|
||||||
"ja": "Japanese",
|
|
||||||
"kn": "Kannada",
|
|
||||||
"kk": "Kazakh",
|
|
||||||
"ko": "Korean",
|
|
||||||
"lv": "Latvian",
|
|
||||||
"lt": "Lithuanian",
|
|
||||||
"mk": "Macedonian",
|
|
||||||
"ms": "Malay",
|
|
||||||
"mr": "Marathi",
|
|
||||||
"mi": "Maori",
|
|
||||||
"ne": "Nepali",
|
|
||||||
"no": "Norwegian",
|
|
||||||
"fa": "Persian",
|
|
||||||
"pl": "Polish",
|
|
||||||
"pt": "Portuguese",
|
|
||||||
"ro": "Romanian",
|
|
||||||
"ru": "Russian",
|
|
||||||
"sr": "Serbian",
|
|
||||||
"sk": "Slovak",
|
|
||||||
"sl": "Slovenian",
|
|
||||||
"es": "Spanish",
|
|
||||||
"sw": "Swahili",
|
|
||||||
"sv": "Swedish",
|
|
||||||
"tl": "Tagalog",
|
|
||||||
"ta": "Tamil",
|
|
||||||
"th": "Thai",
|
|
||||||
"tr": "Turkish",
|
|
||||||
"uk": "Ukrainian",
|
|
||||||
"ur": "Urdu",
|
|
||||||
"vi": "Vietnamese",
|
|
||||||
"cy": "Welsh",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class WhisperAudioInputs(TensorSchema):
|
class WhisperAudioInputs(TensorSchema):
|
||||||
@ -184,6 +138,8 @@ class WhisperAttention(nn.Module):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
|
per_layer_sliding_window: int | None = None,
|
||||||
|
block_pool_size: int = 1,
|
||||||
cache_config: CacheConfig | None = None,
|
cache_config: CacheConfig | None = None,
|
||||||
quant_config: QuantizationConfig | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
@ -242,7 +198,14 @@ class WhisperAttention(nn.Module):
|
|||||||
attn_type=self.attn_type,
|
attn_type=self.attn_type,
|
||||||
)
|
)
|
||||||
else: # AttentionType.DECODER (regular decoder self-attention)
|
else: # AttentionType.DECODER (regular decoder self-attention)
|
||||||
self.attn = Attention(
|
if block_pool_size > 1:
|
||||||
|
attn_cls = partial(
|
||||||
|
WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_cls = Attention
|
||||||
|
|
||||||
|
self.attn = attn_cls(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
@ -251,6 +214,7 @@ class WhisperAttention(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
attn_type=self.attn_type,
|
attn_type=self.attn_type,
|
||||||
|
per_layer_sliding_window=per_layer_sliding_window,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_qkv(
|
def _init_qkv(
|
||||||
@ -386,6 +350,9 @@ class WhisperEncoderLayer(nn.Module):
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
|
is_causal = getattr(config, "is_causal", False)
|
||||||
|
sliding_window = getattr(config, "sliding_window", None)
|
||||||
|
block_pool_size = getattr(config, "block_pool_size", 1)
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
@ -393,7 +360,9 @@ class WhisperEncoderLayer(nn.Module):
|
|||||||
self.self_attn = WhisperAttention(
|
self.self_attn = WhisperAttention(
|
||||||
embed_dim=self.embed_dim,
|
embed_dim=self.embed_dim,
|
||||||
num_heads=config.encoder_attention_heads,
|
num_heads=config.encoder_attention_heads,
|
||||||
attn_type=AttentionType.ENCODER,
|
attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER,
|
||||||
|
block_pool_size=block_pool_size,
|
||||||
|
per_layer_sliding_window=sliding_window,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
@ -492,12 +461,21 @@ class WhisperEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
embed_dim = config.d_model
|
embed_dim = config.d_model
|
||||||
|
|
||||||
|
self.pos_embed_type = WhisperPosEmbedType(
|
||||||
|
getattr(config, "pos_embed", "sinusoidal")
|
||||||
|
)
|
||||||
self.num_mel_bins = config.num_mel_bins
|
self.num_mel_bins = config.num_mel_bins
|
||||||
self.max_source_positions = config.max_source_positions
|
self.max_source_positions = config.max_source_positions
|
||||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||||
|
|
||||||
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
|
is_causal = getattr(config, "is_causal", False)
|
||||||
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
|
Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
|
||||||
|
|
||||||
|
self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
|
||||||
|
self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
|
||||||
|
|
||||||
|
self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
config.encoder_layers,
|
config.encoder_layers,
|
||||||
lambda prefix: WhisperEncoderLayer(
|
lambda prefix: WhisperEncoderLayer(
|
||||||
@ -507,29 +485,54 @@ class WhisperEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||||
|
|
||||||
maybe_fp32_init_ctx = (
|
if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
|
||||||
set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
|
raise ValueError(
|
||||||
)
|
"Only NOPE position embeddings are supported "
|
||||||
|
f"for causal models, but got {self.pos_embed_type}"
|
||||||
with (
|
)
|
||||||
torch.no_grad(),
|
elif self.pos_embed_type in (
|
||||||
maybe_fp32_init_ctx,
|
WhisperPosEmbedType.SINUSOIDAL,
|
||||||
|
WhisperPosEmbedType.LEARNED,
|
||||||
):
|
):
|
||||||
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
|
maybe_fp32_init_ctx = (
|
||||||
self.embed_positions.weight.copy_(
|
set_default_torch_dtype(torch.float32)
|
||||||
sinusoids(*self.embed_positions.weight.shape)
|
if init_in_fp32
|
||||||
|
else nullcontext()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
with (
|
||||||
|
torch.no_grad(),
|
||||||
|
maybe_fp32_init_ctx,
|
||||||
|
):
|
||||||
|
self.embed_positions = nn.Embedding(
|
||||||
|
self.max_source_positions, embed_dim
|
||||||
|
)
|
||||||
|
self.embed_positions.weight.copy_(
|
||||||
|
sinusoids(*self.embed_positions.weight.shape)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_conv(
|
||||||
|
self, input_features: torch.Tensor | list[torch.Tensor]
|
||||||
|
) -> torch.Tensor:
|
||||||
hidden_states = []
|
hidden_states = []
|
||||||
input_is_batched = False
|
input_is_batched = False
|
||||||
for features in input_features:
|
for features in input_features:
|
||||||
embeds = nn.functional.gelu(self.conv1(features))
|
embeds = nn.functional.gelu(self.conv1(features))
|
||||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||||
embeds = embeds.transpose(-1, -2)
|
|
||||||
embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
|
if self.pos_embed_type in (
|
||||||
embeds.dtype
|
WhisperPosEmbedType.SINUSOIDAL,
|
||||||
)
|
WhisperPosEmbedType.LEARNED,
|
||||||
|
):
|
||||||
|
embeds = embeds.transpose(-1, -2)
|
||||||
|
embeds = (
|
||||||
|
embeds + self.embed_positions.weight[: embeds.size(-2), :]
|
||||||
|
).to(embeds.dtype)
|
||||||
|
elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
|
||||||
|
embeds = embeds.transpose(-1, -2).to(embeds.dtype)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
|
||||||
|
|
||||||
hidden_states.append(embeds)
|
hidden_states.append(embeds)
|
||||||
input_is_batched = embeds.ndim > 2
|
input_is_batched = embeds.ndim > 2
|
||||||
# Input to MHA must be B x T x D
|
# Input to MHA must be B x T x D
|
||||||
@ -539,12 +542,19 @@ class WhisperEncoder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
hidden_states = torch.stack(hidden_states, dim=0)
|
hidden_states = torch.stack(hidden_states, dim=0)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
hidden_states = encoder_layer(hidden_states)
|
hidden_states = encoder_layer(hidden_states)
|
||||||
|
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
||||||
|
hidden_states = self.forward_conv(input_features)
|
||||||
|
return self.forward_layers(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
class WhisperDecoder(nn.Module):
|
class WhisperDecoder(nn.Module):
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
|||||||
299
vllm/model_executor/models/whisper_utils.py
Normal file
299
vllm/model_executor/models/whisper_utils.py
Normal file
@ -0,0 +1,299 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import copy
|
||||||
|
import functools
|
||||||
|
import math
|
||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import (
|
||||||
|
AttentionBackend,
|
||||||
|
AttentionMetadata,
|
||||||
|
AttentionType,
|
||||||
|
)
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
subclass_attention_backend_with_overrides,
|
||||||
|
)
|
||||||
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
|
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
|
||||||
|
ISO639_1_SUPPORTED_LANGS = {
|
||||||
|
"af": "Afrikaans",
|
||||||
|
"ar": "Arabic",
|
||||||
|
"hy": "Armenian",
|
||||||
|
"az": "Azerbaijani",
|
||||||
|
"be": "Belarusian",
|
||||||
|
"bs": "Bosnian",
|
||||||
|
"bg": "Bulgarian",
|
||||||
|
"ca": "Catalan",
|
||||||
|
"zh": "Chinese",
|
||||||
|
"hr": "Croatian",
|
||||||
|
"cs": "Czech",
|
||||||
|
"da": "Danish",
|
||||||
|
"nl": "Dutch",
|
||||||
|
"en": "English",
|
||||||
|
"et": "Estonian",
|
||||||
|
"fi": "Finnish",
|
||||||
|
"fr": "French",
|
||||||
|
"gl": "Galician",
|
||||||
|
"de": "German",
|
||||||
|
"el": "Greek",
|
||||||
|
"he": "Hebrew",
|
||||||
|
"hi": "Hindi",
|
||||||
|
"hu": "Hungarian",
|
||||||
|
"is": "Icelandic",
|
||||||
|
"id": "Indonesian",
|
||||||
|
"it": "Italian",
|
||||||
|
"ja": "Japanese",
|
||||||
|
"kn": "Kannada",
|
||||||
|
"kk": "Kazakh",
|
||||||
|
"ko": "Korean",
|
||||||
|
"lv": "Latvian",
|
||||||
|
"lt": "Lithuanian",
|
||||||
|
"mk": "Macedonian",
|
||||||
|
"ms": "Malay",
|
||||||
|
"mr": "Marathi",
|
||||||
|
"mi": "Maori",
|
||||||
|
"ne": "Nepali",
|
||||||
|
"no": "Norwegian",
|
||||||
|
"fa": "Persian",
|
||||||
|
"pl": "Polish",
|
||||||
|
"pt": "Portuguese",
|
||||||
|
"ro": "Romanian",
|
||||||
|
"ru": "Russian",
|
||||||
|
"sr": "Serbian",
|
||||||
|
"sk": "Slovak",
|
||||||
|
"sl": "Slovenian",
|
||||||
|
"es": "Spanish",
|
||||||
|
"sw": "Swahili",
|
||||||
|
"sv": "Swedish",
|
||||||
|
"tl": "Tagalog",
|
||||||
|
"ta": "Tamil",
|
||||||
|
"th": "Thai",
|
||||||
|
"tr": "Turkish",
|
||||||
|
"uk": "Ukrainian",
|
||||||
|
"ur": "Urdu",
|
||||||
|
"vi": "Vietnamese",
|
||||||
|
"cy": "Welsh",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _pad1d(
|
||||||
|
x: torch.Tensor,
|
||||||
|
paddings: tuple[int, int],
|
||||||
|
mode: str = "constant",
|
||||||
|
value: float = 0.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Tiny wrapper around F.pad, just to allow for
|
||||||
|
reflect padding on small input.
|
||||||
|
If this is the case, we insert extra 0 padding
|
||||||
|
to the right before the reflection happen.
|
||||||
|
"""
|
||||||
|
length = x.shape[-1]
|
||||||
|
padding_left, padding_right = paddings
|
||||||
|
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
|
||||||
|
if mode == "reflect":
|
||||||
|
max_pad = max(padding_left, padding_right)
|
||||||
|
extra_pad = 0
|
||||||
|
if length <= max_pad:
|
||||||
|
extra_pad = max_pad - length + 1
|
||||||
|
x = F.pad(x, (0, extra_pad))
|
||||||
|
padded = F.pad(x, paddings, mode, value)
|
||||||
|
end = padded.shape[-1] - extra_pad
|
||||||
|
return padded[..., :end]
|
||||||
|
else:
|
||||||
|
return F.pad(x, paddings, mode, value)
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperCausalConv1d(nn.Conv1d):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
stride: int = 1,
|
||||||
|
padding: int = 0,
|
||||||
|
bias: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
self._stride = self.stride[0]
|
||||||
|
self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1
|
||||||
|
self._padding_total = self._effective_kernel_size - self._stride
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
n_frames = (
|
||||||
|
x.shape[-1] - self._effective_kernel_size + self._padding_total
|
||||||
|
) / self._stride + 1
|
||||||
|
target_length = (math.ceil(n_frames) - 1) * self._stride + (
|
||||||
|
self._effective_kernel_size - self._padding_total
|
||||||
|
)
|
||||||
|
extra_padding = target_length - x.shape[-1]
|
||||||
|
x = _pad1d(x, (self._padding_total, extra_padding), mode="constant")
|
||||||
|
return super().forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
|
||||||
|
def create_whisper_attention_backend_with_block_pooling(
|
||||||
|
underlying_attn_backend: AttentionBackend, block_pool_size: int
|
||||||
|
) -> type[AttentionBackend]:
|
||||||
|
prefix = "WhisperAttentionWithBlockPooling_"
|
||||||
|
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||||
|
|
||||||
|
class WhisperAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
layer_names: list[str],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
assert kv_cache_spec.num_kv_heads % block_pool_size == 0
|
||||||
|
kv_cache_spec = replace(
|
||||||
|
kv_cache_spec,
|
||||||
|
block_size=kv_cache_spec.block_size * block_pool_size,
|
||||||
|
num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
|
||||||
|
)
|
||||||
|
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||||
|
|
||||||
|
def build(
|
||||||
|
self,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
fast_build: bool = False,
|
||||||
|
) -> AttentionMetadata:
|
||||||
|
new_common_attn_metadata = copy.deepcopy(common_attn_metadata)
|
||||||
|
new_common_attn_metadata.query_start_loc *= block_pool_size
|
||||||
|
new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
|
||||||
|
new_common_attn_metadata.seq_lens *= block_pool_size
|
||||||
|
new_common_attn_metadata._seq_lens_cpu *= block_pool_size
|
||||||
|
new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
|
||||||
|
new_common_attn_metadata.num_actual_tokens *= block_pool_size
|
||||||
|
new_common_attn_metadata.max_query_len *= block_pool_size
|
||||||
|
new_common_attn_metadata.max_seq_len *= block_pool_size
|
||||||
|
original_slot_mapping = common_attn_metadata.slot_mapping
|
||||||
|
common_prefix_len *= block_pool_size
|
||||||
|
new_common_attn_metadata.slot_mapping = (
|
||||||
|
(
|
||||||
|
original_slot_mapping.unsqueeze(1) * block_pool_size
|
||||||
|
+ torch.arange(block_pool_size, device=original_slot_mapping.device)
|
||||||
|
)
|
||||||
|
.flatten()
|
||||||
|
.clamp(min=-1)
|
||||||
|
)
|
||||||
|
return super().build(
|
||||||
|
common_prefix_len, new_common_attn_metadata, fast_build
|
||||||
|
)
|
||||||
|
|
||||||
|
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{underlying_attn_backend} is not yet supported."
|
||||||
|
"Contributions to support more backends are much "
|
||||||
|
"appreciated."
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_backend = subclass_attention_backend_with_overrides(
|
||||||
|
name_prefix=prefix,
|
||||||
|
attention_backend_cls=underlying_attn_backend,
|
||||||
|
overrides={
|
||||||
|
"get_builder_cls": lambda: WhisperAttentionWithBlockPoolingBuilder,
|
||||||
|
"get_kv_cache_shape": lambda num_blocks,
|
||||||
|
block_size,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size,
|
||||||
|
cache_dtype_str: (
|
||||||
|
2,
|
||||||
|
num_blocks,
|
||||||
|
# we stretch each block by `block_pool_size`
|
||||||
|
block_size * block_pool_size,
|
||||||
|
num_kv_heads // block_pool_size,
|
||||||
|
head_size,
|
||||||
|
), # TODO: generalize to other backends
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_backend
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperAttentionWithBlockPooling(Attention):
|
||||||
|
"""Attention layer with block pooling."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: int | None = None,
|
||||||
|
alibi_slopes: list[float] | None = None,
|
||||||
|
cache_config: CacheConfig | None = None,
|
||||||
|
quant_config: QuantizationConfig | None = None,
|
||||||
|
logits_soft_cap: float | None = None,
|
||||||
|
per_layer_sliding_window: int | None = None,
|
||||||
|
prefix: str = "",
|
||||||
|
attn_type: str = AttentionType.DECODER,
|
||||||
|
kv_sharing_target_layer_name: str | None = None,
|
||||||
|
block_pool_size: int = 1,
|
||||||
|
attn_backend: type[AttentionBackend] | None = None,
|
||||||
|
**extra_impl_args,
|
||||||
|
) -> None:
|
||||||
|
self.block_pool_size = block_pool_size
|
||||||
|
dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
|
if cache_config is not None:
|
||||||
|
kv_cache_dtype = cache_config.cache_dtype
|
||||||
|
block_size = cache_config.block_size
|
||||||
|
else:
|
||||||
|
kv_cache_dtype = "auto"
|
||||||
|
block_size = 16
|
||||||
|
|
||||||
|
underlying_attn_backend = get_attn_backend(
|
||||||
|
head_size,
|
||||||
|
dtype,
|
||||||
|
kv_cache_dtype,
|
||||||
|
block_size,
|
||||||
|
attn_type=attn_type,
|
||||||
|
)
|
||||||
|
attn_backend = create_whisper_attention_backend_with_block_pooling(
|
||||||
|
underlying_attn_backend, block_pool_size
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_size=head_size,
|
||||||
|
scale=scale,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
alibi_slopes=alibi_slopes,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
logits_soft_cap=logits_soft_cap,
|
||||||
|
per_layer_sliding_window=per_layer_sliding_window,
|
||||||
|
prefix=prefix,
|
||||||
|
attn_type=attn_type,
|
||||||
|
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
|
||||||
|
attn_backend=attn_backend,
|
||||||
|
**extra_impl_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_kv_cache_spec(self, vllm_config: VllmConfig):
|
||||||
|
kv_cache_spec = super().get_kv_cache_spec(vllm_config)
|
||||||
|
assert isinstance(kv_cache_spec, AttentionSpec)
|
||||||
|
kv_cache_spec = replace(
|
||||||
|
kv_cache_spec,
|
||||||
|
num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads,
|
||||||
|
)
|
||||||
|
return kv_cache_spec
|
||||||
@ -184,18 +184,42 @@ def _remap_mistral_audio_args(config: dict) -> dict:
|
|||||||
whisper_args = config["multimodal"].pop("whisper_model_args")
|
whisper_args = config["multimodal"].pop("whisper_model_args")
|
||||||
encoder_args = whisper_args["encoder_args"]
|
encoder_args = whisper_args["encoder_args"]
|
||||||
downsample_args = whisper_args["downsample_args"]
|
downsample_args = whisper_args["downsample_args"]
|
||||||
|
downsample_factor = downsample_args["downsample_factor"]
|
||||||
|
|
||||||
|
# make sure that k/v blocks can be allocated with
|
||||||
|
# unified k/v cache class and pool whisper k/v cache blocks
|
||||||
|
# with downsample_factor:1 ratio
|
||||||
|
if encoder_args.get("causal"):
|
||||||
|
block_pool_size = downsample_factor
|
||||||
|
config["projection_size"] = downsample_factor * encoder_args["dim"]
|
||||||
|
else:
|
||||||
|
block_pool_size = 1
|
||||||
|
|
||||||
|
_maybe_sliding_window = encoder_args.get("ragged_attention", None)
|
||||||
|
if _maybe_sliding_window is None:
|
||||||
|
sliding_window = None
|
||||||
|
elif _maybe_sliding_window.isdigit():
|
||||||
|
sliding_window = int(_maybe_sliding_window)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")
|
||||||
|
|
||||||
|
architecture = (
|
||||||
|
"VoxtralStreamingGeneration"
|
||||||
|
if encoder_args.get("causal")
|
||||||
|
else "VoxtralForConditionalGeneration"
|
||||||
|
)
|
||||||
|
|
||||||
quant_config = config.get("quantization_config")
|
quant_config = config.get("quantization_config")
|
||||||
config = {
|
config = {
|
||||||
"model_type": "whixtral",
|
"model_type": "voxtral",
|
||||||
"architectures": ["VoxtralForConditionalGeneration"],
|
"architectures": [architecture],
|
||||||
"text_config": PretrainedConfig.from_dict(config),
|
"text_config": PretrainedConfig.from_dict(config),
|
||||||
"audio_config": WhisperConfig(
|
"audio_config": WhisperConfig(
|
||||||
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
|
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
|
||||||
window_size=encoder_args["audio_encoding_args"]["window_size"],
|
window_size=encoder_args["audio_encoding_args"]["window_size"],
|
||||||
sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
|
sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
|
||||||
hop_length=encoder_args["audio_encoding_args"]["hop_length"],
|
hop_length=encoder_args["audio_encoding_args"]["hop_length"],
|
||||||
downsample_factor=downsample_args["downsample_factor"],
|
downsample_factor=downsample_factor,
|
||||||
d_model=encoder_args["dim"],
|
d_model=encoder_args["dim"],
|
||||||
encoder_layers=encoder_args["n_layers"],
|
encoder_layers=encoder_args["n_layers"],
|
||||||
encoder_ffn_dim=encoder_args["hidden_dim"],
|
encoder_ffn_dim=encoder_args["hidden_dim"],
|
||||||
@ -203,6 +227,10 @@ def _remap_mistral_audio_args(config: dict) -> dict:
|
|||||||
vocab_size=encoder_args["vocab_size"],
|
vocab_size=encoder_args["vocab_size"],
|
||||||
max_source_positions=encoder_args["max_source_positions"],
|
max_source_positions=encoder_args["max_source_positions"],
|
||||||
is_encoder_decoder=False, # Override WhisperConfig default
|
is_encoder_decoder=False, # Override WhisperConfig default
|
||||||
|
is_causal=encoder_args.get("causal", False),
|
||||||
|
sliding_window=sliding_window,
|
||||||
|
block_pool_size=block_pool_size,
|
||||||
|
pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
if quant_config:
|
if quant_config:
|
||||||
|
|||||||
@ -835,6 +835,15 @@ def subclass_attention_backend(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def subclass_attention_backend_with_overrides(
|
||||||
|
name_prefix: str,
|
||||||
|
attention_backend_cls: type[AttentionBackend],
|
||||||
|
overrides: dict[str, Any],
|
||||||
|
) -> type[AttentionBackend]:
|
||||||
|
name: str = name_prefix + attention_backend_cls.__name__ # type: ignore
|
||||||
|
return type(name, (attention_backend_cls,), overrides)
|
||||||
|
|
||||||
|
|
||||||
def split_decodes_prefills_and_extends(
|
def split_decodes_prefills_and_extends(
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
decode_threshold: int = 1,
|
decode_threshold: int = 1,
|
||||||
|
|||||||
@ -2457,6 +2457,17 @@ class GPUModelRunner(
|
|||||||
return round_up(num_scheduled_tokens, tp_size)
|
return round_up(num_scheduled_tokens, tp_size)
|
||||||
return num_scheduled_tokens
|
return num_scheduled_tokens
|
||||||
|
|
||||||
|
def _prepare_mm_inputs(
|
||||||
|
self, num_tokens: int
|
||||||
|
) -> tuple[torch.Tensor | None, torch.Tensor]:
|
||||||
|
if self.model.requires_raw_input_tokens:
|
||||||
|
input_ids = self.input_ids.gpu[:num_tokens]
|
||||||
|
else:
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
|
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
|
||||||
|
return input_ids, inputs_embeds
|
||||||
|
|
||||||
def _preprocess(
|
def _preprocess(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@ -2499,8 +2510,7 @@ class GPUModelRunner(
|
|||||||
# TODO(woosuk): Avoid the copy. Optimize.
|
# TODO(woosuk): Avoid the copy. Optimize.
|
||||||
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)
|
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)
|
||||||
|
|
||||||
input_ids = None
|
input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens)
|
||||||
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
**self._init_model_kwargs(num_scheduled_tokens),
|
**self._init_model_kwargs(num_scheduled_tokens),
|
||||||
**self._extract_mm_kwargs(scheduler_output),
|
**self._extract_mm_kwargs(scheduler_output),
|
||||||
@ -4220,8 +4230,8 @@ class GPUModelRunner(
|
|||||||
assert num_tokens_padded <= self.max_num_tokens
|
assert num_tokens_padded <= self.max_num_tokens
|
||||||
model_kwargs = self._init_model_kwargs(num_tokens_padded)
|
model_kwargs = self._init_model_kwargs(num_tokens_padded)
|
||||||
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
|
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
|
||||||
input_ids = None
|
input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded)
|
||||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
|
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
**self._dummy_mm_kwargs(num_reqs),
|
**self._dummy_mm_kwargs(num_reqs),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user