# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from typing import Annotated, Literal, cast

import numpy as np
import torch
from torch import nn
from transformers import (
    BatchFeature,
    WhisperConfig,
    WhisperFeatureExtractor,
    WhisperProcessor,
)
from transformers.models.whisper.modeling_whisper import sinusoids

from vllm.attention import Attention, AttentionType
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    cast_overflow_tensors,
    make_layers,
    maybe_prefix,
)

logger = init_logger(__name__)

# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = {
    "af": "Afrikaans",
    "ar": "Arabic",
    "hy": "Armenian",
    "az": "Azerbaijani",
    "be": "Belarusian",
    "bs": "Bosnian",
    "bg": "Bulgarian",
    "ca": "Catalan",
    "zh": "Chinese",
    "hr": "Croatian",
    "cs": "Czech",
    "da": "Danish",
    "nl": "Dutch",
    "en": "English",
    "et": "Estonian",
    "fi": "Finnish",
    "fr": "French",
    "gl": "Galician",
    "de": "German",
    "el": "Greek",
    "he": "Hebrew",
    "hi": "Hindi",
    "hu": "Hungarian",
    "is": "Icelandic",
    "id": "Indonesian",
    "it": "Italian",
    "ja": "Japanese",
    "kn": "Kannada",
    "kk": "Kazakh",
    "ko": "Korean",
    "lv": "Latvian",
    "lt": "Lithuanian",
    "mk": "Macedonian",
    "ms": "Malay",
    "mr": "Marathi",
    "mi": "Maori",
    "ne": "Nepali",
    "no": "Norwegian",
    "fa": "Persian",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sr": "Serbian",
    "sk": "Slovak",
    "sl": "Slovenian",
    "es": "Spanish",
    "sw": "Swahili",
    "sv": "Swedish",
    "tl": "Tagalog",
    "ta": "Tamil",
    "th": "Thai",
    "tr": "Turkish",
    "uk": "Ukrainian",
    "ur": "Urdu",
    "vi": "Vietnamese",
    "cy": "Welsh",
}


class WhisperAudioInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]


class WhisperEncoderAttention(MultiHeadAttention):
    """Multi-headed attention for the Whisper encoder with 2D tensor support."""

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Input shape: batch_size x seq_len x hidden_size
        or seq_len x hidden_size
        """
        is_2d = query.dim() == 2
        if is_2d:
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)

        # Call the parent forward method
        out = super().forward(query, key, value)

        if is_2d:
            out = out.squeeze(0)

        return out


class WhisperPositionalEmbedding(nn.Embedding):
    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__(num_positions, embedding_dim)

    def forward(self, position_ids):
        return self.weight[position_ids]


class WhisperAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        if self.total_num_heads >= tp_size:
            # Number of heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_heads % tp_size == 0
        else:
            # Number of heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_heads == 0
        self.num_kv_heads = max(1, self.total_num_heads // tp_size)
        self.head_dim = self.embed_dim // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5

        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        if attn_type == AttentionType.ENCODER:
            self.attn = WhisperEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
            )
        elif self.attn_type == AttentionType.ENCODER_DECODER:
            self.attn = CrossAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )
        else:  # AttentionType.DECODER (regular decoder self-attention)
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)

        return output


class WhisperCrossAttention(WhisperAttention):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=bias,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            attn_type=AttentionType.ENCODER_DECODER,
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        self.q_proj = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        self.kv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=0,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        q, _ = self.q_proj(hidden_states)

        # Encoder hidden states are only computed once, during the prefill
        # phase. Afterwards, the keys and values should be available in the
        # kv-cache.
        if encoder_hidden_states is not None:
            kv, _ = self.kv_proj(encoder_hidden_states)
            k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
        else:
            k = v = None

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)

        return output


class WhisperMLP(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        act_fn: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.activation_fn = get_act_fn(act_fn)
        self.fc1 = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=ffn_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            input_size=ffn_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor):
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class WhisperEncoderLayer(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_dim = config.d_model
        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            attn_type=AttentionType.ENCODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.mlp = WhisperMLP(
            embed_dim=config.d_model,
            ffn_dim=config.encoder_ffn_dim,
            act_fn=config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        hidden_states = cast_overflow_tensors(hidden_states)

        return hidden_states


class WhisperDecoderLayer(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.self_attn = WhisperAttention(
            embed_dim=config.d_model,
            num_heads=config.decoder_attention_heads,
            attn_type=AttentionType.DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(config.d_model)
        self.encoder_attn = WhisperCrossAttention(
            embed_dim=config.d_model,
            num_heads=config.decoder_attention_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder_attn",
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model)
        self.mlp = WhisperMLP(
            embed_dim=config.d_model,
            ffn_dim=config.decoder_ffn_dim,
            act_fn=config.activation_function,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)
        hidden_states = self.encoder_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class WhisperEncoder(nn.Module):
    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
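        # Note: conv2 uses stride 2, so the mel time dimension is halved; with
        # the standard Whisper feature extractor (30 s -> 3000 mel frames),
        # this yields 1500 encoder positions, matching max_source_positions.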
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.encoder_layers,
            lambda prefix: WhisperEncoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        maybe_fp32_init_ctx = (
            set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
        )

        with (
            torch.no_grad(),
            maybe_fp32_init_ctx,
        ):
            self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
            self.embed_positions.weight.copy_(
                sinusoids(*self.embed_positions.weight.shape)
            )

    def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
        hidden_states = []
        for features in input_features:
            embeds = nn.functional.gelu(self.conv1(features))
            embeds = nn.functional.gelu(self.conv2(embeds))
            embeds = embeds.transpose(-1, -2)
            embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
                embeds.dtype
            )
            hidden_states.append(embeds)
        hidden_states = torch.cat(hidden_states)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


class WhisperDecoder(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_target_positions
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.d_model, self.padding_idx
        )
        self.embed_positions = WhisperPositionalEmbedding(
            self.max_target_positions, config.d_model
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.decoder_layers,
            lambda prefix: WhisperDecoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids,
        positions: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ):
        inputs_embeds = self.get_input_embeddings(input_ids)
        positions = self.embed_positions(positions)
        hidden_states = inputs_embeds + positions

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)


class WhisperModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.encoder = WhisperEncoder(
            vllm_config=vllm_config, prefix=f"{prefix}.encoder"
        )
        self.decoder = WhisperDecoder(
            vllm_config=vllm_config, prefix=f"{prefix}.decoder"
        )

    def forward(
        self,
        input_features: torch.Tensor | list[torch.Tensor] | None,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        encoder_outputs = self.get_encoder_outputs(input_features)
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            positions=positions,
            encoder_hidden_states=encoder_outputs,
        )
        return decoder_outputs

    def get_encoder_outputs(
        self,
        input_features: torch.Tensor | list[torch.Tensor] | None,
    ) -> torch.Tensor | None:
        if input_features is None:
            return None
        return self.encoder(input_features)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
            (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class WhisperProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self) -> WhisperConfig:
        return self.ctx.get_hf_config(WhisperConfig)

    def get_hf_processor(self, **kwargs: object) -> WhisperProcessor:
        # HACK: Transformers 4.53.2 has an issue with the Whisper tokenizer
        # that breaks processor initialization, so we monkeypatch the
        # processor's tokenizer class here as a workaround.
        # See: https://github.com/vllm-project/vllm/issues/20224
        processor_class = WhisperProcessor
        tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast")
        if processor_class.tokenizer_class != tokenizer_class:
            processor_class.tokenizer_class = tokenizer_class
        return self.ctx.get_hf_processor(processor_class, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_num_audio_tokens(self) -> int:
        return self.get_hf_config().max_source_positions


class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)

        return "<|startoftranscript|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=audio_len, num_audios=num_audios, overrides=audio_overrides
            )
        }


class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
    def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        return True

    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
    ) -> str | list[int]:
        # Strictly speaking, the Whisper encoder only accepts audio features.
        # We create a dummy encoder prompt here, which will be padded to
        # num_audio_tokens, so that dummy data for encoder profiling can be
        # derived from it.
        return [0]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            mm_data = dict(audio=mm_data.pop("audios"))
            mm_kwargs = dict(
                **mm_kwargs,
                sampling_rate=feature_extractor.sampling_rate,
            )
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(input_features=MultiModalFieldConfig.batched("audio"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        num_tokens = self.info.get_num_audio_tokens()
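        # Each audio item expands the single dummy encoder token into
        # `num_tokens` placeholder tokens (max_source_positions, which is
        # 1500 for standard Whisper checkpoints).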
        return [
            PromptReplacement(
                modality="audio",
                target=[0],
                replacement=[0] * num_tokens,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    WhisperMultiModalProcessor,
    info=WhisperProcessingInfo,
    dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
    nn.Module, SupportsTranscription, SupportsMultiModal
):
    merge_by_field_config = True
    packed_modules_mapping = {
        "self_attn.qkv_proj": [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
        ],
        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
    }
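
    # HF checkpoints keep fc1/fc2 directly on each layer, whereas this
    # implementation groups them under a WhisperMLP submodule, hence the
    # substring remap below.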
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )

    # Whisper only supports audio-conditioned generation.
    supports_transcription_only = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        if language is None:
            # TODO: language should be optional and could be guessed
            # automatically. For now we default to "en". See
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
            logger.warning(
                "Defaulting to language='en'. If you wish to transcribe "
                "audio in a different language, pass the `language` field "
                "in the TranscriptionRequest."
            )
            language = "en"
        return super().validate_language(language)

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,  # not needed here
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        if language is None:
            raise ValueError(
                "Language must be specified when creating the Whisper prompt"
            )
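        # e.g. language="en" and task_type="transcribe" with no request
        # prompt produce the decoder prompt
        # "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>".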
        prompt = {
            "encoder_prompt": {
                # Whisper does not support text-based encoder prompts.
                "prompt": "",
                "multi_modal_data": {
                    "audio": (audio, stt_config.sample_rate),
                },
            },
            "decoder_prompt": (
                (f"<|prev|>{request_prompt}" if request_prompt else "")
                + f"<|startoftranscript|><|{language}|>"
                + f"<|{task_type}|><|notimestamps|>"
            ),
        }
        return cast(PromptType, prompt)

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("audio"):
            return None

        raise ValueError("Only audio modality is supported")

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        processor = cached_get_processor(model_config.model)

        return SpeechToTextConfig(
            max_audio_clip_s=processor.feature_extractor.chunk_length,
            sample_rate=processor.feature_extractor.sampling_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        processor = cached_get_processor(model_config.model)
        hop_length = processor.feature_extractor.hop_length
        assert hop_length is not None
        # NOTE(NickLucche): users can't pass encoder prompts directly,
        # at least not to Whisper. One indicator of the amount of encoder
        # processing is the length of the log-mel spectrogram.
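        # For example, with the default Whisper feature extractor
        # (sample_rate=16000, hop_length=160), a 30 s clip maps to
        # ceil(30 * 16000 / 160) = 3000 log-mel frames.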
        return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.dtype = vllm_config.model_config.dtype

        self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)

        self.proj_out = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj_out"),
        )
        self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        decoder_outputs = self.model(
            input_features=audio_input["input_features"],
            input_ids=input_ids,
            positions=positions,
        )
        return decoder_outputs

    def get_language_model(self) -> torch.nn.Module:
        return self.model.decoder

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        # Required as part of the SupportsMultiModal interface.
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        return [self.model.get_encoder_outputs(audio_input["input_features"])]

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        # This method just returns the decoder sequence embeddings, since
        # Whisper does not have encoder text tokens.
        return self.model.decoder.get_input_embeddings(input_ids)

    def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
        input_features = kwargs.pop("input_features", None)

        if input_features is not None:
            input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features)

        return WhisperAudioInputs(input_features=input_features)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.proj_out, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])

        # Add a fake all-zeros bias for k_proj to the state_dict.
        weights = _create_fake_bias_for_k_proj(weights)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def _create_fake_bias_for_k_proj(
    weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
    """
    Create an all-zeros bias for the k_proj weight in self-attention and
    cross-attention layers, so that the bias for k_proj in the fused
    qkv_proj/kv_proj can be initialized with zeros.
    """
    for name, weight in weights:
        if name.endswith(".k_proj.weight"):
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
            yield from [(name, weight), (bias_name, bias)]
        else:
            yield name, weight