diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 848b6e0f8093..97e8cd6e7695 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -4,7 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from contextlib import nullcontext
-from typing import Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Literal, Optional, Union, cast
 
 import numpy as np
 import torch
@@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.transformers_utils.processor import cached_get_processor
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription, SupportsV0Only)
@@ -111,9 +112,16 @@ ISO639_1_SUPPORTED_LANGS = {
 }
 
 
-class WhisperAudioInputs(TypedDict):
-    input_features: NestedTensors
-    """Shape: `(batch_size, 128, M)`"""
+class WhisperAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - b: Batch size
+        - nmb: Number of mel bins
+        - t: Time frames (M)
+    """
+
+    input_features: Annotated[Optional[NestedTensors],
+                              TensorShape("b", "nmb", "t")]
 
 
 class WhisperPositionalEmbedding(nn.Embedding):
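
For reviewers unfamiliar with the schema style this diff adopts, below is a minimal, self-contained sketch of the pattern: tensor fields are annotated with symbolic dimension names via `TensorShape`, and the schema checks incoming tensors against those names when an instance is built, instead of only documenting the shape in a docstring as the old `TypedDict` did. The `TensorShape`/`TensorSchema` bodies and the validation logic here are invented for illustration and are not vLLM's actual `vllm.utils.tensor_schema` implementation; `torch.Tensor` stands in for `NestedTensors`.

```python
# Illustrative sketch only, not vLLM's TensorSchema implementation.
from typing import Annotated, Optional, get_type_hints

import torch


class TensorShape:
    """Carries symbolic dimension names as Annotated metadata."""

    def __init__(self, *dims: str) -> None:
        self.dims = dims


class TensorSchema:
    """Checks annotated fields against their declared dimensionality."""

    def __init__(self, **kwargs) -> None:
        bound = {}  # symbolic dim name -> concrete size seen so far
        hints = get_type_hints(type(self), include_extras=True)
        for name, hint in hints.items():
            value = kwargs.get(name)
            setattr(self, name, value)
            if value is None or not isinstance(value, torch.Tensor):
                continue
            shape_meta = next(
                (m for m in getattr(hint, "__metadata__", ())
                 if isinstance(m, TensorShape)), None)
            if shape_meta is None:
                continue
            if value.ndim != len(shape_meta.dims):
                raise ValueError(
                    f"{name}: expected {len(shape_meta.dims)} dims "
                    f"{shape_meta.dims}, got shape {tuple(value.shape)}")
            for dim_name, size in zip(shape_meta.dims, value.shape):
                if bound.setdefault(dim_name, int(size)) != int(size):
                    raise ValueError(
                        f"{name}: dim '{dim_name}' is {size}, but was "
                        f"previously bound to {bound[dim_name]}")


class WhisperAudioInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    input_features: Annotated[Optional[torch.Tensor],
                              TensorShape("b", "nmb", "t")]


# A (batch, mel_bins, time) tensor passes; a 2-D tensor is rejected with a
# descriptive error, which the TypedDict version could not provide.
ok = WhisperAudioInputs(input_features=torch.zeros(2, 128, 3000))
try:
    WhisperAudioInputs(input_features=torch.zeros(2, 128))
except ValueError as exc:
    print(exc)
```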