[Model] Whisper model implementation (#11280)

Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
Authored by Aurick Qiao on 2025-01-03 03:39:19 -05:00; committed by GitHub
parent fd3a62a122
commit e1a5c2f0a1
16 changed files with 1045 additions and 55 deletions

View File

@ -363,12 +363,14 @@ steps:
- tests/models/decoder_only/audio_language
- tests/models/decoder_only/vision_language
- tests/models/embedding/vision_language
- tests/models/encoder_decoder/audio_language
- tests/models/encoder_decoder/vision_language
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model
- pytest -v -s models/encoder_decoder/audio_language -m core_model
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model

View File

@ -0,0 +1,59 @@
import time
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
# Create a Whisper encoder/decoder model instance
llm = LLM(
model="openai/whisper-large-v3",
max_model_len=448,
max_num_seqs=400,
limit_mm_per_prompt={"audio": 1},
kv_cache_dtype="fp8",
)
prompts = [
{
"prompt": "<|startoftranscript|>",
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
},
},
{ # Test explicit encoder/decoder prompt
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": AudioAsset("winning_call").audio_and_sample_rate,
},
},
"decoder_prompt": "<|startoftranscript|>",
}
] * 1024
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0,
top_p=1.0,
max_tokens=200,
)
start = time.time()
# Generate output tokens from the prompts. The output is a list of
# RequestOutput objects that contain the prompt, generated
# text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text
print(f"Encoder prompt: {encoder_prompt!r}, "
f"Decoder prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
duration = time.time() - start
print("Duration:", duration)
print("RPS:", len(prompts) / duration)

View File

@ -0,0 +1,136 @@
"""Compare the outputs of HF and vLLM for Whisper models using greedy sampling.
Run `pytest tests/models/encoder_decoder/audio_language/test_whisper.py`.
"""
from typing import Optional
import pytest
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from ....utils import fork_new_process_for_each_test, multi_gpu_test
PROMPTS = [
{
"prompt":
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
},
},
{ # Test explicit encoder/decoder prompt
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": AudioAsset("winning_call").audio_and_sample_rate,
},
},
"decoder_prompt":
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
}
]
EXPECTED = {
"openai/whisper-tiny": [
" He has birth words I spoke in the original corner of that. And a"
" little piece of black coat poetry. Mary had a little sandwich,"
" sweet, with white and snow. And everyone had it very went the last"
" would sure to go.",
" >> And the old one, fit John the way to Edgar Martinez. >> One more"
" to line down the field line for our base camp. Here comes joy. Here"
" is June and the third base. They're going to wave him in. The throw"
" to the plate will be late. The Mariners are going to play for the"
" American League Championship. I don't believe it. It just continues"
" by all five."
],
"openai/whisper-small": [
" The first words I spoke in the original pornograph. A little piece"
" of practical poetry. Mary had a little lamb, its fleece was quite a"
" slow, and everywhere that Mary went the lamb was sure to go.",
" And the old one pitch on the way to Edgar Martinez one month. Here"
" comes joy. Here is Junior to third base. They're gonna wave him"
" in. The throw to the plate will be late. The Mariners are going to"
" play for the American League Championship. I don't believe it. It"
" just continues. My, oh my."
],
"openai/whisper-medium": [
" The first words I spoke in the original phonograph, a little piece"
" of practical poetry. Mary had a little lamb, its fleece was quite as"
" slow, and everywhere that Mary went the lamb was sure to go.",
" And the 0-1 pitch on the way to Edgar Martinez swung on the line"
" down the left field line for Obeyshev. Here comes Joy. Here is"
" Jorgen at third base. They're going to wave him in. The throw to the"
" plate will be late. The Mariners are going to play for the American"
" League Championship. I don't believe it. It just continues. My, oh"
" my."
],
"openai/whisper-large-v3": [
" The first words I spoke in the original phonograph, a little piece"
" of practical poetry. Mary had a little lamb, its feet were quite as"
" slow, and everywhere that Mary went, the lamb was sure to go.",
" And the 0-1 pitch on the way to Edgar Martinez. Swung on the line."
" Now the left field line for a base hit. Here comes Joy. Here is"
" Junior to third base. They're going to wave him in. The throw to the"
" plate will be late. The Mariners are going to play for the American"
" League Championship. I don't believe it. It just continues. My, oh,"
" my."
],
"openai/whisper-large-v3-turbo": [
" The first words I spoke in the original phonograph, a little piece"
" of practical poetry. Mary had a little lamb, its streets were quite"
" as slow, and everywhere that Mary went the lamb was sure to go.",
" And the 0-1 pitch on the way to Edgar Martinez. Swung on the line"
" down the left field line for a base hit. Here comes Joy. Here is"
" Junior to third base. They're going to wave him in. The throw to the"
" plate will be late. The Mariners are going to play for the American"
" League Championship. I don't believe it. It just continues. My, oh,"
" my."
]
}
def run_test(
model: str,
*,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
) -> None:
prompt_list = PROMPTS * 10
expected_list = EXPECTED[model] * 10
llm = LLM(
model=model,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
)
sampling_params = SamplingParams(
temperature=0,
top_p=1.0,
max_tokens=200,
)
outputs = llm.generate(prompt_list, sampling_params)
for output, expected in zip(outputs, expected_list):
print(output.outputs[0].text)
assert output.outputs[0].text == expected
@fork_new_process_for_each_test
@pytest.mark.core_model
@pytest.mark.parametrize(
"model", ["openai/whisper-small", "openai/whisper-large-v3-turbo"])
def test_models(model) -> None:
run_test(model, tensor_parallel_size=1)
@multi_gpu_test(num_gpus=2)
@pytest.mark.core_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
def test_models_distributed(model, distributed_executor_backend) -> None:
run_test(model,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend)

View File

@ -204,6 +204,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
# [Encoder-decoder]
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}
_SPECULATIVE_DECODING_EXAMPLE_MODELS = {

View File

@ -2312,6 +2312,8 @@ def _get_and_verify_max_len(
"seq_length",
# Command-R
"model_max_length",
# Whisper
"max_target_positions",
# Others
"max_sequence_length",
"max_seq_length",

View File

@ -184,10 +184,16 @@ class InputPreprocessor:
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
add_special_tokens = None
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
return tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
lora_request=lora_request,
add_special_tokens=add_special_tokens)
async def _tokenize_prompt_async(
self,
@ -197,10 +203,17 @@ class InputPreprocessor:
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group()
return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
add_special_tokens = None
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens = False
return await tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
def _can_process_multimodal(self) -> bool:
model_config = self.model_config
@ -439,8 +452,15 @@ class InputPreprocessor:
assert_never(encoder_inputs) # type: ignore[arg-type]
if decoder_inputs is None:
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
None)
if self.model_config.hf_config.model_type == "whisper":
# For Whisper models, the text prompt should go to the decoder.
# If no explicit encoder/decoder inputs are provided, copy the prompt
# from the encoder to the decoder. The encoder tokens are later
# overridden by the audio features.
dec_token_ids = encoder_inputs["prompt_token_ids"].copy()
else:
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
None)
decoder_inputs = token_inputs(dec_token_ids)
elif (decoder_inputs["type"] == "token"
or decoder_inputs["type"] == "multimodal"):
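
To illustrate why add_special_tokens=False matters for Whisper, a hedged sketch of the HF tokenizer behaviour the comments above describe (token IDs omitted; assumes transformers is installed):

from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")
prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
# Default encoding wraps the text in the tokenizer's own prefix tokens and appends
# <|endoftext|>, which would immediately terminate generation.
with_special = tok.encode(prompt)
# With add_special_tokens=False the user-supplied task/language tokens are kept as-is
# and no EOS is appended, which is what the Whisper branch above relies on.
without_special = tok.encode(prompt, add_special_tokens=False)
print(len(with_special), len(without_special))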

View File

@ -170,6 +170,7 @@ _MULTIMODAL_MODELS = {
"UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder]
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
_SPECULATIVE_DECODING_MODELS = {

View File

@ -0,0 +1,737 @@
import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
import numpy as np
import torch
from torch import nn
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.audio import resample_audio
from vllm.sequence import SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
logger = init_logger(__name__)
class WhisperAudioInputs(TypedDict):
input_features: NestedTensors
"""Shape: `(batch_size, 128, M)`"""
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self,
num_positions: int,
embedding_dim: int,
padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
def forward(self, position_ids):
return self.weight[position_ids]
class WhisperAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
attn_type: AttentionType = AttentionType.DECODER,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = embed_dim
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
if self.total_num_heads >= tp_size:
# Number of heads is greater than or equal to TP size, so we
# partition the KV heads across multiple tensor parallel GPUs.
assert self.total_num_heads % tp_size == 0
else:
# Number of heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_heads == 0
self.num_kv_heads = max(1, self.total_num_heads // tp_size)
self.head_dim = self.embed_dim // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn_type = attn_type
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: "
f"{self.embed_dim} and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
self.out_proj = RowParallelLinear(
input_size=embed_dim,
output_size=embed_dim,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def _init_qkv(
self,
embed_dim: int,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
self.qkv_proj = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=self.attn_type)
output, _ = self.out_proj(attn_output)
return output
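
A worked example of the head partitioning above, using assumed whisper-large-v3 shapes (20 attention heads, d_model 1280; these values are assumptions, not read from this diff):

total_num_heads, embed_dim, tp_size = 20, 1280, 2     # assumed checkpoint values
head_dim = embed_dim // total_num_heads               # 64
num_heads = total_num_heads // tp_size                # 10 query heads per TP rank
num_kv_heads = max(1, total_num_heads // tp_size)     # 10 KV heads per TP rank
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim   # 640, 640
print(head_dim, num_heads, num_kv_heads, q_size, kv_size)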
class WhisperCrossAttention(WhisperAttention):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__(
embed_dim=embed_dim,
num_heads=num_heads,
bias=bias,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
)
def _init_qkv(
self,
embed_dim: int,
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
self.q_proj = ColumnParallelLinear(
input_size=embed_dim,
output_size=embed_dim,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_proj = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.head_dim,
total_num_heads=0,
total_num_kv_heads=self.total_num_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.kv_proj",
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
):
q, _ = self.q_proj(hidden_states)
# Encoder hidden states are only computed once, during the prefill phase.
# Afterwards, the keys and values should be available in the kv-cache.
if encoder_hidden_states is not None:
kv, _ = self.kv_proj(encoder_hidden_states)
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
else:
k = v = None
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
output, _ = self.out_proj(attn_output)
return output
class WhisperMLP(nn.Module):
def __init__(
self,
embed_dim: int,
ffn_dim: int,
act_fn: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.activation_fn = get_act_fn(act_fn)
self.fc1 = ColumnParallelLinear(
input_size=embed_dim,
output_size=ffn_dim,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
input_size=ffn_dim,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, hidden_states: torch.Tensor):
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class WhisperEncoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embed_dim = config.d_model
self.self_attn = WhisperAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
attn_type=AttentionType.ENCODER,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.mlp = WhisperMLP(
embed_dim=config.d_model,
ffn_dim=config.encoder_ffn_dim,
act_fn=config.activation_function,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
):
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.isinf().any() or hidden_states.isnan().any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states,
min=-clamp_value,
max=clamp_value)
return hidden_states
class WhisperDecoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.self_attn = WhisperAttention(
embed_dim=config.d_model,
num_heads=config.decoder_attention_heads,
attn_type=AttentionType.DECODER,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.self_attn_layer_norm = nn.LayerNorm(config.d_model)
self.encoder_attn = WhisperCrossAttention(
embed_dim=config.d_model,
num_heads=config.decoder_attention_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.encoder_attn",
)
self.encoder_attn_layer_norm = nn.LayerNorm(config.d_model)
self.mlp = WhisperMLP(
embed_dim=config.d_model,
ffn_dim=config.decoder_ffn_dim,
act_fn=config.activation_function,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.final_layer_norm = nn.LayerNorm(config.d_model)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
):
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
hidden_states = self.encoder_attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class WhisperEncoder(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(embed_dim)
if config.scale_embedding else 1.0)
self.conv1 = nn.Conv1d(self.num_mel_bins,
embed_dim,
kernel_size=3,
padding=1)
self.conv2 = nn.Conv1d(embed_dim,
embed_dim,
kernel_size=3,
stride=2,
padding=1)
self.embed_positions = nn.Embedding(self.max_source_positions,
embed_dim)
self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers"),
prefix=f"{prefix}.layers",
)
self.layer_norm = nn.LayerNorm(config.d_model)
with torch.no_grad():
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape))
def forward(
self,
input_features: Union[torch.Tensor, List[torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
):
hidden_states = []
for features in input_features:
embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds))
embeds = embeds.permute(1, 0)
embeds = embeds + self.embed_positions.weight[:embeds.size(0), :]
hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states)
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(
hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
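
Shape-wise, the stride-2 convolution halves the mel-frame count, which is how a 30-second window lines up with max_source_positions. A small sketch with assumed whisper-large-v3 dimensions (128 mel bins, d_model 1280, 3000 frames per 30 s window):

import torch
from torch import nn

conv1 = nn.Conv1d(128, 1280, kernel_size=3, padding=1)
conv2 = nn.Conv1d(1280, 1280, kernel_size=3, stride=2, padding=1)
mel = torch.randn(128, 3000)                 # one 30 s clip of log-mel features
embeds = nn.functional.gelu(conv1(mel))      # (1280, 3000)
embeds = nn.functional.gelu(conv2(embeds))   # (1280, 1500): 3000 frames -> 1500 positions
embeds = embeds.permute(1, 0)
print(embeds.shape)  # torch.Size([1500, 1280]) == (max_source_positions, d_model)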
class WhisperDecoder(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_target_positions
self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(config.d_model)
if config.scale_embedding else 1.0)
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model,
self.padding_idx)
self.embed_positions = WhisperPositionalEmbedding(
self.max_target_positions, config.d_model)
self.start_layer, self.end_layer, self.layers = make_layers(
config.decoder_layers,
lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers"),
prefix=f"{prefix}.layers",
)
self.layer_norm = nn.LayerNorm(config.d_model)
def forward(
self,
input_ids,
positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
):
inputs_embeds = self.get_input_embeddings(input_ids)
positions = self.embed_positions(positions)
hidden_states = inputs_embeds + positions
for idx, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids)
class WhisperModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.encoder = WhisperEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
self.decoder = WhisperDecoder(vllm_config=vllm_config,
prefix=f"{prefix}.decoder")
def forward(
self,
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
encoder_outputs = self.get_encoder_outputs(
input_features,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
decoder_outputs = self.decoder(
input_ids=input_ids,
positions=positions,
encoder_hidden_states=encoder_outputs,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
return decoder_outputs
def get_encoder_outputs(
self,
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> Optional[torch.Tensor]:
if input_features is None:
return None
return self.encoder(
input_features,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
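
For clarity, a minimal sketch of how stacked_params_mapping rewrites checkpoint parameter names so that the separate q/k/v (or k/v) projections land in the fused vLLM layers; the example name is illustrative, not taken from a real checkpoint:

stacked_params_mapping = [
    (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
    (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
    (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
    (".encoder_attn.kv_proj", ".encoder_attn.k_proj", "k"),
    (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
]
name = "model.decoder.layers.0.encoder_attn.k_proj.weight"   # hypothetical checkpoint key
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        print(name.replace(weight_name, param_name), shard_id)
        # -> model.decoder.layers.0.encoder_attn.kv_proj.weight k
        break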
def get_max_whisper_audio_tokens(ctx: InputContext) -> int:
return ctx.model_config.hf_config.max_source_positions
def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
assert mm_counts["audio"] == 1
num_tokens = get_max_whisper_audio_tokens(ctx)
processor = cached_get_processor(ctx.model_config.model)
chunk_length = processor.feature_extractor.chunk_length
sampling_rate = processor.feature_extractor.sampling_rate
num_samples = chunk_length * sampling_rate
return DummyData(
SequenceData.from_prompt_token_counts((0, num_tokens)),
{"audio": [(np.zeros(num_samples), sampling_rate)]},
)
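
With HF's Whisper feature-extractor defaults (30 s chunk_length, 16 kHz sampling_rate; assumed here rather than read from this diff), the dummy profiling data works out as follows:

chunk_length = 30                             # seconds per Whisper window (assumed HF default)
sampling_rate = 16_000                        # Hz (assumed HF default)
num_samples = chunk_length * sampling_rate    # 480_000 zero samples of dummy audio
max_source_positions = 1500                   # encoder placeholder tokens for large-v3 (assumed)
print(num_samples, max_source_positions)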
def input_processor_for_whisper(ctx: InputContext, inputs):
multi_modal_data = inputs["encoder"]["multi_modal_data"]
if isinstance(multi_modal_data["audio"], list):
assert len(multi_modal_data["audio"]) == 1
multi_modal_data["audio"] = multi_modal_data["audio"][0]
# Resample and process audio
audio, orig_sr = multi_modal_data["audio"]
processor = cached_get_processor(ctx.model_config.model)
target_sr = processor.feature_extractor.sampling_rate
audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr)
multi_modal_data["audio"] = (audio, target_sr)
# Pre-allocate placeholder tokens in encoder sequence
num_tokens = get_max_whisper_audio_tokens(ctx)
inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens
return inputs
def input_mapper_for_whisper(
ctx: InputContext,
multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalKwargs:
if not isinstance(multi_modal_data, list):
multi_modal_data = [multi_modal_data]
if len(multi_modal_data) == 0:
return MultiModalKwargs()
assert len(multi_modal_data) == 1
processor = cached_get_processor(ctx.model_config.model)
sampling_rate = processor.feature_extractor.sampling_rate
audios = [audio for audio, _ in multi_modal_data]
kwargs = processor(audios,
sampling_rate=sampling_rate,
return_tensors="pt")
kwargs["input_features"] = kwargs["input_features"].squeeze(0).to(
ctx.model_config.dtype)
return MultiModalKwargs(kwargs)
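
A hedged sketch of what the processor call returns for a single clip, to show where the squeeze(0) and dtype cast come from (shapes assume openai/whisper-large-v3 and a 30 s, 16 kHz input; not measured here):

import numpy as np
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
audio = np.zeros(30 * 16_000, dtype=np.float32)      # one silent 30 s clip
out = processor([audio], sampling_rate=16_000, return_tensors="pt")
# A batch dimension of 1 for the single clip; the mapper above squeezes it away and
# casts to the model dtype before the features reach the encoder.
print(out["input_features"].shape)                    # expected: torch.Size([1, 128, 3000])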
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_whisper_audio_tokens)
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.dtype = vllm_config.model_config.dtype
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
self.unpadded_vocab_size = config.vocab_size
self.proj_out = ParallelLMHead(config.vocab_size,
config.d_model,
quant_config=quant_config)
self.proj_out = self.proj_out.tie_weights(
self.model.decoder.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
audio_input = self._parse_and_validate_audio_input(**kwargs)
decoder_outputs = self.model(
input_features=audio_input["input_features"],
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
return decoder_outputs
def get_multimodal_embeddings(
self,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs,
) -> Optional[NestedTensors]:
# TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return self.model.get_encoder_outputs(
audio_input["input_features"],
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
# TODO: This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. Refactor this once
# encoder/decoder support is implemented in V1.
return self.model.decoder.get_input_embeddings(input_ids)
def _parse_and_validate_audio_input(
self, **kwargs: object) -> WhisperAudioInputs:
input_features = kwargs.pop("input_features", None)
if input_features is not None:
if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(input_features)}")
input_features = [feat.to(self.dtype) for feat in input_features]
return WhisperAudioInputs(input_features=input_features)
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.proj_out, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
loaded_weights = [(name, loaded_weight)
for name, loaded_weight in weights]
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
return loader.load_weights(loaded_weights, mapper=mapper)

View File

@ -16,7 +16,7 @@ from transformers import BatchFeature, ProcessorMixin
from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
@ -57,24 +57,6 @@ class PromptReplacement:
)
def _encode(
tokenizer: AnyTokenizer,
text: str,
*,
add_special_tokens: bool = False,
) -> list[int]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
"""
if isinstance(tokenizer, MistralTokenizer):
return tokenizer.tokenizer.encode(text,
bos=add_special_tokens,
eos=add_special_tokens)
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
@lru_cache(maxsize=2048)
def _cached_encode(
tokenizer: AnyTokenizer,
@ -82,7 +64,9 @@ def _cached_encode(
*,
add_special_tokens: bool = False,
) -> list[int]:
return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
return encode_tokens(tokenizer,
text,
add_special_tokens=add_special_tokens)
def _decode(
@ -983,7 +967,9 @@ class BaseMultiModalProcessor(ABC):
mm_item_counts,
)
token_ids = _encode(tokenizer, text)
token_ids = encode_tokens(tokenizer,
text,
add_special_tokens=False)
matched_repls = [match.prompt_repl for match in text_matches]
placeholders = self._find_placeholders(matched_repls, token_ids,

View File

@ -710,15 +710,27 @@ class SequenceGroup:
@property
def multi_modal_data(self) -> MultiModalDataDict:
return self.first_seq.multi_modal_data
if self.first_seq.multi_modal_data:
return self.first_seq.multi_modal_data
elif self.encoder_seq is not None:
return self.encoder_seq.multi_modal_data
return {}
@property
def multi_modal_placeholders(self) -> MultiModalPlaceholderDict:
return self.first_seq.multi_modal_placeholders
if self.first_seq.multi_modal_data:
return self.first_seq.multi_modal_placeholders
elif self.encoder_seq is not None:
return self.encoder_seq.multi_modal_placeholders
return {}
@property
def mm_processor_kwargs(self) -> Dict[str, Any]:
return self.first_seq.mm_processor_kwargs
if self.first_seq.multi_modal_data:
return self.first_seq.mm_processor_kwargs
elif self.encoder_seq is not None:
return self.encoder_seq.mm_processor_kwargs
return {}
@property
def lora_int_id(self) -> int:

View File

@ -21,6 +21,25 @@ AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer]
def encode_tokens(
tokenizer: AnyTokenizer,
text: str,
*,
add_special_tokens: Optional[bool] = None,
) -> list[int]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
"""
if isinstance(tokenizer, MistralTokenizer):
return tokenizer.tokenizer.encode(text,
bos=add_special_tokens,
eos=add_special_tokens)
elif add_special_tokens is not None:
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
return tokenizer.encode(text)
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
"""Get tokenizer with cached properties.

View File

@ -32,7 +32,8 @@ class BaseTokenizerGroup(ABC):
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass
@ -41,7 +42,8 @@ class BaseTokenizerGroup(ABC):
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass

View File

@ -112,7 +112,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
@ -132,7 +133,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
ret = ray.get(
actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request))
lora_request=lora_request,
add_special_tokens=add_special_tokens))
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
@ -143,7 +145,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
ret = ray.get(
actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request))
lora_request=lora_request,
add_special_tokens=add_special_tokens))
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "
@ -160,7 +163,8 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
@ -177,9 +181,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor_is_alive = True
original_actor = actor
try:
ret = await actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
ret = await actor.encode.remote(
request_id=request_id,
prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
@ -187,9 +193,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
exc_info=e)
actor = self._init_actor()
try:
ret = await actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
ret = await actor.encode.remote(
request_id=request_id,
prompt=prompt,
lora_request=lora_request,
add_special_tokens=add_special_tokens)
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "

View File

@ -2,7 +2,7 @@ from typing import List, Optional
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
@ -55,9 +55,12 @@ class TokenizerGroup(BaseTokenizerGroup):
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
ret = tokenizer.encode(prompt)
ret = encode_tokens(tokenizer,
prompt,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret
@ -65,9 +68,12 @@ class TokenizerGroup(BaseTokenizerGroup):
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
ret = tokenizer.encode(prompt)
ret = encode_tokens(tokenizer,
prompt,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret

View File

@ -287,12 +287,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
seq_len,
self.mm_registry,
is_encoder_data=False)
encoder_dummy_data \
= self.input_registry.dummy_data_for_profiling(
self.model_config,
seq_len,
self.mm_registry,
is_encoder_data=True)
encoder_dummy_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
seq_len,
self.mm_registry,
is_encoder_data=True)
# Having more tokens is over-conservative but otherwise fine
assert len(