From ba5106e519724d8f591ba24bbf7e700a4179eb25 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sun, 23 Feb 2025 17:46:03 +0800 Subject: [PATCH] [LMM] Implement merged multimodal processor for whisper (#13278) --- .../multimodal/processing/test_common.py | 11 +- vllm/model_executor/models/whisper.py | 194 ++++++++++++------ vllm/multimodal/processing.py | 5 +- vllm/multimodal/profiling.py | 11 +- 4 files changed, 144 insertions(+), 77 deletions(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 331ffe82ec85d..0115863f56263 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -83,11 +83,11 @@ def _test_processing_correctness( } tokenizer_encode_kwargs = {} - if model_config.hf_config.model_type == "mllama": - # For Mllama, tokenizer will always add bos_token at the beginning of - # prompt by default, causing hf_processor outputs incorrect token ids. - # So we need use `add_special_tokens=False` here to leave bos_token - # to be added by the processor. + if model_config.hf_config.model_type in ("mllama", "whisper"): + # For some encoder-decoder models, tokenizer will always add bos_token + # at the beginning of prompt by default, causing hf_processor outputs + # incorrect token ids. So we need use `add_special_tokens=False` here + # to leave bos_token to be added by the processor. tokenizer_encode_kwargs = {"add_special_tokens": False} for batch_idx in range(num_batches): @@ -173,6 +173,7 @@ def _test_processing_correctness( "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "openai/whisper-large-v3", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 073a30d25e239..2ad1731144ef9 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -4,15 +4,15 @@ import math from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, Union) -import numpy as np import torch from torch import nn +from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, + WhisperProcessor) from transformers.models.whisper.modeling_whisper import sinusoids from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, - NestedTensors) -from vllm.multimodal.audio import resample_audio -from vllm.sequence import SequenceData -from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from .interfaces import SupportsMultiModal, SupportsTranscription from .utils import AutoWeightsLoader, WeightsMapper, make_layers @@ -571,72 +574,126 @@ class WhisperModel(nn.Module): return loaded_params -def get_max_whisper_audio_tokens(ctx: InputContext) -> int: - return ctx.model_config.hf_config.max_source_positions +class WhisperProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> WhisperConfig: + return self.ctx.get_hf_config(WhisperConfig) + + def get_hf_processor(self, + sampling_rate: Optional[int] = None + ) -> WhisperProcessor: + return self.ctx.get_hf_processor(WhisperProcessor) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": 1} + + def get_feature_extractor(self) -> WhisperFeatureExtractor: + hf_processor = self.get_hf_processor() + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_max_audio_tokens(self) -> int: + return self.get_hf_config().max_source_positions + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"audio": self.get_max_audio_tokens()} -def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - assert mm_counts["audio"] == 1 - num_tokens = get_max_whisper_audio_tokens(ctx) - processor = cached_processor_from_config(ctx.model_config) - chunk_length = processor.feature_extractor.chunk_length - sampling_rate = processor.feature_extractor.sampling_rate - num_samples = chunk_length * sampling_rate - return DummyData( - SequenceData.from_prompt_token_counts((0, num_tokens)), - {"audio": [(np.zeros(num_samples), sampling_rate)]}, - ) +class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + return ProcessorInputs( + prompt_text="<|startoftranscript|>" * num_audios, + mm_data=mm_data, + ) -def input_processor_for_whisper(ctx: InputContext, inputs): - multi_modal_data = inputs["encoder"]["multi_modal_data"] - if isinstance(multi_modal_data["audio"], list): - assert len(multi_modal_data["audio"]) == 1 - multi_modal_data["audio"] = multi_modal_data["audio"][0] - # Resample and process audio - audio, orig_sr = multi_modal_data["audio"] - processor = cached_processor_from_config(ctx.model_config) - target_sr = processor.feature_extractor.sampling_rate - audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) - multi_modal_data["audio"] = (audio, target_sr) - # Pre-allocate placeholder tokens in encoder sequence - num_tokens = get_max_whisper_audio_tokens(ctx) - inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens - return inputs +class WhisperMultiModalProcessor( + EncDecMultiModalProcessor[WhisperProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + # Strictly speaking, whisper encoder only accept audio features. + # We create a dummy encoder prompt here which will be padded to + # num_audio_tokens. So that we can create dummy data from this + # for encoder profiling. + return [0] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_data = dict(audio=mm_data.pop("audios")) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + if "labels" in processed_outputs: + processed_outputs["input_ids"] = processed_outputs.pop("labels") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(input_features=MultiModalFieldConfig.batched("audio")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + num_tokens = self.info.get_max_audio_tokens() + return [ + PromptReplacement( + modality="audio", + target=[0], + replacement=[0] * num_tokens, + ) + ] -def input_mapper_for_whisper( - ctx: InputContext, - multi_modal_data: Union[np.ndarray, List[np.ndarray]], -) -> MultiModalKwargs: - if not isinstance(multi_modal_data, list): - multi_modal_data = [multi_modal_data] - - assert len(multi_modal_data) == 1 - - if len(multi_modal_data) == 0: - return MultiModalKwargs() - - processor = cached_processor_from_config(ctx.model_config) - sampling_rate = processor.feature_extractor.sampling_rate - - audios = [audio for audio, _ in multi_modal_data] - - kwargs = processor(audios, - sampling_rate=sampling_rate, - return_tensors="pt") - kwargs["input_features"] = kwargs["input_features"].squeeze(0).to( - ctx.model_config.dtype) - - return MultiModalKwargs(kwargs) - - -@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper) -@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper) -@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_whisper_audio_tokens) +@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor, + info=WhisperProcessingInfo, + dummy_inputs=WhisperDummyInputsBuilder) class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, SupportsMultiModal): packed_modules_mapping = { @@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, if not isinstance(input_features, (torch.Tensor, list)): raise ValueError("Incorrect type of audio features. " f"Got type: {type(input_features)}") - input_features = [feat.to(self.dtype) for feat in input_features] + input_features = torch.cat( + [feat.to(self.dtype) for feat in input_features]) return WhisperAudioInputs(input_features=input_features) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index fcd02fbd5203c..93756364dea15 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): prompt: Union[str, list[int]], mm_data: MultiModalDataDict, ) -> Union[str, list[int]]: - """Create input prompt for the encoder.""" + """ + Create input prompt for the encoder. HF processor will be applied on + this prompt during profiling and generation. + """ raise NotImplementedError def apply( diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 81c92b38f8e95..802e40a0c9523 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]): f"({set(mm_max_tokens_per_item.keys())})") mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) - prompt_token_ids = mm_inputs["prompt_token_ids"] placeholders_by_modality = mm_inputs["mm_placeholders"] + # For encoder-decoder models, use encoder prompt token ids instead of + # decoder prompt to construct dummy seq_data for encoder profiling. + prompt_token_ids = ( + mm_inputs["prompt_token_ids"] if not is_encoder_data else + mm_inputs["encoder_prompt_token_ids"]) # type: ignore total_placeholders_by_modality = { modality: sum(item["length"] for item in placeholders) @@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]): # V0 does not support chunked prefill. if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data: - if total_len > seq_len: + if total_len > seq_len and not is_encoder_data: logger.warning( "The context length (%d) of the model is too short " "to hold the multi-modal embeddings in the worst case " @@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]): total_placeholders_by_modality) return DummyData( - seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), + seq_data=SequenceData.from_prompt_token_counts( + (0, max(seq_len, total_len))), multi_modal_data=None, multi_modal_placeholders=None, )