[LMM] Implement merged multimodal processor for whisper (#13278)

This commit is contained in:
Isotr0py 2025-02-23 17:46:03 +08:00 committed by GitHub
parent d5ca2110f1
commit ba5106e519
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 144 additions and 77 deletions

View File

@ -83,11 +83,11 @@ def _test_processing_correctness(
} }
tokenizer_encode_kwargs = {} tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type == "mllama": if model_config.hf_config.model_type in ("mllama", "whisper"):
# For Mllama, tokenizer will always add bos_token at the beginning of # For some encoder-decoder models, tokenizer will always add bos_token
# prompt by default, causing hf_processor outputs incorrect token ids. # at the beginning of prompt by default, causing hf_processor outputs
# So we need use `add_special_tokens=False` here to leave bos_token # incorrect token ids. So we need use `add_special_tokens=False` here
# to be added by the processor. # to leave bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False} tokenizer_encode_kwargs = {"add_special_tokens": False}
for batch_idx in range(num_batches): for batch_idx in range(num_batches):
@ -173,6 +173,7 @@ def _test_processing_correctness(
"Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_5-llama-3_2-1b", "fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
]) ])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("num_batches", [32])

View File

@ -4,15 +4,15 @@ import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union) Union)
import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor)
from transformers.models.whisper.modeling_whisper import sinusoids from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
NestedTensors) from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.audio import resample_audio from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
from vllm.sequence import SequenceData MultiModalDataParser)
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import SupportsMultiModal, SupportsTranscription from .interfaces import SupportsMultiModal, SupportsTranscription
from .utils import AutoWeightsLoader, WeightsMapper, make_layers from .utils import AutoWeightsLoader, WeightsMapper, make_layers
@ -571,72 +574,126 @@ class WhisperModel(nn.Module):
return loaded_params return loaded_params
def get_max_whisper_audio_tokens(ctx: InputContext) -> int: class WhisperProcessingInfo(BaseProcessingInfo):
return ctx.model_config.hf_config.max_source_positions
def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)
def get_hf_processor(self,
sampling_rate: Optional[int] = None
) -> WhisperProcessor:
return self.ctx.get_hf_processor(WhisperProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1}
def get_feature_extractor(self) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor()
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_max_audio_tokens(self) -> int:
return self.get_hf_config().max_source_positions
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"audio": self.get_max_audio_tokens()}
def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int, class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
mm_counts: Mapping[str, int]):
assert mm_counts["audio"] == 1 def get_dummy_processor_inputs(
num_tokens = get_max_whisper_audio_tokens(ctx) self,
processor = cached_processor_from_config(ctx.model_config) seq_len: int,
chunk_length = processor.feature_extractor.chunk_length mm_counts: Mapping[str, int],
sampling_rate = processor.feature_extractor.sampling_rate ) -> ProcessorInputs:
num_samples = chunk_length * sampling_rate feature_extractor = self.info.get_feature_extractor()
return DummyData(
SequenceData.from_prompt_token_counts((0, num_tokens)), sampling_rate = feature_extractor.sampling_rate
{"audio": [(np.zeros(num_samples), sampling_rate)]}, audio_len = feature_extractor.chunk_length * sampling_rate
) num_audios = mm_counts.get("audio", 0)
mm_data = {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text="<|startoftranscript|>" * num_audios,
mm_data=mm_data,
)
def input_processor_for_whisper(ctx: InputContext, inputs): class WhisperMultiModalProcessor(
multi_modal_data = inputs["encoder"]["multi_modal_data"] EncDecMultiModalProcessor[WhisperProcessingInfo]):
if isinstance(multi_modal_data["audio"], list):
assert len(multi_modal_data["audio"]) == 1 def _get_data_parser(self) -> MultiModalDataParser:
multi_modal_data["audio"] = multi_modal_data["audio"][0] feature_extractor = self.info.get_feature_extractor()
# Resample and process audio return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
audio, orig_sr = multi_modal_data["audio"]
processor = cached_processor_from_config(ctx.model_config) def create_encoder_prompt(
target_sr = processor.feature_extractor.sampling_rate self,
audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr) prompt: Union[str, list[int]],
multi_modal_data["audio"] = (audio, target_sr) mm_data: MultiModalDataDict,
# Pre-allocate placeholder tokens in encoder sequence ) -> Union[str, list[int]]:
num_tokens = get_max_whisper_audio_tokens(ctx) # Strictly speaking, whisper encoder only accept audio features.
inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens # We create a dummy encoder prompt here which will be padded to
return inputs # num_audio_tokens. So that we can create dummy data from this
# for encoder profiling.
return [0]
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
if "labels" in processed_outputs:
processed_outputs["input_ids"] = processed_outputs.pop("labels")
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(input_features=MultiModalFieldConfig.batched("audio"))
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
num_tokens = self.info.get_max_audio_tokens()
return [
PromptReplacement(
modality="audio",
target=[0],
replacement=[0] * num_tokens,
)
]
def input_mapper_for_whisper( @MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
ctx: InputContext, info=WhisperProcessingInfo,
multi_modal_data: Union[np.ndarray, List[np.ndarray]], dummy_inputs=WhisperDummyInputsBuilder)
) -> MultiModalKwargs:
if not isinstance(multi_modal_data, list):
multi_modal_data = [multi_modal_data]
assert len(multi_modal_data) == 1
if len(multi_modal_data) == 0:
return MultiModalKwargs()
processor = cached_processor_from_config(ctx.model_config)
sampling_rate = processor.feature_extractor.sampling_rate
audios = [audio for audio, _ in multi_modal_data]
kwargs = processor(audios,
sampling_rate=sampling_rate,
return_tensors="pt")
kwargs["input_features"] = kwargs["input_features"].squeeze(0).to(
ctx.model_config.dtype)
return MultiModalKwargs(kwargs)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_whisper_audio_tokens)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal): SupportsMultiModal):
packed_modules_mapping = { packed_modules_mapping = {
@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
if not isinstance(input_features, (torch.Tensor, list)): if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. " raise ValueError("Incorrect type of audio features. "
f"Got type: {type(input_features)}") f"Got type: {type(input_features)}")
input_features = [feat.to(self.dtype) for feat in input_features] input_features = torch.cat(
[feat.to(self.dtype) for feat in input_features])
return WhisperAudioInputs(input_features=input_features) return WhisperAudioInputs(input_features=input_features)

View File

@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
) -> Union[str, list[int]]: ) -> Union[str, list[int]]:
"""Create input prompt for the encoder.""" """
Create input prompt for the encoder. HF processor will be applied on
this prompt during profiling and generation.
"""
raise NotImplementedError raise NotImplementedError
def apply( def apply(

View File

@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]):
f"({set(mm_max_tokens_per_item.keys())})") f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"] placeholders_by_modality = mm_inputs["mm_placeholders"]
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
prompt_token_ids = (
mm_inputs["prompt_token_ids"] if not is_encoder_data else
mm_inputs["encoder_prompt_token_ids"]) # type: ignore
total_placeholders_by_modality = { total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders) modality: sum(item["length"] for item in placeholders)
@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]):
# V0 does not support chunked prefill. # V0 does not support chunked prefill.
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data: if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
if total_len > seq_len: if total_len > seq_len and not is_encoder_data:
logger.warning( logger.warning(
"The context length (%d) of the model is too short " "The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case " "to hold the multi-modal embeddings in the worst case "
@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]):
total_placeholders_by_modality) total_placeholders_by_modality)
return DummyData( return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), seq_data=SequenceData.from_prompt_token_counts(
(0, max(seq_len, total_len))),
multi_modal_data=None, multi_modal_data=None,
multi_modal_placeholders=None, multi_modal_placeholders=None,
) )