mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 06:17:13 +08:00
[LMM] Implement merged multimodal processor for whisper (#13278)
This commit is contained in:
parent
d5ca2110f1
commit
ba5106e519
@ -83,11 +83,11 @@ def _test_processing_correctness(
|
|||||||
}
|
}
|
||||||
|
|
||||||
tokenizer_encode_kwargs = {}
|
tokenizer_encode_kwargs = {}
|
||||||
if model_config.hf_config.model_type == "mllama":
|
if model_config.hf_config.model_type in ("mllama", "whisper"):
|
||||||
# For Mllama, tokenizer will always add bos_token at the beginning of
|
# For some encoder-decoder models, tokenizer will always add bos_token
|
||||||
# prompt by default, causing hf_processor outputs incorrect token ids.
|
# at the beginning of prompt by default, causing hf_processor outputs
|
||||||
# So we need use `add_special_tokens=False` here to leave bos_token
|
# incorrect token ids. So we need use `add_special_tokens=False` here
|
||||||
# to be added by the processor.
|
# to leave bos_token to be added by the processor.
|
||||||
tokenizer_encode_kwargs = {"add_special_tokens": False}
|
tokenizer_encode_kwargs = {"add_special_tokens": False}
|
||||||
|
|
||||||
for batch_idx in range(num_batches):
|
for batch_idx in range(num_batches):
|
||||||
@ -173,6 +173,7 @@ def _test_processing_correctness(
|
|||||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||||
|
"openai/whisper-large-v3",
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||||
@pytest.mark.parametrize("num_batches", [32])
|
@pytest.mark.parametrize("num_batches", [32])
|
||||||
|
|||||||
@ -4,15 +4,15 @@ import math
|
|||||||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
||||||
Union)
|
Union)
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
|
||||||
|
WhisperProcessor)
|
||||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import get_act_fn
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
|
||||||
NestedTensors)
|
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
|
||||||
from vllm.multimodal.audio import resample_audio
|
from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
|
||||||
from vllm.sequence import SequenceData
|
MultiModalDataParser)
|
||||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
from vllm.multimodal.processing import (BaseProcessingInfo,
|
||||||
|
EncDecMultiModalProcessor,
|
||||||
|
PromptReplacement)
|
||||||
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||||
|
|
||||||
from .interfaces import SupportsMultiModal, SupportsTranscription
|
from .interfaces import SupportsMultiModal, SupportsTranscription
|
||||||
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
|
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
|
||||||
@ -571,72 +574,126 @@ class WhisperModel(nn.Module):
|
|||||||
return loaded_params
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
def get_max_whisper_audio_tokens(ctx: InputContext) -> int:
|
class WhisperProcessingInfo(BaseProcessingInfo):
|
||||||
return ctx.model_config.hf_config.max_source_positions
|
|
||||||
|
def get_hf_config(self) -> WhisperConfig:
|
||||||
|
return self.ctx.get_hf_config(WhisperConfig)
|
||||||
|
|
||||||
|
def get_hf_processor(self,
|
||||||
|
sampling_rate: Optional[int] = None
|
||||||
|
) -> WhisperProcessor:
|
||||||
|
return self.ctx.get_hf_processor(WhisperProcessor)
|
||||||
|
|
||||||
|
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||||
|
return {"audio": 1}
|
||||||
|
|
||||||
|
def get_feature_extractor(self) -> WhisperFeatureExtractor:
|
||||||
|
hf_processor = self.get_hf_processor()
|
||||||
|
feature_extractor = hf_processor.feature_extractor # type: ignore
|
||||||
|
assert isinstance(feature_extractor, WhisperFeatureExtractor)
|
||||||
|
return feature_extractor
|
||||||
|
|
||||||
|
def get_max_audio_tokens(self) -> int:
|
||||||
|
return self.get_hf_config().max_source_positions
|
||||||
|
|
||||||
|
def get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> Mapping[str, int]:
|
||||||
|
return {"audio": self.get_max_audio_tokens()}
|
||||||
|
|
||||||
|
|
||||||
def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int,
|
class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
|
||||||
mm_counts: Mapping[str, int]):
|
|
||||||
assert mm_counts["audio"] == 1
|
def get_dummy_processor_inputs(
|
||||||
num_tokens = get_max_whisper_audio_tokens(ctx)
|
self,
|
||||||
processor = cached_processor_from_config(ctx.model_config)
|
seq_len: int,
|
||||||
chunk_length = processor.feature_extractor.chunk_length
|
mm_counts: Mapping[str, int],
|
||||||
sampling_rate = processor.feature_extractor.sampling_rate
|
) -> ProcessorInputs:
|
||||||
num_samples = chunk_length * sampling_rate
|
feature_extractor = self.info.get_feature_extractor()
|
||||||
return DummyData(
|
|
||||||
SequenceData.from_prompt_token_counts((0, num_tokens)),
|
sampling_rate = feature_extractor.sampling_rate
|
||||||
{"audio": [(np.zeros(num_samples), sampling_rate)]},
|
audio_len = feature_extractor.chunk_length * sampling_rate
|
||||||
)
|
num_audios = mm_counts.get("audio", 0)
|
||||||
|
|
||||||
|
mm_data = {
|
||||||
|
"audio":
|
||||||
|
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ProcessorInputs(
|
||||||
|
prompt_text="<|startoftranscript|>" * num_audios,
|
||||||
|
mm_data=mm_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def input_processor_for_whisper(ctx: InputContext, inputs):
|
class WhisperMultiModalProcessor(
|
||||||
multi_modal_data = inputs["encoder"]["multi_modal_data"]
|
EncDecMultiModalProcessor[WhisperProcessingInfo]):
|
||||||
if isinstance(multi_modal_data["audio"], list):
|
|
||||||
assert len(multi_modal_data["audio"]) == 1
|
def _get_data_parser(self) -> MultiModalDataParser:
|
||||||
multi_modal_data["audio"] = multi_modal_data["audio"][0]
|
feature_extractor = self.info.get_feature_extractor()
|
||||||
# Resample and process audio
|
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
|
||||||
audio, orig_sr = multi_modal_data["audio"]
|
|
||||||
processor = cached_processor_from_config(ctx.model_config)
|
def create_encoder_prompt(
|
||||||
target_sr = processor.feature_extractor.sampling_rate
|
self,
|
||||||
audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr)
|
prompt: Union[str, list[int]],
|
||||||
multi_modal_data["audio"] = (audio, target_sr)
|
mm_data: MultiModalDataDict,
|
||||||
# Pre-allocate placeholder tokens in encoder sequence
|
) -> Union[str, list[int]]:
|
||||||
num_tokens = get_max_whisper_audio_tokens(ctx)
|
# Strictly speaking, whisper encoder only accept audio features.
|
||||||
inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens
|
# We create a dummy encoder prompt here which will be padded to
|
||||||
return inputs
|
# num_audio_tokens. So that we can create dummy data from this
|
||||||
|
# for encoder profiling.
|
||||||
|
return [0]
|
||||||
|
|
||||||
|
def _call_hf_processor(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
mm_data: Mapping[str, object],
|
||||||
|
mm_kwargs: Mapping[str, object],
|
||||||
|
) -> BatchFeature:
|
||||||
|
if mm_data:
|
||||||
|
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
|
||||||
|
mm_data = dict(audio=mm_data.pop("audios"))
|
||||||
|
mm_kwargs = dict(
|
||||||
|
**mm_kwargs,
|
||||||
|
sampling_rate=feature_extractor.sampling_rate,
|
||||||
|
)
|
||||||
|
processed_outputs = super()._call_hf_processor(
|
||||||
|
prompt=prompt,
|
||||||
|
mm_data=mm_data,
|
||||||
|
mm_kwargs=mm_kwargs,
|
||||||
|
)
|
||||||
|
if "labels" in processed_outputs:
|
||||||
|
processed_outputs["input_ids"] = processed_outputs.pop("labels")
|
||||||
|
return processed_outputs
|
||||||
|
|
||||||
|
def _get_mm_fields_config(
|
||||||
|
self,
|
||||||
|
hf_inputs: BatchFeature,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
|
return dict(input_features=MultiModalFieldConfig.batched("audio"))
|
||||||
|
|
||||||
|
def _get_prompt_replacements(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
out_mm_kwargs: MultiModalKwargs,
|
||||||
|
) -> list[PromptReplacement]:
|
||||||
|
num_tokens = self.info.get_max_audio_tokens()
|
||||||
|
return [
|
||||||
|
PromptReplacement(
|
||||||
|
modality="audio",
|
||||||
|
target=[0],
|
||||||
|
replacement=[0] * num_tokens,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def input_mapper_for_whisper(
|
@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
|
||||||
ctx: InputContext,
|
info=WhisperProcessingInfo,
|
||||||
multi_modal_data: Union[np.ndarray, List[np.ndarray]],
|
dummy_inputs=WhisperDummyInputsBuilder)
|
||||||
) -> MultiModalKwargs:
|
|
||||||
if not isinstance(multi_modal_data, list):
|
|
||||||
multi_modal_data = [multi_modal_data]
|
|
||||||
|
|
||||||
assert len(multi_modal_data) == 1
|
|
||||||
|
|
||||||
if len(multi_modal_data) == 0:
|
|
||||||
return MultiModalKwargs()
|
|
||||||
|
|
||||||
processor = cached_processor_from_config(ctx.model_config)
|
|
||||||
sampling_rate = processor.feature_extractor.sampling_rate
|
|
||||||
|
|
||||||
audios = [audio for audio, _ in multi_modal_data]
|
|
||||||
|
|
||||||
kwargs = processor(audios,
|
|
||||||
sampling_rate=sampling_rate,
|
|
||||||
return_tensors="pt")
|
|
||||||
kwargs["input_features"] = kwargs["input_features"].squeeze(0).to(
|
|
||||||
ctx.model_config.dtype)
|
|
||||||
|
|
||||||
return MultiModalKwargs(kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
|
|
||||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
|
|
||||||
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
|
|
||||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
|
||||||
"audio", get_max_whisper_audio_tokens)
|
|
||||||
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||||
SupportsMultiModal):
|
SupportsMultiModal):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
|||||||
if not isinstance(input_features, (torch.Tensor, list)):
|
if not isinstance(input_features, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of audio features. "
|
raise ValueError("Incorrect type of audio features. "
|
||||||
f"Got type: {type(input_features)}")
|
f"Got type: {type(input_features)}")
|
||||||
input_features = [feat.to(self.dtype) for feat in input_features]
|
input_features = torch.cat(
|
||||||
|
[feat.to(self.dtype) for feat in input_features])
|
||||||
|
|
||||||
return WhisperAudioInputs(input_features=input_features)
|
return WhisperAudioInputs(input_features=input_features)
|
||||||
|
|
||||||
|
|||||||
@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
prompt: Union[str, list[int]],
|
prompt: Union[str, list[int]],
|
||||||
mm_data: MultiModalDataDict,
|
mm_data: MultiModalDataDict,
|
||||||
) -> Union[str, list[int]]:
|
) -> Union[str, list[int]]:
|
||||||
"""Create input prompt for the encoder."""
|
"""
|
||||||
|
Create input prompt for the encoder. HF processor will be applied on
|
||||||
|
this prompt during profiling and generation.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
|
|||||||
@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
f"({set(mm_max_tokens_per_item.keys())})")
|
f"({set(mm_max_tokens_per_item.keys())})")
|
||||||
|
|
||||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||||
prompt_token_ids = mm_inputs["prompt_token_ids"]
|
|
||||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||||
|
# For encoder-decoder models, use encoder prompt token ids instead of
|
||||||
|
# decoder prompt to construct dummy seq_data for encoder profiling.
|
||||||
|
prompt_token_ids = (
|
||||||
|
mm_inputs["prompt_token_ids"] if not is_encoder_data else
|
||||||
|
mm_inputs["encoder_prompt_token_ids"]) # type: ignore
|
||||||
|
|
||||||
total_placeholders_by_modality = {
|
total_placeholders_by_modality = {
|
||||||
modality: sum(item["length"] for item in placeholders)
|
modality: sum(item["length"] for item in placeholders)
|
||||||
@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
|
|
||||||
# V0 does not support chunked prefill.
|
# V0 does not support chunked prefill.
|
||||||
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
|
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
|
||||||
if total_len > seq_len:
|
if total_len > seq_len and not is_encoder_data:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The context length (%d) of the model is too short "
|
"The context length (%d) of the model is too short "
|
||||||
"to hold the multi-modal embeddings in the worst case "
|
"to hold the multi-modal embeddings in the worst case "
|
||||||
@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
total_placeholders_by_modality)
|
total_placeholders_by_modality)
|
||||||
|
|
||||||
return DummyData(
|
return DummyData(
|
||||||
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
|
seq_data=SequenceData.from_prompt_token_counts(
|
||||||
|
(0, max(seq_len, total_len))),
|
||||||
multi_modal_data=None,
|
multi_modal_data=None,
|
||||||
multi_modal_placeholders=None,
|
multi_modal_placeholders=None,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user