From 75856bc2cb8e559dc5cd62cc0e495662be1aaae3 Mon Sep 17 00:00:00 2001
From: Benji Beck
Date: Sun, 27 Jul 2025 22:37:20 -0700
Subject: [PATCH] Migrate GraniteSpeechAudioInputs to TensorSchema (#21682)

Signed-off-by: Benji Beck
Signed-off-by: DarkLight1337
Co-authored-by: Cyrus Leung
---
 vllm/model_executor/models/granite_speech.py | 26 ++++++++++++++------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py
index 6a4dee9ae48d4..5a3e715c3e748 100644
--- a/vllm/model_executor/models/granite_speech.py
+++ b/vllm/model_executor/models/granite_speech.py
@@ -25,7 +25,7 @@
 """Inference-only IBM Granite speech model."""
 import math
 from collections.abc import Iterable, Mapping
-from typing import Optional, TypedDict, Union
+from typing import Annotated, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -48,6 +48,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .blip2 import Blip2QFormerModel
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -57,16 +58,24 @@ from .utils import (AutoWeightsLoader, embed_multimodal,
 
 ### Audio Input
-class GraniteSpeechAudioInputs(TypedDict):
+class GraniteSpeechAudioInputs(TensorSchema):
+    """
+    Audio input features for Granite Speech model.
+
+    Dimensions:
+        - b: Batch size
+        - nf: Number of audio features (variable length)
+        - 160: Fixed feature dimension for Mel spectrogram features
+    """
 
-    input_features: torch.Tensor
-    """Shape: `(bsz, num_features, 160)`"""
+    input_features: Annotated[torch.Tensor, TensorShape("b", "nf", 160)]
+    """Audio input features."""
 
-    input_features_mask: torch.Tensor
-    """Shape: `(bsz, num_features)`"""
+    input_features_mask: Annotated[torch.Tensor, TensorShape("b", "nf")]
+    """Mask for variable length audio features."""
 
-    audio_embed_sizes: list[int]
-    """List of length `bsz`"""
+    audio_embed_sizes: Annotated[list[int], TensorShape("b")]
+    """List of audio embedding sizes for each item in batch."""
 
 
 class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
@@ -581,6 +590,7 @@ class GraniteSpeechForConditionalGeneration(
         input_features = kwargs.pop("input_features", None)
         input_features_mask = kwargs.pop("input_features_mask", None)
         audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)
+
        if input_features is None:
            return None
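
Note (reviewer sketch, not part of the patch): with the TypedDict replaced by a
TensorSchema subclass, the audio inputs become a schema object whose annotated
TensorShape dimensions ("b", "nf", 160) describe the expected shapes. The
snippet below is a minimal sketch of how the migrated class might be used,
assuming TensorSchema accepts the fields as keyword arguments and checks the
annotated shapes at construction time; the tensor sizes are illustrative only.

    import torch

    from vllm.model_executor.models.granite_speech import (
        GraniteSpeechAudioInputs)

    batch_size, num_features = 2, 12  # illustrative sizes, not from the patch

    audio_input = GraniteSpeechAudioInputs(
        # (b, nf, 160): Mel spectrogram features with a fixed last dimension
        input_features=torch.zeros(batch_size, num_features, 160),
        # (b, nf): marks which feature frames are valid for each batch item
        input_features_mask=torch.ones(batch_size, num_features,
                                       dtype=torch.bool),
        # one embedding size per batch item ("b")
        audio_embed_sizes=[num_features, num_features],
    )

    # Under the stated assumption, a mismatched shape (e.g. a last dimension
    # of 80 instead of 160) would be flagged by the schema here rather than
    # surfacing later inside the model's forward pass.

Whether validation actually raises at construction or is deferred is up to
TensorSchema's implementation in vllm/utils/tensor_schema.py, which this
sketch does not inspect.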