Migrate whisper inputs to TensorSchema (#23505)

Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
Benji Beck 2025-09-03 11:04:00 -07:00 committed by GitHub
parent e9b92dcd89
commit 731a6940e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,7 +4,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from typing import Literal, Optional, TypedDict, Union, cast
from typing import Annotated, Literal, Optional, Union, cast
import numpy as np
import torch
@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription, SupportsV0Only)
@ -111,9 +112,16 @@ ISO639_1_SUPPORTED_LANGS = {
}
class WhisperAudioInputs(TypedDict):
    """Pre-migration (TypedDict) container for Whisper audio inputs.

    NOTE(review): this is the old form being replaced by the TensorSchema
    version in this commit; shown here as the removed side of the diff.
    """
    input_features: NestedTensors
    """Shape: `(batch_size, 128, M)`"""
class WhisperAudioInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """
    # Log-mel spectrogram features; Optional because a text-only request may
    # carry no audio. TensorShape symbols above name the expected axes
    # (presumably validated by TensorSchema at construction — confirm against
    # vllm.utils.tensor_schema).
    input_features: Annotated[Optional[NestedTensors],
                              TensorShape("b", "nmb", "t")]
class WhisperPositionalEmbedding(nn.Embedding):